Verifier: wrap self.store
in Arc<Mutex<T>>;
Refactoring: * Add prelude * Move some error types into modules * make most modules private and re-export their items
This commit is contained in:
parent
d2802ced83
commit
02de655640
8 changed files with 171 additions and 143 deletions
13
src/lib.rs
13
src/lib.rs
|
@ -1,8 +1,9 @@
|
|||
#![warn(clippy::all, clippy::pedantic)]
|
||||
pub mod fingerprint;
|
||||
pub mod host;
|
||||
mod fingerprint;
|
||||
mod host;
|
||||
pub mod prelude;
|
||||
pub mod receiver;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod sender;
|
||||
pub mod status;
|
||||
mod request;
|
||||
mod response;
|
||||
mod sender;
|
||||
mod status;
|
||||
|
|
9
src/prelude.rs
Normal file
9
src/prelude.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
pub use super::{
|
||||
fingerprint::{Fingerprint, Error as FingerprintError},
|
||||
host::{Host, ParseHostError},
|
||||
receiver,
|
||||
response::{Response, Error as ParseResponseError},
|
||||
request::{Request, ParseRequestError},
|
||||
sender::{CertificateStore, Error as SenderError, Sender, Verifier},
|
||||
status::*,
|
||||
};
|
|
@ -1,6 +1,8 @@
|
|||
use std::{fmt, num::ParseIntError, str::FromStr};
|
||||
use std::{fmt, str::FromStr};
|
||||
use crate::prelude::Status;
|
||||
|
||||
use crate::status::{ParseStatusError, Status};
|
||||
mod error;
|
||||
pub use error::Error;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Response {
|
||||
|
@ -14,58 +16,18 @@ impl fmt::Display for Response {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum ParseResponseError {
|
||||
TooLong,
|
||||
ParseInt(ParseIntError),
|
||||
StatusError,
|
||||
Malformed,
|
||||
}
|
||||
|
||||
impl fmt::Display for ParseResponseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::TooLong => write!(f, "ParseResponseError: too long"),
|
||||
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
|
||||
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
|
||||
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ParseResponseError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::ParseInt(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseIntError> for ParseResponseError {
|
||||
fn from(value: ParseIntError) -> Self {
|
||||
Self::ParseInt(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseStatusError> for ParseResponseError {
|
||||
fn from(_value: ParseStatusError) -> Self {
|
||||
Self::StatusError
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Response {
|
||||
type Err = ParseResponseError;
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
if s.len() > 2048 {
|
||||
return Err(ParseResponseError::TooLong);
|
||||
return Err(Error::TooLong);
|
||||
}
|
||||
if !s.ends_with("\r\n") {
|
||||
return Err(ParseResponseError::Malformed);
|
||||
return Err(Error::Malformed);
|
||||
}
|
||||
let Some((status, meta)) = s.split_once(' ') else {
|
||||
return Err(ParseResponseError::Malformed);
|
||||
return Err(Error::Malformed);
|
||||
};
|
||||
let status: u8 = status.parse()?;
|
||||
let status: Status = status.try_into()?;
|
||||
|
@ -90,14 +52,14 @@ mod tests {
|
|||
#[test]
|
||||
fn parse_badend() {
|
||||
let response = "20 message delivered\n".parse::<Response>();
|
||||
assert_eq!(response, Err(ParseResponseError::Malformed));
|
||||
assert_eq!(response, Err(Error::Malformed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_badint() {
|
||||
let response = "twenty message deliverred\r\n".parse::<Response>();
|
||||
match response {
|
||||
Err(ParseResponseError::ParseInt(_)) => {}
|
||||
Err(Error::ParseInt(_)) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
|
45
src/response/error.rs
Normal file
45
src/response/error.rs
Normal file
|
@ -0,0 +1,45 @@
|
|||
use {
|
||||
crate::prelude::ParseStatusError,
|
||||
std::{fmt, num::ParseIntError},
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
TooLong,
|
||||
ParseInt(ParseIntError),
|
||||
StatusError,
|
||||
Malformed,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::TooLong => write!(f, "ParseResponseError: too long"),
|
||||
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
|
||||
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
|
||||
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::ParseInt(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseIntError> for Error {
|
||||
fn from(value: ParseIntError) -> Self {
|
||||
Self::ParseInt(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseStatusError> for Error {
|
||||
fn from(_value: ParseStatusError) -> Self {
|
||||
Self::StatusError
|
||||
}
|
||||
}
|
||||
|
|
@ -1,84 +1,32 @@
|
|||
use self::{store::CertificateStore, verifier::Verifier};
|
||||
use crate::{
|
||||
request::{ParseRequestError, Request},
|
||||
response::{ParseResponseError, Response},
|
||||
};
|
||||
use std::{
|
||||
fmt,
|
||||
io::{self, Read, Write},
|
||||
request::Request,
|
||||
response::Response,
|
||||
};
|
||||
use std::io::{Read, Write};
|
||||
pub use self::{error::Error, verifier::{CertificateStore, Verifier}};
|
||||
|
||||
pub mod store;
|
||||
pub mod verifier;
|
||||
mod error;
|
||||
mod verifier;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sender<'a, S: CertificateStore, C: Sized, T: Read + Write + Sized> {
|
||||
pub request: Request,
|
||||
pub verifier: Verifier<'a, S>,
|
||||
pub stream: rustls::StreamOwned<C, T>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
TlsError(rustls::Error),
|
||||
RequestError(ParseRequestError),
|
||||
ResponseError(ParseResponseError),
|
||||
IoError(io::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::TlsError(e) => write!(f, "{e}"),
|
||||
Self::RequestError(e) => write!(f, "{e}"),
|
||||
Self::ResponseError(e) => write!(f, "{e}"),
|
||||
Self::IoError(e) => write!(f, "{e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::RequestError(e) => Some(e),
|
||||
Self::ResponseError(e) => Some(e),
|
||||
Self::TlsError(e) => Some(e),
|
||||
Self::IoError(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rustls::Error> for Error {
|
||||
fn from(value: rustls::Error) -> Self {
|
||||
Self::TlsError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseRequestError> for Error {
|
||||
fn from(value: ParseRequestError) -> Self {
|
||||
Self::RequestError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseResponseError> for Error {
|
||||
fn from(value: ParseResponseError) -> Self {
|
||||
Self::ResponseError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::IoError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S, C, T> Sender<'a, S, C, T>
|
||||
pub struct Sender<S, C, T>
|
||||
where
|
||||
S: CertificateStore + Sync,
|
||||
S: CertificateStore,
|
||||
C: Sized,
|
||||
T: Read + Write + Sized,
|
||||
{
|
||||
pub fn new(request_str: &str, store: &'a S) -> Result<Self, Error> {
|
||||
pub request: Request,
|
||||
pub verifier: Verifier<S>,
|
||||
pub stream: rustls::StreamOwned<C, T>,
|
||||
}
|
||||
|
||||
impl<S, C, T> Sender<S, C, T>
|
||||
where
|
||||
S: CertificateStore,
|
||||
C: Sized,
|
||||
T: Read + Write + Sized,
|
||||
{
|
||||
pub fn new(request_str: &str, store: S) -> Result<Self, Error> {
|
||||
let request: Request = request_str.parse()?;
|
||||
let verifier = Verifier::new(store);
|
||||
unimplemented!();
|
||||
|
|
59
src/sender/error.rs
Normal file
59
src/sender/error.rs
Normal file
|
@ -0,0 +1,59 @@
|
|||
use {
|
||||
crate::prelude::{ParseRequestError, ParseResponseError},
|
||||
std::{fmt, io},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
TlsError(rustls::Error),
|
||||
RequestError(ParseRequestError),
|
||||
ResponseError(ParseResponseError),
|
||||
IoError(io::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::TlsError(e) => write!(f, "{e}"),
|
||||
Self::RequestError(e) => write!(f, "{e}"),
|
||||
Self::ResponseError(e) => write!(f, "{e}"),
|
||||
Self::IoError(e) => write!(f, "{e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::RequestError(e) => Some(e),
|
||||
Self::ResponseError(e) => Some(e),
|
||||
Self::TlsError(e) => Some(e),
|
||||
Self::IoError(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rustls::Error> for Error {
|
||||
fn from(value: rustls::Error) -> Self {
|
||||
Self::TlsError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseRequestError> for Error {
|
||||
fn from(value: ParseRequestError) -> Self {
|
||||
Self::RequestError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ParseResponseError> for Error {
|
||||
fn from(value: ParseResponseError) -> Self {
|
||||
Self::ResponseError(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::IoError(value)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
pub trait CertificateStore {
|
||||
fn get(&self, host: &str) -> Option<String>;
|
||||
fn insert(&mut self, host: &str, fingerprint: &str);
|
||||
}
|
|
@ -1,16 +1,21 @@
|
|||
use super::store::CertificateStore;
|
||||
use crate::fingerprint::Fingerprint;
|
||||
use rustls::{
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
Certificate,
|
||||
};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Verifier<'a, T: CertificateStore> {
|
||||
store: &'a T,
|
||||
pub trait CertificateStore: Send + Sync {
|
||||
fn get(&self, host: &str) -> Option<String>;
|
||||
fn insert(&mut self, host: &str, fingerprint: &str);
|
||||
}
|
||||
|
||||
impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
|
||||
#[derive(Debug)]
|
||||
pub struct Verifier<S: CertificateStore> {
|
||||
store: Arc<Mutex<S>>,
|
||||
}
|
||||
|
||||
impl<S: CertificateStore> ServerCertVerifier for Verifier<S> {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
end_entity: &Certificate,
|
||||
|
@ -29,8 +34,8 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
|
|||
_ => todo!(),
|
||||
};
|
||||
if let Some(fingerprint) = match server_name {
|
||||
rustls::ServerName::DnsName(n) => self.store.get(n.as_ref()),
|
||||
rustls::ServerName::IpAddress(ip) => self.store.get(&ip.to_string()),
|
||||
rustls::ServerName::DnsName(n) => self.store.lock().unwrap().get(n.as_ref()),
|
||||
rustls::ServerName::IpAddress(ip) => self.store.lock().unwrap().get(&ip.to_string()),
|
||||
_ => todo!(),
|
||||
} {
|
||||
if fingerprint == fp.1 && name == fp.0 {
|
||||
|
@ -38,16 +43,19 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
|
|||
}
|
||||
} else {
|
||||
// todo: need a way to update `self.store`. Probably will require
|
||||
// an Arc<Mutex<T>> for interior mutability
|
||||
// an Arc<Mutex<T>> for interior mutability.
|
||||
// UPDATE: Now wrapped in Arc<Mutex<T>>
|
||||
}
|
||||
return Err(rustls::Error::General(
|
||||
Err(rustls::Error::General(
|
||||
"Unrecognized certificate".to_string(),
|
||||
));
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: CertificateStore + Sync> Verifier<'a, T> {
|
||||
pub fn new(store: &'a T) -> Self {
|
||||
Self { store }
|
||||
impl<T: CertificateStore> Verifier<T> {
|
||||
pub fn new(store: T) -> Self {
|
||||
Self {
|
||||
store: Arc::new(Mutex::new(store)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue