From 2236a572f77bd287672e033c55fd89935801bbef Mon Sep 17 00:00:00 2001 From: Nathan Fisher Date: Fri, 3 May 2024 00:31:08 -0400 Subject: [PATCH] Add decode function and tests TODO: Fix test failure in encoding with incorrect padding calculation --- src/decode.rs | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/encode.rs | 12 ++++++--- src/lib.rs | 2 +- 3 files changed, 82 insertions(+), 4 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 8c6c82a..21af1f3 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -7,6 +7,7 @@ pub use { pub enum Error { Io(io::Error), IllegalChar(char), + MissingPadding, } impl From for Error { @@ -27,7 +28,78 @@ pub struct Decoder { alphabet: B64Alphabet, } +impl Decoder { + pub fn new(reader: R, writer: W, alphabet: Option) -> Self { + Self { + reader, + writer, + alphabet: alphabet.unwrap_or_default(), + } + } + + pub fn decode(&mut self) -> Result<(), Error> { + loop { + let mut ibuf = [0_u8; 4]; + let mut obuf = [0_u8; 3]; + let mut num = 0_u32; + let mut n_bytes = 0; + loop { + n_bytes += match self.reader.read(&mut ibuf) { + Ok(n) => n, + Err(e) if e.kind() == ErrorKind::Interrupted => continue, + Err(e) => return Err(e.into()), + }; + break; + } + match n_bytes { + 0 => break, + 4 => {} + _ => return Err(Error::MissingPadding), + } + let mut bytes = 0; + for (i, &c) in ibuf.iter().enumerate() { + let c = c.into(); + num <<= 6; + if c == self.alphabet.pad() { + continue; + } + if i != bytes { + return Err(Error::IllegalChar(c)); + } + let Some(idx) = self.alphabet.get(c) else { + return Err(Error::IllegalChar(c)); + }; + num |= idx as u32; + bytes += 1; + } + let olen = bytes * 6 / 8; + for i in (0..3).rev() { + obuf[i] = (num & 0xff) as u8; + num >>= 8; + } + self.writer.write_all(&mut obuf[0..olen])?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; + + #[test] + fn decode() { + let mut decoder = Decoder::new("SGVsbG8sIFdvcmxk".as_bytes(), vec![], None); + decoder.decode().unwrap(); + assert_eq!(String::from_utf8(decoder.writer).unwrap(), "Hello, World"); + decoder = Decoder::new("SGVsbG8sIFdvcmxkIQ==".as_bytes(), vec![], None); + decoder.decode().unwrap(); + assert_eq!(String::from_utf8(decoder.writer).unwrap(), "Hello, World!"); + decoder = Decoder::new("SGVsbG8sIFdvcmxkIQo=".as_bytes(), vec![], None); + decoder.decode().unwrap(); + assert_eq!( + String::from_utf8(decoder.writer).unwrap(), + "Hello, World!\n" + ); + } } diff --git a/src/encode.rs b/src/encode.rs index dd08d12..fe59294 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -49,11 +49,11 @@ pub struct Encoder { } impl Encoder { - pub fn new(reader: R, writer: W, alphabet: B64Alphabet) -> Self { + pub fn new(reader: R, writer: W, alphabet: Option) -> Self { Self { reader, writer, - alphabet, + alphabet: alphabet.unwrap_or_default(), } } @@ -116,12 +116,18 @@ mod tests { #[test] fn encode() { - let mut encoder = Encoder { + let mut encoder = Encoder::new("Hello, World".as_bytes(), String::new(), None); + encoder.encode().unwrap(); + assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxk"); + encoder = Encoder { reader: "Hello, World!".as_bytes(), writer: String::new(), alphabet: B64Alphabet::default(), }; encoder.encode().unwrap(); assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxkIQ=="); + encoder = Encoder::new("Hello, World!\n".as_bytes(), String::new(), None); + encoder.encode().unwrap(); + assert_eq!(encoder.output(), "SGVsbG8sIFdvcmxkIQo="); } } diff --git a/src/lib.rs b/src/lib.rs index d349fca..0a90c62 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ impl Default for B64Alphabet { } impl B64Alphabet { - pub fn idx(&self, c: char) -> Option { + pub fn get(&self, c: char) -> Option { for (idx, x) in self.items.iter().enumerate() { if *x == c { return Some(idx);