Enum variant parsing
parent
dee1f49d1a
commit
c7cc6507ad
@ -1,36 +1,15 @@
|
||||
#[macro_use]
|
||||
extern crate enumscribe;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(EnumStrDeserialize, PartialEq, Eq, Debug)]
|
||||
#[case_insensitive]
|
||||
enum Airport {
|
||||
#[str_name("LHR")]
|
||||
Heathrow,
|
||||
#[str_name("LGW")]
|
||||
Gatwick,
|
||||
#[str_name("LTN")]
|
||||
Luton,
|
||||
#[str_name("BHX")]
|
||||
BirminghamInternational,
|
||||
#[other]
|
||||
Other(Box<String>),
|
||||
#[derive(EnumToString)]
|
||||
enum Foo {
|
||||
Baa,
|
||||
#[enumscribe(ignore)]
|
||||
Baz(),
|
||||
#[enumscribe(other)]
|
||||
Lorem { inner: String }
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let json_str = r#"
|
||||
{
|
||||
"airport_1": "LTN",
|
||||
"airport_2": "bhx",
|
||||
"airport_3": "lHr",
|
||||
"airport_4": "MAN"
|
||||
}"#;
|
||||
|
||||
let json: HashMap<String, Airport> = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
println!("{:?}", json.get("airport_1").unwrap());
|
||||
println!("{:?}", json.get("airport_2").unwrap());
|
||||
println!("{:?}", json.get("airport_3").unwrap());
|
||||
println!("{:?}", json.get("airport_4").unwrap());
|
||||
println!("Hello world!");
|
||||
}
|
||||
|
||||
@ -1,88 +1,169 @@
|
||||
use proc_macro::TokenStream;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use quote::quote;
|
||||
use syn::{Attribute, Data, DeriveInput, LitStr};
|
||||
|
||||
/// Derives `serde::Deserialize` for an enum with variants associated with strings.
|
||||
/// The `#[str_name("...")]` attribute is used to specify the string associated with each variant.
|
||||
///
|
||||
/// An "other" variant can be specified with `#[other]`. This variant should have a parameter
|
||||
/// which implements `From<String>` to store the string that could not be deserialized to any
|
||||
/// of the other variants.
|
||||
///
|
||||
/// If no "other" variant is specified, strings which are not associated with any of the variants
|
||||
/// will produce a deserialization error.
|
||||
///
|
||||
/// The enum may have the attribute `#[case_insensitive]`, in which case string comparisons will
|
||||
/// be done case-insensitively.
|
||||
#[proc_macro_derive(EnumStrDeserialize, attributes(case_insensitive, str_name, other))]
|
||||
pub fn derive_enum_str_de(ast: TokenStream) -> TokenStream {
|
||||
const ATTR_CASE_INSENSITIVE: &'static str = "case_insensitive";
|
||||
const ATTR_STR_NAME: &'static str = "str_name";
|
||||
const ATTR_OTHER: &'static str = "other";
|
||||
|
||||
let ast: DeriveInput = syn::parse(ast).unwrap();
|
||||
|
||||
let enum_name = &ast.ident;
|
||||
let enum_names = std::iter::repeat(enum_name);
|
||||
|
||||
let case_insensitive = find_attribute(ATTR_CASE_INSENSITIVE, &ast.attrs).is_some();
|
||||
|
||||
let enum_data = match ast.data {
|
||||
Data::Enum(e) => e,
|
||||
_ => panic!("cannot derive EnumStrDeserialize for anything other than an enum"),
|
||||
};
|
||||
use proc_macro2::Span;
|
||||
use quote::{quote, quote_spanned};
|
||||
use syn::{Attribute, Data, DataEnum, DeriveInput, Fields, LitStr};
|
||||
use syn::parse::{ParseBuffer, ParseStream};
|
||||
use syn::spanned::Spanned;
|
||||
|
||||
use attribute::*;
|
||||
use error::{MacroError, MacroResult};
|
||||
|
||||
mod error;
|
||||
mod attribute;
|
||||
|
||||
const CRATE_ATTR: &'static str = "enumscribe";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Enum {
|
||||
variants: Vec<Variant>,
|
||||
}
|
||||
|
||||
let (variant_names, variant_strings): (Vec<_>, Vec<_>) = enum_data.variants.iter()
|
||||
.map(|variant| (&variant.ident, find_attribute(ATTR_STR_NAME, &variant.attrs)))
|
||||
.filter(|(_, attribute)| attribute.is_some())
|
||||
.map(|(variant_ident, attribute)| (variant_ident, attribute
|
||||
.unwrap()
|
||||
.parse_args::<LitStr>()
|
||||
.unwrap()
|
||||
.value()))
|
||||
.map(|(variant_ident, attribute)| (variant_ident, if case_insensitive {
|
||||
attribute.to_lowercase()
|
||||
#[derive(Clone, Debug)]
|
||||
struct Variant {
|
||||
ident: String,
|
||||
v_type: VariantType,
|
||||
span: Span,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum VariantType {
|
||||
Ignore,
|
||||
Named { name: String, constructor: VariantConstructor },
|
||||
Other { field_name: Option<String> }, //use {} for constructor if Some, use () if None
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
enum VariantConstructor {
|
||||
None,
|
||||
Paren,
|
||||
Brace,
|
||||
}
|
||||
|
||||
fn parse_enum(data: DataEnum) -> MacroResult<Enum> {
|
||||
const NAME: &'static str = "str";
|
||||
const OTHER: &'static str = "other";
|
||||
const IGNORE: &'static str = "ignore";
|
||||
|
||||
let mut variants = Vec::with_capacity(data.variants.len());
|
||||
let mut taken_names = HashSet::new();
|
||||
|
||||
for variant in data.variants {
|
||||
let variant_ident = variant.ident.to_string();
|
||||
let variant_span = variant.span();
|
||||
|
||||
let mut dict = Dict::from_attrs(CRATE_ATTR, &variant.attrs)?;
|
||||
|
||||
let name_opt = dict.remove_typed_value(NAME, Value::value_string)?;
|
||||
|
||||
let other = match dict.remove_typed_value(OTHER, Value::value_bool)? {
|
||||
Some((other, _)) => other,
|
||||
None => false
|
||||
};
|
||||
|
||||
let ignore = match dict.remove_typed_value(IGNORE, Value::value_bool)? {
|
||||
Some((ignore, _)) => ignore,
|
||||
None => false
|
||||
};
|
||||
|
||||
dict.assert_empty()?;
|
||||
|
||||
let scribe_variant = if ignore {
|
||||
Variant {
|
||||
ident: variant_ident,
|
||||
v_type: VariantType::Ignore,
|
||||
span: variant_span,
|
||||
}
|
||||
} else if other {
|
||||
if let Some((_, name_span)) = name_opt {
|
||||
return Err(MacroError::new(
|
||||
format!(
|
||||
"cannot use {} for variant {} because it is marked as {}",
|
||||
NAME, variant_ident, OTHER
|
||||
),
|
||||
name_span,
|
||||
));
|
||||
}
|
||||
|
||||
if variant.fields.len() != 1 {
|
||||
return Err(MacroError::new(
|
||||
format!(
|
||||
"the variant {} must have exactly one field because it is marked as {}",
|
||||
variant_ident, OTHER
|
||||
),
|
||||
variant_span,
|
||||
));
|
||||
}
|
||||
|
||||
let field_name = variant.fields.iter().next()
|
||||
.and_then(|field| field.ident.as_ref().map(|ident| ident.to_string()));
|
||||
|
||||
Variant {
|
||||
ident: variant_ident,
|
||||
v_type: VariantType::Other { field_name },
|
||||
span: variant_span,
|
||||
}
|
||||
} else {
|
||||
attribute
|
||||
}))
|
||||
.unzip();
|
||||
let (name, name_span) = match name_opt {
|
||||
Some((name, name_span)) => (name, name_span),
|
||||
None => (variant.ident.to_string(), variant.ident.span())
|
||||
};
|
||||
|
||||
let other_variant = enum_data.variants.iter()
|
||||
.find(|variant| find_attribute(ATTR_OTHER, &variant.attrs).is_some());
|
||||
if taken_names.contains(&name) {
|
||||
return Err(MacroError::new(
|
||||
format!("duplicate name \"{}\"", name),
|
||||
name_span,
|
||||
));
|
||||
}
|
||||
|
||||
let matching_string = if case_insensitive {
|
||||
quote! { deserialized_string.to_lowercase() }
|
||||
} else {
|
||||
quote! { deserialized_string }
|
||||
};
|
||||
if variant.fields.len() != 0 {
|
||||
return Err(MacroError::new(
|
||||
format!(
|
||||
"the variant {} must not have any fields",
|
||||
variant_ident
|
||||
),
|
||||
variant_span,
|
||||
));
|
||||
}
|
||||
|
||||
let (base_case_pattern, base_case_value) = if let Some(other_variant) = other_variant {
|
||||
let other_variant_name = &other_variant.ident;
|
||||
(quote! { _ }, quote! { ::core::result::Result::Ok(#enum_name::#other_variant_name(deserialized_string.into())) })
|
||||
} else {
|
||||
(quote! { s }, quote! { ::core::result::Result::Err(::serde::de::Error::unknown_variant(s, &[#(#variant_strings),*])) })
|
||||
};
|
||||
let constructor = match variant.fields {
|
||||
Fields::Named(_) => VariantConstructor::Brace,
|
||||
Fields::Unnamed(_) => VariantConstructor::Paren,
|
||||
Fields::Unit => VariantConstructor::None,
|
||||
};
|
||||
|
||||
(quote! {
|
||||
impl<'de> ::serde::Deserialize<'de> for #enum_name {
|
||||
fn deserialize<D>(deserializer: D) -> ::core::result::Result<Self, D::Error>
|
||||
where
|
||||
D: ::serde::Deserializer<'de>,
|
||||
{
|
||||
let deserialized_string = ::std::string::String::deserialize(deserializer)?;
|
||||
match #matching_string.as_str() {
|
||||
#(#variant_strings => ::core::result::Result::Ok(#enum_names::#variant_names),)*
|
||||
#base_case_pattern => #base_case_value,
|
||||
}
|
||||
taken_names.insert(name.clone());
|
||||
|
||||
Variant {
|
||||
ident: variant_ident,
|
||||
v_type: VariantType::Named { name, constructor },
|
||||
span: variant_span,
|
||||
}
|
||||
}
|
||||
}).into()
|
||||
}
|
||||
};
|
||||
|
||||
variants.push(scribe_variant);
|
||||
}
|
||||
|
||||
fn find_attribute<'a>(name: &str, attributes: &'a [Attribute]) -> Option<&'a Attribute> {
|
||||
attributes
|
||||
.iter()
|
||||
.find(|attribute| attribute.path.is_ident(name))
|
||||
Ok(Enum { variants })
|
||||
}
|
||||
|
||||
#[proc_macro_derive(EnumToString, attributes(enumscribe))]
|
||||
pub fn derive_enum_to_string(input: TokenStream) -> TokenStream {
|
||||
let input: DeriveInput = syn::parse(input).unwrap();
|
||||
|
||||
let enum_data = match input.data {
|
||||
Data::Enum(data) => data,
|
||||
Data::Struct(_) => return MacroError::new("cannot use enumscribe for structs", input.ident.span()).to_token_stream(),
|
||||
Data::Union(_) => return MacroError::new("cannot use enumscribe for unions", input.ident.span()).to_token_stream()
|
||||
};
|
||||
|
||||
let variants = match parse_enum(enum_data) {
|
||||
Ok(variants) => variants,
|
||||
Err(err) => return err.to_token_stream()
|
||||
};
|
||||
|
||||
println!("{:?}", variants);
|
||||
|
||||
TokenStream::new()
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue