smtp refactor

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-12 00:49:48 -04:00
parent 9e8a5c0c2e
commit 0b8372c68c
2 changed files with 58 additions and 31 deletions

View File

@ -6,8 +6,8 @@ from jata import Jata, MutableDefault
class Match(Jata): class Match(Jata):
name: str name: str
addrs: list[str] = MutableDefault(lambda: []) addrs: list[str] = MutableDefault(lambda: []) # type: ignore
addr_rexs: list[str] = MutableDefault(lambda: []) addr_rexs: list[str] = MutableDefault(lambda: []) # type: ignore
DEFAULT_MATCH_ALL = "default_match_all" DEFAULT_MATCH_ALL = "default_match_all"
@ -100,3 +100,8 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
return return
return list(inner()) return list(inner())
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], [str]]:
checks = parse_checkers(cfg)
return lambda addr: get_mboxes(addr, checks)

View File

@ -3,41 +3,61 @@ import io
import logging import logging
import mailbox import mailbox
import ssl import ssl
import uuid
import shutil
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Callable
from . import config from . import config
from email.message import Message
import email.policy
from email.generator import BytesGenerator
import tempfile
from aiosmtpd.handlers import Mailbox from aiosmtpd.handlers import Mailbox, AsyncMessage
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
from aiosmtpd.smtp import SMTP as SMTPServer
from aiosmtpd.smtp import Envelope as SMTPEnvelope
from aiosmtpd.smtp import Session as SMTPSession
class MaildirCRLF(mailbox.Maildir): class MyHandler(AsyncMessage):
_append_newline = True
def _dump_message(self, message, target, mangle_from_=False): def __init__(self, mails_path: Path, mbox_finder: Callable[[str], [str]]):
temp_buffer = io.BytesIO() super().__init__()
super()._dump_message(message, temp_buffer, mangle_from_=mangle_from_) self.mails_path = mails_path
temp_buffer.seek(0) self.mbox_finder = mbox_finder
data = temp_buffer.read()
data = data.replace(b'\n', b'\r\n') async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
target.write(data) envelope: SMTPEnvelope) -> str:
self.rcpt_tos = envelope.rcpt_tos
return await super().handle_DATA(server, session, envelope)
async def handle_message(self, m: Message): # type: ignore[override]
all_mboxes: set[str] = set()
for addr in self.rcpt_tos:
all_mboxes.union(self.mbox_finder(addr))
if not all_mboxes:
return
for mbox in all_mboxes:
for sub in ('new', 'tmp', 'cur'):
sub_path = self.mails_path / mbox / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
with tempfile.TemporaryDirectory() as tmpdir:
temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml"
with open(temp_email_path, "wb") as fp:
gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m)
for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / 'new')
class MailboxCRLF(Mailbox): def protocol_factory_starttls(mails_path: Path,
mbox_finder: Callable[[str], [str]],
def __init__(self, mail_dir: Path):
super().__init__(mail_dir)
for sub in ('new', 'tmp', 'cur'):
sub_path = mail_dir / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
self.mailbox = MaildirCRLF(mail_dir)
def protocol_factory_starttls(dirpath: Path,
context: ssl.SSLContext | None = None): context: ssl.SSLContext | None = None):
logging.info("Got smtp client cb") logging.info("Got smtp client cb starttls")
try: try:
handler = MailboxCRLF(dirpath) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, smtp = SMTP(handler=handler,
require_starttls=True, require_starttls=True,
tls_context=context, tls_context=context,
@ -48,10 +68,10 @@ def protocol_factory_starttls(dirpath: Path,
return smtp return smtp
def protocol_factory(dirpath: Path): def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], [str]]):
logging.info("Got smtp client cb") logging.info("Got smtp client cb")
try: try:
handler = MailboxCRLF(dirpath) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True) smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
except Exception as e: except Exception as e:
logging.error("Something went wrong", e) logging.error("Something went wrong", e)
@ -59,19 +79,21 @@ def protocol_factory(dirpath: Path):
return smtp return smtp
async def create_smtp_server_starttls(dirpath: Path, async def create_smtp_server_starttls(host: str,
port: int, port: int,
host="", mails_path: Path,
mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None): context: ssl.SSLContext | None = None):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, dirpath, return await loop.create_server(partial(protocol_factory_starttls,
context), mails_path, mbox_finder, context),
host=host, host=host,
port=port, port=port,
start_serving=False) start_serving=False)
async def create_smtp_server_tls(dirpath: Path, async def create_smtp_server_tls(dirpath: Path,
mbox_finder: Callable[[str], [str]],
port: int, port: int,
host="", host="",
context: ssl.SSLContext | None = None): context: ssl.SSLContext | None = None):