diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index c95f813..0421537 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -1,6 +1,6 @@ use crate::attribute::BitcodeAttrs; use crate::private; -use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; +use crate::shared::{remove_lifetimes, replace_lifetimes, VariantIndexType}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{ @@ -111,6 +111,7 @@ impl crate::shared::Item for Item { self, crate_name: &Path, variant_count: usize, + variant_index_type: VariantIndexType, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { @@ -126,7 +127,12 @@ impl crate::shared::Item for Item { .then(|| { let private = private(crate_name); let c_style = inners.is_empty(); - quote! { variants: #private::VariantDecoder<#de, #variant_count, #c_style>, } + let histogram = if c_style { + 0 + } else { + variant_count + }; + quote! { variants: #private::VariantDecoder<#de, #variant_index_type, #variant_count, #histogram>, } }) .unwrap_or_default(); quote! { @@ -165,7 +171,7 @@ impl crate::shared::Item for Item { if inner.is_empty() { quote! {} } else { - let i = variant_index(i); + let i = variant_index_type.instance_to_tokens(i); let length = decode_variants .then(|| { quote! { @@ -209,7 +215,7 @@ impl crate::shared::Item for Item { .map(|i| { let inner = inner(i); let pattern = pattern(i); - let i = variant_index(i); + let i = variant_index_type.instance_to_tokens(i); quote! { #i => { #inner @@ -221,7 +227,7 @@ impl crate::shared::Item for Item { quote! { match self.variants.decode() { #variants - // Safety: VariantDecoder::decode outputs numbers less than N. + // Safety: VariantDecoder<_, N, _>::decode outputs numbers less than N. _ => unsafe { ::core::hint::unreachable_unchecked() } } } diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index 0ddbdbc..d134a37 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -1,6 +1,6 @@ use crate::attribute::BitcodeAttrs; use crate::private; -use crate::shared::{remove_lifetimes, replace_lifetimes, variant_index}; +use crate::shared::{remove_lifetimes, replace_lifetimes, VariantIndexType}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{parse_quote, Generics, Path, Type}; @@ -114,6 +114,7 @@ impl crate::shared::Item for Item { self, crate_name: &Path, variant_count: usize, + variant_index_type: VariantIndexType, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream { @@ -124,7 +125,7 @@ impl crate::shared::Item for Item { let variants = encode_variants .then(|| { let private = private(crate_name); - quote! { variants: #private::VariantEncoder<#variant_count>, } + quote! { variants: #private::VariantEncoder<#variant_index_type, #variant_count>, } }) .unwrap_or_default(); let inners: TokenStream = (0..variant_count).map(|i| inner(self, i)).collect(); @@ -149,7 +150,7 @@ impl crate::shared::Item for Item { let variants: TokenStream = (0..variant_count) .map(|i| { let pattern = pattern(i); - let i = variant_index(i); + let i = variant_index_type.instance_to_tokens(i); quote! { #pattern => #i, } diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 3d4f67e..879e114 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -9,9 +9,60 @@ use syn::{ Result, Type, WherePredicate, }; -type VariantIndex = u8; -pub fn variant_index(i: usize) -> VariantIndex { - i.try_into().unwrap() +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum VariantIndexType { + U8, + U16, +} + +impl VariantIndexType { + pub fn new(variant_count: usize, ident: &Ident) -> Result { + for candidate in [Self::U8, Self::U16] { + if variant_count <= candidate.max_variants() { + return Ok(candidate); + } + } + err( + &ident, + &format!( + "enums with more than {} variants are not supported", + Self::U16.max_variants() + ), + ) + } + + fn max_variants(self) -> usize { + (match self { + Self::U8 => u8::MAX as usize, + Self::U16 => u16::MAX as usize, + }) + 1 + } + + pub fn instance_to_tokens(self, index: usize) -> TokenStream { + match self { + Self::U8 => { + let n: u8 = index.try_into().unwrap(); + quote! {#n} + } + Self::U16 => { + let n: u16 = index.try_into().unwrap(); + quote! {#n} + } + } + } +} + +impl ToTokens for VariantIndexType { + fn to_tokens(&self, tokens: &mut TokenStream) { + use quote::TokenStreamExt; + tokens.append(Ident::new( + match self { + Self::U8 => "u8", + Self::U16 => "u16", + }, + Span::call_site(), + )); + } } pub trait Item: Copy + Sized { @@ -36,6 +87,7 @@ pub trait Item: Copy + Sized { self, crate_name: &Path, variant_count: usize, + variant_index_type: VariantIndexType, pattern: impl Fn(usize) -> TokenStream, inner: impl Fn(Self, usize) -> TokenStream, ) -> TokenStream; @@ -132,12 +184,20 @@ pub trait Derive { }) } Data::Enum(data_enum) => { - let max_variants = VariantIndex::MAX as usize + 1; - if data_enum.variants.len() > max_variants { - return err( - &ident, - &format!("enums with more than {max_variants} variants are not supported"), - ); + let variant_index_type = VariantIndexType::new(data_enum.variants.len(), &ident)?; + + if variant_index_type != VariantIndexType::U8 { + for variant in &data_enum.variants { + if !variant.fields.is_empty() { + return err( + &ident, + &format!( + "enums with more than {} variants must not have any variants with fields", + VariantIndexType::U8.max_variants() + ), + ); + } + } } // Used for adding `bounds` and skipping fields. Would be used by `#[bitcode(with_serde)]`. @@ -154,6 +214,7 @@ pub trait Derive { item.enum_impl( &attrs.crate_name, data_enum.variants.len(), + variant_index_type, |i| { let variant = &data_enum.variants[i]; let variant_name = &variant.ident; diff --git a/fuzz/fuzz_targets/fuzz.rs b/fuzz/fuzz_targets/fuzz.rs index a9beda4..aed0aad 100644 --- a/fuzz/fuzz_targets/fuzz.rs +++ b/fuzz/fuzz_targets/fuzz.rs @@ -3,14 +3,14 @@ use libfuzzer_sys::fuzz_target; extern crate bitcode; use arrayvec::{ArrayString, ArrayVec}; use bitcode::{Decode, DecodeOwned, Encode}; +use rust_decimal::Decimal; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use std::fmt::Debug; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::num::NonZeroU32; use std::time::Duration; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use rust_decimal::Decimal; #[inline(never)] fn test_derive(data: &[u8]) { @@ -140,6 +140,39 @@ fuzz_target!(|data: &[u8]| { pub enum Enum16 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P } #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] pub enum Enum17 { A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q } + #[derive(Serialize, Deserialize, Encode, Decode, Debug, PartialEq)] + pub enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } } use enums::*; @@ -148,10 +181,20 @@ fuzz_target!(|data: &[u8]| { A, B, C(u16), - D { a: u8, b: u8, #[serde(skip)] #[bitcode(skip)] c: u8 }, + D { + a: u8, + b: u8, + #[serde(skip)] + #[bitcode(skip)] + c: u8, + }, E(String), F, - G(#[bitcode(skip)] #[serde(skip)] i16), + G( + #[bitcode(skip)] + #[serde(skip)] + i16, + ), P(BTreeMap), } @@ -219,6 +262,7 @@ fuzz_target!(|data: &[u8]| { Enum15, Enum16, Enum17, + Enum300, Enum, ArrayString<5>, ArrayString<70>, diff --git a/src/derive/option.rs b/src/derive/option.rs index b192bae..967aec6 100644 --- a/src/derive/option.rs +++ b/src/derive/option.rs @@ -7,7 +7,7 @@ use core::mem::MaybeUninit; use core::num::NonZeroUsize; pub struct OptionEncoder { - variants: VariantEncoder<2>, + variants: VariantEncoder, some: T::Encoder, } @@ -86,7 +86,7 @@ impl Buffer for OptionEncoder { } pub struct OptionDecoder<'a, T: Decode<'a>> { - variants: VariantDecoder<'a, 2, false>, + variants: VariantDecoder<'a, u8, 2, 2>, some: T::Decoder, } diff --git a/src/derive/result.rs b/src/derive/result.rs index 9ec6971..fb7dede 100644 --- a/src/derive/result.rs +++ b/src/derive/result.rs @@ -7,7 +7,7 @@ use core::mem::MaybeUninit; use core::num::NonZeroUsize; pub struct ResultEncoder { - variants: VariantEncoder<2>, + variants: VariantEncoder, ok: T::Encoder, err: E::Encoder, } @@ -55,7 +55,7 @@ impl Buffer for ResultEncoder { } pub struct ResultDecoder<'a, T: Decode<'a>, E: Decode<'a>> { - variants: VariantDecoder<'a, 2, false>, + variants: VariantDecoder<'a, u8, 2, 2>, ok: T::Decoder, err: E::Decoder, } diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 67463c3..2c4b37d 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -1,23 +1,30 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; +use crate::error::err; use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; -use crate::pack::{pack_bytes_less_than, unpack_bytes_less_than}; +use crate::pack::{check_less_than, pack_bytes_less_than, unpack_bytes_less_than}; +use crate::pack_ints::{pack_ints, unpack_ints, Int}; use alloc::vec::Vec; +use core::any::TypeId; use core::num::NonZeroUsize; #[derive(Default)] -pub struct VariantEncoder(VecImpl); +pub struct VariantEncoder(VecImpl); -impl Encoder for VariantEncoder { +impl Encoder for VariantEncoder { #[inline(always)] - fn encode(&mut self, v: &u8) { + fn encode(&mut self, v: &T) { unsafe { self.0.push_unchecked(*v) }; } } -impl Buffer for VariantEncoder { +impl Buffer for VariantEncoder { fn collect_into(&mut self, out: &mut Vec) { assert!(N >= 2); - pack_bytes_less_than::(self.0.as_slice(), out); + if TypeId::of::() != TypeId::of::() { + pack_ints(self.0.as_mut_slice(), out); + } else { + pack_bytes_less_than::(bytemuck::must_cast_slice::(self.0.as_slice()), out); + }; self.0.clear(); } @@ -26,13 +33,16 @@ impl Buffer for VariantEncoder { } } -pub struct VariantDecoder<'a, const N: usize, const C_STYLE: bool> { - variants: CowSlice<'a, u8>, - histogram: [usize; N], // Not required if C_STYLE. TODO don't reserve space for it. +pub struct VariantDecoder<'a, T: Int, const N: usize, const HISTOGRAM: usize> { + variants: CowSlice<'a, T::Une>, + // `HISTOGRAM` is 0 for C style (fieldless) enums. + histogram: [usize; HISTOGRAM], } // [(); N] doesn't implement Default. -impl Default for VariantDecoder<'_, N, C_STYLE> { +impl Default + for VariantDecoder<'_, T, N, HISTOGRAM> +{ fn default() -> Self { Self { variants: Default::default(), @@ -41,30 +51,41 @@ impl Default for VariantDecoder<'_, N, C_ST } } -// C style enums don't require length, so we can skip making a histogram for them. -impl<'a, const N: usize> VariantDecoder<'a, N, false> { +// C style enums (`HISTOGRAM` = 0) don't require length, so we +// can skip making a histogram for them. +impl<'a, T: Int, const N: usize> VariantDecoder<'a, T, N, N> { pub fn length(&self, variant_index: u8) -> usize { self.histogram[variant_index as usize] } } -impl<'a, const N: usize, const C_STYLE: bool> View<'a> for VariantDecoder<'a, N, C_STYLE> { +impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> View<'a> + for VariantDecoder<'a, T, N, HISTOGRAM> +{ fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> { assert!(N >= 2); - if C_STYLE { - unpack_bytes_less_than::(input, length, &mut self.variants)?; + if TypeId::of::() != TypeId::of::() { + unpack_ints::(input, length, &mut self.variants)?; + + check_less_than::(unsafe { + self.variants.as_slice(length) + })?; } else { - self.histogram = unpack_bytes_less_than::(input, length, &mut self.variants)?; + assert!(HISTOGRAM == 0 || HISTOGRAM == N); + let out = self.variants.cast_mut::(); + self.histogram = unpack_bytes_less_than::(input, length, out)?; } Ok(()) } } -impl<'a, const N: usize, const C_STYLE: bool> Decoder<'a, u8> for VariantDecoder<'a, N, C_STYLE> { +impl<'a, T: Int + Into, const N: usize, const HISTOGRAM: usize> Decoder<'a, T> + for VariantDecoder<'a, T, N, HISTOGRAM> +{ // Guaranteed to output numbers less than N. #[inline(always)] - fn decode(&mut self) -> u8 { - unsafe { self.variants.mut_slice().next_unchecked() } + fn decode(&mut self) -> T { + T::from_unaligned(unsafe { self.variants.mut_slice().next_unchecked() }) } } @@ -138,3 +159,59 @@ mod tests { } crate::bench_encode_decode!(bool_enum_vec: Vec<_>); } + +#[cfg(test)] +mod test2 { + use crate::{decode, encode, Decode, Encode}; + use alloc::vec::Vec; + + #[cfg_attr(not(test), rustfmt::skip)] + #[derive(Encode, Decode, Debug, PartialEq)] + pub enum Enum300 { + V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, + V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, + V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, + V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, + V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, + V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, + V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, + V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, + V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, + V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, + V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, + V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, + V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, + V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, + V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, + V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, + V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, + V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, + V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, + V191, V192, V193, V194, V195, V196, V197, V198, V199, V200, + V201, V202, V203, V204, V205, V206, V207, V208, V209, V210, + V211, V212, V213, V214, V215, V216, V217, V218, V219, V220, + V221, V222, V223, V224, V225, V226, V227, V228, V229, V230, + V231, V232, V233, V234, V235, V236, V237, V238, V239, V240, + V241, V242, V243, V244, V245, V246, V247, V248, V249, V250, + V251, V252, V253, V254, V255, V256, V257, V258, V259, V260, + V261, V262, V263, V264, V265, V266, V267, V268, V269, V270, + V271, V272, V273, V274, V275, V276, V277, V278, V279, V280, + V281, V282, V283, V284, V285, V286, V287, V288, V289, V290, + V291, V292, V293, V294, V295, V296, V297, V298, V299, V300, + } + + #[allow(unused)] + #[test] + fn test_large_c_style_enum() { + assert!(matches!(decode(&encode(&Enum300::V42)), Ok(Enum300::V42))); + assert!(matches!(decode(&encode(&Enum300::V300)), Ok(Enum300::V300))); + } + + fn bench_data() -> Vec { + crate::random_data(1000) + .into_iter() + .map(|v: u16| unsafe { core::mem::transmute_copy::<_, Enum300>(&(v % 300)) }) + .collect() + } + crate::bench_encode_decode!(enum_300_variants_vec: Vec<_>); +} diff --git a/src/pack.rs b/src/pack.rs index 6aee87e..6c0684b 100644 --- a/src/pack.rs +++ b/src/pack.rs @@ -2,7 +2,7 @@ use crate::coder::Result; use crate::consume::{consume_byte, consume_byte_arrays, consume_bytes}; use crate::error::err; use crate::fast::CowSlice; -use crate::pack_ints::SizedInt; +use crate::pack_ints::{Int, SizedInt}; use alloc::vec::Vec; /// Possible states per byte in descending order. Each packed byte will use `log2(states)` bits. @@ -201,6 +201,39 @@ pub fn pack_bytes_less_than(bytes: &[u8], out: &mut Vec) { } } +fn check_less_than_u8( + unpacked: &[u8], +) -> Result<[usize; HISTOGRAM]> { + check_less_than::(bytemuck::must_cast_slice(unpacked)) +} + +pub fn check_less_than< + T: Int + Into, + const N: usize, + const HISTOGRAM: usize, + const FACTOR: usize, +>( + unpacked: &[T::Une], +) -> Result<[usize; HISTOGRAM]> { + assert!(FACTOR >= N); + debug_assert!(unpacked + .iter() + .all(|&v| T::from_unaligned(v).into() < FACTOR)); + if FACTOR > N + && unpacked + .iter() + .copied() + .map(T::from_unaligned) + .max() + .map(Into::into) + .unwrap_or(0) + >= N + { + return invalid_packing(); + } + Ok(core::array::from_fn(|_| unreachable!("HISTOGRAM not 0"))) +} + /// Like `unpack_bytes` but all values are less than `N` so it can avoid encoding the packing. /// Bytes returned by this function are guaranteed less than `N`. /// @@ -215,19 +248,6 @@ pub fn unpack_bytes_less_than<'a, const N: usize, const HISTOGRAM: usize>( ) -> Result<[usize; HISTOGRAM]> { assert!(HISTOGRAM == N || HISTOGRAM == 0); - /// Checks that `unpacked` bytes are less than `N`. All of `unpacked` is assumed to be < `FACTOR`. - /// `HISTOGRAM` must be 0. - fn check_less_than( - unpacked: &[u8], - ) -> Result<[usize; HISTOGRAM]> { - assert!(FACTOR >= N); - debug_assert!(unpacked.iter().all(|&v| (v as usize) < FACTOR)); - if FACTOR > N && unpacked.iter().copied().max().unwrap_or(0) as usize >= N { - return invalid_packing(); - } - Ok(core::array::from_fn(|_| unreachable!("HISTOGRAM not 0"))) - } - /// Returns `Ok(histogram)` if buckets after `OUT` are 0. fn check_histogram( histogram: [usize; IN], @@ -244,7 +264,7 @@ pub fn unpack_bytes_less_than<'a, const N: usize, const HISTOGRAM: usize>( let bytes = consume_bytes(input, length)?; out.set_borrowed(bytes); return if HISTOGRAM == 0 { - check_less_than::(bytes) + check_less_than_u8::(bytes) } else { check_histogram(crate::histogram::histogram(bytes)) }; @@ -269,7 +289,7 @@ pub fn unpack_bytes_less_than<'a, const N: usize, const HISTOGRAM: usize>( let original_input = *input; unpack_arithmetic::(input, length, out)?; if HISTOGRAM == 0 { - check_less_than::(out) + check_less_than_u8::(out) } else { let floor = length / divisor; let ceil = crate::nightly::div_ceil_usize(length, divisor); diff --git a/src/serde/ser.rs b/src/serde/ser.rs index 63057a6..331c467 100644 --- a/src/serde/ser.rs +++ b/src/serde/ser.rs @@ -240,6 +240,11 @@ impl<'a> EncoderWrapper<'a> { #[inline(always)] fn variant_index_u8(variant_index: u32) -> Result { if variant_index > u8::MAX as u32 { + // Properly optimizing the size of large enums would + // require `serde` to specify the variant count. + // + // Good news: the `derive` version of `bitcode` supports + // arbitrary-sized fieldless enums! err("enums with more than 256 variants are unsupported") } else { Ok(variant_index as u8) diff --git a/src/str.rs b/src/str.rs index 80d727d..8d88766 100644 --- a/src/str.rs +++ b/src/str.rs @@ -254,6 +254,7 @@ mod tests { #[doc(hidden)] pub fn _cant_decode_static_from_non_static_buffer() {} +/// TODO: specify error. /// ```compile_fail /// use bitcode::{encode, decode, Encode, Decode}; ///