From 406b42f0ca21d152b3722226a04a8c136fa71b11 Mon Sep 17 00:00:00 2001 From: pantonshire Date: Sun, 1 Oct 2023 13:12:30 +0100 Subject: [PATCH] use CappedString and CowCappedString in generated unscribe fn --- enumscribe/src/internal/capped_string.rs | 91 ++++++++++---- enumscribe_derive/src/enums.rs | 148 +++++++++++++++++++---- enumscribe_derive/src/lib.rs | 83 ++++++++----- 3 files changed, 240 insertions(+), 82 deletions(-) diff --git a/enumscribe/src/internal/capped_string.rs b/enumscribe/src/internal/capped_string.rs index ba4a145..d6d5b7f 100644 --- a/enumscribe/src/internal/capped_string.rs +++ b/enumscribe/src/internal/capped_string.rs @@ -21,6 +21,37 @@ impl<'a, const N: usize> CowCappedString<'a, N> { CowCappedString::Owned(s) => s, } } + + /// Returns a new `CappedString` with capacity `M` containing the string converted to + /// uppercase. Returns `None` if the uppercase-converted string is longer than `M` bytes. + #[inline] + #[must_use] + pub fn to_uppercase(&self) -> Option> { + CappedString::::uppercase_from_str(self) + } +} + +impl<'a, const N: usize> Deref for CowCappedString<'a, N> { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl<'a, const N: usize> AsRef for CowCappedString<'a, N> { + #[inline] + fn as_ref(&self) -> &str { + self + } +} + +impl<'a, const N: usize> Borrow for CowCappedString<'a, N> { + #[inline] + fn borrow(&self) -> &str { + self + } } #[cfg(feature = "serde")] @@ -86,16 +117,45 @@ pub struct CappedString { } impl CappedString { - /// Returns a new `CappedString` containing a copy of the given string data. Returns an error - /// if the string data is larger than `N` bytes. + /// Returns a new `CappedString` containing a copy of the given string data. Returns `None` if + /// the string data is larger than `N` bytes. #[inline] #[must_use] pub fn from_str(s: &str) -> Option { unsafe { Self::from_utf8_unchecked(s.as_bytes()) } } + /// Returns a new `CappedString` containing an uppercase conversion of the given string data. + /// Returns `None` if the converted string is larger than `N` bytes. + #[inline] + #[must_use] + pub fn uppercase_from_str(s: &str) -> Option { + let mut buf = [0u8; N]; + let mut cursor = 0usize; + + for c_orig in s.chars() { + for c_upper in c_orig.to_uppercase() { + let encode_buf = cursor + .checked_add(c_upper.len_utf8()) + .and_then(|encode_buf_end| buf.get_mut(cursor..encode_buf_end))?; + + // FIXME: avoid the panic asm that gets generated for this encode (can never panic, + // as we always have at least `c_upper.len_utf8()` buffer space). + let encoded = c_upper.encode_utf8(encode_buf); + cursor = cursor.checked_add(encoded.len())?; + } + } + + let filled_buf = buf.get(..cursor)?; + + // SAFETY: + // `filled_buf` has been filled with a sequence of bytes obtained from `char::encode_utf8`, + // so it is valid UTF-8. + unsafe { Self::from_utf8_unchecked(filled_buf) } + } + /// Returns a new `CappedString` containing a copy of the given UTF-8 encoded string data. - /// Returns an error if more than `N` bytes of data are given. + /// Returns `None` if more than `N` bytes of data are given. /// /// # Safety /// - `bs` must be valid UTF-8. @@ -146,32 +206,11 @@ impl CappedString { } /// Returns a new `CappedString` with capacity `M` containing the string converted to - /// uppercase. Returns an error if the uppercase-converted string is longer than `M` bytes. + /// uppercase. Returns `None` if the uppercase-converted string is longer than `M` bytes. #[inline] #[must_use] pub fn to_uppercase(&self) -> Option> { - let mut buf = [0u8; M]; - let mut cursor = 0usize; - - for c_orig in self.as_str().chars() { - for c_upper in c_orig.to_uppercase() { - let encode_buf = cursor - .checked_add(c_upper.len_utf8()) - .and_then(|encode_buf_end| buf.get_mut(cursor..encode_buf_end))?; - - // FIXME: avoid the panic asm that gets generated for this encode (can never panic, - // as we always have at least `c_upper.len_utf8()` buffer space). - let encoded = c_upper.encode_utf8(encode_buf); - cursor = cursor.checked_add(encoded.len())?; - } - } - - let filled_buf = buf.get(..cursor)?; - - // SAFETY: - // `filled_buf` has been filled with a sequence of bytes obtained from `char::encode_utf8`, - // so it is valid UTF-8. - unsafe { CappedString::from_utf8_unchecked(filled_buf) } + CappedString::::uppercase_from_str(self) } } diff --git a/enumscribe_derive/src/enums.rs b/enumscribe_derive/src/enums.rs index 4363b98..7e9da6e 100644 --- a/enumscribe_derive/src/enums.rs +++ b/enumscribe_derive/src/enums.rs @@ -12,7 +12,45 @@ use crate::{CASE_INSENSITIVE, CRATE_ATTR, IGNORE, NAME, OTHER}; #[derive(Clone)] pub(crate) struct Enum<'a> { - pub(crate) variants: Vec>, + variants: Box<[Variant<'a>]>, + name_capacity: usize, + name_upper_capacity: usize, +} + +impl<'a> Enum<'a> { + pub(crate) fn new(variants: Box<[Variant<'a>]>) -> Self { + let name_capacity = variants + .iter() + .filter_map(|v| v.v_type.as_named()) + .map(|named| named.name().len()) + .max() + .unwrap_or(0); + + let name_upper_capacity = variants + .iter() + .filter_map(|v| v.v_type.as_named()) + .map(|named| named.name_upper().len()) + .max() + .unwrap_or(0); + + Self { + variants, + name_capacity, + name_upper_capacity, + } + } + + pub(crate) fn variants(&self) -> &[Variant<'a>] { + &self.variants + } + + pub(crate) fn name_capacity(&self) -> usize { + self.name_capacity + } + + pub(crate) fn name_upper_capacity(&self) -> usize { + self.name_upper_capacity + } } #[derive(Clone)] @@ -25,14 +63,69 @@ pub(crate) struct Variant<'a> { #[derive(Clone)] pub(crate) enum VariantType<'a> { Ignore, - Named { - name: String, + Named(NamedVariant), + Other(OtherVariant<'a>), +} + +impl<'a> VariantType<'a> { + pub(crate) fn as_named(&self) -> Option<&NamedVariant> { + match self { + Self::Named(named) => Some(named), + _ => None, + } + } +} + +#[derive(Clone)] +pub(crate) struct NamedVariant { + name: Box, + name_upper: Box, + constructor: VariantConstructor, + case_insensitive: bool, +} + +impl NamedVariant { + pub(crate) fn new( + name: Box, constructor: VariantConstructor, - case_insensitive: bool, - }, - Other { - field_name: Option<&'a Ident>, - }, + case_insensitive: bool + ) -> Self + { + let name_upper = char_wise_uppercase(&name); + Self { + name, + name_upper, + constructor, + case_insensitive, + } + } + + pub(crate) fn name(&self) -> &str { + &self.name + } + + pub(crate) fn name_upper(&self) -> &str { + &self.name_upper + } + + pub(crate) fn constructor(&self) -> VariantConstructor { + self.constructor + } + + pub(crate) fn case_insensitive(&self) -> bool { + self.case_insensitive + } +} + +#[derive(Clone)] +pub(crate) struct OtherVariant<'a> { + field_name: Option<&'a Ident>, +} + +impl<'a> OtherVariant<'a> { + pub(crate) fn field_name(&self) -> Option<&'a Ident> { + self.field_name + } } #[derive(Clone, Copy, Debug)] @@ -58,20 +151,18 @@ impl<'a> Variant<'a> { match &self.v_type { VariantType::Ignore => Ok(None), - VariantType::Named { - name, constructor, .. - } => { - let constructor_tokens = constructor.empty(); + VariantType::Named(named) => { + let constructor_tokens = named.constructor().empty_toks(); let pattern = quote! { #enum_ident::#variant_ident #constructor_tokens }; - Ok(Some((pattern, named_fn(self, enum_ident, name)?))) + Ok(Some((pattern, named_fn(self, enum_ident, named.name())?))) } - VariantType::Other { field_name } => { - let field_name_tokens = match field_name { + VariantType::Other(other) => { + let field_name_tokens = match other.field_name() { Some(field_name) => field_name.to_token_stream(), None => quote! { __enumscribe_other_inner }, }; - let pattern = match field_name { + let pattern = match other.field_name() { Some(_) => quote! { #enum_ident::#variant_ident{#field_name_tokens} }, None => quote! { #enum_ident::#variant_ident(#field_name_tokens) }, }; @@ -85,7 +176,7 @@ impl<'a> Variant<'a> { } impl VariantConstructor { - pub(crate) fn empty(&self) -> TokenStream2 { + pub(crate) fn empty_toks(&self) -> TokenStream2 { match self { VariantConstructor::None => quote! {}, VariantConstructor::Paren => quote! { () }, @@ -220,7 +311,7 @@ pub(crate) fn parse_enum<'a>(data: &'a DataEnum, attrs: &'a [Attribute]) -> Macr Variant { data: variant, - v_type: VariantType::Other { field_name }, + v_type: VariantType::Other(OtherVariant { field_name }), span: variant_span, } } else { @@ -279,13 +370,12 @@ pub(crate) fn parse_enum<'a>(data: &'a DataEnum, attrs: &'a [Attribute]) -> Macr Fields::Unit => VariantConstructor::None, }; + let named = NamedVariant::new(name.into_boxed_str(), constructor, case_insensitive); + let v_type = VariantType::Named(named); + Variant { data: variant, - v_type: VariantType::Named { - name, - constructor, - case_insensitive, - }, + v_type, span: variant_span, } }; @@ -293,5 +383,13 @@ pub(crate) fn parse_enum<'a>(data: &'a DataEnum, attrs: &'a [Attribute]) -> Macr variants.push(scribe_variant); } - Ok(Enum { variants }) -} \ No newline at end of file + Ok(Enum::new(variants.into_boxed_slice())) +} + +fn char_wise_uppercase(s: &str) -> Box { + // Use the same uppercase algorithm as `enumscribe::internal::capped_string`. + s.chars() + .flat_map(char::to_uppercase) + .collect::() + .into_boxed_str() +} diff --git a/enumscribe_derive/src/lib.rs b/enumscribe_derive/src/lib.rs index 2c9b78c..c21228c 100644 --- a/enumscribe_derive/src/lib.rs +++ b/enumscribe_derive/src/lib.rs @@ -59,9 +59,9 @@ where let enum_ident = &input.ident; - let mut match_arms = Vec::with_capacity(parsed_enum.variants.len()); + let mut match_arms = Vec::with_capacity(parsed_enum.variants().len()); - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { match variant.match_variant(enum_ident, &named_fn, &other_fn) { Ok(Some((pattern, result))) => match_arms.push(quote! { #pattern => #result }), Ok(None) => return ignore_err_fn(variant, enum_ident).into(), @@ -102,9 +102,9 @@ where let enum_ident = &input.ident; let mut ignore_variant = false; - let mut match_arms = Vec::with_capacity(parsed_enum.variants.len()); + let mut match_arms = Vec::with_capacity(parsed_enum.variants().len()); - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { match variant.match_variant(enum_ident, &named_fn, &other_fn) { Ok(Some((pattern, result))) => match_arms.push(quote! { #pattern => #result }), Ok(None) => ignore_variant = true, @@ -192,30 +192,28 @@ where let mut case_sensitive_arms = Vec::new(); let mut case_insensitive_arms = Vec::new(); - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { let variant_ident = &variant.data.ident; match &variant.v_type { VariantType::Ignore => (), - VariantType::Named { - name, - constructor, - case_insensitive, - } => { - let match_pattern = if *case_insensitive { - let lowercase_name = name.to_lowercase(); - quote! { #lowercase_name } + VariantType::Named(named) => { + let match_pattern = if named.case_insensitive() { + let uppercase_name = named.name_upper(); + quote! { #uppercase_name } } else { + let name = named.name(); quote! { #name } }; - let constructor_tokens = constructor.empty(); - let constructed_variant = - quote! { #enum_ident::#variant_ident #constructor_tokens }; + let constructor_tokens = named.constructor().empty_toks(); + let constructed_variant = quote! { + #enum_ident::#variant_ident #constructor_tokens + }; let match_result = named_fn(constructed_variant); - if *case_insensitive { + if named.case_insensitive() { &mut case_insensitive_arms } else { &mut case_sensitive_arms @@ -223,11 +221,11 @@ where .push(quote! { #match_pattern => #match_result }); } - VariantType::Other { field_name } => { + VariantType::Other(other) => { let unscribe_value = quote! { <_ as ::std::convert::Into<_>>::into(#match_against) }; - let constructed_variant = match field_name { + let constructed_variant = match other.field_name() { None => quote! { #enum_ident::#variant_ident(#unscribe_value) }, @@ -251,10 +249,23 @@ where let case_insensitive_match = if case_insensitive_arms.is_empty() { None } else { + let match_against_upper_ident = quote! { __enumscribe_unscribe_uppercase }; + let name_upper_cap = parsed_enum.name_upper_capacity(); + Some(quote! { - let __enumscribe_unscribe_lowercase = #match_against.to_lowercase(); - match __enumscribe_unscribe_lowercase.as_str() { - #(#case_insensitive_arms,)* + match ::enumscribe + ::internal + ::capped_string + ::CappedString + ::<#name_upper_cap> + ::uppercase_from_str(#match_against) + { + Some(#match_against_upper_ident) => { + match &*#match_against_upper_ident { + #(#case_insensitive_arms,)* + #other_arm, + } + }, #other_arm, } }) @@ -659,23 +670,22 @@ pub fn derive_enum_serialize(input: TokenStream) -> TokenStream { let mut match_arms = Vec::new(); let mut ignore_variant = false; - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { let variant_ident = &variant.data.ident; match &variant.v_type { VariantType::Ignore => ignore_variant = true, - VariantType::Named { - name, constructor, .. - } => { - let constructor_tokens = constructor.empty(); + VariantType::Named(named) => { + let constructor_tokens = named.constructor().empty_toks(); + let name = named.name(); match_arms.push(quote! { #enum_ident::#variant_ident #constructor_tokens => #serializer_ident.serialize_str(#name) }) } - VariantType::Other { field_name } => match field_name { + VariantType::Other(other) => match other.field_name() { Some(field_name) => match_arms.push(quote! { #enum_ident::#variant_ident { #field_name } => #serializer_ident.serialize_str(&#field_name) @@ -748,13 +758,14 @@ pub fn derive_enum_deserialize(input: TokenStream) -> TokenStream { let enum_ident = &input.ident; let deserializer_ident = quote! { __enumscribe_deserializer }; + let deserialized_cow_capped_str_ident = quote! { __enumscribe_deserialized_cow_capped_str }; let deserialized_str_ident = quote! { __enumscribe_deserialized_str }; let variant_strings = parsed_enum - .variants + .variants() .iter() .map(|variant| match &variant.v_type { - VariantType::Named { name, .. } => Some(name.as_str()), + VariantType::Named(named) => Some(named.name()), _ => None, }) .filter_map(|name| name) @@ -780,13 +791,23 @@ pub fn derive_enum_deserialize(input: TokenStream) -> TokenStream { }), )); + let name_cap = parsed_enum.name_capacity(); + (quote! { #[automatically_derived] impl<'de> ::serde::Deserialize<'de> for #enum_ident { fn deserialize(#deserializer_ident: D) -> ::core::result::Result where D: ::serde::Deserializer<'de> { - let #deserialized_str_ident = <&str as ::serde::Deserialize<'_>>::deserialize(#deserializer_ident)?; + let #deserialized_cow_capped_str_ident = < + ::enumscribe + ::internal + ::capped_string + ::CowCappedString<'de, #name_cap> + as ::serde::Deserialize<'_> + >::deserialize(#deserializer_ident)?; + + let #deserialized_str_ident = &*#deserialized_cow_capped_str_ident; #main_match } }