Verifier: wrap self.store in Arc<Mutex<T>>;

Refactoring:
* Add prelude
* Move some error types into modules
* make most modules private and re-export their items
This commit is contained in:
Nathan Fisher 2023-05-24 12:07:18 -04:00
parent d2802ced83
commit 02de655640
8 changed files with 171 additions and 143 deletions

View file

@ -1,8 +1,9 @@
#![warn(clippy::all, clippy::pedantic)]
pub mod fingerprint;
pub mod host;
mod fingerprint;
mod host;
pub mod prelude;
pub mod receiver;
pub mod request;
pub mod response;
pub mod sender;
pub mod status;
mod request;
mod response;
mod sender;
mod status;

9
src/prelude.rs Normal file
View file

@ -0,0 +1,9 @@
pub use super::{
fingerprint::{Fingerprint, Error as FingerprintError},
host::{Host, ParseHostError},
receiver,
response::{Response, Error as ParseResponseError},
request::{Request, ParseRequestError},
sender::{CertificateStore, Error as SenderError, Sender, Verifier},
status::*,
};

View file

@ -1,6 +1,8 @@
use std::{fmt, num::ParseIntError, str::FromStr};
use std::{fmt, str::FromStr};
use crate::prelude::Status;
use crate::status::{ParseStatusError, Status};
mod error;
pub use error::Error;
#[derive(Clone, Debug, PartialEq)]
pub struct Response {
@ -14,58 +16,18 @@ impl fmt::Display for Response {
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum ParseResponseError {
TooLong,
ParseInt(ParseIntError),
StatusError,
Malformed,
}
impl fmt::Display for ParseResponseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooLong => write!(f, "ParseResponseError: too long"),
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
}
}
}
impl std::error::Error for ParseResponseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ParseInt(e) => Some(e),
_ => None,
}
}
}
impl From<ParseIntError> for ParseResponseError {
fn from(value: ParseIntError) -> Self {
Self::ParseInt(value)
}
}
impl From<ParseStatusError> for ParseResponseError {
fn from(_value: ParseStatusError) -> Self {
Self::StatusError
}
}
impl FromStr for Response {
type Err = ParseResponseError;
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.len() > 2048 {
return Err(ParseResponseError::TooLong);
return Err(Error::TooLong);
}
if !s.ends_with("\r\n") {
return Err(ParseResponseError::Malformed);
return Err(Error::Malformed);
}
let Some((status, meta)) = s.split_once(' ') else {
return Err(ParseResponseError::Malformed);
return Err(Error::Malformed);
};
let status: u8 = status.parse()?;
let status: Status = status.try_into()?;
@ -90,14 +52,14 @@ mod tests {
#[test]
fn parse_badend() {
let response = "20 message delivered\n".parse::<Response>();
assert_eq!(response, Err(ParseResponseError::Malformed));
assert_eq!(response, Err(Error::Malformed));
}
#[test]
fn parse_badint() {
let response = "twenty message deliverred\r\n".parse::<Response>();
match response {
Err(ParseResponseError::ParseInt(_)) => {}
Err(Error::ParseInt(_)) => {}
_ => panic!(),
}
}

45
src/response/error.rs Normal file
View file

@ -0,0 +1,45 @@
use {
crate::prelude::ParseStatusError,
std::{fmt, num::ParseIntError},
};
#[derive(Clone, Debug, PartialEq)]
pub enum Error {
TooLong,
ParseInt(ParseIntError),
StatusError,
Malformed,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TooLong => write!(f, "ParseResponseError: too long"),
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ParseInt(e) => Some(e),
_ => None,
}
}
}
impl From<ParseIntError> for Error {
fn from(value: ParseIntError) -> Self {
Self::ParseInt(value)
}
}
impl From<ParseStatusError> for Error {
fn from(_value: ParseStatusError) -> Self {
Self::StatusError
}
}

View file

@ -1,84 +1,32 @@
use self::{store::CertificateStore, verifier::Verifier};
use crate::{
request::{ParseRequestError, Request},
response::{ParseResponseError, Response},
};
use std::{
fmt,
io::{self, Read, Write},
request::Request,
response::Response,
};
use std::io::{Read, Write};
pub use self::{error::Error, verifier::{CertificateStore, Verifier}};
pub mod store;
pub mod verifier;
mod error;
mod verifier;
#[derive(Debug)]
pub struct Sender<'a, S: CertificateStore, C: Sized, T: Read + Write + Sized> {
pub request: Request,
pub verifier: Verifier<'a, S>,
pub stream: rustls::StreamOwned<C, T>,
}
#[derive(Debug)]
pub enum Error {
TlsError(rustls::Error),
RequestError(ParseRequestError),
ResponseError(ParseResponseError),
IoError(io::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TlsError(e) => write!(f, "{e}"),
Self::RequestError(e) => write!(f, "{e}"),
Self::ResponseError(e) => write!(f, "{e}"),
Self::IoError(e) => write!(f, "{e}"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::RequestError(e) => Some(e),
Self::ResponseError(e) => Some(e),
Self::TlsError(e) => Some(e),
Self::IoError(e) => Some(e),
}
}
}
impl From<rustls::Error> for Error {
fn from(value: rustls::Error) -> Self {
Self::TlsError(value)
}
}
impl From<ParseRequestError> for Error {
fn from(value: ParseRequestError) -> Self {
Self::RequestError(value)
}
}
impl From<ParseResponseError> for Error {
fn from(value: ParseResponseError) -> Self {
Self::ResponseError(value)
}
}
impl From<io::Error> for Error {
fn from(value: io::Error) -> Self {
Self::IoError(value)
}
}
impl<'a, S, C, T> Sender<'a, S, C, T>
pub struct Sender<S, C, T>
where
S: CertificateStore + Sync,
S: CertificateStore,
C: Sized,
T: Read + Write + Sized,
{
pub fn new(request_str: &str, store: &'a S) -> Result<Self, Error> {
pub request: Request,
pub verifier: Verifier<S>,
pub stream: rustls::StreamOwned<C, T>,
}
impl<S, C, T> Sender<S, C, T>
where
S: CertificateStore,
C: Sized,
T: Read + Write + Sized,
{
pub fn new(request_str: &str, store: S) -> Result<Self, Error> {
let request: Request = request_str.parse()?;
let verifier = Verifier::new(store);
unimplemented!();

59
src/sender/error.rs Normal file
View file

@ -0,0 +1,59 @@
use {
crate::prelude::{ParseRequestError, ParseResponseError},
std::{fmt, io},
};
#[derive(Debug)]
pub enum Error {
TlsError(rustls::Error),
RequestError(ParseRequestError),
ResponseError(ParseResponseError),
IoError(io::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::TlsError(e) => write!(f, "{e}"),
Self::RequestError(e) => write!(f, "{e}"),
Self::ResponseError(e) => write!(f, "{e}"),
Self::IoError(e) => write!(f, "{e}"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::RequestError(e) => Some(e),
Self::ResponseError(e) => Some(e),
Self::TlsError(e) => Some(e),
Self::IoError(e) => Some(e),
}
}
}
impl From<rustls::Error> for Error {
fn from(value: rustls::Error) -> Self {
Self::TlsError(value)
}
}
impl From<ParseRequestError> for Error {
fn from(value: ParseRequestError) -> Self {
Self::RequestError(value)
}
}
impl From<ParseResponseError> for Error {
fn from(value: ParseResponseError) -> Self {
Self::ResponseError(value)
}
}
impl From<io::Error> for Error {
fn from(value: io::Error) -> Self {
Self::IoError(value)
}
}

View file

@ -1,4 +0,0 @@
pub trait CertificateStore {
fn get(&self, host: &str) -> Option<String>;
fn insert(&mut self, host: &str, fingerprint: &str);
}

View file

@ -1,16 +1,21 @@
use super::store::CertificateStore;
use crate::fingerprint::Fingerprint;
use rustls::{
client::{ServerCertVerified, ServerCertVerifier},
Certificate,
};
use std::sync::{Arc, Mutex};
#[derive(Debug)]
pub struct Verifier<'a, T: CertificateStore> {
store: &'a T,
pub trait CertificateStore: Send + Sync {
fn get(&self, host: &str) -> Option<String>;
fn insert(&mut self, host: &str, fingerprint: &str);
}
impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
#[derive(Debug)]
pub struct Verifier<S: CertificateStore> {
store: Arc<Mutex<S>>,
}
impl<S: CertificateStore> ServerCertVerifier for Verifier<S> {
fn verify_server_cert(
&self,
end_entity: &Certificate,
@ -29,8 +34,8 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
_ => todo!(),
};
if let Some(fingerprint) = match server_name {
rustls::ServerName::DnsName(n) => self.store.get(n.as_ref()),
rustls::ServerName::IpAddress(ip) => self.store.get(&ip.to_string()),
rustls::ServerName::DnsName(n) => self.store.lock().unwrap().get(n.as_ref()),
rustls::ServerName::IpAddress(ip) => self.store.lock().unwrap().get(&ip.to_string()),
_ => todo!(),
} {
if fingerprint == fp.1 && name == fp.0 {
@ -38,16 +43,19 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
}
} else {
// todo: need a way to update `self.store`. Probably will require
// an Arc<Mutex<T>> for interior mutability
// an Arc<Mutex<T>> for interior mutability.
// UPDATE: Now wrapped in Arc<Mutex<T>>
}
return Err(rustls::Error::General(
Err(rustls::Error::General(
"Unrecognized certificate".to_string(),
));
))
}
}
impl<'a, T: CertificateStore + Sync> Verifier<'a, T> {
pub fn new(store: &'a T) -> Self {
Self { store }
impl<T: CertificateStore> Verifier<T> {
pub fn new(store: T) -> Self {
Self {
store: Arc::new(Mutex::new(store)),
}
}
}