From c7cc6507adc9cd474ed4588d731c314b26d9d6ad Mon Sep 17 00:00:00 2001 From: Pantonshire Date: Sat, 15 May 2021 15:21:16 +0100 Subject: [PATCH] Enum variant parsing --- examples/airports.rs | 37 ++----- src/attribute.rs | 154 ++++++++++++++++++++++------ src/error.rs | 43 +++++++- src/lib.rs | 233 +++++++++++++++++++++++++++++-------------- 4 files changed, 328 insertions(+), 139 deletions(-) diff --git a/examples/airports.rs b/examples/airports.rs index 46f70c0..2d1482d 100644 --- a/examples/airports.rs +++ b/examples/airports.rs @@ -1,36 +1,15 @@ #[macro_use] extern crate enumscribe; -use std::collections::HashMap; - -#[derive(EnumStrDeserialize, PartialEq, Eq, Debug)] -#[case_insensitive] -enum Airport { - #[str_name("LHR")] - Heathrow, - #[str_name("LGW")] - Gatwick, - #[str_name("LTN")] - Luton, - #[str_name("BHX")] - BirminghamInternational, - #[other] - Other(Box), +#[derive(EnumToString)] +enum Foo { + Baa, + #[enumscribe(ignore)] + Baz(), + #[enumscribe(other)] + Lorem { inner: String } } fn main() { - let json_str = r#" - { - "airport_1": "LTN", - "airport_2": "bhx", - "airport_3": "lHr", - "airport_4": "MAN" - }"#; - - let json: HashMap = serde_json::from_str(json_str).unwrap(); - - println!("{:?}", json.get("airport_1").unwrap()); - println!("{:?}", json.get("airport_2").unwrap()); - println!("{:?}", json.get("airport_3").unwrap()); - println!("{:?}", json.get("airport_4").unwrap()); + println!("Hello world!"); } diff --git a/src/attribute.rs b/src/attribute.rs index 71ae7df..8716b08 100644 --- a/src/attribute.rs +++ b/src/attribute.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::fmt; use proc_macro2::Span; use syn::{Attribute, Ident, Lit, Token}; @@ -6,6 +7,8 @@ use syn::parse::{Parse, ParseBuffer, ParseStream}; use syn::parse::discouraged::Speculative; use syn::token::Token; +use crate::error::{MacroError, MacroResult, ValueTypeError, ValueTypeResult}; + #[derive(Clone)] pub(crate) enum Value { None, @@ -13,9 +16,84 @@ pub(crate) enum Value { Ident(Ident), } -#[derive(Clone)] +impl Value { + pub(crate) fn type_name(&self) -> &'static str { + match self { + Value::None => "nothing", + Value::Lit(lit) => match lit { + Lit::Str(_) => "string", + Lit::ByteStr(_) => "byte string", + Lit::Byte(_) => "byte", + Lit::Char(_) => "character", + Lit::Int(_) => "integer", + Lit::Float(_) => "float", + Lit::Bool(_) => "boolean", + Lit::Verbatim(_) => "verbatim literal", + }, + Value::Ident(_) => "identifier", + } + } + + /// Gets the boolean value associated with this Value. `Value::None` value is considered to + /// be true. If this value cannot represent a boolean, a `ValueTypeError` will be returned. + pub(crate) fn value_bool(&self) -> ValueTypeResult { + match self { + Value::None => Ok(true), + Value::Lit(Lit::Bool(lit_bool)) => Ok(lit_bool.value), + val => Err(ValueTypeError { + message: format!("expected boolean but found {}", val.type_name()).into() + }) + } + } + + /// Gets the string value associated with this Value. If this value cannot represent a string, + /// a `ValueTypeError` will be returned. + pub(crate) fn value_string(&self) -> ValueTypeResult { + match self { + Value::Lit(Lit::Str(lit_str)) => Ok(lit_str.value()), + val => Err(ValueTypeError { + message: format!("expected string but found {}", val.type_name()).into() + }) + } + } +} + +impl fmt::Debug for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::None => write!(f, "ε"), + Value::Lit(lit) => match lit { + Lit::Str(lit_str) => lit_str.value().fmt(f), + Lit::ByteStr(lit_byte_str) => lit_byte_str.value().fmt(f), + Lit::Byte(lit_byte) => lit_byte.value().fmt(f), + Lit::Char(lit_char) => lit_char.value().fmt(f), + Lit::Int(lit_int) => write!(f, "{}", lit_int.base10_digits()), + Lit::Float(lit_float) => write!(f, "{}", lit_float.base10_digits()), + Lit::Bool(lit_bool) => lit_bool.value.fmt(f), + Lit::Verbatim(lit_verbatim) => lit_verbatim.fmt(f), + }, + Value::Ident(ident) => ident.fmt(f), + } + } +} + +#[derive(Clone, Debug)] pub(crate) struct Dict { - pub(crate) inner: HashMap + pub(crate) inner: HashMap, +} + +/// Represents the contents of a single `#[tag(...)]`. +/// The contents are parsed from `key = value` pairs, separated by commas. +#[derive(Clone, Debug)] +struct AttributeTag { + inner: Vec<(String, Value, Span)>, +} + +#[derive(Clone, Debug)] +struct KeyValPair { + key: String, + val: Value, + span: Span, } impl Dict { @@ -23,53 +101,71 @@ impl Dict { Dict { inner: HashMap::new() } } - pub(crate) fn from_attrs(name: &str, attrs: &[Attribute]) -> syn::Result { + pub(crate) fn from_attrs(name: &str, attrs: &[Attribute]) -> MacroResult { let mut dict = Dict::new(); - let sub_dicts = attrs.iter() + let attribute_tags = attrs.iter() .filter(|attr| attr.path.is_ident(name)) - .map(|attr| attr.parse_args::()); + .map(|attr| attr.parse_args::()); + + for tag in attribute_tags { + let tag = tag.map_err(MacroError::from)?; - for sub_dict in sub_dicts { - dict.inner.extend(sub_dict?.inner.into_iter()); + for (key, val, span) in tag.inner { + if dict.inner.contains_key(&key) { + return Err(MacroError::new(format!( + "key appears more than once: {}", key + ), span)); + } + + dict.inner.insert(key, (val, span)); + } } Ok(dict) } - pub(crate) fn require_keys(&self, keys: &[&str]) -> Result<(), String> { - match keys.iter().find(|&&key| !self.inner.contains_key(key)) { - Some(absent_key) => Err(absent_key.to_string()), - None => Ok(()) + pub(crate) fn remove_typed_value(&mut self, key: &str, converter: F) -> MacroResult> + where + F: Fn(&Value) -> ValueTypeResult + { + match self.inner.remove(key) { + None => Ok(None), + Some((val, span)) => match converter(&val) { + Ok(converted) => Ok(Some((converted, span))), + Err(ValueTypeError { message }) => Err(MacroError::new( + format!("{} for key: {}", message, key), + span, + )) + } } } - pub(crate) fn allow_keys(&self, keys: &[&str]) -> Result<(), String> { - match self.inner.keys().find(|key| !keys.contains(&key.as_str())) { - Some(disallowed_key) => Err(disallowed_key.clone()), - None => Ok(()) + pub(crate) fn assert_empty(&self) -> MacroResult<()> { + if self.inner.is_empty() { + Ok(()) + } else { + let (unexpected_key, (_, unexpected_span)) = self.inner.iter().next().unwrap(); + Err(MacroError::new( + format!("unexpected key: {}", unexpected_key), + *unexpected_span, + )) } } } -impl Parse for Dict { +impl Parse for AttributeTag { fn parse(input: ParseStream) -> syn::Result { - Ok(Dict { + Ok(AttributeTag { inner: input .parse_terminated::(KeyValPair::parse)? .into_iter() - .map(|pair| (pair.key, (pair.val, pair.span))) - .collect::>() + .map(|pair| (pair.key, pair.val, pair.span)) + .collect::>() }) } } -struct KeyValPair { - key: String, - val: Value, - span: Span, -} - impl Parse for KeyValPair { fn parse(input: ParseStream) -> syn::Result { let key = input @@ -77,15 +173,15 @@ impl Parse for KeyValPair { let val = if input.peek(Token![=]) { input.parse::()?; + if let Ok(lit) = speculative_parse::(input) { Value::Lit(lit) } else if let Ok(ident) = speculative_parse::(input) { Value::Ident(ident) } else { - let err_msg = format!( - "expected either a literal or identifier as the value corresponding to the \ - key \"{}\", but found neither", key); - return Err(input.error(err_msg)); + return Err(input.error(format!( + "could not parse value corresponding to key: {}", key + ))); } } else { Value::None diff --git a/src/error.rs b/src/error.rs index 4f33e0e..b506ab9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,18 +1,24 @@ use std::borrow::Cow; +use std::fmt; +use std::error; +use std::result; use proc_macro2::Span; use quote::quote_spanned; +use syn::Error; #[derive(Clone, Debug)] pub(crate) struct MacroError { - message: Cow<'static, str>, - span: Span, + pub(crate) message: Cow<'static, str>, + pub(crate) span: Span, } +pub(crate) type MacroResult = result::Result; + impl MacroError { - pub(crate) fn new(message: Cow<'static, str>, span: Span) -> Self { + pub(crate) fn new(message: T, span: Span) -> Self where T : Into> { MacroError { - message, + message: message.into(), span, } } @@ -29,4 +35,31 @@ impl MacroError { } } -pub(crate) type Result = std::result::Result; +impl From for MacroError { + fn from(err: Error) -> Self { + MacroError::new(err.to_string(), err.span()) + } +} + +impl fmt::Display for MacroError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl error::Error for MacroError {} + +#[derive(Clone, Debug)] +pub(crate) struct ValueTypeError { + pub(crate) message: Cow<'static, str> +} + +pub(crate) type ValueTypeResult = result::Result; + +impl fmt::Display for ValueTypeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl error::Error for ValueTypeError {} diff --git a/src/lib.rs b/src/lib.rs index 8e83f98..8da1aef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,88 +1,169 @@ use proc_macro::TokenStream; +use std::collections::HashSet; -use quote::quote; -use syn::{Attribute, Data, DeriveInput, LitStr}; - -/// Derives `serde::Deserialize` for an enum with variants associated with strings. -/// The `#[str_name("...")]` attribute is used to specify the string associated with each variant. -/// -/// An "other" variant can be specified with `#[other]`. This variant should have a parameter -/// which implements `From` to store the string that could not be deserialized to any -/// of the other variants. -/// -/// If no "other" variant is specified, strings which are not associated with any of the variants -/// will produce a deserialization error. -/// -/// The enum may have the attribute `#[case_insensitive]`, in which case string comparisons will -/// be done case-insensitively. -#[proc_macro_derive(EnumStrDeserialize, attributes(case_insensitive, str_name, other))] -pub fn derive_enum_str_de(ast: TokenStream) -> TokenStream { - const ATTR_CASE_INSENSITIVE: &'static str = "case_insensitive"; - const ATTR_STR_NAME: &'static str = "str_name"; - const ATTR_OTHER: &'static str = "other"; - - let ast: DeriveInput = syn::parse(ast).unwrap(); - - let enum_name = &ast.ident; - let enum_names = std::iter::repeat(enum_name); - - let case_insensitive = find_attribute(ATTR_CASE_INSENSITIVE, &ast.attrs).is_some(); - - let enum_data = match ast.data { - Data::Enum(e) => e, - _ => panic!("cannot derive EnumStrDeserialize for anything other than an enum"), - }; +use proc_macro2::Span; +use quote::{quote, quote_spanned}; +use syn::{Attribute, Data, DataEnum, DeriveInput, Fields, LitStr}; +use syn::parse::{ParseBuffer, ParseStream}; +use syn::spanned::Spanned; + +use attribute::*; +use error::{MacroError, MacroResult}; + +mod error; +mod attribute; + +const CRATE_ATTR: &'static str = "enumscribe"; + +#[derive(Clone, Debug)] +struct Enum { + variants: Vec, +} - let (variant_names, variant_strings): (Vec<_>, Vec<_>) = enum_data.variants.iter() - .map(|variant| (&variant.ident, find_attribute(ATTR_STR_NAME, &variant.attrs))) - .filter(|(_, attribute)| attribute.is_some()) - .map(|(variant_ident, attribute)| (variant_ident, attribute - .unwrap() - .parse_args::() - .unwrap() - .value())) - .map(|(variant_ident, attribute)| (variant_ident, if case_insensitive { - attribute.to_lowercase() +#[derive(Clone, Debug)] +struct Variant { + ident: String, + v_type: VariantType, + span: Span, +} + +#[derive(Clone, Debug)] +enum VariantType { + Ignore, + Named { name: String, constructor: VariantConstructor }, + Other { field_name: Option }, //use {} for constructor if Some, use () if None +} + +#[derive(Clone, Copy, Debug)] +enum VariantConstructor { + None, + Paren, + Brace, +} + +fn parse_enum(data: DataEnum) -> MacroResult { + const NAME: &'static str = "str"; + const OTHER: &'static str = "other"; + const IGNORE: &'static str = "ignore"; + + let mut variants = Vec::with_capacity(data.variants.len()); + let mut taken_names = HashSet::new(); + + for variant in data.variants { + let variant_ident = variant.ident.to_string(); + let variant_span = variant.span(); + + let mut dict = Dict::from_attrs(CRATE_ATTR, &variant.attrs)?; + + let name_opt = dict.remove_typed_value(NAME, Value::value_string)?; + + let other = match dict.remove_typed_value(OTHER, Value::value_bool)? { + Some((other, _)) => other, + None => false + }; + + let ignore = match dict.remove_typed_value(IGNORE, Value::value_bool)? { + Some((ignore, _)) => ignore, + None => false + }; + + dict.assert_empty()?; + + let scribe_variant = if ignore { + Variant { + ident: variant_ident, + v_type: VariantType::Ignore, + span: variant_span, + } + } else if other { + if let Some((_, name_span)) = name_opt { + return Err(MacroError::new( + format!( + "cannot use {} for variant {} because it is marked as {}", + NAME, variant_ident, OTHER + ), + name_span, + )); + } + + if variant.fields.len() != 1 { + return Err(MacroError::new( + format!( + "the variant {} must have exactly one field because it is marked as {}", + variant_ident, OTHER + ), + variant_span, + )); + } + + let field_name = variant.fields.iter().next() + .and_then(|field| field.ident.as_ref().map(|ident| ident.to_string())); + + Variant { + ident: variant_ident, + v_type: VariantType::Other { field_name }, + span: variant_span, + } } else { - attribute - })) - .unzip(); + let (name, name_span) = match name_opt { + Some((name, name_span)) => (name, name_span), + None => (variant.ident.to_string(), variant.ident.span()) + }; - let other_variant = enum_data.variants.iter() - .find(|variant| find_attribute(ATTR_OTHER, &variant.attrs).is_some()); + if taken_names.contains(&name) { + return Err(MacroError::new( + format!("duplicate name \"{}\"", name), + name_span, + )); + } - let matching_string = if case_insensitive { - quote! { deserialized_string.to_lowercase() } - } else { - quote! { deserialized_string } - }; + if variant.fields.len() != 0 { + return Err(MacroError::new( + format!( + "the variant {} must not have any fields", + variant_ident + ), + variant_span, + )); + } - let (base_case_pattern, base_case_value) = if let Some(other_variant) = other_variant { - let other_variant_name = &other_variant.ident; - (quote! { _ }, quote! { ::core::result::Result::Ok(#enum_name::#other_variant_name(deserialized_string.into())) }) - } else { - (quote! { s }, quote! { ::core::result::Result::Err(::serde::de::Error::unknown_variant(s, &[#(#variant_strings),*])) }) - }; + let constructor = match variant.fields { + Fields::Named(_) => VariantConstructor::Brace, + Fields::Unnamed(_) => VariantConstructor::Paren, + Fields::Unit => VariantConstructor::None, + }; - (quote! { - impl<'de> ::serde::Deserialize<'de> for #enum_name { - fn deserialize(deserializer: D) -> ::core::result::Result - where - D: ::serde::Deserializer<'de>, - { - let deserialized_string = ::std::string::String::deserialize(deserializer)?; - match #matching_string.as_str() { - #(#variant_strings => ::core::result::Result::Ok(#enum_names::#variant_names),)* - #base_case_pattern => #base_case_value, - } + taken_names.insert(name.clone()); + + Variant { + ident: variant_ident, + v_type: VariantType::Named { name, constructor }, + span: variant_span, } - } - }).into() -} + }; + + variants.push(scribe_variant); + } -fn find_attribute<'a>(name: &str, attributes: &'a [Attribute]) -> Option<&'a Attribute> { - attributes - .iter() - .find(|attribute| attribute.path.is_ident(name)) + Ok(Enum { variants }) } +#[proc_macro_derive(EnumToString, attributes(enumscribe))] +pub fn derive_enum_to_string(input: TokenStream) -> TokenStream { + let input: DeriveInput = syn::parse(input).unwrap(); + + let enum_data = match input.data { + Data::Enum(data) => data, + Data::Struct(_) => return MacroError::new("cannot use enumscribe for structs", input.ident.span()).to_token_stream(), + Data::Union(_) => return MacroError::new("cannot use enumscribe for unions", input.ident.span()).to_token_stream() + }; + + let variants = match parse_enum(enum_data) { + Ok(variants) => variants, + Err(err) => return err.to_token_stream() + }; + + println!("{:?}", variants); + + TokenStream::new() +}