diff --git a/enumscribe/Cargo.toml b/enumscribe/Cargo.toml index 9c2378f..fcccae0 100644 --- a/enumscribe/Cargo.toml +++ b/enumscribe/Cargo.toml @@ -12,6 +12,7 @@ keywords = ["enum", "derive", "serde"] [dependencies] enumscribe_derive = { version = "0.3.0", path = "../enumscribe_derive", default-features = false, optional = true } +serde = { version = "1.0", default-features = false, optional = true } [dev-dependencies] serde = { version = "1.0", features = ["derive"] } @@ -21,4 +22,5 @@ serde_json = "1.0" default = ["std", "derive", "derive_serde"] std = ["enumscribe_derive/std"] derive = ["enumscribe_derive"] -derive_serde = ["derive", "enumscribe_derive/serde"] +derive_serde = ["derive", "serde", "enumscribe_derive/serde"] +serde = ["derive_serde", "dep:serde"] diff --git a/enumscribe/src/internal/capped_string.rs b/enumscribe/src/internal/capped_string.rs new file mode 100644 index 0000000..d6d5b7f --- /dev/null +++ b/enumscribe/src/internal/capped_string.rs @@ -0,0 +1,442 @@ +//! Module for the [`CappedString`](CappedString) type, which is a string type which always stores +//! its data inline. + +use core::{str, ops::Deref, borrow::Borrow, fmt}; + +/// TODO: documentation +pub enum CowCappedString<'a, const N: usize> { + /// TODO: documentation + Borrowed(&'a str), + /// TODO: documentation + Owned(CappedString), +} + +impl<'a, const N: usize> CowCappedString<'a, N> { + /// Returns the string data contained by this `CowCappedString`. + #[inline] + #[must_use] + pub fn as_str(&self) -> &str { + match self { + CowCappedString::Borrowed(s) => s, + CowCappedString::Owned(s) => s, + } + } + + /// Returns a new `CappedString` with capacity `M` containing the string converted to + /// uppercase. Returns `None` if the uppercase-converted string is longer than `M` bytes. + #[inline] + #[must_use] + pub fn to_uppercase(&self) -> Option> { + CappedString::::uppercase_from_str(self) + } +} + +impl<'a, const N: usize> Deref for CowCappedString<'a, N> { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl<'a, const N: usize> AsRef for CowCappedString<'a, N> { + #[inline] + fn as_ref(&self) -> &str { + self + } +} + +impl<'a, const N: usize> Borrow for CowCappedString<'a, N> { + #[inline] + fn borrow(&self) -> &str { + self + } +} + +#[cfg(feature = "serde")] +impl<'de, const N: usize> serde::Deserialize<'de> for CowCappedString<'de, N> { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> + { + deserializer.deserialize_str(CowCappedStringVisitor::) + } +} + +#[cfg(feature = "serde")] +struct CowCappedStringVisitor; + +#[cfg(feature = "serde")] +impl<'de, const N: usize> serde::de::Visitor<'de> for CowCappedStringVisitor { + type Value = CowCappedString<'de, N>; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a borrowed string or a string up to {} bytes long", N) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + CappedStringVisitor::.visit_str(v) + .map(CowCappedString::Owned) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + CappedStringVisitor::.visit_bytes(v) + .map(CowCappedString::Owned) + } + + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: serde::de::Error, + { + Ok(CowCappedString::Borrowed(v)) + } + + fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result + where + E: serde::de::Error, + { + str::from_utf8(v) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Bytes(v), &self)) + .and_then(|v| self.visit_borrowed_str(v)) + } +} + +/// A string type which stores up to `N` bytes of string data inline. +pub struct CappedString { + /// The string data. It is an invariant that the first `len` bytes must be valid UTF-8. + buf: [u8; N], + // The length of the string data in the buffer. It is an invariant that `len <= N`. + len: usize, +} + +impl CappedString { + /// Returns a new `CappedString` containing a copy of the given string data. Returns `None` if + /// the string data is larger than `N` bytes. + #[inline] + #[must_use] + pub fn from_str(s: &str) -> Option { + unsafe { Self::from_utf8_unchecked(s.as_bytes()) } + } + + /// Returns a new `CappedString` containing an uppercase conversion of the given string data. + /// Returns `None` if the converted string is larger than `N` bytes. + #[inline] + #[must_use] + pub fn uppercase_from_str(s: &str) -> Option { + let mut buf = [0u8; N]; + let mut cursor = 0usize; + + for c_orig in s.chars() { + for c_upper in c_orig.to_uppercase() { + let encode_buf = cursor + .checked_add(c_upper.len_utf8()) + .and_then(|encode_buf_end| buf.get_mut(cursor..encode_buf_end))?; + + // FIXME: avoid the panic asm that gets generated for this encode (can never panic, + // as we always have at least `c_upper.len_utf8()` buffer space). + let encoded = c_upper.encode_utf8(encode_buf); + cursor = cursor.checked_add(encoded.len())?; + } + } + + let filled_buf = buf.get(..cursor)?; + + // SAFETY: + // `filled_buf` has been filled with a sequence of bytes obtained from `char::encode_utf8`, + // so it is valid UTF-8. + unsafe { Self::from_utf8_unchecked(filled_buf) } + } + + /// Returns a new `CappedString` containing a copy of the given UTF-8 encoded string data. + /// Returns `None` if more than `N` bytes of data are given. + /// + /// # Safety + /// - `bs` must be valid UTF-8. + #[inline] + #[must_use] + pub unsafe fn from_utf8_unchecked(bs: &[u8]) -> Option { + let mut buf = [0u8; N]; + buf.get_mut(..bs.len())?.copy_from_slice(bs); + + // SAFETY: + // - `bs.len() <= N` has already been checked by the `get_mut` call, which will return + // `None` and cause us to return early if the condition does not hold. + // + unsafe { Some(Self::from_raw_parts(buf, bs.len())) } + } + + /// Returns a new `CappedString` from a given buffer and length. + /// + /// # Safety + /// - `len <= N` must hold. + /// - The first `len` bytes of `buf` must be valid UTF-8. + #[inline] + #[must_use] + pub unsafe fn from_raw_parts(buf: [u8; N], len: usize) -> Self { + Self { buf, len } + } + + + /// Consumes the `CappedString` and returns its buffer and length. + #[inline] + #[must_use] + pub fn into_raw_parts(self) -> ([u8; N], usize) { + (self.buf, self.len) + } + + /// Returns the string data contained by this `CappedString`. + #[inline] + #[must_use] + pub fn as_str(&self) -> &str { + // SAFETY: + // - It is an invariant of `CappedString` that `len <= N`. + // - It is an invariant of `CappedString` that the first `len` bytes of `buf` are valid + // UTF-8. + unsafe { + let buf_occupied_prefix = self.buf.get_unchecked(..self.len); + str::from_utf8_unchecked(buf_occupied_prefix) + } + } + + /// Returns a new `CappedString` with capacity `M` containing the string converted to + /// uppercase. Returns `None` if the uppercase-converted string is longer than `M` bytes. + #[inline] + #[must_use] + pub fn to_uppercase(&self) -> Option> { + CappedString::::uppercase_from_str(self) + } +} + +impl Deref for CappedString { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl AsRef for CappedString { + #[inline] + fn as_ref(&self) -> &str { + self + } +} + +impl Borrow for CappedString { + #[inline] + fn borrow(&self) -> &str { + self + } +} + +impl PartialEq for CappedString { + fn eq(&self, other: &Self) -> bool { + self.as_str() == other.as_str() + } +} + +impl Eq for CappedString {} + +impl PartialEq for CappedString { + fn eq(&self, other: &str) -> bool { + self.as_str() == other + } +} + +#[cfg(feature = "serde")] +impl<'de, const N: usize> serde::Deserialize<'de> for CappedString { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> + { + deserializer.deserialize_str(CappedStringVisitor::) + } +} + +#[cfg(feature = "serde")] +struct CappedStringVisitor; + +#[cfg(feature = "serde")] +impl<'de, const N: usize> serde::de::Visitor<'de> for CappedStringVisitor { + type Value = CappedString; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a string up to {} bytes long", N) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + CappedString::from_str(v) + .ok_or_else(|| E::invalid_length(v.len(), &self)) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + str::from_utf8(v) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Bytes(v), &self)) + .and_then(|v| self.visit_str(v)) + } +} + +#[cfg(test)] +mod tests { + use super::{CappedString, CowCappedString}; + + #[cfg(feature = "serde")] + #[test] + fn test_cow_capped_string_deserialize() { + struct DeBorrowedOnly(String); + + impl<'de, const N: usize> serde::Deserialize<'de> for DeBorrowedOnly { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> + { + match CowCappedString::<'de, N>::deserialize(deserializer)? { + CowCappedString::Borrowed(s) => Ok(Self(s.to_owned())), + CowCappedString::Owned(_) => { + Err(serde::de::Error::custom("expected borrowed CowCappedString")) + }, + } + } + } + + struct DeOwnedOnly(String); + + impl<'de, const N: usize> serde::Deserialize<'de> for DeOwnedOnly { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> + { + match CowCappedString::<'de, N>::deserialize(deserializer)? { + CowCappedString::Borrowed(_) => { + Err(serde::de::Error::custom("expected owned CowCappedString")) + }, + CowCappedString::Owned(s) => Ok(Self(s.to_owned())), + } + } + } + + { + let DeBorrowedOnly(s) = serde_json::from_str::>( + r#""hello""# + ).unwrap(); + assert_eq!(s, "hello"); + } + { + let DeBorrowedOnly(s) = serde_json::from_str::>( + r#""hello""# + ).unwrap(); + assert_eq!(s, "hello"); + } + { + let s = serde_json::from_str::>( + r#""hello""# + ); + assert!(s.is_err()); + } + { + let DeOwnedOnly(s) = serde_json::from_str::>( + r#""\u87f9""# + ).unwrap(); + assert_eq!(s, "蟹"); + } + { + let s = serde_json::from_str::>( + r#""\u87f9""# + ); + assert!(s.is_err()); + } + } + + #[cfg(feature = "serde")] + #[test] + fn test_capped_string_deserialize() { + { + let s = serde_json::from_str::>( + r#""hello""# + ).unwrap(); + assert_eq!(s.as_str(), "hello"); + } + { + let s = serde_json::from_str::>( + r#""hello""# + ); + assert!(s.is_err()); + } + { + let s = serde_json::from_str::>( + r#""hello""# + ).unwrap(); + assert_eq!(s.as_str(), "hello"); + } + { + let s = serde_json::from_str::>( + r#""hello\tworld\n""# + ).unwrap(); + assert_eq!(s.as_str(), "hello\tworld\n"); + } + { + let s = serde_json::from_str::>( + r#""\u87f9""# + ).unwrap(); + assert_eq!(s.as_str(), "蟹"); + } + { + let s = serde_json::from_str::>( + r#""\u87f9""# + ); + assert!(s.is_err()); + } + } + + #[test] + fn test_capped_string_uppercase() { + { + let s1 = CappedString::<5>::from_str("hello").unwrap(); + let s2 = s1.to_uppercase::<5>().unwrap(); + assert_eq!(s2.as_str(), "HELLO"); + } + { + let s1 = CappedString::<20>::from_str("hello").unwrap(); + let s2 = s1.to_uppercase::<20>().unwrap(); + assert_eq!(s2.as_str(), "HELLO"); + } + { + let s1 = CappedString::<5>::from_str("hElLo").unwrap(); + let s2 = s1.to_uppercase::<5>().unwrap(); + assert_eq!(s2.as_str(), "HELLO"); + } + { + let s = CappedString::<5>::from_str("hello").unwrap(); + assert!(s.to_uppercase::<4>().is_none()); + } + { + let s1 = CappedString::<5>::from_str("groß").unwrap(); + let s2 = s1.to_uppercase::<5>().unwrap(); + assert_eq!(s2.as_str(), "GROSS"); + } + { + let s1 = CappedString::<1>::from_str("").unwrap(); + let s2 = s1.to_uppercase::<1>().unwrap(); + assert_eq!(s2.as_str(), ""); + } + { + let s1 = CappedString::<0>::from_str("").unwrap(); + let s2 = s1.to_uppercase::<0>().unwrap(); + assert_eq!(s2.as_str(), ""); + } + } +} diff --git a/enumscribe/src/internal/mod.rs b/enumscribe/src/internal/mod.rs new file mode 100644 index 0000000..6a46570 --- /dev/null +++ b/enumscribe/src/internal/mod.rs @@ -0,0 +1,3 @@ +//! Utilities for use by code generated by `enumscribe_derive`. + +pub mod capped_string; diff --git a/enumscribe/src/lib.rs b/enumscribe/src/lib.rs index 438b39b..04e41de 100644 --- a/enumscribe/src/lib.rs +++ b/enumscribe/src/lib.rs @@ -181,9 +181,11 @@ //! you *really* don't want to use a `Cow` for whatever reason. #![deny(missing_docs)] +#![deny(unsafe_op_in_unsafe_fn)] #![cfg_attr(not(feature = "std"), no_std)] -#[macro_use] +pub mod internal; + extern crate enumscribe_derive; pub use enumscribe_derive::*; diff --git a/enumscribe_derive/src/enums.rs b/enumscribe_derive/src/enums.rs index 4363b98..7e9da6e 100644 --- a/enumscribe_derive/src/enums.rs +++ b/enumscribe_derive/src/enums.rs @@ -12,7 +12,45 @@ use crate::{CASE_INSENSITIVE, CRATE_ATTR, IGNORE, NAME, OTHER}; #[derive(Clone)] pub(crate) struct Enum<'a> { - pub(crate) variants: Vec>, + variants: Box<[Variant<'a>]>, + name_capacity: usize, + name_upper_capacity: usize, +} + +impl<'a> Enum<'a> { + pub(crate) fn new(variants: Box<[Variant<'a>]>) -> Self { + let name_capacity = variants + .iter() + .filter_map(|v| v.v_type.as_named()) + .map(|named| named.name().len()) + .max() + .unwrap_or(0); + + let name_upper_capacity = variants + .iter() + .filter_map(|v| v.v_type.as_named()) + .map(|named| named.name_upper().len()) + .max() + .unwrap_or(0); + + Self { + variants, + name_capacity, + name_upper_capacity, + } + } + + pub(crate) fn variants(&self) -> &[Variant<'a>] { + &self.variants + } + + pub(crate) fn name_capacity(&self) -> usize { + self.name_capacity + } + + pub(crate) fn name_upper_capacity(&self) -> usize { + self.name_upper_capacity + } } #[derive(Clone)] @@ -25,14 +63,69 @@ pub(crate) struct Variant<'a> { #[derive(Clone)] pub(crate) enum VariantType<'a> { Ignore, - Named { - name: String, + Named(NamedVariant), + Other(OtherVariant<'a>), +} + +impl<'a> VariantType<'a> { + pub(crate) fn as_named(&self) -> Option<&NamedVariant> { + match self { + Self::Named(named) => Some(named), + _ => None, + } + } +} + +#[derive(Clone)] +pub(crate) struct NamedVariant { + name: Box, + name_upper: Box, + constructor: VariantConstructor, + case_insensitive: bool, +} + +impl NamedVariant { + pub(crate) fn new( + name: Box, constructor: VariantConstructor, - case_insensitive: bool, - }, - Other { - field_name: Option<&'a Ident>, - }, + case_insensitive: bool + ) -> Self + { + let name_upper = char_wise_uppercase(&name); + Self { + name, + name_upper, + constructor, + case_insensitive, + } + } + + pub(crate) fn name(&self) -> &str { + &self.name + } + + pub(crate) fn name_upper(&self) -> &str { + &self.name_upper + } + + pub(crate) fn constructor(&self) -> VariantConstructor { + self.constructor + } + + pub(crate) fn case_insensitive(&self) -> bool { + self.case_insensitive + } +} + +#[derive(Clone)] +pub(crate) struct OtherVariant<'a> { + field_name: Option<&'a Ident>, +} + +impl<'a> OtherVariant<'a> { + pub(crate) fn field_name(&self) -> Option<&'a Ident> { + self.field_name + } } #[derive(Clone, Copy, Debug)] @@ -58,20 +151,18 @@ impl<'a> Variant<'a> { match &self.v_type { VariantType::Ignore => Ok(None), - VariantType::Named { - name, constructor, .. - } => { - let constructor_tokens = constructor.empty(); + VariantType::Named(named) => { + let constructor_tokens = named.constructor().empty_toks(); let pattern = quote! { #enum_ident::#variant_ident #constructor_tokens }; - Ok(Some((pattern, named_fn(self, enum_ident, name)?))) + Ok(Some((pattern, named_fn(self, enum_ident, named.name())?))) } - VariantType::Other { field_name } => { - let field_name_tokens = match field_name { + VariantType::Other(other) => { + let field_name_tokens = match other.field_name() { Some(field_name) => field_name.to_token_stream(), None => quote! { __enumscribe_other_inner }, }; - let pattern = match field_name { + let pattern = match other.field_name() { Some(_) => quote! { #enum_ident::#variant_ident{#field_name_tokens} }, None => quote! { #enum_ident::#variant_ident(#field_name_tokens) }, }; @@ -85,7 +176,7 @@ impl<'a> Variant<'a> { } impl VariantConstructor { - pub(crate) fn empty(&self) -> TokenStream2 { + pub(crate) fn empty_toks(&self) -> TokenStream2 { match self { VariantConstructor::None => quote! {}, VariantConstructor::Paren => quote! { () }, @@ -220,7 +311,7 @@ pub(crate) fn parse_enum<'a>(data: &'a DataEnum, attrs: &'a [Attribute]) -> Macr Variant { data: variant, - v_type: VariantType::Other { field_name }, + v_type: VariantType::Other(OtherVariant { field_name }), span: variant_span, } } else { @@ -279,13 +370,12 @@ pub(crate) fn parse_enum<'a>(data: &'a DataEnum, attrs: &'a [Attribute]) -> Macr Fields::Unit => VariantConstructor::None, }; + let named = NamedVariant::new(name.into_boxed_str(), constructor, case_insensitive); + let v_type = VariantType::Named(named); + Variant { data: variant, - v_type: VariantType::Named { - name, - constructor, - case_insensitive, - }, + v_type, span: variant_span, } }; @@ -293,5 +383,13 @@ pub(crate) fn parse_enum<'a>(data: &'a DataEnum, attrs: &'a [Attribute]) -> Macr variants.push(scribe_variant); } - Ok(Enum { variants }) -} \ No newline at end of file + Ok(Enum::new(variants.into_boxed_slice())) +} + +fn char_wise_uppercase(s: &str) -> Box { + // Use the same uppercase algorithm as `enumscribe::internal::capped_string`. + s.chars() + .flat_map(char::to_uppercase) + .collect::() + .into_boxed_str() +} diff --git a/enumscribe_derive/src/lib.rs b/enumscribe_derive/src/lib.rs index 2c9b78c..74f96f3 100644 --- a/enumscribe_derive/src/lib.rs +++ b/enumscribe_derive/src/lib.rs @@ -59,9 +59,9 @@ where let enum_ident = &input.ident; - let mut match_arms = Vec::with_capacity(parsed_enum.variants.len()); + let mut match_arms = Vec::with_capacity(parsed_enum.variants().len()); - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { match variant.match_variant(enum_ident, &named_fn, &other_fn) { Ok(Some((pattern, result))) => match_arms.push(quote! { #pattern => #result }), Ok(None) => return ignore_err_fn(variant, enum_ident).into(), @@ -102,9 +102,9 @@ where let enum_ident = &input.ident; let mut ignore_variant = false; - let mut match_arms = Vec::with_capacity(parsed_enum.variants.len()); + let mut match_arms = Vec::with_capacity(parsed_enum.variants().len()); - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { match variant.match_variant(enum_ident, &named_fn, &other_fn) { Ok(Some((pattern, result))) => match_arms.push(quote! { #pattern => #result }), Ok(None) => ignore_variant = true, @@ -192,30 +192,28 @@ where let mut case_sensitive_arms = Vec::new(); let mut case_insensitive_arms = Vec::new(); - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { let variant_ident = &variant.data.ident; match &variant.v_type { VariantType::Ignore => (), - VariantType::Named { - name, - constructor, - case_insensitive, - } => { - let match_pattern = if *case_insensitive { - let lowercase_name = name.to_lowercase(); - quote! { #lowercase_name } + VariantType::Named(named) => { + let match_pattern = if named.case_insensitive() { + let uppercase_name = named.name_upper(); + quote! { #uppercase_name } } else { + let name = named.name(); quote! { #name } }; - let constructor_tokens = constructor.empty(); - let constructed_variant = - quote! { #enum_ident::#variant_ident #constructor_tokens }; + let constructor_tokens = named.constructor().empty_toks(); + let constructed_variant = quote! { + #enum_ident::#variant_ident #constructor_tokens + }; let match_result = named_fn(constructed_variant); - if *case_insensitive { + if named.case_insensitive() { &mut case_insensitive_arms } else { &mut case_sensitive_arms @@ -223,11 +221,11 @@ where .push(quote! { #match_pattern => #match_result }); } - VariantType::Other { field_name } => { + VariantType::Other(other) => { let unscribe_value = quote! { <_ as ::std::convert::Into<_>>::into(#match_against) }; - let constructed_variant = match field_name { + let constructed_variant = match other.field_name() { None => quote! { #enum_ident::#variant_ident(#unscribe_value) }, @@ -251,10 +249,23 @@ where let case_insensitive_match = if case_insensitive_arms.is_empty() { None } else { + let match_against_upper_ident = quote! { __enumscribe_unscribe_uppercase }; + let name_upper_cap = parsed_enum.name_upper_capacity(); + Some(quote! { - let __enumscribe_unscribe_lowercase = #match_against.to_lowercase(); - match __enumscribe_unscribe_lowercase.as_str() { - #(#case_insensitive_arms,)* + match ::enumscribe + ::internal + ::capped_string + ::CappedString + ::<#name_upper_cap> + ::uppercase_from_str(#match_against) + { + Some(#match_against_upper_ident) => { + match &*#match_against_upper_ident { + #(#case_insensitive_arms,)* + #other_arm, + } + }, #other_arm, } }) @@ -659,23 +670,22 @@ pub fn derive_enum_serialize(input: TokenStream) -> TokenStream { let mut match_arms = Vec::new(); let mut ignore_variant = false; - for variant in parsed_enum.variants.iter() { + for variant in parsed_enum.variants().iter() { let variant_ident = &variant.data.ident; match &variant.v_type { VariantType::Ignore => ignore_variant = true, - VariantType::Named { - name, constructor, .. - } => { - let constructor_tokens = constructor.empty(); + VariantType::Named(named) => { + let constructor_tokens = named.constructor().empty_toks(); + let name = named.name(); match_arms.push(quote! { #enum_ident::#variant_ident #constructor_tokens => #serializer_ident.serialize_str(#name) }) } - VariantType::Other { field_name } => match field_name { + VariantType::Other(other) => match other.field_name() { Some(field_name) => match_arms.push(quote! { #enum_ident::#variant_ident { #field_name } => #serializer_ident.serialize_str(&#field_name) @@ -748,13 +758,14 @@ pub fn derive_enum_deserialize(input: TokenStream) -> TokenStream { let enum_ident = &input.ident; let deserializer_ident = quote! { __enumscribe_deserializer }; + let deserialized_cow_capped_str_ident = quote! { __enumscribe_deserialized_cow_capped_str }; let deserialized_str_ident = quote! { __enumscribe_deserialized_str }; let variant_strings = parsed_enum - .variants + .variants() .iter() .map(|variant| match &variant.v_type { - VariantType::Named { name, .. } => Some(name.as_str()), + VariantType::Named(named) => Some(named.name()), _ => None, }) .filter_map(|name| name) @@ -771,22 +782,32 @@ pub fn derive_enum_deserialize(input: TokenStream) -> TokenStream { ::core::result::Result::Ok(#constructed_other_variant) }, |_| Ok(quote! { - __enumscribe_deserialize_base_case => ::core::result::Result::Err( + _ => ::core::result::Result::Err( ::serde::de::Error::unknown_variant( - __enumscribe_deserialize_base_case, + #deserialized_str_ident, &[#(#variant_strings),*] ) ) }), )); + let name_cap = parsed_enum.name_capacity(); + (quote! { #[automatically_derived] impl<'de> ::serde::Deserialize<'de> for #enum_ident { fn deserialize(#deserializer_ident: D) -> ::core::result::Result where D: ::serde::Deserializer<'de> { - let #deserialized_str_ident = <&str as ::serde::Deserialize<'_>>::deserialize(#deserializer_ident)?; + let #deserialized_cow_capped_str_ident = < + ::enumscribe + ::internal + ::capped_string + ::CowCappedString<'de, #name_cap> + as ::serde::Deserialize<'_> + >::deserialize(#deserializer_ident)?; + + let #deserialized_str_ident = &*#deserialized_cow_capped_str_ident; #main_match } } diff --git a/enumscribe_examples/Cargo.toml b/enumscribe_examples/Cargo.toml index cc708a3..94112bb 100644 --- a/enumscribe_examples/Cargo.toml +++ b/enumscribe_examples/Cargo.toml @@ -6,6 +6,6 @@ edition = "2018" license = "MIT" [dev-dependencies] -enumscribe = { path = "../enumscribe" } +enumscribe = { path = "../enumscribe", features = ["std", "derive", "serde"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/enumscribe_tests/tests/test_serde.rs b/enumscribe_tests/tests/test_serde.rs new file mode 100644 index 0000000..a3aed39 --- /dev/null +++ b/enumscribe_tests/tests/test_serde.rs @@ -0,0 +1,25 @@ +use enumscribe::EnumDeserialize; + +#[test] +fn test_deserialize() { + #[derive(EnumDeserialize, Eq, PartialEq, Debug)] + enum E0 { + V0, + #[enumscribe(str = "baa", case_insensitive)] + V1, + #[enumscribe(str = "bAz\n", case_insensitive)] + V2, + #[enumscribe(str = "蟹")] + V3, + } + + assert_eq!(serde_json::from_str::(r#""V0""#).unwrap(), E0::V0); + assert!(serde_json::from_str::(r#""v0""#).is_err()); + assert_eq!(serde_json::from_str::(r#""baa""#).unwrap(), E0::V1); + assert_eq!(serde_json::from_str::(r#""BAA""#).unwrap(), E0::V1); + assert_eq!(serde_json::from_str::(r#""BaA""#).unwrap(), E0::V1); + assert_eq!(serde_json::from_str::(r#""baz\n""#).unwrap(), E0::V2); + assert_eq!(serde_json::from_str::(r#""BAZ\n""#).unwrap(), E0::V2); + assert_eq!(serde_json::from_str::(r#""BaZ\n""#).unwrap(), E0::V2); + assert_eq!(serde_json::from_str::(r#""\u87f9""#).unwrap(), E0::V3); +} diff --git a/justfile b/justfile new file mode 100644 index 0000000..62a7b45 --- /dev/null +++ b/justfile @@ -0,0 +1,6 @@ +nightly := 'cargo +nightly' +rustc_nightly_flags := '-Z randomize-layout -Z macro-backtrace' + +test: + RUSTFLAGS='{{rustc_nightly_flags}}' {{nightly}} miri test +