Verifier: wrap self.store
in Arc<Mutex<T>>;
Refactoring: * Add prelude * Move some error types into modules * make most modules private and re-export their items
This commit is contained in:
parent
d2802ced83
commit
02de655640
8 changed files with 171 additions and 143 deletions
13
src/lib.rs
13
src/lib.rs
|
@ -1,8 +1,9 @@
|
||||||
#![warn(clippy::all, clippy::pedantic)]
|
#![warn(clippy::all, clippy::pedantic)]
|
||||||
pub mod fingerprint;
|
mod fingerprint;
|
||||||
pub mod host;
|
mod host;
|
||||||
|
pub mod prelude;
|
||||||
pub mod receiver;
|
pub mod receiver;
|
||||||
pub mod request;
|
mod request;
|
||||||
pub mod response;
|
mod response;
|
||||||
pub mod sender;
|
mod sender;
|
||||||
pub mod status;
|
mod status;
|
||||||
|
|
9
src/prelude.rs
Normal file
9
src/prelude.rs
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
pub use super::{
|
||||||
|
fingerprint::{Fingerprint, Error as FingerprintError},
|
||||||
|
host::{Host, ParseHostError},
|
||||||
|
receiver,
|
||||||
|
response::{Response, Error as ParseResponseError},
|
||||||
|
request::{Request, ParseRequestError},
|
||||||
|
sender::{CertificateStore, Error as SenderError, Sender, Verifier},
|
||||||
|
status::*,
|
||||||
|
};
|
|
@ -1,6 +1,8 @@
|
||||||
use std::{fmt, num::ParseIntError, str::FromStr};
|
use std::{fmt, str::FromStr};
|
||||||
|
use crate::prelude::Status;
|
||||||
|
|
||||||
use crate::status::{ParseStatusError, Status};
|
mod error;
|
||||||
|
pub use error::Error;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
pub struct Response {
|
pub struct Response {
|
||||||
|
@ -14,58 +16,18 @@ impl fmt::Display for Response {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
|
||||||
pub enum ParseResponseError {
|
|
||||||
TooLong,
|
|
||||||
ParseInt(ParseIntError),
|
|
||||||
StatusError,
|
|
||||||
Malformed,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for ParseResponseError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::TooLong => write!(f, "ParseResponseError: too long"),
|
|
||||||
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
|
|
||||||
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
|
|
||||||
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for ParseResponseError {
|
|
||||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
||||||
match self {
|
|
||||||
Self::ParseInt(e) => Some(e),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ParseIntError> for ParseResponseError {
|
|
||||||
fn from(value: ParseIntError) -> Self {
|
|
||||||
Self::ParseInt(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ParseStatusError> for ParseResponseError {
|
|
||||||
fn from(_value: ParseStatusError) -> Self {
|
|
||||||
Self::StatusError
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromStr for Response {
|
impl FromStr for Response {
|
||||||
type Err = ParseResponseError;
|
type Err = Error;
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
if s.len() > 2048 {
|
if s.len() > 2048 {
|
||||||
return Err(ParseResponseError::TooLong);
|
return Err(Error::TooLong);
|
||||||
}
|
}
|
||||||
if !s.ends_with("\r\n") {
|
if !s.ends_with("\r\n") {
|
||||||
return Err(ParseResponseError::Malformed);
|
return Err(Error::Malformed);
|
||||||
}
|
}
|
||||||
let Some((status, meta)) = s.split_once(' ') else {
|
let Some((status, meta)) = s.split_once(' ') else {
|
||||||
return Err(ParseResponseError::Malformed);
|
return Err(Error::Malformed);
|
||||||
};
|
};
|
||||||
let status: u8 = status.parse()?;
|
let status: u8 = status.parse()?;
|
||||||
let status: Status = status.try_into()?;
|
let status: Status = status.try_into()?;
|
||||||
|
@ -90,14 +52,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_badend() {
|
fn parse_badend() {
|
||||||
let response = "20 message delivered\n".parse::<Response>();
|
let response = "20 message delivered\n".parse::<Response>();
|
||||||
assert_eq!(response, Err(ParseResponseError::Malformed));
|
assert_eq!(response, Err(Error::Malformed));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_badint() {
|
fn parse_badint() {
|
||||||
let response = "twenty message deliverred\r\n".parse::<Response>();
|
let response = "twenty message deliverred\r\n".parse::<Response>();
|
||||||
match response {
|
match response {
|
||||||
Err(ParseResponseError::ParseInt(_)) => {}
|
Err(Error::ParseInt(_)) => {}
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
45
src/response/error.rs
Normal file
45
src/response/error.rs
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
use {
|
||||||
|
crate::prelude::ParseStatusError,
|
||||||
|
std::{fmt, num::ParseIntError},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
|
pub enum Error {
|
||||||
|
TooLong,
|
||||||
|
ParseInt(ParseIntError),
|
||||||
|
StatusError,
|
||||||
|
Malformed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Error {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::TooLong => write!(f, "ParseResponseError: too long"),
|
||||||
|
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
|
||||||
|
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
|
||||||
|
Self::Malformed => write!(f, "ParseResponseError: Malformed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for Error {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::ParseInt(e) => Some(e),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ParseIntError> for Error {
|
||||||
|
fn from(value: ParseIntError) -> Self {
|
||||||
|
Self::ParseInt(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ParseStatusError> for Error {
|
||||||
|
fn from(_value: ParseStatusError) -> Self {
|
||||||
|
Self::StatusError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,84 +1,32 @@
|
||||||
use self::{store::CertificateStore, verifier::Verifier};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
request::{ParseRequestError, Request},
|
request::Request,
|
||||||
response::{ParseResponseError, Response},
|
response::Response,
|
||||||
};
|
|
||||||
use std::{
|
|
||||||
fmt,
|
|
||||||
io::{self, Read, Write},
|
|
||||||
};
|
};
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
pub use self::{error::Error, verifier::{CertificateStore, Verifier}};
|
||||||
|
|
||||||
pub mod store;
|
mod error;
|
||||||
pub mod verifier;
|
mod verifier;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Sender<'a, S: CertificateStore, C: Sized, T: Read + Write + Sized> {
|
pub struct Sender<S, C, T>
|
||||||
pub request: Request,
|
|
||||||
pub verifier: Verifier<'a, S>,
|
|
||||||
pub stream: rustls::StreamOwned<C, T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum Error {
|
|
||||||
TlsError(rustls::Error),
|
|
||||||
RequestError(ParseRequestError),
|
|
||||||
ResponseError(ParseResponseError),
|
|
||||||
IoError(io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for Error {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
match self {
|
|
||||||
Self::TlsError(e) => write!(f, "{e}"),
|
|
||||||
Self::RequestError(e) => write!(f, "{e}"),
|
|
||||||
Self::ResponseError(e) => write!(f, "{e}"),
|
|
||||||
Self::IoError(e) => write!(f, "{e}"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for Error {
|
|
||||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
||||||
match self {
|
|
||||||
Self::RequestError(e) => Some(e),
|
|
||||||
Self::ResponseError(e) => Some(e),
|
|
||||||
Self::TlsError(e) => Some(e),
|
|
||||||
Self::IoError(e) => Some(e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<rustls::Error> for Error {
|
|
||||||
fn from(value: rustls::Error) -> Self {
|
|
||||||
Self::TlsError(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ParseRequestError> for Error {
|
|
||||||
fn from(value: ParseRequestError) -> Self {
|
|
||||||
Self::RequestError(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ParseResponseError> for Error {
|
|
||||||
fn from(value: ParseResponseError) -> Self {
|
|
||||||
Self::ResponseError(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<io::Error> for Error {
|
|
||||||
fn from(value: io::Error) -> Self {
|
|
||||||
Self::IoError(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, S, C, T> Sender<'a, S, C, T>
|
|
||||||
where
|
where
|
||||||
S: CertificateStore + Sync,
|
S: CertificateStore,
|
||||||
C: Sized,
|
C: Sized,
|
||||||
T: Read + Write + Sized,
|
T: Read + Write + Sized,
|
||||||
{
|
{
|
||||||
pub fn new(request_str: &str, store: &'a S) -> Result<Self, Error> {
|
pub request: Request,
|
||||||
|
pub verifier: Verifier<S>,
|
||||||
|
pub stream: rustls::StreamOwned<C, T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, C, T> Sender<S, C, T>
|
||||||
|
where
|
||||||
|
S: CertificateStore,
|
||||||
|
C: Sized,
|
||||||
|
T: Read + Write + Sized,
|
||||||
|
{
|
||||||
|
pub fn new(request_str: &str, store: S) -> Result<Self, Error> {
|
||||||
let request: Request = request_str.parse()?;
|
let request: Request = request_str.parse()?;
|
||||||
let verifier = Verifier::new(store);
|
let verifier = Verifier::new(store);
|
||||||
unimplemented!();
|
unimplemented!();
|
||||||
|
|
59
src/sender/error.rs
Normal file
59
src/sender/error.rs
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
use {
|
||||||
|
crate::prelude::{ParseRequestError, ParseResponseError},
|
||||||
|
std::{fmt, io},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
TlsError(rustls::Error),
|
||||||
|
RequestError(ParseRequestError),
|
||||||
|
ResponseError(ParseResponseError),
|
||||||
|
IoError(io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for Error {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::TlsError(e) => write!(f, "{e}"),
|
||||||
|
Self::RequestError(e) => write!(f, "{e}"),
|
||||||
|
Self::ResponseError(e) => write!(f, "{e}"),
|
||||||
|
Self::IoError(e) => write!(f, "{e}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for Error {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::RequestError(e) => Some(e),
|
||||||
|
Self::ResponseError(e) => Some(e),
|
||||||
|
Self::TlsError(e) => Some(e),
|
||||||
|
Self::IoError(e) => Some(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<rustls::Error> for Error {
|
||||||
|
fn from(value: rustls::Error) -> Self {
|
||||||
|
Self::TlsError(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ParseRequestError> for Error {
|
||||||
|
fn from(value: ParseRequestError) -> Self {
|
||||||
|
Self::RequestError(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ParseResponseError> for Error {
|
||||||
|
fn from(value: ParseResponseError) -> Self {
|
||||||
|
Self::ResponseError(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<io::Error> for Error {
|
||||||
|
fn from(value: io::Error) -> Self {
|
||||||
|
Self::IoError(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +0,0 @@
|
||||||
pub trait CertificateStore {
|
|
||||||
fn get(&self, host: &str) -> Option<String>;
|
|
||||||
fn insert(&mut self, host: &str, fingerprint: &str);
|
|
||||||
}
|
|
|
@ -1,16 +1,21 @@
|
||||||
use super::store::CertificateStore;
|
|
||||||
use crate::fingerprint::Fingerprint;
|
use crate::fingerprint::Fingerprint;
|
||||||
use rustls::{
|
use rustls::{
|
||||||
client::{ServerCertVerified, ServerCertVerifier},
|
client::{ServerCertVerified, ServerCertVerifier},
|
||||||
Certificate,
|
Certificate,
|
||||||
};
|
};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
#[derive(Debug)]
|
pub trait CertificateStore: Send + Sync {
|
||||||
pub struct Verifier<'a, T: CertificateStore> {
|
fn get(&self, host: &str) -> Option<String>;
|
||||||
store: &'a T,
|
fn insert(&mut self, host: &str, fingerprint: &str);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
|
#[derive(Debug)]
|
||||||
|
pub struct Verifier<S: CertificateStore> {
|
||||||
|
store: Arc<Mutex<S>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: CertificateStore> ServerCertVerifier for Verifier<S> {
|
||||||
fn verify_server_cert(
|
fn verify_server_cert(
|
||||||
&self,
|
&self,
|
||||||
end_entity: &Certificate,
|
end_entity: &Certificate,
|
||||||
|
@ -29,8 +34,8 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
if let Some(fingerprint) = match server_name {
|
if let Some(fingerprint) = match server_name {
|
||||||
rustls::ServerName::DnsName(n) => self.store.get(n.as_ref()),
|
rustls::ServerName::DnsName(n) => self.store.lock().unwrap().get(n.as_ref()),
|
||||||
rustls::ServerName::IpAddress(ip) => self.store.get(&ip.to_string()),
|
rustls::ServerName::IpAddress(ip) => self.store.lock().unwrap().get(&ip.to_string()),
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
} {
|
} {
|
||||||
if fingerprint == fp.1 && name == fp.0 {
|
if fingerprint == fp.1 && name == fp.0 {
|
||||||
|
@ -38,16 +43,19 @@ impl<'a, T: CertificateStore + Sync> ServerCertVerifier for Verifier<'a, T> {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// todo: need a way to update `self.store`. Probably will require
|
// todo: need a way to update `self.store`. Probably will require
|
||||||
// an Arc<Mutex<T>> for interior mutability
|
// an Arc<Mutex<T>> for interior mutability.
|
||||||
|
// UPDATE: Now wrapped in Arc<Mutex<T>>
|
||||||
}
|
}
|
||||||
return Err(rustls::Error::General(
|
Err(rustls::Error::General(
|
||||||
"Unrecognized certificate".to_string(),
|
"Unrecognized certificate".to_string(),
|
||||||
));
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: CertificateStore + Sync> Verifier<'a, T> {
|
impl<T: CertificateStore> Verifier<T> {
|
||||||
pub fn new(store: &'a T) -> Self {
|
pub fn new(store: T) -> Self {
|
||||||
Self { store }
|
Self {
|
||||||
|
store: Arc::new(Mutex::new(store)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue