Enum variant parsing
parent
dee1f49d1a
commit
c7cc6507ad
@ -1,36 +1,15 @@
|
|||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate enumscribe;
|
extern crate enumscribe;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
#[derive(EnumToString)]
|
||||||
|
enum Foo {
|
||||||
#[derive(EnumStrDeserialize, PartialEq, Eq, Debug)]
|
Baa,
|
||||||
#[case_insensitive]
|
#[enumscribe(ignore)]
|
||||||
enum Airport {
|
Baz(),
|
||||||
#[str_name("LHR")]
|
#[enumscribe(other)]
|
||||||
Heathrow,
|
Lorem { inner: String }
|
||||||
#[str_name("LGW")]
|
|
||||||
Gatwick,
|
|
||||||
#[str_name("LTN")]
|
|
||||||
Luton,
|
|
||||||
#[str_name("BHX")]
|
|
||||||
BirminghamInternational,
|
|
||||||
#[other]
|
|
||||||
Other(Box<String>),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let json_str = r#"
|
println!("Hello world!");
|
||||||
{
|
|
||||||
"airport_1": "LTN",
|
|
||||||
"airport_2": "bhx",
|
|
||||||
"airport_3": "lHr",
|
|
||||||
"airport_4": "MAN"
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let json: HashMap<String, Airport> = serde_json::from_str(json_str).unwrap();
|
|
||||||
|
|
||||||
println!("{:?}", json.get("airport_1").unwrap());
|
|
||||||
println!("{:?}", json.get("airport_2").unwrap());
|
|
||||||
println!("{:?}", json.get("airport_3").unwrap());
|
|
||||||
println!("{:?}", json.get("airport_4").unwrap());
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,88 +1,169 @@
|
|||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use quote::quote;
|
use proc_macro2::Span;
|
||||||
use syn::{Attribute, Data, DeriveInput, LitStr};
|
use quote::{quote, quote_spanned};
|
||||||
|
use syn::{Attribute, Data, DataEnum, DeriveInput, Fields, LitStr};
|
||||||
/// Derives `serde::Deserialize` for an enum with variants associated with strings.
|
use syn::parse::{ParseBuffer, ParseStream};
|
||||||
/// The `#[str_name("...")]` attribute is used to specify the string associated with each variant.
|
use syn::spanned::Spanned;
|
||||||
///
|
|
||||||
/// An "other" variant can be specified with `#[other]`. This variant should have a parameter
|
|
||||||
/// which implements `From<String>` to store the string that could not be deserialized to any
|
|
||||||
/// of the other variants.
|
|
||||||
///
|
|
||||||
/// If no "other" variant is specified, strings which are not associated with any of the variants
|
|
||||||
/// will produce a deserialization error.
|
|
||||||
///
|
|
||||||
/// The enum may have the attribute `#[case_insensitive]`, in which case string comparisons will
|
|
||||||
/// be done case-insensitively.
|
|
||||||
#[proc_macro_derive(EnumStrDeserialize, attributes(case_insensitive, str_name, other))]
|
|
||||||
pub fn derive_enum_str_de(ast: TokenStream) -> TokenStream {
|
|
||||||
const ATTR_CASE_INSENSITIVE: &'static str = "case_insensitive";
|
|
||||||
const ATTR_STR_NAME: &'static str = "str_name";
|
|
||||||
const ATTR_OTHER: &'static str = "other";
|
|
||||||
|
|
||||||
let ast: DeriveInput = syn::parse(ast).unwrap();
|
|
||||||
|
|
||||||
let enum_name = &ast.ident;
|
|
||||||
let enum_names = std::iter::repeat(enum_name);
|
|
||||||
|
|
||||||
let case_insensitive = find_attribute(ATTR_CASE_INSENSITIVE, &ast.attrs).is_some();
|
|
||||||
|
|
||||||
let enum_data = match ast.data {
|
|
||||||
Data::Enum(e) => e,
|
|
||||||
_ => panic!("cannot derive EnumStrDeserialize for anything other than an enum"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (variant_names, variant_strings): (Vec<_>, Vec<_>) = enum_data.variants.iter()
|
use attribute::*;
|
||||||
.map(|variant| (&variant.ident, find_attribute(ATTR_STR_NAME, &variant.attrs)))
|
use error::{MacroError, MacroResult};
|
||||||
.filter(|(_, attribute)| attribute.is_some())
|
|
||||||
.map(|(variant_ident, attribute)| (variant_ident, attribute
|
|
||||||
.unwrap()
|
|
||||||
.parse_args::<LitStr>()
|
|
||||||
.unwrap()
|
|
||||||
.value()))
|
|
||||||
.map(|(variant_ident, attribute)| (variant_ident, if case_insensitive {
|
|
||||||
attribute.to_lowercase()
|
|
||||||
} else {
|
|
||||||
attribute
|
|
||||||
}))
|
|
||||||
.unzip();
|
|
||||||
|
|
||||||
let other_variant = enum_data.variants.iter()
|
mod error;
|
||||||
.find(|variant| find_attribute(ATTR_OTHER, &variant.attrs).is_some());
|
mod attribute;
|
||||||
|
|
||||||
let matching_string = if case_insensitive {
|
const CRATE_ATTR: &'static str = "enumscribe";
|
||||||
quote! { deserialized_string.to_lowercase() }
|
|
||||||
} else {
|
#[derive(Clone, Debug)]
|
||||||
quote! { deserialized_string }
|
struct Enum {
|
||||||
|
variants: Vec<Variant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct Variant {
|
||||||
|
ident: String,
|
||||||
|
v_type: VariantType,
|
||||||
|
span: Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
enum VariantType {
|
||||||
|
Ignore,
|
||||||
|
Named { name: String, constructor: VariantConstructor },
|
||||||
|
Other { field_name: Option<String> }, //use {} for constructor if Some, use () if None
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
enum VariantConstructor {
|
||||||
|
None,
|
||||||
|
Paren,
|
||||||
|
Brace,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_enum(data: DataEnum) -> MacroResult<Enum> {
|
||||||
|
const NAME: &'static str = "str";
|
||||||
|
const OTHER: &'static str = "other";
|
||||||
|
const IGNORE: &'static str = "ignore";
|
||||||
|
|
||||||
|
let mut variants = Vec::with_capacity(data.variants.len());
|
||||||
|
let mut taken_names = HashSet::new();
|
||||||
|
|
||||||
|
for variant in data.variants {
|
||||||
|
let variant_ident = variant.ident.to_string();
|
||||||
|
let variant_span = variant.span();
|
||||||
|
|
||||||
|
let mut dict = Dict::from_attrs(CRATE_ATTR, &variant.attrs)?;
|
||||||
|
|
||||||
|
let name_opt = dict.remove_typed_value(NAME, Value::value_string)?;
|
||||||
|
|
||||||
|
let other = match dict.remove_typed_value(OTHER, Value::value_bool)? {
|
||||||
|
Some((other, _)) => other,
|
||||||
|
None => false
|
||||||
|
};
|
||||||
|
|
||||||
|
let ignore = match dict.remove_typed_value(IGNORE, Value::value_bool)? {
|
||||||
|
Some((ignore, _)) => ignore,
|
||||||
|
None => false
|
||||||
};
|
};
|
||||||
|
|
||||||
let (base_case_pattern, base_case_value) = if let Some(other_variant) = other_variant {
|
dict.assert_empty()?;
|
||||||
let other_variant_name = &other_variant.ident;
|
|
||||||
(quote! { _ }, quote! { ::core::result::Result::Ok(#enum_name::#other_variant_name(deserialized_string.into())) })
|
let scribe_variant = if ignore {
|
||||||
|
Variant {
|
||||||
|
ident: variant_ident,
|
||||||
|
v_type: VariantType::Ignore,
|
||||||
|
span: variant_span,
|
||||||
|
}
|
||||||
|
} else if other {
|
||||||
|
if let Some((_, name_span)) = name_opt {
|
||||||
|
return Err(MacroError::new(
|
||||||
|
format!(
|
||||||
|
"cannot use {} for variant {} because it is marked as {}",
|
||||||
|
NAME, variant_ident, OTHER
|
||||||
|
),
|
||||||
|
name_span,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if variant.fields.len() != 1 {
|
||||||
|
return Err(MacroError::new(
|
||||||
|
format!(
|
||||||
|
"the variant {} must have exactly one field because it is marked as {}",
|
||||||
|
variant_ident, OTHER
|
||||||
|
),
|
||||||
|
variant_span,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let field_name = variant.fields.iter().next()
|
||||||
|
.and_then(|field| field.ident.as_ref().map(|ident| ident.to_string()));
|
||||||
|
|
||||||
|
Variant {
|
||||||
|
ident: variant_ident,
|
||||||
|
v_type: VariantType::Other { field_name },
|
||||||
|
span: variant_span,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
(quote! { s }, quote! { ::core::result::Result::Err(::serde::de::Error::unknown_variant(s, &[#(#variant_strings),*])) })
|
let (name, name_span) = match name_opt {
|
||||||
|
Some((name, name_span)) => (name, name_span),
|
||||||
|
None => (variant.ident.to_string(), variant.ident.span())
|
||||||
};
|
};
|
||||||
|
|
||||||
(quote! {
|
if taken_names.contains(&name) {
|
||||||
impl<'de> ::serde::Deserialize<'de> for #enum_name {
|
return Err(MacroError::new(
|
||||||
fn deserialize<D>(deserializer: D) -> ::core::result::Result<Self, D::Error>
|
format!("duplicate name \"{}\"", name),
|
||||||
where
|
name_span,
|
||||||
D: ::serde::Deserializer<'de>,
|
));
|
||||||
{
|
|
||||||
let deserialized_string = ::std::string::String::deserialize(deserializer)?;
|
|
||||||
match #matching_string.as_str() {
|
|
||||||
#(#variant_strings => ::core::result::Result::Ok(#enum_names::#variant_names),)*
|
|
||||||
#base_case_pattern => #base_case_value,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if variant.fields.len() != 0 {
|
||||||
|
return Err(MacroError::new(
|
||||||
|
format!(
|
||||||
|
"the variant {} must not have any fields",
|
||||||
|
variant_ident
|
||||||
|
),
|
||||||
|
variant_span,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let constructor = match variant.fields {
|
||||||
|
Fields::Named(_) => VariantConstructor::Brace,
|
||||||
|
Fields::Unnamed(_) => VariantConstructor::Paren,
|
||||||
|
Fields::Unit => VariantConstructor::None,
|
||||||
|
};
|
||||||
|
|
||||||
|
taken_names.insert(name.clone());
|
||||||
|
|
||||||
|
Variant {
|
||||||
|
ident: variant_ident,
|
||||||
|
v_type: VariantType::Named { name, constructor },
|
||||||
|
span: variant_span,
|
||||||
}
|
}
|
||||||
}).into()
|
};
|
||||||
|
|
||||||
|
variants.push(scribe_variant);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_attribute<'a>(name: &str, attributes: &'a [Attribute]) -> Option<&'a Attribute> {
|
Ok(Enum { variants })
|
||||||
attributes
|
|
||||||
.iter()
|
|
||||||
.find(|attribute| attribute.path.is_ident(name))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[proc_macro_derive(EnumToString, attributes(enumscribe))]
|
||||||
|
pub fn derive_enum_to_string(input: TokenStream) -> TokenStream {
|
||||||
|
let input: DeriveInput = syn::parse(input).unwrap();
|
||||||
|
|
||||||
|
let enum_data = match input.data {
|
||||||
|
Data::Enum(data) => data,
|
||||||
|
Data::Struct(_) => return MacroError::new("cannot use enumscribe for structs", input.ident.span()).to_token_stream(),
|
||||||
|
Data::Union(_) => return MacroError::new("cannot use enumscribe for unions", input.ident.span()).to_token_stream()
|
||||||
|
};
|
||||||
|
|
||||||
|
let variants = match parse_enum(enum_data) {
|
||||||
|
Ok(variants) => variants,
|
||||||
|
Err(err) => return err.to_token_stream()
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("{:?}", variants);
|
||||||
|
|
||||||
|
TokenStream::new()
|
||||||
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue