From d0ea531c05031cf8496d1275478f24fcb1e8492a Mon Sep 17 00:00:00 2001 From: pantonshire Date: Fri, 26 Aug 2022 15:00:11 +0100 Subject: [PATCH] RFC 3986 percent encoding, std feature now depends on alloc feature --- Cargo.toml | 2 +- src/encoding/hex.rs | 89 ++++++------ src/encoding/rfc3986.rs | 299 ++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- src/strings/capped.rs | 8 +- src/strings/mod.rs | 4 +- src/uuid.rs | 12 +- 7 files changed, 359 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 387dc5b..f6feafa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [features] default = ["std"] alloc = ["serde?/alloc"] -std = ["serde?/std"] +std = ["serde?/std", "alloc"] [dependencies] serde = { version = "1", default-features = false, optional = true } diff --git a/src/encoding/hex.rs b/src/encoding/hex.rs index 24d1ae2..9e719f5 100644 --- a/src/encoding/hex.rs +++ b/src/encoding/hex.rs @@ -179,23 +179,26 @@ fn nybble_to_hex_upper(nybble: u8) -> u8 { } } -pub fn hex_to_be_byte_array(hex: &str) -> Result<[u8; N], ArrayParseError> { - let mut iter = hex.chars().rev(); +pub fn hex_to_be_byte_array(hex: &B) -> Result<[u8; N], ArrayParseError> +where + B: AsRef<[u8]> + ?Sized, +{ + let mut iter = hex.as_ref().iter().copied().rev(); let mut buf = [0u8; N]; let mut bytes = 0usize; - while let Some(ch0) = iter.next() { + while let Some(lsb) = iter.next() { bytes += 1; if bytes > N { return Err(ArrayParseError::TooLong(N)) } match iter.next() { - Some(ch1) => { - buf[N - bytes] = hex_to_byte(ch1, ch0)?; + Some(msb) => { + buf[N - bytes] = hex_to_byte(msb, lsb)?; }, None => { - buf[N - bytes] = hex_to_nybble(ch0)?; + buf[N - bytes] = hex_to_nybble(lsb)?; return Ok(buf); }, } @@ -205,23 +208,23 @@ pub fn hex_to_be_byte_array(hex: &str) -> Result<[u8; N], ArrayP } #[inline] -pub fn hex_to_byte(ch0: char, ch1: char) -> Result { - Ok((hex_to_nybble(ch0)? << 4) | hex_to_nybble(ch1)?) +pub fn hex_to_byte(msb: u8, lsb: u8) -> Result { + Ok((hex_to_nybble(msb)? << 4) | hex_to_nybble(lsb)?) } #[inline] -fn hex_to_nybble(ch: char) -> Result { - match ch { - '0'..='9' => Ok((ch as u8) - 0x30), - 'A'..='F' => Ok((ch as u8) - 0x37), - 'a'..='f' => Ok((ch as u8) - 0x57), - ch => Err(ParseError::BadChar(ch)), +fn hex_to_nybble(byte: u8) -> Result { + match byte { + b'0'..=b'9' => Ok(byte - 0x30), + b'A'..=b'F' => Ok(byte - 0x37), + b'a'..=b'f' => Ok(byte - 0x57), + byte => Err(ParseError::BadChar(byte)), } } #[derive(Debug)] pub enum ParseError { - BadChar(char), + BadChar(u8), } impl fmt::Display for ParseError { @@ -263,7 +266,7 @@ impl From for ArrayParseError { mod tests { use super::*; - #[cfg(any(feature = "alloc", feature = "std"))] + #[cfg(feature = "alloc")] #[test] fn test_hex_bytes_debug() { #[cfg(not(feature = "std"))] @@ -280,7 +283,7 @@ mod tests { ); } - #[cfg(any(feature = "alloc", feature = "std"))] + #[cfg(feature = "alloc")] #[test] fn test_hex_bytes_display() { #[cfg(not(feature = "std"))] @@ -365,31 +368,31 @@ mod tests { #[test] fn test_hex_to_nybble() { - assert_eq!(hex_to_nybble('0').unwrap(), 0x0); - assert_eq!(hex_to_nybble('1').unwrap(), 0x1); - assert_eq!(hex_to_nybble('2').unwrap(), 0x2); - assert_eq!(hex_to_nybble('3').unwrap(), 0x3); - assert_eq!(hex_to_nybble('4').unwrap(), 0x4); - assert_eq!(hex_to_nybble('5').unwrap(), 0x5); - assert_eq!(hex_to_nybble('6').unwrap(), 0x6); - assert_eq!(hex_to_nybble('7').unwrap(), 0x7); - assert_eq!(hex_to_nybble('8').unwrap(), 0x8); - assert_eq!(hex_to_nybble('9').unwrap(), 0x9); - assert_eq!(hex_to_nybble('a').unwrap(), 0xa); - assert_eq!(hex_to_nybble('b').unwrap(), 0xb); - assert_eq!(hex_to_nybble('c').unwrap(), 0xc); - assert_eq!(hex_to_nybble('d').unwrap(), 0xd); - assert_eq!(hex_to_nybble('e').unwrap(), 0xe); - assert_eq!(hex_to_nybble('f').unwrap(), 0xf); - assert_eq!(hex_to_nybble('A').unwrap(), 0xa); - assert_eq!(hex_to_nybble('B').unwrap(), 0xb); - assert_eq!(hex_to_nybble('C').unwrap(), 0xc); - assert_eq!(hex_to_nybble('D').unwrap(), 0xd); - assert_eq!(hex_to_nybble('E').unwrap(), 0xe); - assert_eq!(hex_to_nybble('F').unwrap(), 0xf); - - assert!(matches!(hex_to_nybble('g'), Err(ParseError::BadChar('g')))); - assert!(matches!(hex_to_nybble('G'), Err(ParseError::BadChar('G')))); + assert_eq!(hex_to_nybble(b'0').unwrap(), 0x0); + assert_eq!(hex_to_nybble(b'1').unwrap(), 0x1); + assert_eq!(hex_to_nybble(b'2').unwrap(), 0x2); + assert_eq!(hex_to_nybble(b'3').unwrap(), 0x3); + assert_eq!(hex_to_nybble(b'4').unwrap(), 0x4); + assert_eq!(hex_to_nybble(b'5').unwrap(), 0x5); + assert_eq!(hex_to_nybble(b'6').unwrap(), 0x6); + assert_eq!(hex_to_nybble(b'7').unwrap(), 0x7); + assert_eq!(hex_to_nybble(b'8').unwrap(), 0x8); + assert_eq!(hex_to_nybble(b'9').unwrap(), 0x9); + assert_eq!(hex_to_nybble(b'a').unwrap(), 0xa); + assert_eq!(hex_to_nybble(b'b').unwrap(), 0xb); + assert_eq!(hex_to_nybble(b'c').unwrap(), 0xc); + assert_eq!(hex_to_nybble(b'd').unwrap(), 0xd); + assert_eq!(hex_to_nybble(b'e').unwrap(), 0xe); + assert_eq!(hex_to_nybble(b'f').unwrap(), 0xf); + assert_eq!(hex_to_nybble(b'A').unwrap(), 0xa); + assert_eq!(hex_to_nybble(b'B').unwrap(), 0xb); + assert_eq!(hex_to_nybble(b'C').unwrap(), 0xc); + assert_eq!(hex_to_nybble(b'D').unwrap(), 0xd); + assert_eq!(hex_to_nybble(b'E').unwrap(), 0xe); + assert_eq!(hex_to_nybble(b'F').unwrap(), 0xf); + + assert!(matches!(hex_to_nybble(b'g'), Err(ParseError::BadChar(b'g')))); + assert!(matches!(hex_to_nybble(b'G'), Err(ParseError::BadChar(b'G')))); } #[test] @@ -414,7 +417,7 @@ mod tests { ); assert!(matches!( - hex_to_be_byte_array::<5>("d90058decebf"), + hex_to_be_byte_array::<5, _>("d90058decebf"), Err(ArrayParseError::TooLong(5)) )); } diff --git a/src/encoding/rfc3986.rs b/src/encoding/rfc3986.rs index e69de29..0c63cd0 100644 --- a/src/encoding/rfc3986.rs +++ b/src/encoding/rfc3986.rs @@ -0,0 +1,299 @@ +// Following RFC3986 (https://www.rfc-editor.org/rfc/rfc3986#section-2.1) + +use core::{ + fmt::{self, Write}, + str, +}; + +#[cfg(all(feature = "alloc", not(feature = "std")))] +use alloc::{borrow::Cow, string::String, vec::Vec}; + +#[cfg(feature = "std")] +use std::borrow::Cow; + +use crate::{either::{Either, Inl, Inr}, strings::FixedString}; + +use super::hex; + +/// Finds the first element of the slice which does not match the given predicate and returns the +/// sub-slice preceding that element, the element itself, and the sub-slice following the element. +#[inline] +fn split_at_non_matching(xs: &[T], predicate: P) -> (&[T], Option<(T, &[T])>) +where + T: Copy, + P: Fn(T) -> bool, +{ + let mut i = 0; + while i < xs.len() { + let x = xs[i]; + if !predicate(x) { + // `get_unchecked` is used here because the compiler currently seems to struggle to + // reason about the correctness of the start and end indexes here, and can end up + // leaving in unnecessary bound checks. + // SAFETY: + // We have already checked that `i < xs.len()`, so `..i` is in bounds for `xs`. + let prefix = unsafe { xs.get_unchecked(..i) }; + // SAFETY: + // We have already checked that `i < xs.len()`, so `i + 1 <= xs.len()` must hold. + // Therefore, `(i + 1)..` is in bounds for `xs`. + let suffix = unsafe { xs.get_unchecked((i + 1)..) }; + return (prefix, Some((x, suffix))); + } + i += 1; + } + (xs, None) +} + +fn byte_unreserved(byte: u8) -> bool { + matches!(byte, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'-' | b'.' | b'_' | b'~') +} + +struct PercentEncoder<'a>(&'a [u8]); + +impl<'a> PercentEncoder<'a> { + pub fn partial_encode(&mut self) -> Option<(&'a str, Option>)> { + if self.0.is_empty() { + return None; + } + + let (prefix, suffix) = split_at_non_matching(self.0, byte_unreserved); + + // SAFETY: + // `prefix` only contains bytes which satisfy `byte_unreserved`, which are all valid ASCII + // characters. Therefore, it is valid UTF-8. + let prefix = unsafe { str::from_utf8_unchecked(prefix) }; + + match suffix { + Some((byte, suffix)) => { + self.0 = suffix; + Some((prefix, Some(Self::percent_encode_byte(byte)))) + }, + + None => { + self.0 = &self.0[self.0.len()..]; + Some((prefix, None)) + }, + } + } + + fn percent_encode_byte(byte: u8) -> FixedString<3> { + let [msb, lsb] = hex::byte_to_hex_upper(byte).into_raw(); + // SAFETY: + // The bytes obtained from `hex::byte_to_hex_upper` are valid UTF-8, and `b'%'` is a valid + // UTF-8 codepoint, so the byte array is valid UTF-8. + unsafe { FixedString::from_raw_array([b'%', msb, lsb]) } + } +} + +#[cfg(feature = "alloc")] +pub fn percent_encode(bytes: &B) -> Cow +where + B: AsRef<[u8]> + ?Sized, +{ + let mut encoder = PercentEncoder(bytes.as_ref()); + + match encoder.partial_encode().unwrap_or(("", None)) { + (prefix, Some(encoded_byte)) => { + let mut buf = String::new(); + buf.push_str(prefix); + buf.push_str(&encoded_byte); + + while let Some((prefix, encoded_byte)) = encoder.partial_encode() { + buf.push_str(prefix); + if let Some(encoded_byte) = encoded_byte { + buf.push_str(&encoded_byte); + } + } + + Cow::Owned(buf) + }, + + (prefix, None) => Cow::Borrowed(prefix), + } +} + +pub fn percent_encode_to_fmt_writer(writer: &mut W, bytes: &B) -> fmt::Result +where + W: Write + ?Sized, + B: AsRef<[u8]> + ?Sized, +{ + let mut encoder = PercentEncoder(bytes.as_ref()); + + while let Some((prefix, encoded_byte)) = encoder.partial_encode() { + if !prefix.is_empty() { + writer.write_str(prefix)?; + } + if let Some(encoded_byte) = encoded_byte { + writer.write_str(&encoded_byte)?; + } + } + + Ok(()) +} + +struct PercentDecoder<'a>(&'a [u8]); + +impl<'a> PercentDecoder<'a> { + fn partial_decode(&mut self) -> Result)>, PercentDecodeError> { + if self.0.is_empty() { + return Ok(None); + } + + let (prefix, suffix) = split_at_non_matching(self.0, byte_unreserved); + + // SAFETY: + // `prefix` only contains bytes which satisfy `byte_unreserved`, which are all valid ASCII + // characters. Therefore, it is valid UTF-8. + let prefix = unsafe { str::from_utf8_unchecked(prefix) }; + + match suffix { + Some((byte, suffix)) => { + if byte != b'%' { + return Err(PercentDecodeError); + } + + let [hex_msb, hex_lsb]: [u8; 2] = suffix + .get(..2) + .and_then(|hex_bytes| hex_bytes.try_into().ok()) + .ok_or(PercentDecodeError)?; + + let hex_byte = hex::hex_to_byte(hex_msb, hex_lsb) + .map_err(|_| PercentDecodeError)?; + + self.0 = &suffix[2..]; + + Ok(Some((prefix, Some(hex_byte)))) + }, + + None => { + self.0 = &self.0[self.0.len()..]; + Ok(Some((prefix, None))) + }, + } + } +} + +#[cfg(feature = "alloc")] +fn percent_decode_internal(bytes: &B) -> Result>, PercentDecodeError> +where + B: AsRef<[u8]> + ?Sized, +{ + let mut decoder = PercentDecoder(bytes.as_ref()); + + match decoder.partial_decode()?.unwrap_or(("", None)) { + (prefix, Some(byte)) => { + let mut buf = Vec::new(); + buf.extend(prefix.bytes()); + buf.push(byte); + + while let Some((prefix, byte)) = decoder.partial_decode()? { + buf.extend(prefix.bytes()); + if let Some(byte) = byte { + buf.push(byte); + } + } + + Ok(Inr(buf)) + }, + + (prefix, None) => Ok(Inl(prefix)) + } +} + +#[cfg(feature = "alloc")] +pub fn percent_decode_to_utf8(bytes: &B) -> Result, PercentDecodeError> +where + B: AsRef<[u8]> + ?Sized, +{ + percent_decode_internal(bytes).and_then(|decoded| match decoded { + Inl(decoded_str) => Ok(Cow::Borrowed(decoded_str)), + Inr(decoded_bytes) => String::from_utf8(decoded_bytes) + .map(Cow::Owned) + .map_err(|_| PercentDecodeError), + }) +} + +#[cfg(feature = "alloc")] +pub fn percent_decode_to_bytes(bytes: &B) -> Result, PercentDecodeError> +where + B: AsRef<[u8]> + ?Sized, +{ + percent_decode_internal(bytes).map(|decoded| match decoded { + Inl(decoded_str) => Cow::Borrowed(decoded_str.as_bytes()), + Inr(decoded_bytes) => Cow::Owned(decoded_bytes), + }) +} + +#[derive(Debug)] +pub struct PercentDecodeError; + +impl fmt::Display for PercentDecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "invalid rfc 3986 percent-encoded string") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for PercentDecodeError {} + +#[cfg(test)] +mod tests { + #[cfg(feature = "alloc")] + #[test] + fn test_percent_encode() { + #[cfg(all(feature = "alloc", not(feature = "std")))] + use alloc::borrow::Cow; + + #[cfg(feature = "std")] + use std::borrow::Cow; + + use super::percent_encode; + + assert!(matches!(percent_encode(""), Cow::Borrowed(""))); + assert!(matches!(percent_encode("foobar"), Cow::Borrowed("foobar"))); + + assert_eq!(&*percent_encode("Ladies + Gentlemen"), "Ladies%20%2B%20Gentlemen"); + assert_eq!(&*percent_encode("An encoded string!"), "An%20encoded%20string%21"); + assert_eq!(&*percent_encode("Dogs, Cats & Mice"), "Dogs%2C%20Cats%20%26%20Mice"); + assert_eq!(&*percent_encode("☃"), "%E2%98%83"); + } + + #[cfg(feature = "alloc")] + #[test] + fn test_percent_decode() { + #[cfg(all(feature = "alloc", not(feature = "std")))] + use alloc::borrow::Cow; + + #[cfg(feature = "std")] + use std::borrow::Cow; + + use super::{percent_decode_to_utf8, percent_decode_to_bytes}; + + assert!(matches!(percent_decode_to_utf8(""), Ok(Cow::Borrowed("")))); + assert!(matches!(percent_decode_to_bytes(""), Ok(Cow::Borrowed(b"")))); + assert!(matches!(percent_decode_to_utf8("foobar"), Ok(Cow::Borrowed("foobar")))); + assert!(matches!(percent_decode_to_bytes("foobar"), Ok(Cow::Borrowed(b"foobar")))); + + assert!(matches!(percent_decode_to_utf8("Ladies%20%2B%20Gentlemen").as_deref(), Ok("Ladies + Gentlemen"))); + assert!(matches!(percent_decode_to_bytes("Ladies%20%2B%20Gentlemen").as_deref(), Ok(b"Ladies + Gentlemen"))); + assert!(matches!(percent_decode_to_utf8("An%20encoded%20string%21").as_deref(), Ok("An encoded string!"))); + assert!(matches!(percent_decode_to_bytes("An%20encoded%20string%21").as_deref(), Ok(b"An encoded string!"))); + assert!(matches!(percent_decode_to_utf8("Dogs%2C%20Cats%20%26%20Mice").as_deref(), Ok("Dogs, Cats & Mice"))); + assert!(matches!(percent_decode_to_bytes("Dogs%2C%20Cats%20%26%20Mice").as_deref(), Ok(b"Dogs, Cats & Mice"))); + assert!(matches!(percent_decode_to_utf8("%E2%98%83").as_deref(), Ok("☃"))); + + assert!(matches!(percent_decode_to_utf8("%e2%98%83").as_deref(), Ok("☃"))); + + assert!(matches!(percent_decode_to_utf8("%41%6E%20%65%6E%63%6F%64%65%64%20%73%74%72%69%6E%67%21").as_deref(), Ok("An encoded string!"))); + + assert!(matches!(percent_decode_to_utf8("hello!"), Err(_))); + assert!(matches!(percent_decode_to_bytes("hello!"), Err(_))); + assert!(matches!(percent_decode_to_utf8("%2"), Err(_))); + assert!(matches!(percent_decode_to_bytes("%2"), Err(_))); + assert!(matches!(percent_decode_to_utf8("%2!"), Err(_))); + assert!(matches!(percent_decode_to_bytes("%2!"), Err(_))); + + assert!(matches!(percent_decode_to_utf8("%FF"), Err(_))); + assert!(matches!(percent_decode_to_bytes("%FF").as_deref(), Ok(&[0xff]))); + } +} diff --git a/src/lib.rs b/src/lib.rs index 9220636..064f823 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![cfg_attr(not(feature = "std"), no_std)] -#[cfg(feature = "alloc")] +#[cfg(all(feature = "alloc", not(feature = "std")))] extern crate alloc; pub mod convert; diff --git a/src/strings/capped.rs b/src/strings/capped.rs index ae41e5b..25bf6cc 100644 --- a/src/strings/capped.rs +++ b/src/strings/capped.rs @@ -118,7 +118,7 @@ impl CappedString { } } -#[cfg(any(feature = "alloc", feature = "std"))] +#[cfg(feature = "alloc")] impl CappedString { pub fn into_boxed_str(self) -> Box { self.as_str().into() @@ -189,7 +189,7 @@ impl<'a, const N: usize> TryFrom<&'a str> for CappedString { } } -#[cfg(any(feature = "alloc", feature = "std"))] +#[cfg(feature = "alloc")] impl TryFrom for CappedString { type Error = Error; @@ -199,7 +199,7 @@ impl TryFrom for CappedString { } } -#[cfg(any(feature = "alloc", feature = "std"))] +#[cfg(feature = "alloc")] impl<'a, const N: usize> TryFrom> for CappedString { type Error = Error; @@ -209,7 +209,7 @@ impl<'a, const N: usize> TryFrom> for CappedString { } } -#[cfg(any(feature = "alloc", feature = "std"))] +#[cfg(feature = "alloc")] impl From> for String { #[inline] fn from(s: CappedString) -> Self { diff --git a/src/strings/mod.rs b/src/strings/mod.rs index 30aacbc..6facc8d 100644 --- a/src/strings/mod.rs +++ b/src/strings/mod.rs @@ -1,9 +1,9 @@ pub mod fixed; pub mod capped; -#[cfg(any(feature = "alloc", feature = "std"))] +#[cfg(feature = "alloc")] pub mod inlining; pub use fixed::{FixedString, Error as FixedStringError}; pub use capped::{CappedString, Error as CappedStringError}; -#[cfg(any(feature = "alloc", feature = "std"))] +#[cfg(feature = "alloc")] pub use inlining::{InliningString, InliningString23}; diff --git a/src/uuid.rs b/src/uuid.rs index 85a73f9..c8cb716 100644 --- a/src/uuid.rs +++ b/src/uuid.rs @@ -133,19 +133,19 @@ impl str::FromStr for Uuid { let mut buf = [0u8; 16]; buf[..4].copy_from_slice( - &hex::hex_to_be_byte_array::<4>(groups[0]).map_err(ParseError::BadTimeLow)?, + &hex::hex_to_be_byte_array::<4, _>(groups[0]).map_err(ParseError::BadTimeLow)?, ); buf[4..6].copy_from_slice( - &hex::hex_to_be_byte_array::<2>(groups[1]).map_err(ParseError::BadTimeMid)?, + &hex::hex_to_be_byte_array::<2, _>(groups[1]).map_err(ParseError::BadTimeMid)?, ); buf[6..8].copy_from_slice( - &hex::hex_to_be_byte_array::<2>(groups[2]).map_err(ParseError::BadTimeHi)?, + &hex::hex_to_be_byte_array::<2, _>(groups[2]).map_err(ParseError::BadTimeHi)?, ); buf[8..10].copy_from_slice( - &hex::hex_to_be_byte_array::<2>(groups[3]).map_err(ParseError::BadClockSeq)?, + &hex::hex_to_be_byte_array::<2, _>(groups[3]).map_err(ParseError::BadClockSeq)?, ); buf[10..].copy_from_slice( - &hex::hex_to_be_byte_array::<6>(groups[4]).map_err(ParseError::BadNode)?, + &hex::hex_to_be_byte_array::<6, _>(groups[4]).map_err(ParseError::BadNode)?, ); Ok(Self::from_bytes(buf)) @@ -412,7 +412,7 @@ impl<'a> Iterator for Sha1ChunkIter<'a> { mod tests { use super::Uuid; - #[cfg(any(feature = "alloc", feature = "std"))] + #[cfg(feature = "alloc")] #[test] fn test_uuid_display() { #[cfg(not(feature = "std"))]