Optional decoding of plus character in percent encoded strings

main
pantonshire 3 years ago
parent f0cc229a65
commit 8f263e330a

@ -5,12 +5,12 @@
use core::str;
#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::{borrow::Cow, string::String, vec::Vec};
use alloc::{borrow::Cow, boxed::Box, string::String, vec::Vec};
#[cfg(feature = "std")]
use std::borrow::Cow;
use crate::{strings::FixedString, sink::StrSink};
use crate::{sink::StrSink, strings::FixedString};
use super::hex;
@ -35,12 +35,12 @@ where
// SAFETY:
// We have already checked that `i < xs.len()`, so `..i` is in bounds for `xs`.
let prefix = unsafe { xs.get_unchecked(..i) };
// SAFETY:
// We have already checked that `i < xs.len()`, so `i + 1 <= xs.len()` must hold.
// Therefore, `(i + 1)..` is in bounds for `xs`.
let suffix = unsafe { xs.get_unchecked((i + 1)..) };
return (prefix, Some((x, suffix)));
}
@ -87,9 +87,10 @@ impl<'a> Iterator for PercentEncoder<'a> {
// a `prefix` consisting entirely of characters which do not need to be percent-encoded,
// followed by a `suffix` which is either `None` or starts with a character which needs
// to be percent-encoded.
let (prefix, suffix) = split_at(self.remaining, |b| {
!matches!(b, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'-' | b'.' | b'_' | b'~')
});
let (prefix, suffix) = split_at(
self.remaining,
|b| !matches!(b, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'-' | b'.' | b'_' | b'~'),
);
// SAFETY:
// `prefix` only contains characters in the unreserved set, which are all valid ASCII
@ -103,7 +104,7 @@ impl<'a> Iterator for PercentEncoder<'a> {
Some((byte, suffix)) => {
self.remaining = suffix;
Some((prefix, Some(Self::percent_encode_byte(byte))))
},
}
// If there's no suffix, then we've reached the end of the input string. Therefore, we
// set the length of the iterator's slice to 0 to indicate that we are done, and then
@ -111,7 +112,7 @@ impl<'a> Iterator for PercentEncoder<'a> {
None => {
self.remaining = &self.remaining[self.remaining.len()..];
Some((prefix, None))
},
}
}
}
}
@ -130,7 +131,7 @@ where
buf.push_str(prefix);
buf.push_str(&encoded_byte);
for (prefix, encoded_byte) in encoder {
for (prefix, encoded_byte) in encoder {
buf.push_str(prefix);
if let Some(encoded_byte) = encoded_byte {
buf.push_str(&encoded_byte);
@ -138,7 +139,7 @@ where
}
Cow::Owned(buf)
},
}
Some((prefix, None)) => Cow::Borrowed(prefix),
@ -174,22 +175,27 @@ where
Ok(())
}
/// Iterator state for percent-decoding a byte slice.
///
/// `remaining` holds the bytes that have not been consumed yet. `mode` is asked, via
/// `plus_space()`, whether a `+` byte should be decoded to a space.
pub struct PercentDecoder<'a> {
pub struct PercentDecoder<'a, M> {
remaining: &'a [u8],
mode: M,
}
impl<'a> PercentDecoder<'a> {
pub fn new<B>(bytes: &'a B) -> Self
impl<'a, M> PercentDecoder<'a, M> {
/// Creates a decoder over the bytes of `bytes` (borrowed through `AsRef<[u8]>`), storing
/// `mode` for later use during iteration. No decoding work happens here.
pub fn new<B>(bytes: &'a B, mode: M) -> Self
where
B: AsRef<[u8]> + ?Sized,
{
Self {
remaining: bytes.as_ref(),
mode,
}
}
}
impl<'a> Iterator for PercentDecoder<'a> {
impl<'a, M> Iterator for PercentDecoder<'a, M>
where
M: PercentDecodeMode,
{
type Item = (&'a [u8], Option<u8>);
fn next(&mut self) -> Option<Self::Item> {
@ -200,6 +206,23 @@ impl<'a> Iterator for PercentDecoder<'a> {
let mut i = 0;
while i < self.remaining.len() {
// The '+' character being decoded to a space does not appear in the URL standard
// section on percent-encoding, but it does appear in the section on
// application/x-www-form-urlencoded. We implement it here as an optional feature to
// simplify things.
if self.mode.plus_space() && self.remaining[i] == b'+' {
// SAFETY:
// `i < self.remaining.len()`, so `..i` is a valid range over the slice.
let prefix = unsafe { self.remaining.get_unchecked(..i) };
// SAFETY:
// `i < self.remaining.len()`, so `i + 1 <= self.remaining.len()`. Therefore,
// `(i + 1)..` is a valid range over the slice.
self.remaining = unsafe { self.remaining.get_unchecked((i + 1)..) };
return Some((prefix, Some(b' ')));
}
// According to the URL standard, the only special case we need to handle is when the
// percent character '%' is followed immediately by two hex digits. We check that there
// are at least two characters after the percent with `self.remaining.len() - i > 2`,
@ -214,9 +237,12 @@ impl<'a> Iterator for PercentDecoder<'a> {
// gives `i + 2 < self.remaining.len()`. Therefore, `i + 1` and `i + 2` are valid
// indexes into the slice.
let (msb, lsb) = unsafe {
(*self.remaining.get_unchecked(i + 1), *self.remaining.get_unchecked(i + 2))
(
*self.remaining.get_unchecked(i + 1),
*self.remaining.get_unchecked(i + 2),
)
};
// If the two bytes are valid hex digits, decode the hex number.
if let Ok(decoded) = hex::hex_to_byte(msb, lsb) {
// SAFETY:
@ -243,12 +269,66 @@ impl<'a> Iterator for PercentDecoder<'a> {
}
}
/// Selects how the percent decoder treats the `+` character.
///
/// This trait is sealed: its supertrait is declared in a private module, so the only
/// implementors are the ones in this file (`StandardDecode`, `FormDecode`, and references /
/// boxes of implementors).
pub trait PercentDecodeMode: percent_decode_mode::PercentDecodeModeSealed {}

impl<T: PercentDecodeMode> PercentDecodeMode for &T {}

#[cfg(feature = "alloc")]
impl<T: PercentDecodeMode> PercentDecodeMode for Box<T> {}

/// Decode mode in which `+` is left untouched (`plus_space()` is `false`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct StandardDecode;

impl PercentDecodeMode for StandardDecode {}

/// Decode mode in which `+` is decoded to a space (`plus_space()` is `true`), as used by
/// `application/x-www-form-urlencoded` data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FormDecode;

impl PercentDecodeMode for FormDecode {}

/// Private home of the sealing trait behind `PercentDecodeMode`.
mod percent_decode_mode {
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::boxed::Box;

    /// Implementation side of a decode mode.
    pub trait PercentDecodeModeSealed {
        /// Returns `true` if a `+` byte should be decoded to a space.
        fn plus_space(&self) -> bool;
    }

    // A reference to a mode behaves exactly like the mode it points to.
    impl<T: PercentDecodeModeSealed> PercentDecodeModeSealed for &T {
        #[inline]
        fn plus_space(&self) -> bool {
            T::plus_space(*self)
        }
    }

    // A boxed mode behaves exactly like the mode it owns.
    #[cfg(feature = "alloc")]
    impl<T: PercentDecodeModeSealed> PercentDecodeModeSealed for Box<T> {
        #[inline]
        fn plus_space(&self) -> bool {
            T::plus_space(&**self)
        }
    }

    impl PercentDecodeModeSealed for super::StandardDecode {
        #[inline]
        fn plus_space(&self) -> bool {
            false
        }
    }

    impl PercentDecodeModeSealed for super::FormDecode {
        #[inline]
        fn plus_space(&self) -> bool {
            true
        }
    }
}
#[cfg(feature = "alloc")]
pub fn percent_decode<B>(bytes: &B) -> Cow<[u8]>
pub fn percent_decode<B, M>(bytes: &B, mode: M) -> Cow<[u8]>
where
B: AsRef<[u8]> + ?Sized,
M: PercentDecodeMode,
{
let mut decoder = PercentDecoder::new(bytes);
let mut decoder = PercentDecoder::new(bytes, mode);
match decoder.next() {
Some((prefix, Some(byte))) => {
@ -264,7 +344,7 @@ where
}
Cow::Owned(buf)
},
}
Some((prefix, None)) => Cow::Borrowed(prefix),
@ -273,23 +353,27 @@ where
}
#[cfg(feature = "alloc")]
pub fn percent_decode_utf8<B>(bytes: &B) -> Cow<str>
pub fn percent_decode_utf8<B, M>(bytes: &B, mode: M) -> Cow<str>
where
B: AsRef<[u8]> + ?Sized,
M: PercentDecodeMode,
{
match percent_decode(bytes) {
match percent_decode(bytes, mode) {
Cow::Borrowed(decoded) => String::from_utf8_lossy(decoded),
Cow::Owned(decoded) => match String::from_utf8_lossy(&decoded) {
Cow::Borrowed(decoded_str) => {
debug_assert_eq!(decoded_str.len(), decoded.len());
debug_assert_eq!(decoded_str.as_bytes().as_ptr() as *const u8, decoded.as_ptr());
debug_assert_eq!(
decoded_str.as_bytes().as_ptr() as *const u8,
decoded.as_ptr()
);
// SAFETY:
// `String::from_utf8_lossy` returned a `Cow::Borrowed`, which means that
// `decoded` is valid UTF-8.
let decoded = unsafe { String::from_utf8_unchecked(decoded) };
Cow::Owned(decoded)
},
}
Cow::Owned(decoded) => Cow::Owned(decoded),
},
}
@ -311,9 +395,18 @@ mod tests {
assert!(matches!(percent_encode(""), Cow::Borrowed("")));
assert!(matches!(percent_encode("foobar"), Cow::Borrowed("foobar")));
assert_eq!(&*percent_encode("Ladies + Gentlemen"), "Ladies%20%2B%20Gentlemen");
assert_eq!(&*percent_encode("An encoded string!"), "An%20encoded%20string%21");
assert_eq!(&*percent_encode("Dogs, Cats & Mice"), "Dogs%2C%20Cats%20%26%20Mice");
assert_eq!(
&*percent_encode("Ladies + Gentlemen"),
"Ladies%20%2B%20Gentlemen"
);
assert_eq!(
&*percent_encode("An encoded string!"),
"An%20encoded%20string%21"
);
assert_eq!(
&*percent_encode("Dogs, Cats & Mice"),
"Dogs%2C%20Cats%20%26%20Mice"
);
assert_eq!(&*percent_encode("☃"), "%E2%98%83");
}
@ -326,26 +419,73 @@ mod tests {
#[cfg(feature = "std")]
use std::borrow::Cow;
use super::{percent_decode_utf8};
assert!(matches!(percent_decode_utf8(""), Cow::Borrowed("")));
assert!(matches!(percent_decode_utf8("foobar"), Cow::Borrowed("foobar")));
assert_eq!(&*percent_decode_utf8("Ladies%20%2B%20Gentlemen"), "Ladies + Gentlemen");
assert_eq!(&*percent_decode_utf8("An%20encoded%20string%21"), "An encoded string!");
assert_eq!(&*percent_decode_utf8("Dogs%2C%20Cats%20%26%20Mice"), "Dogs, Cats & Mice");
assert_eq!(&*percent_decode_utf8("%E2%98%83"), "☃");
assert_eq!(&*percent_decode_utf8("%e2%98%83"), "☃");
assert_eq!(&*percent_decode_utf8("%41%6E%20%65%6E%63%6F%64%65%64%20%73%74%72%69%6E%67%21"), "An encoded string!");
assert_eq!(&*percent_decode_utf8("hello!"), "hello!");
assert_eq!(&*percent_decode_utf8("hello%"), "hello%");
assert_eq!(&*percent_decode_utf8("%a"), "%a");
assert_eq!(&*percent_decode_utf8("%za"), "%za");
assert_eq!(&*percent_decode_utf8("%az"), "%az");
assert_eq!(&*percent_decode_utf8("hello%FFworld"), "hello\u{FFFD}world");
use super::{percent_decode_utf8, FormDecode, StandardDecode};
assert!(matches!(
percent_decode_utf8("", StandardDecode),
Cow::Borrowed("")
));
assert!(matches!(
percent_decode_utf8("foobar", StandardDecode),
Cow::Borrowed("foobar")
));
assert_eq!(
&*percent_decode_utf8("Ladies%20%2B%20Gentlemen", StandardDecode),
"Ladies + Gentlemen"
);
assert_eq!(
&*percent_decode_utf8("An%20encoded%20string%21", StandardDecode),
"An encoded string!"
);
assert_eq!(
&*percent_decode_utf8("Dogs%2C%20Cats%20%26%20Mice", StandardDecode),
"Dogs, Cats & Mice"
);
assert_eq!(
&*percent_decode_utf8("%E2%98%83", StandardDecode),
"☃"
);
assert_eq!(
&*percent_decode_utf8("%e2%98%83", StandardDecode),
"☃"
);
assert_eq!(
&*percent_decode_utf8(
"%41%6E%20%65%6E%63%6F%64%65%64%20%73%74%72%69%6E%67%21",
StandardDecode
),
"An encoded string!"
);
assert_eq!(&*percent_decode_utf8("hello!", StandardDecode), "hello!");
assert_eq!(&*percent_decode_utf8("hello%", StandardDecode), "hello%");
assert_eq!(&*percent_decode_utf8("%a", StandardDecode), "%a");
assert_eq!(&*percent_decode_utf8("%za", StandardDecode), "%za");
assert_eq!(&*percent_decode_utf8("%az", StandardDecode), "%az");
assert_eq!(
&*percent_decode_utf8("hello%FFworld", StandardDecode),
"hello\u{FFFD}world"
);
assert_eq!(
&*percent_decode_utf8("hello+world", StandardDecode),
"hello+world"
);
assert_eq!(
&*percent_decode_utf8("hello+world", FormDecode),
"hello world"
);
assert_eq!(
&*percent_decode_utf8("hello++world", FormDecode),
"hello world"
);
assert_eq!(
&*percent_decode_utf8("+hello+world+", FormDecode),
" hello world "
);
}
}

Loading…
Cancel
Save