Add decode function and tests

TODO: Fix test failure in encoding with incorrect padding calculation
This commit is contained in:
Nathan Fisher 2024-05-03 00:31:08 -04:00
parent e12cc66fba
commit 2236a572f7
3 changed files with 82 additions and 4 deletions

View File

@ -7,6 +7,7 @@ pub use {
pub enum Error { pub enum Error {
Io(io::Error), Io(io::Error),
IllegalChar(char), IllegalChar(char),
MissingPadding,
} }
impl From<io::Error> for Error { impl From<io::Error> for Error {
@ -27,7 +28,78 @@ pub struct Decoder<R: Read, W: Write> {
alphabet: B64Alphabet, alphabet: B64Alphabet,
} }
impl<R: Read, W: Write> Decoder<R, W> {
pub fn new(reader: R, writer: W, alphabet: Option<B64Alphabet>) -> Self {
Self {
reader,
writer,
alphabet: alphabet.unwrap_or_default(),
}
}
pub fn decode(&mut self) -> Result<(), Error> {
loop {
let mut ibuf = [0_u8; 4];
let mut obuf = [0_u8; 3];
let mut num = 0_u32;
let mut n_bytes = 0;
loop {
n_bytes += match self.reader.read(&mut ibuf) {
Ok(n) => n,
Err(e) if e.kind() == ErrorKind::Interrupted => continue,
Err(e) => return Err(e.into()),
};
break;
}
match n_bytes {
0 => break,
4 => {}
_ => return Err(Error::MissingPadding),
}
let mut bytes = 0;
for (i, &c) in ibuf.iter().enumerate() {
let c = c.into();
num <<= 6;
if c == self.alphabet.pad() {
continue;
}
if i != bytes {
return Err(Error::IllegalChar(c));
}
let Some(idx) = self.alphabet.get(c) else {
return Err(Error::IllegalChar(c));
};
num |= idx as u32;
bytes += 1;
}
let olen = bytes * 6 / 8;
for i in (0..3).rev() {
obuf[i] = (num & 0xff) as u8;
num >>= 8;
}
self.writer.write_all(&mut obuf[0..olen])?;
}
Ok(())
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn decode() {
let mut decoder = Decoder::new("SGVsbG8sIFdvcmxk".as_bytes(), vec![], None);
decoder.decode().unwrap();
assert_eq!(String::from_utf8(decoder.writer).unwrap(), "Hello, World");
decoder = Decoder::new("SGVsbG8sIFdvcmxkIQ==".as_bytes(), vec![], None);
decoder.decode().unwrap();
assert_eq!(String::from_utf8(decoder.writer).unwrap(), "Hello, World!");
decoder = Decoder::new("SGVsbG8sIFdvcmxkIQo=".as_bytes(), vec![], None);
decoder.decode().unwrap();
assert_eq!(
String::from_utf8(decoder.writer).unwrap(),
"Hello, World!\n"
);
}
} }

View File

@ -49,11 +49,11 @@ pub struct Encoder<R: Read, W: Write> {
} }
impl<R: Read, W: Write> Encoder<R, W> { impl<R: Read, W: Write> Encoder<R, W> {
pub fn new(reader: R, writer: W, alphabet: B64Alphabet) -> Self { pub fn new(reader: R, writer: W, alphabet: Option<B64Alphabet>) -> Self {
Self { Self {
reader, reader,
writer, writer,
alphabet, alphabet: alphabet.unwrap_or_default(),
} }
} }
@ -116,12 +116,18 @@ mod tests {
#[test] #[test]
fn encode() { fn encode() {
let mut encoder = Encoder { let mut encoder = Encoder::new("Hello, World".as_bytes(), String::new(), None);
encoder.encode().unwrap();
assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxk");
encoder = Encoder {
reader: "Hello, World!".as_bytes(), reader: "Hello, World!".as_bytes(),
writer: String::new(), writer: String::new(),
alphabet: B64Alphabet::default(), alphabet: B64Alphabet::default(),
}; };
encoder.encode().unwrap(); encoder.encode().unwrap();
assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxkIQ=="); assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxkIQ==");
encoder = Encoder::new("Hello, World!\n".as_bytes(), String::new(), None);
encoder.encode().unwrap();
assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxkIQo=");
} }
} }

View File

@ -29,7 +29,7 @@ impl Default for B64Alphabet {
} }
impl B64Alphabet { impl B64Alphabet {
pub fn idx(&self, c: char) -> Option<usize> { pub fn get(&self, c: char) -> Option<usize> {
for (idx, x) in self.items.iter().enumerate() { for (idx, x) in self.items.iter().enumerate() {
if *x == c { if *x == c {
return Some(idx); return Some(idx);