diff --git a/enumscribe/src/internal/capped_string.rs b/enumscribe/src/internal/capped_string.rs index 8ee4aca..43ab7b3 100644 --- a/enumscribe/src/internal/capped_string.rs +++ b/enumscribe/src/internal/capped_string.rs @@ -3,6 +3,68 @@ use core::{str, ops::Deref, borrow::Borrow, fmt}; +/// TODO: documentation +pub enum CowCappedString<'a, const N: usize> { + /// TODO: documentation + Borrowed(&'a str), + /// TODO: documentation + Owned(CappedString), +} + +#[cfg(feature = "serde")] +impl<'de, const N: usize> serde::Deserialize<'de> for CowCappedString<'de, N> { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> + { + deserializer.deserialize_str(CowCappedStringVisitor::) + } +} + +#[cfg(feature = "serde")] +struct CowCappedStringVisitor; + +#[cfg(feature = "serde")] +impl<'de, const N: usize> serde::de::Visitor<'de> for CowCappedStringVisitor { + type Value = CowCappedString<'de, N>; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a borrowed string or a string up to {} bytes long", N) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + CappedStringVisitor::.visit_str(v) + .map(CowCappedString::Owned) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + CappedStringVisitor::.visit_bytes(v) + .map(CowCappedString::Owned) + } + + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: serde::de::Error, + { + Ok(CowCappedString::Borrowed(v)) + } + + fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result + where + E: serde::de::Error, + { + str::from_utf8(v) + .map_err(|_| E::invalid_value(serde::de::Unexpected::Bytes(v), &self)) + .and_then(|v| self.visit_borrowed_str(v)) + } +} + /// A string type which stores up to `N` bytes of string data inline. pub struct CappedString { /// The string data. It is an invariant that the first `len` bytes must be valid UTF-8. @@ -134,8 +196,10 @@ impl<'de, const N: usize> serde::Deserialize<'de> for CappedString { } } +#[cfg(feature = "serde")] struct CappedStringVisitor; +#[cfg(feature = "serde")] impl<'de, const N: usize> serde::de::Visitor<'de> for CappedStringVisitor { type Value = CappedString; @@ -157,8 +221,7 @@ impl<'de, const N: usize> serde::de::Visitor<'de> for CappedStringVisitor { { str::from_utf8(v) .map_err(|_| E::invalid_value(serde::de::Unexpected::Bytes(v), &self)) - .and_then(|v| CappedString::from_str(v) - .ok_or_else(|| E::invalid_length(v.len(), &self))) + .and_then(|v| self.visit_str(v)) } }