Send client certs if they exist in ClientCertificateStore; Impl

TryFrom<Vec<u8>> for Response; Adjust error types to fit all new cases;
This commit is contained in:
Nathan Fisher 2023-05-27 10:58:14 -04:00
parent 54a099bb44
commit 46d04405ad
4 changed files with 53 additions and 23 deletions

View file

@ -1,6 +1,6 @@
use { use {
crate::prelude::ParseStatusError, crate::prelude::ParseStatusError,
std::{fmt, num::ParseIntError}, std::{fmt, num::ParseIntError, string::FromUtf8Error},
}; };
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
@ -16,6 +16,8 @@ pub enum Error {
StatusError, StatusError,
/// The response was malformed /// The response was malformed
Malformed, Malformed,
/// The response is not valid utf8
Utf8Error,
} }
impl fmt::Display for Error { impl fmt::Display for Error {
@ -25,6 +27,7 @@ impl fmt::Display for Error {
Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"), Self::ParseInt(e) => write!(f, "ParseResponseError: {e}"),
Self::StatusError => write!(f, "ParseResponseError: Invalid Status"), Self::StatusError => write!(f, "ParseResponseError: Invalid Status"),
Self::Malformed => write!(f, "ParseResponseError: Malformed"), Self::Malformed => write!(f, "ParseResponseError: Malformed"),
Self::Utf8Error => write!(f, "ParseResponseError: Not Utf8"),
} }
} }
} }
@ -49,3 +52,9 @@ impl From<ParseStatusError> for Error {
Self::StatusError Self::StatusError
} }
} }
impl From<FromUtf8Error> for Error {
fn from(_value: FromUtf8Error) -> Self {
Self::Utf8Error
}
}

View file

@ -43,6 +43,14 @@ impl FromStr for Response {
} }
} }
impl TryFrom<Vec<u8>> for Response {
type Error = Error;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
String::from_utf8(value)?.as_str().parse()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -1,11 +1,13 @@
use { use {
crate::prelude::{ParseRequestError, ParseResponseError}, crate::prelude::{ParseRequestError, ParseResponseError},
rustls::InvalidMessage,
std::{fmt, io}, std::{fmt, io},
}; };
#[derive(Debug)] #[derive(Debug)]
/// Errors which might occur when sending a message /// Errors which might occur when sending a message
pub enum Error { pub enum Error {
CertificateError(InvalidMessage),
DnsError, DnsError,
TlsError(rustls::Error), TlsError(rustls::Error),
RequestError(ParseRequestError), RequestError(ParseRequestError),
@ -16,6 +18,7 @@ pub enum Error {
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::CertificateError(e) => write!(f, "{e:?}"),
Self::DnsError => write!(f, "Dns Error"), Self::DnsError => write!(f, "Dns Error"),
Self::TlsError(e) => write!(f, "{e}"), Self::TlsError(e) => write!(f, "{e}"),
Self::RequestError(e) => write!(f, "{e}"), Self::RequestError(e) => write!(f, "{e}"),
@ -28,15 +31,21 @@ impl fmt::Display for Error {
impl std::error::Error for Error { impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self { match self {
Self::DnsError => None,
Self::RequestError(e) => Some(e), Self::RequestError(e) => Some(e),
Self::ResponseError(e) => Some(e), Self::ResponseError(e) => Some(e),
Self::TlsError(e) => Some(e), Self::TlsError(e) => Some(e),
Self::IoError(e) => Some(e), Self::IoError(e) => Some(e),
_ => None,
} }
} }
} }
impl From<InvalidMessage> for Error {
fn from(value: InvalidMessage) -> Self {
Self::CertificateError(value)
}
}
impl From<rustls::Error> for Error { impl From<rustls::Error> for Error {
fn from(value: rustls::Error) -> Self { fn from(value: rustls::Error) -> Self {
Self::TlsError(value) Self::TlsError(value)

View file

@ -1,40 +1,40 @@
use std::{ use {
io, crate::prelude::{CertificateStore, ClientCertificateStore, Request, Response},
net::{TcpStream, ToSocketAddrs}, rustls::{internal::msgs::codec::Codec, ClientConfig, ClientConnection, StreamOwned},
sync::Arc, std::{
time::Duration, io::{self, Read, Write},
net::{TcpStream, ToSocketAddrs},
sync::Arc,
time::Duration,
},
}; };
use rustls::{ClientConfig, ClientConnection, StreamOwned};
pub use self::{error::Error, verifier::Verifier}; pub use self::{error::Error, verifier::Verifier};
use {
crate::prelude::{CertificateStore, Request},
std::io::{Read, Write},
};
mod error; mod error;
mod verifier; mod verifier;
#[derive(Debug)] #[derive(Debug)]
/// Sends a piece of mail from the sending server to the receiving server /// Sends a piece of mail from the sending server to the receiving server
pub struct Sender<S> pub struct Sender<S, C>
where where
S: CertificateStore, S: CertificateStore,
C: ClientCertificateStore,
{ {
/// The full message text to be sent /// The full message text to be sent
pub request: Request, pub request: Request,
/// A [CertificateStore] for servers known to us /// A [CertificateStore] for servers known to us
pub store: S, pub store: S,
/// A [CertificateStore] for mailboxes which exist on this system /// A [CertificateStore] for mailboxes which exist on this system
pub client_store: S, pub client_store: C,
} }
impl<S> Sender<S> impl<S, C> Sender<S, C>
where where
S: CertificateStore + 'static, S: CertificateStore + 'static,
C: ClientCertificateStore,
{ {
pub fn new(request_str: &str, store: S, client_store: S) -> Result<Self, Error> { pub fn new(request_str: &str, store: S, client_store: C) -> Result<Self, Error> {
let request: Request = request_str.parse()?; let request: Request = request_str.parse()?;
Ok(Self { Ok(Self {
request, request,
@ -47,16 +47,14 @@ where
self.request.sender.host.to_string() self.request.sender.host.to_string()
} }
pub fn send(self) -> Result<Vec<u8>, Error> { pub fn send(self) -> Result<Response, Error> {
let dnsname = self let dnsname = self
.host_string() .host_string()
.as_str() .as_str()
.try_into() .try_into()
.map_err(|_| Error::DnsError)?; .map_err(|_| Error::DnsError)?;
let mut it = self.request.sender.host.to_socket_addrs()?; let mut it = self.request.sender.host.to_socket_addrs()?;
let client_cert = self let client_cert = self.client_store.get_certificate(&self.request.sender);
.client_store
.get_certificate(&self.request.sender.to_string());
let verifier = Arc::new(Verifier::new(self.store)); let verifier = Arc::new(Verifier::new(self.store));
let Some(socket_addrs) = it.next() else { let Some(socket_addrs) = it.next() else {
return Err(io::Error::new(io::ErrorKind::Other, "no data retrieved").into()); return Err(io::Error::new(io::ErrorKind::Other, "no data retrieved").into());
@ -67,7 +65,12 @@ where
.with_custom_certificate_verifier(verifier); .with_custom_certificate_verifier(verifier);
let cfg = match client_cert { let cfg = match client_cert {
None => cfg.with_no_client_auth(), None => cfg.with_no_client_auth(),
Some(_) => cfg.with_no_client_auth(), // todo: cfg.with_single_cert(cert_chain, key_der) Some(c) => {
let rustls_cert = rustls::Certificate::read_bytes(&c.der)?;
let cert_chain = vec![rustls_cert];
let key_der = rustls::PrivateKey(c.key);
cfg.with_single_cert(cert_chain, key_der)?
}
}; };
let client = ClientConnection::new(Arc::new(cfg), dnsname)?; let client = ClientConnection::new(Arc::new(cfg), dnsname)?;
let mut stream = StreamOwned::new(client, tcp_stream); let mut stream = StreamOwned::new(client, tcp_stream);
@ -76,6 +79,7 @@ where
stream.read_to_end(&mut buf)?; stream.read_to_end(&mut buf)?;
stream.conn.send_close_notify(); stream.conn.send_close_notify();
drop(stream); drop(stream);
Ok(buf) let res = buf.try_into()?;
Ok(res)
} }
} }