diff --git a/src/lib.rs b/src/lib.rs index a88ea23..779971d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,9 @@ #![warn(clippy::all, clippy::pedantic)] -pub mod fingerprint; -pub mod host; +mod fingerprint; +mod host; +pub mod prelude; pub mod receiver; -pub mod request; -pub mod response; -pub mod sender; -pub mod status; +mod request; +mod response; +mod sender; +mod status; diff --git a/src/prelude.rs b/src/prelude.rs new file mode 100644 index 0000000..a42fb1e --- /dev/null +++ b/src/prelude.rs @@ -0,0 +1,9 @@ +pub use super::{ + fingerprint::{Fingerprint, Error as FingerprintError}, + host::{Host, ParseHostError}, + receiver, + response::{Response, Error as ParseResponseError}, + request::{Request, ParseRequestError}, + sender::{CertificateStore, Error as SenderError, Sender, Verifier}, + status::*, +}; diff --git a/src/response.rs b/src/response.rs index 77821f4..a1c377f 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,6 +1,8 @@ -use std::{fmt, num::ParseIntError, str::FromStr}; +use std::{fmt, str::FromStr}; +use crate::prelude::Status; -use crate::status::{ParseStatusError, Status}; +mod error; +pub use error::Error; #[derive(Clone, Debug, PartialEq)] pub struct Response { @@ -14,58 +16,18 @@ impl fmt::Display for Response { } } -#[derive(Clone, Debug, PartialEq)] -pub enum ParseResponseError { - TooLong, - ParseInt(ParseIntError), - StatusError, - Malformed, -} - -impl fmt::Display for ParseResponseError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::TooLong => write!(f, "ParseResponseError: too long"), - Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"), - Self::StatusError => write!(f, "ParseResponseError: Invalid Status"), - Self::Malformed => write!(f, "ParseResponseError: Malformed"), - } - } -} - -impl std::error::Error for ParseResponseError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::ParseInt(e) => Some(e), - _ => None, - } - } -} - -impl From for ParseResponseError { - fn from(value: ParseIntError) -> Self { - Self::ParseInt(value) - } -} - -impl From for ParseResponseError { - fn from(_value: ParseStatusError) -> Self { - Self::StatusError - } -} - impl FromStr for Response { - type Err = ParseResponseError; + type Err = Error; fn from_str(s: &str) -> Result { if s.len() > 2048 { - return Err(ParseResponseError::TooLong); + return Err(Error::TooLong); } if !s.ends_with("\r\n") { - return Err(ParseResponseError::Malformed); + return Err(Error::Malformed); } let Some((status, meta)) = s.split_once(' ') else { - return Err(ParseResponseError::Malformed); + return Err(Error::Malformed); }; let status: u8 = status.parse()?; let status: Status = status.try_into()?; @@ -90,14 +52,14 @@ mod tests { #[test] fn parse_badend() { let response = "20 message delivered\n".parse::(); - assert_eq!(response, Err(ParseResponseError::Malformed)); + assert_eq!(response, Err(Error::Malformed)); } #[test] fn parse_badint() { let response = "twenty message deliverred\r\n".parse::(); match response { - Err(ParseResponseError::ParseInt(_)) => {} + Err(Error::ParseInt(_)) => {} _ => panic!(), } } diff --git a/src/response/error.rs b/src/response/error.rs new file mode 100644 index 0000000..7783943 --- /dev/null +++ b/src/response/error.rs @@ -0,0 +1,45 @@ +use { + crate::prelude::ParseStatusError, + std::{fmt, num::ParseIntError}, +}; + +#[derive(Clone, Debug, PartialEq)] +pub enum Error { + TooLong, + ParseInt(ParseIntError), + StatusError, + Malformed, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::TooLong => write!(f, "ParseResponseError: too long"), + Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"), + Self::StatusError => write!(f, "ParseResponseError: Invalid Status"), + Self::Malformed => write!(f, "ParseResponseError: Malformed"), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::ParseInt(e) => Some(e), + _ => None, + } + } +} + +impl From for Error { + fn from(value: ParseIntError) -> Self { + Self::ParseInt(value) + } +} + +impl From for Error { + fn from(_value: ParseStatusError) -> Self { + Self::StatusError + } +} + diff --git a/src/sender.rs b/src/sender.rs index 520e35f..979c175 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -1,84 +1,32 @@ -use self::{store::CertificateStore, verifier::Verifier}; use crate::{ - request::{ParseRequestError, Request}, - response::{ParseResponseError, Response}, -}; -use std::{ - fmt, - io::{self, Read, Write}, + request::Request, + response::Response, }; +use std::io::{Read, Write}; +pub use self::{error::Error, verifier::{CertificateStore, Verifier}}; -pub mod store; -pub mod verifier; +mod error; +mod verifier; #[derive(Debug)] -pub struct Sender<'a, S: CertificateStore, C: Sized, T: Read + Write + Sized> { - pub request: Request, - pub verifier: Verifier<'a, S>, - pub stream: rustls::StreamOwned, -} - -#[derive(Debug)] -pub enum Error { - TlsError(rustls::Error), - RequestError(ParseRequestError), - ResponseError(ParseResponseError), - IoError(io::Error), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::TlsError(e) => write!(f, "{e}"), - Self::RequestError(e) => write!(f, "{e}"), - Self::ResponseError(e) => write!(f, "{e}"), - Self::IoError(e) => write!(f, "{e}"), - } - } -} - -impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::RequestError(e) => Some(e), - Self::ResponseError(e) => Some(e), - Self::TlsError(e) => Some(e), - Self::IoError(e) => Some(e), - } - } -} - -impl From for Error { - fn from(value: rustls::Error) -> Self { - Self::TlsError(value) - } -} - -impl From for Error { - fn from(value: ParseRequestError) -> Self { - Self::RequestError(value) - } -} - -impl From for Error { - fn from(value: ParseResponseError) -> Self { - Self::ResponseError(value) - } -} - -impl From for Error { - fn from(value: io::Error) -> Self { - Self::IoError(value) - } -} - -impl<'a, S, C, T> Sender<'a, S, C, T> +pub struct Sender where - S: CertificateStore + Sync, + S: CertificateStore, C: Sized, T: Read + Write + Sized, { - pub fn new(request_str: &str, store: &'a S) -> Result { + pub request: Request, + pub verifier: Verifier, + pub stream: rustls::StreamOwned, +} + +impl Sender +where + S: CertificateStore, + C: Sized, + T: Read + Write + Sized, +{ + pub fn new(request_str: &str, store: S) -> Result { let request: Request = request_str.parse()?; let verifier = Verifier::new(store); unimplemented!(); diff --git a/src/sender/error.rs b/src/sender/error.rs new file mode 100644 index 0000000..ae7e573 --- /dev/null +++ b/src/sender/error.rs @@ -0,0 +1,59 @@ +use { + crate::prelude::{ParseRequestError, ParseResponseError}, + std::{fmt, io}, +}; + +#[derive(Debug)] +pub enum Error { + TlsError(rustls::Error), + RequestError(ParseRequestError), + ResponseError(ParseResponseError), + IoError(io::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::TlsError(e) => write!(f, "{e}"), + Self::RequestError(e) => write!(f, "{e}"), + Self::ResponseError(e) => write!(f, "{e}"), + Self::IoError(e) => write!(f, "{e}"), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::RequestError(e) => Some(e), + Self::ResponseError(e) => Some(e), + Self::TlsError(e) => Some(e), + Self::IoError(e) => Some(e), + } + } +} + +impl From for Error { + fn from(value: rustls::Error) -> Self { + Self::TlsError(value) + } +} + +impl From for Error { + fn from(value: ParseRequestError) -> Self { + Self::RequestError(value) + } +} + +impl From for Error { + fn from(value: ParseResponseError) -> Self { + Self::ResponseError(value) + } +} + +impl From for Error { + fn from(value: io::Error) -> Self { + Self::IoError(value) + } +} + diff --git a/src/sender/store.rs b/src/sender/store.rs deleted file mode 100644 index 03cf53c..0000000 --- a/src/sender/store.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub trait CertificateStore { - fn get(&self, host: &str) -> Option; - fn insert(&mut self, host: &str, fingerprint: &str); -} diff --git a/src/sender/verifier.rs b/src/sender/verifier.rs index 7506505..29e1a08 100644 --- a/src/sender/verifier.rs +++ b/src/sender/verifier.rs @@ -1,16 +1,21 @@ -use super::store::CertificateStore; use crate::fingerprint::Fingerprint; use rustls::{ client::{ServerCertVerified, ServerCertVerifier}, Certificate, }; +use std::sync::{Arc, Mutex}; -#[derive(Debug)] -pub struct Verifier<'a, T: CertificateStore> { - store: &'a T, +pub trait CertificateStore: Send + Sync { + fn get(&self, host: &str) -> Option; + fn insert(&mut self, host: &str, fingerprint: &str); } -impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> { +#[derive(Debug)] +pub struct Verifier { + store: Arc>, +} + +impl ServerCertVerifier for Verifier { fn verify_server_cert( &self, end_entity: &Certificate, @@ -29,8 +34,8 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> { _ => todo!(), }; if let Some(fingerprint) = match server_name { - rustls::ServerName::DnsName(n) => self.store.get(n.as_ref()), - rustls::ServerName::IpAddress(ip) => self.store.get(&ip.to_string()), + rustls::ServerName::DnsName(n) => self.store.lock().unwrap().get(n.as_ref()), + rustls::ServerName::IpAddress(ip) => self.store.lock().unwrap().get(&ip.to_string()), _ => todo!(), } { if fingerprint == fp.1 && name == fp.0 { @@ -38,16 +43,19 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> { } } else { // todo: need a way to update `self.store`. Probably will require - // an Arc> for interior mutability + // an Arc> for interior mutability. + // UPDATE: Now wrapped in Arc> } - return Err(rustls::Error::General( + Err(rustls::Error::General( "Unrecognized certificate".to_string(), - )); + )) } } -impl<'a, T: CertificateStore + Sync> Verifier<'a, T> { - pub fn new(store: &'a T) -> Self { - Self { store } +impl Verifier { + pub fn new(store: T) -> Self { + Self { + store: Arc::new(Mutex::new(store)), + } } }