Verifier: wrap self.store in Arc<Mutex<T>>;

Refactoring:
* Add prelude
* Move some error types into modules
* make most modules private and re-export their items
This commit is contained in:
Nathan Fisher 2023-05-24 12:07:18 -04:00
parent d2802ced83
commit 02de655640
8 changed files with 171 additions and 143 deletions

View file

@ -1,8 +1,9 @@
#![warn(clippy::all, clippy::pedantic)] #![warn(clippy::all, clippy::pedantic)]
pub mod fingerprint; mod fingerprint;
pub mod host; mod host;
pub mod prelude;
pub mod receiver; pub mod receiver;
pub mod request; mod request;
pub mod response; mod response;
pub mod sender; mod sender;
pub mod status; mod status;

9
src/prelude.rs Normal file
View file

@ -0,0 +1,9 @@
pub use super::{
fingerprint::{Fingerprint, Error as FingerprintError},
host::{Host, ParseHostError},
receiver,
response::{Response, Error as ParseResponseError},
request::{Request, ParseRequestError},
sender::{CertificateStore, Error as SenderError, Sender, Verifier},
status::*,
};

View file

@ -1,6 +1,8 @@
use std::{fmt, num::ParseIntError, str::FromStr}; use std::{fmt, str::FromStr};
use crate::prelude::Status;
use crate::status::{ParseStatusError, Status}; mod error;
pub use error::Error;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct Response { pub struct Response {
@ -14,58 +16,18 @@ impl fmt::Display for Response {
} }
} }
#[derive(Clone, Debug, PartialEq)]
pub enum ParseResponseError {
TooLong,
ParseInt(ParseIntError),
StatusError,
Malformed,
}
impl fmt::Display for ParseResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooLong => write!(f, "ParseResponseError: too long"),
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
}
}
}
impl std::error::Error for ParseResponseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ParseInt(e) => Some(e),
_ => None,
}
}
}
impl From<ParseIntError> for ParseResponseError {
fn from(value: ParseIntError) -> Self {
Self::ParseInt(value)
}
}
impl From<ParseStatusError> for ParseResponseError {
fn from(_value: ParseStatusError) -> Self {
Self::StatusError
}
}
impl FromStr for Response { impl FromStr for Response {
type Err = ParseResponseError; type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.len() > 2048 { if s.len() > 2048 {
return Err(ParseResponseError::TooLong); return Err(Error::TooLong);
} }
if !s.ends_with("\r\n") { if !s.ends_with("\r\n") {
return Err(ParseResponseError::Malformed); return Err(Error::Malformed);
} }
let Some((status, meta)) = s.split_once(' ') else { let Some((status, meta)) = s.split_once(' ') else {
return Err(ParseResponseError::Malformed); return Err(Error::Malformed);
}; };
let status: u8 = status.parse()?; let status: u8 = status.parse()?;
let status: Status = status.try_into()?; let status: Status = status.try_into()?;
@ -90,14 +52,14 @@ mod tests {
#[test] #[test]
fn parse_badend() { fn parse_badend() {
let response = "20 message delivered\n".parse::<Response>(); let response = "20 message delivered\n".parse::<Response>();
assert_eq!(response, Err(ParseResponseError::Malformed)); assert_eq!(response, Err(Error::Malformed));
} }
#[test] #[test]
fn parse_badint() { fn parse_badint() {
let response = "twenty message deliverred\r\n".parse::<Response>(); let response = "twenty message deliverred\r\n".parse::<Response>();
match response { match response {
Err(ParseResponseError::ParseInt(_)) => {} Err(Error::ParseInt(_)) => {}
_ => panic!(), _ => panic!(),
} }
} }

45
src/response/error.rs Normal file
View file

@ -0,0 +1,45 @@
use {
crate::prelude::ParseStatusError,
std::{fmt, num::ParseIntError},
};
#[derive(Clone, Debug, PartialEq)]
pub enum Error {
TooLong,
ParseInt(ParseIntError),
StatusError,
Malformed,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooLong => write!(f, "ParseResponseError: too long"),
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ParseInt(e) => Some(e),
_ => None,
}
}
}
impl From<ParseIntError> for Error {
fn from(value: ParseIntError) -> Self {
Self::ParseInt(value)
}
}
impl From<ParseStatusError> for Error {
fn from(_value: ParseStatusError) -> Self {
Self::StatusError
}
}

View file

@ -1,84 +1,32 @@
use self::{store::CertificateStore, verifier::Verifier};
use crate::{ use crate::{
request::{ParseRequestError, Request}, request::Request,
response::{ParseResponseError, Response}, response::Response,
};
use std::{
fmt,
io::{self, Read, Write},
}; };
use std::io::{Read, Write};
pub use self::{error::Error, verifier::{CertificateStore, Verifier}};
pub mod store; mod error;
pub mod verifier; mod verifier;
#[derive(Debug)] #[derive(Debug)]
pub struct Sender<'a, S: CertificateStore, C: Sized, T: Read + Write + Sized> { pub struct Sender<S, C, T>
pub request: Request,
pub verifier: Verifier<'a, S>,
pub stream: rustls::StreamOwned<C, T>,
}
#[derive(Debug)]
pub enum Error {
TlsError(rustls::Error),
RequestError(ParseRequestError),
ResponseError(ParseResponseError),
IoError(io::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TlsError(e) => write!(f, "{e}"),
Self::RequestError(e) => write!(f, "{e}"),
Self::ResponseError(e) => write!(f, "{e}"),
Self::IoError(e) => write!(f, "{e}"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::RequestError(e) => Some(e),
Self::ResponseError(e) => Some(e),
Self::TlsError(e) => Some(e),
Self::IoError(e) => Some(e),
}
}
}
impl From<rustls::Error> for Error {
fn from(value: rustls::Error) -> Self {
Self::TlsError(value)
}
}
impl From<ParseRequestError> for Error {
fn from(value: ParseRequestError) -> Self {
Self::RequestError(value)
}
}
impl From<ParseResponseError> for Error {
fn from(value: ParseResponseError) -> Self {
Self::ResponseError(value)
}
}
impl From<io::Error> for Error {
fn from(value: io::Error) -> Self {
Self::IoError(value)
}
}
impl<'a, S, C, T> Sender<'a, S, C, T>
where where
S: CertificateStore + Sync, S: CertificateStore,
C: Sized, C: Sized,
T: Read + Write + Sized, T: Read + Write + Sized,
{ {
pub fn new(request_str: &str, store: &'a S) -> Result<Self, Error> { pub request: Request,
pub verifier: Verifier<S>,
pub stream: rustls::StreamOwned<C, T>,
}
impl<S, C, T> Sender<S, C, T>
where
S: CertificateStore,
C: Sized,
T: Read + Write + Sized,
{
pub fn new(request_str: &str, store: S) -> Result<Self, Error> {
let request: Request = request_str.parse()?; let request: Request = request_str.parse()?;
let verifier = Verifier::new(store); let verifier = Verifier::new(store);
unimplemented!(); unimplemented!();

59
src/sender/error.rs Normal file
View file

@ -0,0 +1,59 @@
use {
crate::prelude::{ParseRequestError, ParseResponseError},
std::{fmt, io},
};
#[derive(Debug)]
pub enum Error {
TlsError(rustls::Error),
RequestError(ParseRequestError),
ResponseError(ParseResponseError),
IoError(io::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TlsError(e) => write!(f, "{e}"),
Self::RequestError(e) => write!(f, "{e}"),
Self::ResponseError(e) => write!(f, "{e}"),
Self::IoError(e) => write!(f, "{e}"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::RequestError(e) => Some(e),
Self::ResponseError(e) => Some(e),
Self::TlsError(e) => Some(e),
Self::IoError(e) => Some(e),
}
}
}
impl From<rustls::Error> for Error {
fn from(value: rustls::Error) -> Self {
Self::TlsError(value)
}
}
impl From<ParseRequestError> for Error {
fn from(value: ParseRequestError) -> Self {
Self::RequestError(value)
}
}
impl From<ParseResponseError> for Error {
fn from(value: ParseResponseError) -> Self {
Self::ResponseError(value)
}
}
impl From<io::Error> for Error {
fn from(value: io::Error) -> Self {
Self::IoError(value)
}
}

View file

@ -1,4 +0,0 @@
pub trait CertificateStore {
fn get(&self, host: &str) -> Option<String>;
fn insert(&mut self, host: &str, fingerprint: &str);
}

View file

@ -1,16 +1,21 @@
use super::store::CertificateStore;
use crate::fingerprint::Fingerprint; use crate::fingerprint::Fingerprint;
use rustls::{ use rustls::{
client::{ServerCertVerified, ServerCertVerifier}, client::{ServerCertVerified, ServerCertVerifier},
Certificate, Certificate,
}; };
use std::sync::{Arc, Mutex};
#[derive(Debug)] pub trait CertificateStore: Send + Sync {
pub struct Verifier<'a, T: CertificateStore> { fn get(&self, host: &str) -> Option<String>;
store: &'a T, fn insert(&mut self, host: &str, fingerprint: &str);
} }
impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> { #[derive(Debug)]
pub struct Verifier<S: CertificateStore> {
store: Arc<Mutex<S>>,
}
impl<S: CertificateStore> ServerCertVerifier for Verifier<S> {
fn verify_server_cert( fn verify_server_cert(
&self, &self,
end_entity: &Certificate, end_entity: &Certificate,
@ -29,8 +34,8 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
_ => todo!(), _ => todo!(),
}; };
if let Some(fingerprint) = match server_name { if let Some(fingerprint) = match server_name {
rustls::ServerName::DnsName(n) => self.store.get(n.as_ref()), rustls::ServerName::DnsName(n) => self.store.lock().unwrap().get(n.as_ref()),
rustls::ServerName::IpAddress(ip) => self.store.get(&ip.to_string()), rustls::ServerName::IpAddress(ip) => self.store.lock().unwrap().get(&ip.to_string()),
_ => todo!(), _ => todo!(),
} { } {
if fingerprint == fp.1 && name == fp.0 { if fingerprint == fp.1 && name == fp.0 {
@ -38,16 +43,19 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
} }
} else { } else {
// todo: need a way to update `self.store`. Probably will require // todo: need a way to update `self.store`. Probably will require
// an Arc<Mutex<T>> for interior mutability // an Arc<Mutex<T>> for interior mutability.
// UPDATE: Now wrapped in Arc<Mutex<T>>
} }
return Err(rustls::Error::General( Err(rustls::Error::General(
"Unrecognized certificate".to_string(), "Unrecognized certificate".to_string(),
)); ))
} }
} }
impl<'a, T: CertificateStore + Sync> Verifier<'a, T> { impl<T: CertificateStore> Verifier<T> {
pub fn new(store: &'a T) -> Self { pub fn new(store: T) -> Self {
Self { store } Self {
store: Arc::new(Mutex::new(store)),
}
} }
} }