use bytes::{Buf, BytesMut}; use mailparse::MailHeaderMap; use std::fmt::{Display, Write as _}; use std::io::Cursor; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufWriter}, net::{TcpListener, TcpStream}, }; use crate::mail::{Mail, MailPart, Mailbox}; #[derive(PartialEq, Eq)] pub enum ConnectionState { Commands, Data, } #[derive(Debug, Clone)] pub enum Frame { Header, Raw(String), Ehlo(String), From(String), To(String), Ok(String), DataStart, DataEnd, Quit, StartMailInput, Close, } impl Display for Frame { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Frame::Ok(v) => write!(f, "250 {}\r\n", v), Frame::Raw(v) => write!(f, "{}\r\n", v), Frame::Close => write!(f, "221 Closing connection\r\n"), Frame::Header => write!(f, "220 Mailspy Test Server\r\n"), Frame::StartMailInput => write!(f, "354 Start mail input\r\n"), _ => Ok(()), } } } impl Frame { pub fn check(buf: &mut Cursor<&[u8]>) -> anyhow::Result<()> { get_line(buf)?; Ok(()) } pub fn parse(buf: &mut Cursor<&[u8]>) -> anyhow::Result { let line = get_line(buf)?.to_vec(); let string = String::from_utf8(line)?; if string.to_lowercase().starts_with("ehlo") { return Ok(Frame::Ehlo(string[3..].to_string())); } if string.to_lowercase().starts_with("mail from:") { return Ok(Frame::From(string[9..].to_string())); } if string.to_lowercase().starts_with("rcpt to:") { return Ok(Frame::To(string[8..].to_string())); } if string.to_lowercase().starts_with("data") { return Ok(Frame::DataStart); } if string.to_lowercase().starts_with("quit") { return Ok(Frame::Quit); } if string.to_lowercase().trim().starts_with('.') { return Ok(Frame::DataEnd); } Ok(Frame::Raw(string)) } } #[async_trait::async_trait] pub trait Transmitter { async fn read_frame(&mut self) -> anyhow::Result>; async fn write_frame(&mut self, frame: &Frame) -> anyhow::Result<()>; } pub struct Connection { stream: BufWriter, buffer: BytesMut, } #[async_trait::async_trait] impl Transmitter for Connection { async fn read_frame(&mut self) -> anyhow::Result> { loop { if let Some(frame) = self.parse_frame()? { return Ok(Some(frame)); } if 0 == self.stream.read_buf(&mut self.buffer).await? { if self.buffer.is_empty() { return Ok(None); } else { return Err(anyhow::anyhow!("Connection reset by peer")); } } } } async fn write_frame(&mut self, frame: &Frame) -> anyhow::Result<()> { self.stream.write_all(frame.to_string().as_bytes()).await?; self.stream.flush().await?; Ok(()) } } impl Connection { pub fn new(socket: TcpStream) -> Connection { Connection { stream: BufWriter::new(socket), buffer: BytesMut::with_capacity(4 * 1024), } } fn parse_frame(&mut self) -> anyhow::Result> { let mut buf = Cursor::new(&self.buffer[..]); match Frame::check(&mut buf) { Ok(_) => { let len = buf.position() as usize; buf.set_position(0); let frame = Frame::parse(&mut buf)?; self.buffer.advance(len); Ok(Some(frame)) } Err(_) => Ok(None), } } } pub async fn server(mailbox: Mailbox, port: u16) -> anyhow::Result<()> { let addr = format!("0.0.0.0:{}", port); let listener = TcpListener::bind(&addr).await?; tracing::info!(port =? port, "SMTP Server running"); tokio::spawn(async move { loop { if let Ok((socket, _)) = listener.accept().await { let mb = mailbox.clone(); tokio::spawn(async move { if let Err(e) = process(Connection::new(socket), mb).await { tracing::error!("Mail processing error: {}", e); } }); } } }); Ok(()) } async fn process(mut connection: impl Transmitter, mailbox: Mailbox) -> anyhow::Result<()> { connection.write_frame(&Frame::Header).await.unwrap(); let mut mail_from = String::new(); let mut rcpt_to = String::new(); let mut state = ConnectionState::Commands; let mut data = String::new(); loop { if let Some(frame) = connection.read_frame().await.unwrap() { match frame { Frame::Ehlo(_) => { connection .write_frame(&Frame::Ok("EHLO".to_string())) .await?; } Frame::From(val) => { mail_from = val; connection .write_frame(&Frame::Ok("2.1.0 Sender OK".to_string())) .await?; } Frame::To(val) => { rcpt_to = val; connection .write_frame(&Frame::Ok("2.1.5 Recipient OK".to_string())) .await?; } Frame::DataStart => { connection.write_frame(&Frame::StartMailInput).await?; state = ConnectionState::Data; } Frame::DataEnd => { connection .write_frame(&Frame::Ok("Mail sent".to_string())) .await?; let mut mail = Mail::default(); let data = data.clone(); let msg = mailparse::parse_mail(data.as_bytes())?; if msg.subparts.is_empty() { mail.body = vec![MailPart { content_type: msg .headers .get_first_value("Content-Type") .unwrap_or_else(|| "text/plain".to_string()), data: msg.get_body()?, }]; } else { mail.body = msg .subparts .iter() .map(|part| { part.get_body().map(|body| MailPart { content_type: part .headers .get_first_value("Content-Type") .unwrap_or_else(|| "text/plain".to_string()), data: body, }) }) .collect::, _>>()?; } mail.subject = msg.headers.get_first_value("Subject").unwrap_or_default(); mail.from = msg .headers .get_first_value("From") .unwrap_or_else(|| mail_from.clone()); mail.to = msg .headers .get_first_value("To") .unwrap_or_else(|| rcpt_to.clone()); mail.date = msg.headers.get_first_value("Date").unwrap_or_default(); mailbox.store(mail).await; connection.write_frame(&Frame::Close).await?; break Ok(()); } Frame::Raw(s) if state == ConnectionState::Data => { writeln!(data, "{}", s)?; } _ => { connection.write_frame(&Frame::Ok("OK".to_string())).await?; } } } } } /// Find a line fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> anyhow::Result<&'a [u8]> { if !src.has_remaining() { return Err(anyhow::anyhow!("Incomplete")); } // Scan the bytes directly let start = src.position() as usize; // Scan to the second to last byte let end = src.get_ref().len() - 1; for i in start..end { if src.get_ref()[i + 1] == b'\n' { // We found a line, update the position to be *after* the \n src.set_position((i + 2) as u64); // Return the line return Ok(&src.get_ref()[start..i + 1]); } } Err(anyhow::anyhow!("Incomplete")) } #[cfg(test)] mod tests { use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use super::*; struct TestConnection { pub in_rx: UnboundedReceiver, pub out_tx: UnboundedSender, } #[async_trait::async_trait] impl Transmitter for TestConnection { async fn read_frame(&mut self) -> anyhow::Result> { Ok(self.in_rx.recv().await) } async fn write_frame(&mut self, frame: &Frame) -> anyhow::Result<()> { self.out_tx.send(frame.clone())?; Ok(()) } } impl TestConnection { pub fn setup() -> (UnboundedSender, UnboundedReceiver, Self) { let (in_tx, in_rx) = tokio::sync::mpsc::unbounded_channel(); let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel(); let s = Self { in_rx, out_tx }; (in_tx, out_rx, s) } } #[tokio::test] async fn test_process() { let mb = Mailbox::new(); let (tx, mut rx, conn) = TestConnection::setup(); tokio::spawn(process(conn, mb.clone())); assert!(matches!(rx.recv().await, Some(Frame::Header))); tx.send(Frame::Ehlo("Test mail client".to_string())) .unwrap(); assert!(matches!(rx.recv().await, Some(Frame::Ok(_)))); tx.send(Frame::From("alice@alison.com".to_string())) .unwrap(); assert!(matches!(rx.recv().await, Some(Frame::Ok(_)))); tx.send(Frame::To("alice@alison.com".to_string())).unwrap(); assert!(matches!(rx.recv().await, Some(Frame::Ok(_)))); tx.send(Frame::To("alice@alison.com".to_string())).unwrap(); tx.send(Frame::DataStart).unwrap(); tx.send(Frame::Raw("body".to_string())).unwrap(); tx.send(Frame::DataEnd).unwrap(); assert!(matches!(rx.recv().await, Some(Frame::Ok(_)))); } #[test] fn test_get_line() { let mut c = Cursor::new("First line\nSecond line\nThird line\n".as_bytes()); assert_eq!(get_line(&mut c).unwrap(), "First line".as_bytes()); assert_eq!(get_line(&mut c).unwrap(), "Second line".as_bytes()); assert_eq!(get_line(&mut c).unwrap(), "Third line".as_bytes()); } #[test] fn test_frame_parse() { assert!(matches!( Frame::parse(&mut Cursor::new("EHLO example.com\n".as_bytes())).unwrap(), Frame::Ehlo(_), )); assert!(matches!( Frame::parse(&mut Cursor::new("RCPT TO: alice@example.com\n".as_bytes())).unwrap(), Frame::To(_), )); assert!(matches!( Frame::parse(&mut Cursor::new( "MAIL FROM: alice@example.com\n".as_bytes() )) .unwrap(), Frame::From(_), )); assert!(matches!( Frame::parse(&mut Cursor::new("DATA\n".as_bytes())).unwrap(), Frame::DataStart, )); assert!(matches!( Frame::parse(&mut Cursor::new(".\n".as_bytes())).unwrap(), Frame::DataEnd, )); assert!(matches!( Frame::parse(&mut Cursor::new("QUIT\n".as_bytes())).unwrap(), Frame::Quit, )); } }