smtp refactor

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-12 00:49:48 -04:00
parent 9e8a5c0c2e
commit 0b8372c68c
2 changed files with 58 additions and 31 deletions

View File

@ -6,8 +6,8 @@ from jata import Jata, MutableDefault
class Match(Jata):
name: str
addrs: list[str] = MutableDefault(lambda: [])
addr_rexs: list[str] = MutableDefault(lambda: [])
addrs: list[str] = MutableDefault(lambda: []) # type: ignore
addr_rexs: list[str] = MutableDefault(lambda: []) # type: ignore
DEFAULT_MATCH_ALL = "default_match_all"
@ -100,3 +100,8 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
return
return list(inner())
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], [str]]:
checks = parse_checkers(cfg)
return lambda addr: get_mboxes(addr, checks)

View File

@ -3,41 +3,61 @@ import io
import logging
import mailbox
import ssl
import uuid
import shutil
from functools import partial
from pathlib import Path
from typing import Callable
from . import config
from email.message import Message
import email.policy
from email.generator import BytesGenerator
import tempfile
from aiosmtpd.handlers import Mailbox
from aiosmtpd.handlers import Mailbox, AsyncMessage
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
from aiosmtpd.smtp import SMTP as SMTPServer
from aiosmtpd.smtp import Envelope as SMTPEnvelope
from aiosmtpd.smtp import Session as SMTPSession
class MaildirCRLF(mailbox.Maildir):
_append_newline = True
class MyHandler(AsyncMessage):
def _dump_message(self, message, target, mangle_from_=False):
temp_buffer = io.BytesIO()
super()._dump_message(message, temp_buffer, mangle_from_=mangle_from_)
temp_buffer.seek(0)
data = temp_buffer.read()
data = data.replace(b'\n', b'\r\n')
target.write(data)
def __init__(self, mails_path: Path, mbox_finder: Callable[[str], [str]]):
super().__init__()
self.mails_path = mails_path
self.mbox_finder = mbox_finder
async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
envelope: SMTPEnvelope) -> str:
self.rcpt_tos = envelope.rcpt_tos
return await super().handle_DATA(server, session, envelope)
async def handle_message(self, m: Message): # type: ignore[override]
all_mboxes: set[str] = set()
for addr in self.rcpt_tos:
all_mboxes.union(self.mbox_finder(addr))
if not all_mboxes:
return
for mbox in all_mboxes:
for sub in ('new', 'tmp', 'cur'):
sub_path = self.mails_path / mbox / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
with tempfile.TemporaryDirectory() as tmpdir:
temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml"
with open(temp_email_path, "wb") as fp:
gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m)
for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / 'new')
class MailboxCRLF(Mailbox):
def __init__(self, mail_dir: Path):
super().__init__(mail_dir)
for sub in ('new', 'tmp', 'cur'):
sub_path = mail_dir / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
self.mailbox = MaildirCRLF(mail_dir)
def protocol_factory_starttls(dirpath: Path,
def protocol_factory_starttls(mails_path: Path,
mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None):
logging.info("Got smtp client cb")
logging.info("Got smtp client cb starttls")
try:
handler = MailboxCRLF(dirpath)
handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler,
require_starttls=True,
tls_context=context,
@ -48,10 +68,10 @@ def protocol_factory_starttls(dirpath: Path,
return smtp
def protocol_factory(dirpath: Path):
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], [str]]):
logging.info("Got smtp client cb")
try:
handler = MailboxCRLF(dirpath)
handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
except Exception as e:
logging.error("Something went wrong", e)
@ -59,19 +79,21 @@ def protocol_factory(dirpath: Path):
return smtp
async def create_smtp_server_starttls(dirpath: Path,
async def create_smtp_server_starttls(host: str,
port: int,
host="",
mails_path: Path,
mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None):
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, dirpath,
context),
return await loop.create_server(partial(protocol_factory_starttls,
mails_path, mbox_finder, context),
host=host,
port=port,
start_serving=False)
async def create_smtp_server_tls(dirpath: Path,
mbox_finder: Callable[[str], [str]],
port: int,
host="",
context: ssl.SSLContext | None = None):