diff --git a/.gitignore b/.gitignore index ea8c4bf..bc4f9f5 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ -/target +target + +.DS_Store +*~* diff --git a/Cargo.lock b/Cargo.lock index 14cd509..90affd8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,6 +5,16 @@ version = 3 [[package]] name = "bitfield-struct" version = "0.3.2" +dependencies = [ + "bitfield-struct-derive", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bitfield-struct-derive" +version = "0.3.2" dependencies = [ "proc-macro2", "quote", @@ -13,9 +23,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.49" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] @@ -31,9 +41,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.107" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 69e42c3..17cd255 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,8 @@ repository = "https://github.com/wrenger/bitfield-struct-rs.git" readme = "README.md" license = "MIT" -[lib] -proc-macro = true - [dependencies] +bitfield-struct-derive = { version = "0.3.2", path = "derive" } quote = "1.0" syn = { version = "1.0", features = ["full"] } proc-macro2 = "1.0" diff --git a/derive/Cargo.toml b/derive/Cargo.toml new file mode 100644 index 0000000..b587114 --- /dev/null +++ b/derive/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "bitfield-struct-derive" +version = "0.3.2" +edition = "2021" +authors = ["Lars Wrenger "] +description = "Procedural macro for bitfields." +keywords = ["bitfield", "bits", "bitfields", "proc-macro"] +categories = ["data-structures", "no-std"] +repository = "https://github.com/wrenger/bitfield-struct-rs.git" +readme = "README.md" +license = "MIT" + +[lib] +proc-macro = true + +[dependencies] +quote = "1.0" +syn = { version = "1.0", features = ["full"] } +proc-macro2 = "1.0" diff --git a/derive/src/lib.rs b/derive/src/lib.rs new file mode 100644 index 0000000..27566d8 --- /dev/null +++ b/derive/src/lib.rs @@ -0,0 +1,532 @@ +use proc_macro as pc; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; +use std::stringify; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::{Error, Token}; + +/// Creates a bitfield for this struct. +/// +/// There are two ways to specify the size and alignment of the bitfield: +/// +/// - Integer type: `#[bitfield(u32)]` +/// - All explicitly sized integers are supported +/// - The alignment defaults to the alignment of the integer if not otherwise specified +/// - Bytes argument: `#[bitfield(bytes = 3)]` +/// - The default alignment is 1 +/// +/// The macro two optional arguments +/// - `align`: Specifies the alignment of the bitfield +/// - `debug`: Whether or not the fmt::Debug trait should be generated (default: true) +/// +/// For example: `#[bitfield(bytes = 6, align = 2, debug = false)]` +#[proc_macro_attribute] +pub fn bitfield(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream { + match bitfield_inner(args.into(), input.into()) { + Ok(result) => result.into(), + Err(e) => e.into_compile_error().into(), + } +} + +fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result { + let input = syn::parse2::(input)?; + let Params { + bytes, + align, + debug, + ty, + } = syn::parse2::(args)?; + + let bits = bytes * 8; + let span = input.fields.span(); + let name = input.ident; + let name_str = name.to_string(); + let vis = input.vis; + let attrs: TokenStream = input.attrs.iter().map(ToTokens::to_token_stream).collect(); + + let syn::Fields::Named(fields) = input.fields else { + return Err(Error::new(span, "only named fields are supported")); + }; + + let mut offset = 0; + let mut members = Vec::with_capacity(fields.named.len()); + for field in fields.named { + let f = Member::new(field, offset)?; + offset += f.bits; + members.push(f); + } + + if offset < bits { + return Err(Error::new( + span, + format_args!( + "The bitfiled size ({bits} bits) has to be equal to the sum of its members ({offset} bits)!. \ + You might have to add padding (a {} bits large member prefixed with \"_\").", + bits - offset + ), + )); + } + if offset > bits { + return Err(Error::new( + span, + format_args!( + "The size of the members ({offset} bits) is larger than the type ({bits} bits)!." + ), + )); + } + + let debug_impl = if debug { + let debug_fields = members.iter().map(|m| m.debug()); + quote! { + impl core::fmt::Debug for #name { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct(#name_str) + #( #debug_fields )* + .finish() + } + } + } + } else { + Default::default() + }; + + // The size of isize and usize is architecture dependent and not known for proc_macros, + // thus we have to check it with const asserts. + let const_asserts = members.iter().filter_map(|m| { + if m.class == TypeClass::SizeInt { + let bits = m.bits; + let msg = format!("overflowing field type of '{}'", m.ident); + Some(quote!( + const _: () = assert!(#bits <= 8 * std::mem::size_of::(), #msg); + )) + } else { + None + } + }); + let type_conversion = if let Some(ty) = ty { + Some(quote! { + impl From<#ty> for #name { + fn from(v: #ty) -> Self { + Self(v.to_be_bytes()) + } + } + impl From<#name> for #ty { + fn from(v: #name) -> #ty { + #ty::from_be_bytes(v.0) + } + } + }) + } else { + None + }; + + let align = syn::LitInt::new(&format!("{align}"), Span::mixed_site()); + Ok(quote! { + #attrs + #[derive(Copy, Clone)] + #[repr(align(#align))] + #vis struct #name([u8; #bytes]); + + impl #name { + #vis const fn new() -> Self { + Self([0; #bytes]) + } + + #( #members )* + } + + impl From<[u8; #bytes]> for #name { + fn from(v: [u8; #bytes]) -> Self { + Self(v) + } + } + impl From<#name> for [u8; #bytes] { + fn from(v: #name) -> [u8; #bytes] { + v.0 + } + } + #type_conversion + + #( #const_asserts )* + + #debug_impl + }) +} + +/// Distinguish between different types for code generation. +/// +/// We need this to make accessor functions for bool and ints const. +/// As soon as we have const conversion traits, we can simply switch to `TryFrom` and don't have to generate different code. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum TypeClass { + /// Booleans with 1 bit size + Bool, + /// Ints with fixes sizes: u8, u64, ... + Int, + /// Ints with architecture dependend sizes: usize, isize + SizeInt, + /// Custom types + Other, +} + +struct Member { + attrs: Vec, + ty: syn::Type, + class: TypeClass, + bits: usize, + ident: syn::Ident, + vis: syn::Visibility, + offset: usize, +} + +impl Member { + fn new(f: syn::Field, offset: usize) -> syn::Result { + let span = f.span(); + + let syn::Field { + mut attrs, + vis, + ident, + ty, + .. + } = f; + + let ident = ident.ok_or_else(|| Error::new(span, "Not supported"))?; + + let (class, bits) = bits(&attrs, &ty)?; + // remove our attribute + attrs.retain(|a| !a.path.is_ident("bits")); + + Ok(Self { + attrs, + ty, + class, + bits, + ident, + vis, + offset, + }) + } + + fn debug(&self) -> TokenStream { + let ident_str = self.ident.to_string(); + if self.bits > 0 && !ident_str.starts_with('_') { + let ident = &self.ident; + quote!(.field(#ident_str, &self.#ident())) + } else { + Default::default() + } + } +} + +impl ToTokens for Member { + fn to_tokens(&self, tokens: &mut TokenStream) { + let Self { + attrs, + ty, + class, + bits, + ident, + vis, + offset, + } = self; + let ident_str = ident.to_string(); + + // Skip zero sized and padding members + if self.bits == 0 || ident_str.starts_with('_') { + return Default::default(); + } + + let with_ident = format_ident!("with_{ident}"); + let set_ident = format_ident!("set_{ident}"); + let bits_ident = format_ident!("{}_BITS", ident_str.to_uppercase()); + let offset_ident = format_ident!("{}_OFFSET", ident_str.to_uppercase()); + + let location = format!("\n\nBits: {offset}..{}", offset + bits); + + let doc: TokenStream = attrs + .iter() + .filter(|a| !a.path.is_ident("bits")) + .map(ToTokens::to_token_stream) + .collect(); + + let general = quote! { + const #bits_ident: usize = #bits; + const #offset_ident: usize = #offset; + + #doc + #[doc = #location] + #vis fn #set_ident(&mut self, value: #ty) { + *self = self.#with_ident(value); + } + }; + + let bytes = (bits + 7) / 8; + + let code = match class { + TypeClass::Bool => quote! { + #general + + #doc + #[doc = #location] + #vis const fn #with_ident(self, value: #ty) -> Self { + let src = [value as u8]; + Self(bitfield_struct::bit_copy(self.0, #offset, &src, 0, 1)) + } + #doc + #[doc = #location] + #vis const fn #ident(&self) -> #ty { + bitfield_struct::is_bit_set(&self.0, #offset) + } + }, + TypeClass::Int | TypeClass::SizeInt => quote! { + #general + + #doc + #[doc = #location] + #vis const fn #with_ident(self, value: #ty) -> Self { + let src = value.to_ne_bytes(); + Self(bitfield_struct::bit_copy(self.0, #offset, &src, 0, #bits)) + } + #doc + #[doc = #location] + #vis const fn #ident(&self) -> #ty { + // copy to the upper half + let out = bitfield_struct::bit_copy( + [0; #ty::BITS as usize / 8], #ty::BITS as usize - #bits, &self.0, #offset, #bits); + // shift down to potentially perform a sign extend + #ty::from_ne_bytes(out) >> (#ty::BITS as usize - #bits) + } + }, + TypeClass::Other => quote! { + #general + + #doc + #[doc = #location] + #vis fn #with_ident(self, value: #ty) -> Self { + let src: [u8; #bytes] = value.into(); + Self(bitfield_struct::bit_copy(self.0, #offset, &src, 0, #bits)) + } + #doc + #[doc = #location] + #vis fn #ident(&self) -> #ty { + let out = bitfield_struct::bit_copy([0; #bytes], 0, &self.0, #offset, #bits); + out.into() + } + }, + }; + tokens.extend(code); + } +} + +/// Parses the `bits` attribute that allows specifying a custom number of bits. +fn bits(attrs: &[syn::Attribute], ty: &syn::Type) -> syn::Result<(TypeClass, usize)> { + let size_int = matches!(ty, syn::Type::Path(syn::TypePath{ path, .. }) + if path.is_ident("usize") || path.is_ident("isize")); + + for attr in attrs { + if let syn::Attribute { + style: syn::AttrStyle::Outer, + path, + tokens, + .. + } = attr + { + if !path.is_ident("bits") { + continue; + } + + let bits = attr + .parse_args::() + .map_err(|e| e.with(Error::new(attr.span(), "malformed #[bits] attribute")))? + .base10_parse() + .map_err(|e| e.with(Error::new(attr.span(), "malformed #[bits] attribute")))?; + + return if bits == 0 { + Ok((TypeClass::Other, 0)) + } else if let Ok((class, size)) = type_bits(ty) { + if bits <= size { + Ok((class, bits)) + } else { + Err(Error::new(tokens.span(), "overflowing field type")) + } + } else if size_int { + // isize and usize are supported but their bit size is not known at this point! + // Meaning that they must have a bits attribute explicitly defining their size + Ok((TypeClass::SizeInt, bits)) + } else { + Ok((TypeClass::Other, bits)) + }; + } + } + + if size_int { + return Err(Error::new( + ty.span(), + "isize and usize fields require the #[bits($1)] attribute", + )); + } + + // Fallback to type size + type_bits(ty) +} + +/// Returns the number of bits for a given type +fn type_bits(ty: &syn::Type) -> syn::Result<(TypeClass, usize)> { + let err_unsupported = Error::new(ty.span(), "unsupported type"); + + let syn::Type::Path(syn::TypePath{ path, .. }) = ty else { + return Err(err_unsupported) + }; + let Some(ident) = path.get_ident() else { + return Err(err_unsupported) + }; + if ident == "bool" { + return Ok((TypeClass::Bool, 1)); + } + macro_rules! integer { + ($ident:ident => $($ty:ident),*) => { + match ident { + $(_ if ident == stringify!($ty) => Ok((TypeClass::Int, $ty::BITS as _)),)* + _ => Err(err_unsupported) + } + }; + } + integer!(ident => u8, i8, u16, i16, u32, i32, u64, i64, u128, i128) +} + +struct Params { + ty: Option, + bytes: usize, + align: usize, + debug: bool, +} + +impl Parse for Params { + fn parse(input: ParseStream) -> syn::Result { + let mut align = 1; + let mut debug = true; + let mut ty = None; + + let bytes = if input.peek2(Token![=]) && input.peek(syn::Ident) { + Bytes::parse(input)?.bytes + } else { + let t = syn::Type::parse(input)?; + let (class, bits) = type_bits(&t)?; + if class != TypeClass::Int { + return Err(Error::new( + input.span(), + "Invalid argument or type, expecting `bytes` or an integer type", + )); + } + ty = Some(t); + align = bits / 8; + bits / 8 + }; + + if let Ok(_) = ::parse(input) { + let params = Punctuated::::parse_terminated(input)?; + for param in params { + match param { + Param::Align(a) => align = a, + Param::Debug(d) => debug = d, + } + } + } + + Ok(Params { + ty, + bytes, + align, + debug, + }) + } +} + +struct Bytes { + bytes: usize, +} + +impl Parse for Bytes { + fn parse(input: ParseStream) -> syn::Result { + let ident = Ident::parse(input)?; + if ident != "bytes" { + return Err(Error::new( + ident.span(), + "Invalid argument, expecting `bytes` or an integer type", + )); + } + ::parse(input)?; + Ok(Self { + bytes: syn::LitInt::parse(input)?.base10_parse()?, + }) + } +} + +enum Param { + Align(usize), + Debug(bool), +} + +impl Parse for Param { + fn parse(input: ParseStream) -> syn::Result { + let ident = Ident::parse(input)?; + + ::parse(input)?; + + if ident == "align" { + Ok(Self::Align(syn::LitInt::parse(input)?.base10_parse()?)) + } else if ident == "debug" { + Ok(Self::Debug(syn::LitBool::parse(input)?.value)) + } else { + Err(Error::new(ident.span(), "unknown argument")) + } + } +} + +trait ErrorExt { + fn with(self, other: Self) -> Self; +} + +impl ErrorExt for Error { + fn with(mut self, other: Self) -> Self { + self.combine(other); + self + } +} + +#[cfg(test)] +mod test { + use quote::quote; + + use crate::Params; + + #[test] + fn parse_args() { + let args = quote! { + bytes = 3 + }; + let params = syn::parse2::(args).unwrap(); + assert!(params.bytes == 3 && params.debug == true); + + let args = quote! { + bytes = 3, align = 2, debug = false + }; + let params = syn::parse2::(args).unwrap(); + assert!(params.bytes == 3 && params.align == 2 && params.debug == false); + + let args = quote! { + u64 + }; + let params = syn::parse2::(args).unwrap(); + assert!(params.bytes == 8 && params.debug == true); + + let args = quote! { + u32, debug = false + }; + let params = syn::parse2::(args).unwrap(); + assert!(params.bytes == 4 && params.debug == false); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3e9b085..9f0488e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,7 +73,7 @@ //! ``` //! # use bitfield_struct::bitfield; //! /// A test bitfield with documentation -//! #[bitfield(u64)] +//! #[bitfield(u64, align = 4)] // <- Set a specific alignment (defaults to the integers alignment) //! #[derive(PartialEq, Eq)] // <- Attributes after `bitfield` are carried over //! struct MyBitfield { //! /// defaults to 16 bits for u16 @@ -102,25 +102,25 @@ //! //! /// A custom enum //! #[derive(Debug, PartialEq, Eq)] -//! #[repr(u64)] +//! #[repr(u8)] //! enum CustomEnum { //! A = 0, //! B = 1, //! C = 2, //! } -//! // implement `From` and `Into` for `CustomEnum`! -//! # impl From for CustomEnum { -//! # fn from(value: u64) -> Self { -//! # match value { +//! // implement `From<[u8; 2]>` and `Into<[u8; 2]>` for `CustomEnum`! +//! # impl From<[u8; 2]> for CustomEnum { +//! # fn from(value: [u8; 2]) -> Self { +//! # match value[0] { //! # 0 => Self::A, //! # 1 => Self::B, //! # _ => Self::C, //! # } //! # } //! # } -//! # impl From for u64 { +//! # impl From for [u8; 2] { //! # fn from(value: CustomEnum) -> Self { -//! # value as _ +//! # [value as _, 0] //! # } //! # } //! @@ -159,9 +159,9 @@ //! //! ```ignore //! // generated struct -//! struct MyBitfield(u64); +//! struct MyBitfield([u8; 8]); //! impl MyBitfield { -//! const fn new() -> Self { Self(0) } +//! const fn new() -> Self { Self([0; 8]) } //! //! const INT_BITS: usize = 16; //! const INT_OFFSET: usize = 0; @@ -173,6 +173,9 @@ //! // other field ... //! } //! // generated trait implementations +//! impl From<[u8; 8]> for MyBitfield { /* ... */ } +//! impl From for [u8; 8] { /* ... */ } +//! // from the `ty` parameter //! impl From for MyBitfield { /* ... */ } //! impl From for u64 { /* ... */ } //! impl Debug for MyBitfield { /* ... */ } @@ -180,6 +183,22 @@ //! //! > Hint: You can use the rust-analyzer "Expand macro recursively" action to view the generated code. //! +//! ## No-type bitfields +//! +//! Instead of specifying a base type, you can manually define the `size` and `align` of the bitfield. +//! +//! ``` +//! # use bitfield_struct::bitfield; +//! # use std::mem::{size_of, align_of}; +//! #[bitfield(bytes = 9)] +//! struct NoTy { +//! data: u64, +//! extra: u8, +//! } +//! assert_eq!(size_of::(), 9); +//! assert_eq!(align_of::(), 1); // align defaults to 1 +//! ``` +//! //! ## `fmt::Debug` //! //! This macro automatically creates a suitable `fmt::Debug` implementation @@ -206,439 +225,321 @@ //! ``` //! -use proc_macro as pc; -use proc_macro2::{Ident, Span, TokenStream}; -use quote::{format_ident, quote, ToTokens}; -use std::stringify; -use syn::parse::{Parse, ParseStream}; -use syn::spanned::Spanned; -use syn::Token; +pub use bitfield_struct_derive::bitfield; -/// Creates a bitfield for this struct. +/// The heart of the bitfield macro. +/// It copies bits (with different offsets) from `src` to `dst`. /// -/// The arguments first, have to begin with the underlying type of the bitfield: -/// For example: `#[bitfield(u64)]`. +/// This function is used both for the getters and setters of the bitfield struct. /// -/// It can contain an extra `debug` argument for disabling the `Debug` trait -/// generation (`#[bitfield(u64, debug = false)]`). -#[proc_macro_attribute] -pub fn bitfield(args: pc::TokenStream, input: pc::TokenStream) -> pc::TokenStream { - match bitfield_inner(args.into(), input.into()) { - Ok(result) => result.into(), - Err(e) => e.into_compile_error().into(), - } -} - -fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result { - let input = syn::parse2::(input)?; - let Params { ty, bits, debug } = - syn::parse2::(args).map_err(|e| unsupported_param(e, input.span()))?; - - let span = input.fields.span(); - let name = input.ident; - let name_str = name.to_string(); - let vis = input.vis; - let attrs: TokenStream = input.attrs.iter().map(ToTokens::to_token_stream).collect(); - - let syn::Fields::Named(fields) = input.fields else { - return Err(syn::Error::new(span, "only named fields are supported")); - }; - - let mut offset = 0; - let mut members = Vec::with_capacity(fields.named.len()); - for field in fields.named { - let f = Member::new(ty.clone(), field, offset)?; - offset += f.bits; - members.push(f); - } - - if offset < bits { - return Err(syn::Error::new( - span, - format!( - "The bitfiled size ({bits} bits) has to be equal to the sum of its members ({offset} bits)!. \ - You might have to add padding (a {} bits large member prefixed with \"_\").", - bits - offset - ), - )); - } - if offset > bits { - return Err(syn::Error::new( - span, - format!( - "The size of the members ({offset} bits) is larger than the type ({bits} bits)!." - ), - )); - } - - let debug_impl = if debug { - let debug_fields = members.iter().map(|m| m.debug()); - quote! { - impl core::fmt::Debug for #name { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct(#name_str) - #( #debug_fields )* - .finish() - } - } - } - } else { - Default::default() - }; - - // The size of isize and usize is architecture dependent and not known for proc_macros, - // thus we have to check it with const asserts. - let const_asserts = members.iter().filter_map(|m| { - if m.class == TypeClass::SizeInt { - let bits = m.bits; - let msg = format!("overflowing field type of '{}'", m.ident); - Some(quote!( - const _: () = assert!(#bits <= 8 * std::mem::size_of::(), #msg); - )) - } else { - None - } - }); - - Ok(quote! { - #attrs - #[derive(Copy, Clone)] - #[repr(transparent)] - #vis struct #name(#ty); - - impl #name { - #vis const fn new() -> Self { - Self(0) - } - - #( #members )* - } - - impl From<#ty> for #name { - fn from(v: #ty) -> Self { - Self(v) - } - } - impl From<#name> for #ty { - fn from(v: #name) -> #ty { - v.0 - } - } - - #( #const_asserts )* - - #debug_impl - }) -} - -/// Distinguish between different types for code generation. +/// General idea: +/// - Copy prefix bits +/// - Copy aligned u8 +/// - Copy suffix bits /// -/// We need this to make accessor functions for bool and ints const. -/// As soon as we have const conversion traits, we can simply switch to `TryFrom` and don't have to generate different code. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -enum TypeClass { - /// Booleans with 1 bit size - Bool, - /// Ints with fixes sizes: u8, u64, ... - Int, - /// Ints with architecture dependend sizes: usize, isize - SizeInt, - /// Custom types - Other, +/// Possible future optimization: +/// - Copy and shift with larger instructions (u16/u32/u64) if the buffers are large enough +/// +/// FIXME: Use mutable reference as soon as `const_mut_refs` is stable +#[inline(always)] +pub const fn bit_copy( + mut dst: [u8; D], + dst_off: usize, + src: &[u8], + src_off: usize, + len: usize, +) -> [u8; D] { + debug_assert!(len > 0); + debug_assert!(dst.len() * 8 >= dst_off + len); + debug_assert!(src.len() * 8 >= src_off + len); + + if len == 1 { + let dst_i = dst_off / 8; + dst[dst_i] = single_bit(dst[dst_i], dst_off % 8, src, src_off); + dst + } else if len < (8 - (dst_off % 8)) { + // edge case if there are less then one byte to be copied + let dst_i = dst_off / 8; + dst[dst_i] = single_byte(dst[dst_i], dst_off % 8, src, src_off, len); + dst + } else if dst_off % 8 == src_off % 8 { + copy_aligned(dst, dst_off / 8, src, src_off / 8, dst_off % 8, len) + } else { + copy_unaligned(dst, dst_off, src, src_off, len) + } } -struct Member { - base_ty: syn::Type, - attrs: Vec, - ty: syn::Type, - class: TypeClass, - bits: usize, - ident: syn::Ident, - vis: syn::Visibility, - offset: usize, +/// Test if this bit is set +#[inline(always)] +pub const fn is_bit_set(src: &[u8], i: usize) -> bool { + debug_assert!(i < src.len() * 8); + (src[i / 8] >> (i % 8)) & 1 != 0 } -impl Member { - fn new(base_ty: syn::Type, f: syn::Field, offset: usize) -> syn::Result { - let span = f.span(); - - let syn::Field { - mut attrs, - vis, - ident, - ty, - .. - } = f; - - let ident = ident.ok_or_else(|| syn::Error::new(span, "Not supported"))?; - - let (class, bits) = bits(&attrs, &ty)?; - // remove our attribute - attrs.retain(|a| !a.path.is_ident("bits")); - - Ok(Self { - base_ty, - attrs, - ty, - class, - bits, - ident, - vis, - offset, - }) - } - - fn debug(&self) -> TokenStream { - let ident_str = self.ident.to_string(); - if self.bits > 0 && !ident_str.starts_with('_') { - let ident = &self.ident; - quote!(.field(#ident_str, &self.#ident())) - } else { - Default::default() - } +/// Only a single bit is copied +#[inline(always)] +const fn single_bit(dst: u8, dst_off: usize, src: &[u8], src_off: usize) -> u8 { + debug_assert!(dst_off < 8); + if is_bit_set(src, src_off) { + dst | (1 << dst_off) + } else { + dst & !(1 << dst_off) } } -impl ToTokens for Member { - fn to_tokens(&self, tokens: &mut TokenStream) { - let Self { - base_ty, - attrs, - ty, - class, - bits, - ident, - vis, - offset, - } = self; - let ident_str = ident.to_string(); - - // Skip zero sized and padding members - if self.bits == 0 || ident_str.starts_with('_') { - return Default::default(); - } - - let with_ident = format_ident!("with_{ident}"); - let set_ident = format_ident!("set_{ident}"); - let bits_ident = format_ident!("{}_BITS", ident_str.to_uppercase()); - let offset_ident = format_ident!("{}_OFFSET", ident_str.to_uppercase()); +/// We have only one destination byte. +#[inline(always)] +const fn single_byte(dst: u8, dst_off: usize, src: &[u8], src_off: usize, len: usize) -> u8 { + const MAX: u8 = u8::MAX; + const BITS: usize = u8::BITS as _; - let location = format!("\n\nBits: {offset}..{}", offset + bits); + debug_assert!(dst_off < BITS); - let doc: TokenStream = attrs - .iter() - .filter(|a| !a.path.is_ident("bits")) - .map(ToTokens::to_token_stream) - .collect(); + let src_i = src_off / BITS; + let src_off = src_off % BITS; - let general = quote! { - const #bits_ident: usize = #bits; - const #offset_ident: usize = #offset; + let mask = (MAX >> (BITS - len)) << dst_off; + let mut dst = dst & !mask; + dst |= ((src[src_i] >> src_off) << dst_off) & mask; - #doc - #[doc = #location] - #vis fn #set_ident(&mut self, value: #ty) { - *self = self.#with_ident(value); - } - }; - - let mask: u128 = !0 >> (u128::BITS as usize - bits); - let mask = syn::LitInt::new(&format!("0x{mask:x}"), Span::mixed_site()); - - let code = match class { - TypeClass::Bool => quote! { - #general - - #doc - #[doc = #location] - #vis const fn #with_ident(self, value: #ty) -> Self { - Self(self.0 & !(1 << #offset) | (value as #base_ty) << #offset) - } - #doc - #[doc = #location] - #vis const fn #ident(&self) -> #ty { - ((self.0 >> #offset) & 1) != 0 - } - }, - TypeClass::Int | TypeClass::SizeInt => quote! { - #general - - #doc - #[doc = #location] - #vis const fn #with_ident(self, value: #ty) -> Self { - debug_assert!(value <= #mask); - Self(self.0 & !(#mask << #offset) | (value as #base_ty & #mask) << #offset) - } - #doc - #[doc = #location] - #vis const fn #ident(&self) -> #ty { - let shift = #ty::BITS as usize - #bits; - (((self.0 >> #offset) as #ty) << shift) >> shift - } - }, - TypeClass::Other => quote! { - #general - - #doc - #[doc = #location] - #vis fn #with_ident(self, value: #ty) -> Self { - let value: #base_ty = value.into(); - debug_assert!(value <= #mask); - Self(self.0 & !(#mask << #offset) | (value & #mask) << #offset) - } - #doc - #[doc = #location] - #vis fn #ident(&self) -> #ty { - let shift = #base_ty::BITS as usize - #bits; - (((self.0 >> #offset) << shift) >> shift).into() - } - }, - }; - tokens.extend(code); + // exceeding a single byte of the src buffer + if len + src_off > BITS { + dst |= (src[src_i + 1] << (BITS - src_off + dst_off)) & mask; } + dst } -/// Parses the `bits` attribute that allows specifying a custom number of bits. -fn bits(attrs: &[syn::Attribute], ty: &syn::Type) -> syn::Result<(TypeClass, usize)> { - fn malformed(mut e: syn::Error, attr: &syn::Attribute) -> syn::Error { - e.combine(syn::Error::new(attr.span(), "malformed #[bits] attribute")); - e - } - - for attr in attrs { - match attr { - syn::Attribute { - style: syn::AttrStyle::Outer, - path, - tokens, - .. - } if path.is_ident("bits") => { - let bits = attr - .parse_args::() - .map_err(|e| malformed(e, attr))? - .base10_parse() - .map_err(|e| malformed(e, attr))?; - - return if bits == 0 { - Ok((TypeClass::Other, 0)) - } else if let Ok((class, size)) = type_bits(ty) { - if bits <= size { - Ok((class, bits)) - } else { - Err(syn::Error::new(tokens.span(), "overflowing field type")) - } - } else if matches!(ty, syn::Type::Path(syn::TypePath{ path, .. }) - if path.is_ident("usize") || path.is_ident("isize")) - { - // isize and usize are supported but types size is not known at this point! - // Meaning that they must have a bits attribute explicitly defining their size - Ok((TypeClass::SizeInt, bits)) - } else { - Ok((TypeClass::Other, bits)) - }; +/// The buffers have different bit offsets +#[inline(always)] +const fn copy_unaligned( + mut dst: [u8; D], + mut dst_off: usize, + src: &[u8], + mut src_off: usize, + mut len: usize, +) -> [u8; D] { + const MAX: u8 = u8::MAX; + const BITS: usize = u8::BITS as _; + + debug_assert!(src_off % BITS != 0 || dst_off % BITS != 0); + debug_assert!(dst.len() * BITS >= dst_off + len); + debug_assert!(src.len() * BITS >= src_off + len); + + let mut dst_i = dst_off / BITS; + dst_off %= BITS; + let mut src_i = src_off / BITS; + src_off %= BITS; + + // copy dst prefix -> byte-align dst + if dst_off > 0 { + let prefix = BITS - dst_off; + let mask = MAX << dst_off; + dst[dst_i] &= !mask; + dst[dst_i] |= (src[src_i] >> src_off) << dst_off; + + // exceeding a single byte of the src buffer + dst_off += BITS - src_off; + src_off += prefix; + if let Some(reminder) = src_off.checked_sub(BITS) { + src_i += 1; + if reminder > 0 { + dst[dst_i] |= src[src_i] << dst_off } - _ => {} + src_off = reminder } + dst_i += 1; + len -= prefix; } - if let syn::Type::Path(syn::TypePath { path, .. }) = ty { - if path.is_ident("usize") || path.is_ident("isize") { - return Err(syn::Error::new( - ty.span(), - "isize and usize fields require the #[bits($1)] attribute", - )); - } + // copy middle + let mut i = 0; + while i < len / BITS { + dst[dst_i + i] = (src[src_i + i] >> src_off) | (src[src_i + i + 1] << (BITS - src_off)); + i += 1; } - // Fallback to type size - type_bits(ty) + // suffix + let suffix = len % BITS; + if suffix > 0 { + let last = len / BITS; + let mask = MAX >> (BITS - suffix); + dst[dst_i + last] &= !mask; + dst[dst_i + last] |= src[src_i + last] >> src_off; + + // exceeding a single byte of the src buffer + if suffix + src_off > BITS { + dst[dst_i + last] |= (src[src_i + last + 1] << (BITS - src_off)) & mask; + } + } + dst } -/// Returns the number of bits for a given type -fn type_bits(ty: &syn::Type) -> syn::Result<(TypeClass, usize)> { - let syn::Type::Path(syn::TypePath{ path, .. }) = ty else { - return Err(syn::Error::new(ty.span(), "unsupported type")) - }; - let Some(ident) = path.get_ident() else { - return Err(syn::Error::new(ty.span(), "unsupported type")) - }; - if ident == "bool" { - return Ok((TypeClass::Bool, 1)); +/// The buffers have the same bit offsets +#[inline(always)] +const fn copy_aligned( + mut dst: [u8; D], + mut dst_i: usize, + src: &[u8], + mut src_i: usize, + off: usize, + mut len: usize, +) -> [u8; D] { + const MAX: u8 = u8::MAX; + const BITS: usize = u8::BITS as _; + + debug_assert!(off < BITS); + debug_assert!(dst.len() * BITS >= dst_i * BITS + len); + debug_assert!(src.len() * BITS >= src_i * BITS + len); + + // copy prefix -> byte-align dst + if off > 0 { + let prefix = BITS - (off % BITS); + let mask = MAX << (off % BITS); + dst[dst_i] &= !mask; + dst[dst_i] |= src[src_i] & mask; + + src_i += 1; + dst_i += 1; + len -= prefix; } - macro_rules! integer { - ($ident:ident => $($ty:ident),*) => { - match ident { - $(_ if ident == stringify!($ty) => Ok((TypeClass::Int, $ty::BITS as _)),)* - _ => Err(syn::Error::new(ty.span(), "unsupported type")) - } - }; + + // copy middle + let mut i = 0; + while i < len / BITS { + dst[dst_i + i] = src[src_i + i]; + i += 1; } - integer!(ident => u8, i8, u16, i16, u32, i32, u64, i64, u128, i128) -} -struct Params { - ty: syn::Type, - bits: usize, - debug: bool, + // copy suffix + let suffix = len % BITS; + if suffix > 0 { + let last = len / BITS; + let mask = MAX >> (BITS - suffix); + dst[dst_i + last] &= !mask; + dst[dst_i + last] |= src[src_i + last]; + } + dst } -impl Parse for Params { - fn parse(input: ParseStream) -> syn::Result { - let Ok(ty) = syn::Type::parse(input) else { - return Err(syn::Error::new(input.span(), "unknown type")); - }; - let (class, bits) = type_bits(&ty)?; - if class != TypeClass::Int { - return Err(syn::Error::new(input.span(), "unsupported type")); - } - - // try parse additional debug arg - let debug = if ::parse(input).is_ok() { - let ident = Ident::parse(input)?; - - if ident != "debug" { - return Err(syn::Error::new(ident.span(), "unknown argument")); - } - ::parse(input)?; - - syn::LitBool::parse(input)?.value - } else { - true - }; +#[cfg(test)] +mod test { - Ok(Params { bits, ty, debug }) + #[allow(unused)] + fn b_print(b: &[u8]) { + for v in b.iter().rev() { + print!("{v:08b} "); + } + println!() } -} -fn unsupported_param(mut e: syn::Error, arg: T) -> syn::Error -where - T: syn::spanned::Spanned, -{ - e.combine(syn::Error::new( - arg.span(), - "unsupported #[bitfield] argument", - )); - e -} + #[test] + fn copy_bits_single_bit() { + // single byte + let src = [0b00100000]; + let dst = [0b10111111]; + let dst = super::bit_copy(dst, 6, &src, 5, 1); + assert_eq!(dst, [0b11111111]); + // reversed + let src = [!0b00100000]; + let dst = [!0b10111111]; + let dst = super::bit_copy(dst, 6, &src, 5, 1); + assert_eq!(dst, [!0b11111111]); + } -#[cfg(test)] -mod test { - use quote::quote; + #[test] + fn copy_bits_single_byte() { + // single byte + let src = [0b00111000]; + let dst = [0b10001111]; + let dst = super::bit_copy(dst, 4, &src, 3, 3); + assert_eq!(dst, [0b11111111]); + // reversed + let src = [!0b00111000]; + let dst = [!0b10001111]; + let dst = super::bit_copy(dst, 4, &src, 3, 3); + assert_eq!(dst, [!0b11111111]); + } - use crate::Params; + #[test] + fn copy_bits_unaligned() { + // two to single byte + let src = [0b00000000, 0b11000000, 0b00000111, 0b00000000]; + let dst = [0b00000000, 0b00000000, 0b00000000, 0b00000000]; + let dst = super::bit_copy(dst, 17, &src, 14, 5); + assert_eq!(dst, [0b00000000, 0b00000000, 0b00111110, 0b0000000]); + // reversed + let src = [!0b00000000, !0b11000000, !0b00000111, !0b00000000]; + let dst = [!0b00000000, !0b00000000, !0b00000000, !0b00000000]; + let dst = super::bit_copy(dst, 17, &src, 14, 5); + assert_eq!(dst, [!0b00000000, !0b00000000, !0b00111110, !0b0000000]); + + // over two bytes + let src = [0b00000000, 0b11000000, 0b00000111, 0b00000000]; + let dst = [0b00000000, 0b00000000, 0b00000000, 0b00000000]; + let dst = super::bit_copy(dst, 23, &src, 14, 5); + assert_eq!(dst, [0b00000000, 0b00000000, 0b10000000, 0b00001111]); + // reversed + let src = [!0b00000000, !0b11000000, !0b00000111, !0b00000000]; + let dst = [!0b00000000, !0b00000000, !0b00000000, !0b00000000]; + let dst = super::bit_copy(dst, 23, &src, 14, 5); + assert_eq!(dst, [!0b00000000, !0b00000000, !0b10000000, !0b00001111]); + + // over three bytes + let src = [0b11000000, 0b11111111, 0b00000111, 0b00000000]; + let dst = [0b00000000, 0b00000000, 0b00000000, 0b00000000]; + let dst = super::bit_copy(dst, 15, &src, 6, 13); + assert_eq!(dst, [0b00000000, 0b10000000, 0b11111111, 0b00001111]); + // reversed + let src = [!0b11000000, !0b11111111, !0b00000111, !0b00000000]; + let dst = [!0b00000000, !0b00000000, !0b00000000, !0b00000000]; + let dst = super::bit_copy(dst, 15, &src, 6, 13); + assert_eq!(dst, [!0b00000000, !0b10000000, !0b11111111, !0b00001111]); + + // prefix exceeds a single byte + let src = [0b00000000, 0b10000000, 0b11111111, 0b00000111]; + let dst = [0b00000000, 0b00000000, 0b00000000, 0b00000000]; + let dst = super::bit_copy(dst, 20, &src, 15, 12); + assert_eq!(dst, [0b00000000, 0b00000000, 0b11110000, 0b11111111]); + // reversed + let src = [!0b00000000, !0b10000000, !0b11111111, !0b00000111]; + let dst = [!0b00000000, !0b00000000, !0b00000000, !0b00000000]; + let dst = super::bit_copy(dst, 20, &src, 15, 12); + assert_eq!(dst, [!0b00000000, !0b00000000, !0b11110000, !0b11111111]); + } #[test] - fn parse_args() { - let args = quote! { - u64 - }; - let params = syn::parse2::(args).unwrap(); - assert!(params.bits == u64::BITS as usize && params.debug == true); - - let args = quote! { - u32, debug = false - }; - let params = syn::parse2::(args).unwrap(); - assert!(params.bits == u32::BITS as usize && params.debug == false); + fn copy_bits_aligned() { + // over two bytes + let src = [0b00000000, 0b11000000, 0b00000111, 0b00000000]; + let dst = [0b00000000, 0b00000000, 0b00000000, 0b00000000]; + let dst = super::bit_copy(dst, 14, &src, 14, 5); + assert_eq!(dst, src); + // reversed + let src = [!0b00000000, !0b11000000, !0b00000111, !0b00000000]; + let dst = [!0b00000000, !0b00000000, !0b00000000, !0b00000000]; + let dst = super::bit_copy(dst, 14, &src, 14, 5); + assert_eq!(dst, src); + + // over three bytes + let src = [0b11000000, 0b11100111, 0b00000111, 0b00000000]; + let dst = [0b00000000, 0b00000000, 0b00000000, 0b00000000]; + let dst = super::bit_copy(dst, 14, &src, 6, 13); + assert_eq!(dst, [0b00000000, 0b11000000, 0b11100111, 0b00000111]); + // reversed + let src = [!0b11000000, !0b11100111, !0b00000111, !0b00000000]; + let dst = [!0b00000000, !0b00000000, !0b00000000, !0b00000000]; + let dst = super::bit_copy(dst, 14, &src, 6, 13); + assert_eq!(dst, [!0b00000000, !0b11000000, !0b11100111, !0b00000111]); + + // all bits + let src = [0xff, 0xff, 0xff, 0xff]; + let dst = [0, 0, 0, 0]; + let dst = super::bit_copy(dst, 0, &src, 0, 4 * 8); + assert_eq!(dst, [0xff, 0xff, 0xff, 0xff]); + // reversed + let src = [0, 0, 0, 0]; + let dst = [0xff, 0xff, 0xff, 0xff]; + let dst = super::bit_copy(dst, 0, &src, 0, 4 * 8); + assert_eq!(dst, [0, 0, 0, 0]); } } diff --git a/tests/test.rs b/tests/test.rs index f1f725b..b439106 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::mem::{align_of, size_of}; use bitfield_struct::bitfield; @@ -33,46 +34,55 @@ fn members() { /// A custom enum #[derive(Debug, PartialEq, Eq)] - #[repr(u64)] + #[repr(u8)] enum CustomEnum { A = 0, B = 1, C = 2, } - impl From for CustomEnum { - fn from(value: u64) -> Self { - match value { + impl From<[u8; 2]> for CustomEnum { + fn from(value: [u8; 2]) -> Self { + match value[0] { 0 => Self::A, 1 => Self::B, _ => Self::C, } } } - impl From for u64 { + impl From for [u8; 2] { fn from(value: CustomEnum) -> Self { - value as _ + [value as _, 0] } } + assert_eq!(align_of::(), 8); + assert_eq!(size_of::(), 8); + let mut val = MyBitfield::new() .with_int(3 << 15) .with_flag(true) .with_tiny(1) .with_negative(-3) .with_custom(CustomEnum::B) - .with_public(2); + .with_public((1 << MyBitfield::PUBLIC_BITS) - 1); println!("{val:?}"); let raw: u64 = val.into(); println!("{raw:b}"); + let raw: [u8; 8] = val.into(); + for v in raw { + print!("{v:08b} "); + } + println!(); + assert_eq!(val.int(), 3 << 15); assert_eq!(val.flag(), true); assert_eq!(val.negative(), -3); assert_eq!(val.tiny(), 1); assert_eq!(val.custom(), CustomEnum::B); - assert_eq!(val.public(), 2); + assert_eq!(val.public(), (1 << MyBitfield::PUBLIC_BITS) - 1); // const members assert_eq!(MyBitfield::FLAG_BITS, 1); @@ -119,3 +129,19 @@ fn debug() { let full = Full::new().with_data(123); println!("{full:?}"); } + +#[test] +fn custom_size() { + #[bitfield(bytes = 9)] + struct NoTy { + data: u64, + extra: u8, + } + + let full = NoTy::new().with_data(123).with_extra(255); + assert_eq!(full.data(), 123); + assert_eq!(full.extra(), 255); + assert_eq!(align_of::(), 1); + assert_eq!(size_of::(), 9); + println!("{full:?}"); +}