Enum variant parsing

rename
Pantonshire 5 years ago
parent dee1f49d1a
commit c7cc6507ad

@ -1,36 +1,15 @@
#[macro_use] #[macro_use]
extern crate enumscribe; extern crate enumscribe;
use std::collections::HashMap; #[derive(EnumToString)]
enum Foo {
#[derive(EnumStrDeserialize, PartialEq, Eq, Debug)] Baa,
#[case_insensitive] #[enumscribe(ignore)]
enum Airport { Baz(),
#[str_name("LHR")] #[enumscribe(other)]
Heathrow, Lorem { inner: String }
#[str_name("LGW")]
Gatwick,
#[str_name("LTN")]
Luton,
#[str_name("BHX")]
BirminghamInternational,
#[other]
Other(Box<String>),
} }
fn main() { fn main() {
let json_str = r#" println!("Hello world!");
{
"airport_1": "LTN",
"airport_2": "bhx",
"airport_3": "lHr",
"airport_4": "MAN"
}"#;
let json: HashMap<String, Airport> = serde_json::from_str(json_str).unwrap();
println!("{:?}", json.get("airport_1").unwrap());
println!("{:?}", json.get("airport_2").unwrap());
println!("{:?}", json.get("airport_3").unwrap());
println!("{:?}", json.get("airport_4").unwrap());
} }

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt;
use proc_macro2::Span; use proc_macro2::Span;
use syn::{Attribute, Ident, Lit, Token}; use syn::{Attribute, Ident, Lit, Token};
@ -6,6 +7,8 @@ use syn::parse::{Parse, ParseBuffer, ParseStream};
use syn::parse::discouraged::Speculative; use syn::parse::discouraged::Speculative;
use syn::token::Token; use syn::token::Token;
use crate::error::{MacroError, MacroResult, ValueTypeError, ValueTypeResult};
#[derive(Clone)] #[derive(Clone)]
pub(crate) enum Value { pub(crate) enum Value {
None, None,
@ -13,9 +16,84 @@ pub(crate) enum Value {
Ident(Ident), Ident(Ident),
} }
#[derive(Clone)] impl Value {
/// Returns a short, human-readable name for the kind of value held.
///
/// The returned text is what `value_bool` and `value_string` embed in their
/// "expected ... but found ..." messages.
pub(crate) fn type_name(&self) -> &'static str {
    match self {
        Value::None => "nothing",
        Value::Ident(_) => "identifier",
        // Literal kinds are matched directly rather than through a nested match.
        Value::Lit(Lit::Str(_)) => "string",
        Value::Lit(Lit::ByteStr(_)) => "byte string",
        Value::Lit(Lit::Byte(_)) => "byte",
        Value::Lit(Lit::Char(_)) => "character",
        Value::Lit(Lit::Int(_)) => "integer",
        Value::Lit(Lit::Float(_)) => "float",
        Value::Lit(Lit::Bool(_)) => "boolean",
        Value::Lit(Lit::Verbatim(_)) => "verbatim literal",
    }
}
/// Interprets this value as a boolean.
///
/// `Value::None` counts as `true`, and a boolean literal yields its own value.
/// Any other value produces a `ValueTypeError` whose message names the type
/// that was found instead.
pub(crate) fn value_bool(&self) -> ValueTypeResult<bool> {
    if let Value::None = self {
        return Ok(true);
    }

    if let Value::Lit(Lit::Bool(lit_bool)) = self {
        return Ok(lit_bool.value);
    }

    let message = format!("expected boolean but found {}", self.type_name());
    Err(ValueTypeError { message: message.into() })
}
/// Interprets this value as a string.
///
/// Only a string literal is accepted; its unescaped contents are returned.
/// Any other value (including `Value::None`) produces a `ValueTypeError`
/// whose message names the type that was found instead.
pub(crate) fn value_string(&self) -> ValueTypeResult<String> {
    if let Value::Lit(Lit::Str(lit_str)) = self {
        Ok(lit_str.value())
    } else {
        let message = format!("expected string but found {}", self.type_name());
        Err(ValueTypeError { message: message.into() })
    }
}
}
/// Formats a `Value` for diagnostic output.
///
/// `Value::None` is written as `ε`; integer and float literals are written as
/// their base-10 digits; every other case delegates to the `fmt` method of the
/// wrapped value.
// NOTE(review): the `.fmt(f)` method calls resolve to whichever formatting
// trait is in scope in this file, which is not visible here. Confirm whether
// `Debug` or `Display` output is intended for the wrapped values.
impl fmt::Debug for Value {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Value::None => write!(f, "ε"),
            Value::Ident(ident) => ident.fmt(f),
            Value::Lit(Lit::Int(lit_int)) => write!(f, "{}", lit_int.base10_digits()),
            Value::Lit(Lit::Float(lit_float)) => write!(f, "{}", lit_float.base10_digits()),
            Value::Lit(Lit::Str(lit_str)) => lit_str.value().fmt(f),
            Value::Lit(Lit::ByteStr(lit_byte_str)) => lit_byte_str.value().fmt(f),
            Value::Lit(Lit::Byte(lit_byte)) => lit_byte.value().fmt(f),
            Value::Lit(Lit::Char(lit_char)) => lit_char.value().fmt(f),
            Value::Lit(Lit::Bool(lit_bool)) => lit_bool.value.fmt(f),
            Value::Lit(Lit::Verbatim(lit_verbatim)) => lit_verbatim.fmt(f),
        }
    }
}
#[derive(Clone, Debug)]
pub(crate) struct Dict { pub(crate) struct Dict {
pub(crate) inner: HashMap<String, (Value, Span)> pub(crate) inner: HashMap<String, (Value, Span)>,
}
/// Represents the contents of a single `#[tag(...)]`.
/// The contents are parsed from `key = value` pairs, separated by commas.
///
/// The pairs are kept in a `Vec` rather than a map, so a key that is repeated
/// within one tag is still present after parsing; `Dict::from_attrs` checks
/// for repeated keys when it merges tags.
#[derive(Clone, Debug)]
struct AttributeTag {
/// One `(key, value, span)` entry per parsed pair, collected from the
/// parsed punctuated sequence.
inner: Vec<(String, Value, Span)>,
}
#[derive(Clone, Debug)]
struct KeyValPair {
key: String,
val: Value,
span: Span,
} }
impl Dict { impl Dict {
@ -23,53 +101,71 @@ impl Dict {
Dict { inner: HashMap::new() } Dict { inner: HashMap::new() }
} }
pub(crate) fn from_attrs(name: &str, attrs: &[Attribute]) -> syn::Result<Self> { pub(crate) fn from_attrs(name: &str, attrs: &[Attribute]) -> MacroResult<Self> {
let mut dict = Dict::new(); let mut dict = Dict::new();
let sub_dicts = attrs.iter() let attribute_tags = attrs.iter()
.filter(|attr| attr.path.is_ident(name)) .filter(|attr| attr.path.is_ident(name))
.map(|attr| attr.parse_args::<Dict>()); .map(|attr| attr.parse_args::<AttributeTag>());
for tag in attribute_tags {
let tag = tag.map_err(MacroError::from)?;
for sub_dict in sub_dicts { for (key, val, span) in tag.inner {
dict.inner.extend(sub_dict?.inner.into_iter()); if dict.inner.contains_key(&key) {
return Err(MacroError::new(format!(
"key appears more than once: {}", key
), span));
}
dict.inner.insert(key, (val, span));
}
} }
Ok(dict) Ok(dict)
} }
pub(crate) fn require_keys(&self, keys: &[&str]) -> Result<(), String> { pub(crate) fn remove_typed_value<T, F>(&mut self, key: &str, converter: F) -> MacroResult<Option<(T, Span)>>
match keys.iter().find(|&&key| !self.inner.contains_key(key)) { where
Some(absent_key) => Err(absent_key.to_string()), F: Fn(&Value) -> ValueTypeResult<T>
None => Ok(()) {
match self.inner.remove(key) {
None => Ok(None),
Some((val, span)) => match converter(&val) {
Ok(converted) => Ok(Some((converted, span))),
Err(ValueTypeError { message }) => Err(MacroError::new(
format!("{} for key: {}", message, key),
span,
))
}
} }
} }
pub(crate) fn allow_keys(&self, keys: &[&str]) -> Result<(), String> { pub(crate) fn assert_empty(&self) -> MacroResult<()> {
match self.inner.keys().find(|key| !keys.contains(&key.as_str())) { if self.inner.is_empty() {
Some(disallowed_key) => Err(disallowed_key.clone()), Ok(())
None => Ok(()) } else {
let (unexpected_key, (_, unexpected_span)) = self.inner.iter().next().unwrap();
Err(MacroError::new(
format!("unexpected key: {}", unexpected_key),
*unexpected_span,
))
} }
} }
} }
impl Parse for Dict { impl Parse for AttributeTag {
fn parse(input: ParseStream) -> syn::Result<Self> { fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(Dict { Ok(AttributeTag {
inner: input inner: input
.parse_terminated::<KeyValPair, Token![,]>(KeyValPair::parse)? .parse_terminated::<KeyValPair, Token![,]>(KeyValPair::parse)?
.into_iter() .into_iter()
.map(|pair| (pair.key, (pair.val, pair.span))) .map(|pair| (pair.key, pair.val, pair.span))
.collect::<HashMap<_, _>>() .collect::<Vec<_>>()
}) })
} }
} }
struct KeyValPair {
key: String,
val: Value,
span: Span,
}
impl Parse for KeyValPair { impl Parse for KeyValPair {
fn parse(input: ParseStream) -> syn::Result<Self> { fn parse(input: ParseStream) -> syn::Result<Self> {
let key = input let key = input
@ -77,15 +173,15 @@ impl Parse for KeyValPair {
let val = if input.peek(Token![=]) { let val = if input.peek(Token![=]) {
input.parse::<Token![=]>()?; input.parse::<Token![=]>()?;
if let Ok(lit) = speculative_parse::<Lit>(input) { if let Ok(lit) = speculative_parse::<Lit>(input) {
Value::Lit(lit) Value::Lit(lit)
} else if let Ok(ident) = speculative_parse::<Ident>(input) { } else if let Ok(ident) = speculative_parse::<Ident>(input) {
Value::Ident(ident) Value::Ident(ident)
} else { } else {
let err_msg = format!( return Err(input.error(format!(
"expected either a literal or identifier as the value corresponding to the \ "could not parse value corresponding to key: {}", key
key \"{}\", but found neither", key); )));
return Err(input.error(err_msg));
} }
} else { } else {
Value::None Value::None

@ -1,18 +1,24 @@
use std::borrow::Cow; use std::borrow::Cow;
use std::fmt;
use std::error;
use std::result;
use proc_macro2::Span; use proc_macro2::Span;
use quote::quote_spanned; use quote::quote_spanned;
use syn::Error;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct MacroError { pub(crate) struct MacroError {
message: Cow<'static, str>, pub(crate) message: Cow<'static, str>,
span: Span, pub(crate) span: Span,
} }
pub(crate) type MacroResult<T> = result::Result<T, MacroError>;
impl MacroError { impl MacroError {
pub(crate) fn new(message: Cow<'static, str>, span: Span) -> Self { pub(crate) fn new<T>(message: T, span: Span) -> Self where T : Into<Cow<'static, str>> {
MacroError { MacroError {
message, message: message.into(),
span, span,
} }
} }
@ -29,4 +35,31 @@ impl MacroError {
} }
} }
pub(crate) type Result<T> = std::result::Result<T, MacroError>; impl From<syn::Error> for MacroError {
/// Builds a `MacroError` from a `syn::Error`, carrying over the error's
/// message text and the span it reports.
fn from(err: Error) -> Self {
    let message = err.to_string();
    let span = err.span();
    MacroError::new(message, span)
}
}
/// Displays a `MacroError` as its message text only; the span is not printed.
impl fmt::Display for MacroError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written straight to the formatter, so width/fill flags given by the
        // caller are not applied to the message.
        f.write_str(&self.message)
    }
}
/// Lets `MacroError` be used as a standard error type. The impl body is empty,
/// so every `Error` method keeps its default implementation.
impl error::Error for MacroError {}
/// Error describing a value that could not be interpreted as the requested type.
///
/// It carries only a message and no span.
#[derive(Clone, Debug)]
pub(crate) struct ValueTypeError {
    /// Human-readable description of the mismatch.
    pub(crate) message: Cow<'static, str>,
}

/// Result alias for operations that may fail with a `ValueTypeError`.
pub(crate) type ValueTypeResult<T> = result::Result<T, ValueTypeError>;

impl fmt::Display for ValueTypeError {
    /// Writes the message text verbatim; width/fill flags given by the caller
    /// are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

/// The impl body is empty, so every `Error` method keeps its default
/// implementation.
impl error::Error for ValueTypeError {}

@ -1,88 +1,169 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use std::collections::HashSet;
use quote::quote; use proc_macro2::Span;
use syn::{Attribute, Data, DeriveInput, LitStr}; use quote::{quote, quote_spanned};
use syn::{Attribute, Data, DataEnum, DeriveInput, Fields, LitStr};
/// Derives `serde::Deserialize` for an enum with variants associated with strings. use syn::parse::{ParseBuffer, ParseStream};
/// The `#[str_name("...")]` attribute is used to specify the string associated with each variant. use syn::spanned::Spanned;
///
/// An "other" variant can be specified with `#[other]`. This variant should have a parameter use attribute::*;
/// which implements `From<String>` to store the string that could not be deserialized to any use error::{MacroError, MacroResult};
/// of the other variants.
/// mod error;
/// If no "other" variant is specified, strings which are not associated with any of the variants mod attribute;
/// will produce a deserialization error.
/// const CRATE_ATTR: &'static str = "enumscribe";
/// The enum may have the attribute `#[case_insensitive]`, in which case string comparisons will
/// be done case-insensitively. #[derive(Clone, Debug)]
#[proc_macro_derive(EnumStrDeserialize, attributes(case_insensitive, str_name, other))] struct Enum {
pub fn derive_enum_str_de(ast: TokenStream) -> TokenStream { variants: Vec<Variant>,
const ATTR_CASE_INSENSITIVE: &'static str = "case_insensitive"; }
const ATTR_STR_NAME: &'static str = "str_name";
const ATTR_OTHER: &'static str = "other";
let ast: DeriveInput = syn::parse(ast).unwrap();
let enum_name = &ast.ident;
let enum_names = std::iter::repeat(enum_name);
let case_insensitive = find_attribute(ATTR_CASE_INSENSITIVE, &ast.attrs).is_some();
let enum_data = match ast.data {
Data::Enum(e) => e,
_ => panic!("cannot derive EnumStrDeserialize for anything other than an enum"),
};
let (variant_names, variant_strings): (Vec<_>, Vec<_>) = enum_data.variants.iter() #[derive(Clone, Debug)]
.map(|variant| (&variant.ident, find_attribute(ATTR_STR_NAME, &variant.attrs))) struct Variant {
.filter(|(_, attribute)| attribute.is_some()) ident: String,
.map(|(variant_ident, attribute)| (variant_ident, attribute v_type: VariantType,
.unwrap() span: Span,
.parse_args::<LitStr>() }
.unwrap()
.value())) #[derive(Clone, Debug)]
.map(|(variant_ident, attribute)| (variant_ident, if case_insensitive { enum VariantType {
attribute.to_lowercase() Ignore,
Named { name: String, constructor: VariantConstructor },
Other { field_name: Option<String> }, //use {} for constructor if Some, use () if None
}
/// The constructor syntax of an enum variant, chosen in `parse_enum` from the
/// kind of fields the variant declares.
#[derive(Clone, Copy, Debug)]
enum VariantConstructor {
/// Chosen for `Fields::Unit` (a variant written with no field list).
None,
/// Chosen for `Fields::Unnamed` (a variant written with parentheses).
Paren,
/// Chosen for `Fields::Named` (a variant written with braces).
Brace,
}
fn parse_enum(data: DataEnum) -> MacroResult<Enum> {
const NAME: &'static str = "str";
const OTHER: &'static str = "other";
const IGNORE: &'static str = "ignore";
let mut variants = Vec::with_capacity(data.variants.len());
let mut taken_names = HashSet::new();
for variant in data.variants {
let variant_ident = variant.ident.to_string();
let variant_span = variant.span();
let mut dict = Dict::from_attrs(CRATE_ATTR, &variant.attrs)?;
let name_opt = dict.remove_typed_value(NAME, Value::value_string)?;
let other = match dict.remove_typed_value(OTHER, Value::value_bool)? {
Some((other, _)) => other,
None => false
};
let ignore = match dict.remove_typed_value(IGNORE, Value::value_bool)? {
Some((ignore, _)) => ignore,
None => false
};
dict.assert_empty()?;
let scribe_variant = if ignore {
Variant {
ident: variant_ident,
v_type: VariantType::Ignore,
span: variant_span,
}
} else if other {
if let Some((_, name_span)) = name_opt {
return Err(MacroError::new(
format!(
"cannot use {} for variant {} because it is marked as {}",
NAME, variant_ident, OTHER
),
name_span,
));
}
if variant.fields.len() != 1 {
return Err(MacroError::new(
format!(
"the variant {} must have exactly one field because it is marked as {}",
variant_ident, OTHER
),
variant_span,
));
}
let field_name = variant.fields.iter().next()
.and_then(|field| field.ident.as_ref().map(|ident| ident.to_string()));
Variant {
ident: variant_ident,
v_type: VariantType::Other { field_name },
span: variant_span,
}
} else { } else {
attribute let (name, name_span) = match name_opt {
})) Some((name, name_span)) => (name, name_span),
.unzip(); None => (variant.ident.to_string(), variant.ident.span())
};
let other_variant = enum_data.variants.iter() if taken_names.contains(&name) {
.find(|variant| find_attribute(ATTR_OTHER, &variant.attrs).is_some()); return Err(MacroError::new(
format!("duplicate name \"{}\"", name),
name_span,
));
}
let matching_string = if case_insensitive { if variant.fields.len() != 0 {
quote! { deserialized_string.to_lowercase() } return Err(MacroError::new(
} else { format!(
quote! { deserialized_string } "the variant {} must not have any fields",
}; variant_ident
),
variant_span,
));
}
let (base_case_pattern, base_case_value) = if let Some(other_variant) = other_variant { let constructor = match variant.fields {
let other_variant_name = &other_variant.ident; Fields::Named(_) => VariantConstructor::Brace,
(quote! { _ }, quote! { ::core::result::Result::Ok(#enum_name::#other_variant_name(deserialized_string.into())) }) Fields::Unnamed(_) => VariantConstructor::Paren,
} else { Fields::Unit => VariantConstructor::None,
(quote! { s }, quote! { ::core::result::Result::Err(::serde::de::Error::unknown_variant(s, &[#(#variant_strings),*])) }) };
};
(quote! { taken_names.insert(name.clone());
impl<'de> ::serde::Deserialize<'de> for #enum_name {
fn deserialize<D>(deserializer: D) -> ::core::result::Result<Self, D::Error> Variant {
where ident: variant_ident,
D: ::serde::Deserializer<'de>, v_type: VariantType::Named { name, constructor },
{ span: variant_span,
let deserialized_string = ::std::string::String::deserialize(deserializer)?;
match #matching_string.as_str() {
#(#variant_strings => ::core::result::Result::Ok(#enum_names::#variant_names),)*
#base_case_pattern => #base_case_value,
}
} }
} };
}).into()
} variants.push(scribe_variant);
}
fn find_attribute<'a>(name: &str, attributes: &'a [Attribute]) -> Option<&'a Attribute> { Ok(Enum { variants })
attributes
.iter()
.find(|attribute| attribute.path.is_ident(name))
} }
/// Entry point for `#[derive(EnumToString)]`, which accepts `#[enumscribe(...)]`
/// helper attributes.
///
/// The input must be an enum; structs and unions are rejected with an error
/// spanned at the type's identifier. The enum's variants are parsed with
/// `parse_enum`, and any parsing error is returned as the macro's output.
///
/// No code is generated yet: on success the parsed enum is printed and an
/// empty token stream is returned.
#[proc_macro_derive(EnumToString, attributes(enumscribe))]
pub fn derive_enum_to_string(input: TokenStream) -> TokenStream {
    // Report unparseable input as a spanned compile error instead of panicking
    // inside the macro, which would surface as an opaque "proc macro panicked".
    let input: DeriveInput = match syn::parse(input) {
        Ok(parsed) => parsed,
        Err(err) => return err.to_compile_error().into(),
    };

    let enum_data = match input.data {
        Data::Enum(data) => data,
        Data::Struct(_) => {
            return MacroError::new("cannot use enumscribe for structs", input.ident.span())
                .to_token_stream()
        }
        Data::Union(_) => {
            return MacroError::new("cannot use enumscribe for unions", input.ident.span())
                .to_token_stream()
        }
    };

    let variants = match parse_enum(enum_data) {
        Ok(variants) => variants,
        Err(err) => return err.to_token_stream(),
    };

    // Debug output of the parsed enum, written to stdout while the macro expands.
    println!("{:?}", variants);

    TokenStream::new()
}

Loading…
Cancel
Save