smtp refactor
This commit is contained in:
		@@ -6,8 +6,8 @@ from jata import Jata, MutableDefault
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class Match(Jata):
 | 
					class Match(Jata):
 | 
				
			||||||
    name: str
 | 
					    name: str
 | 
				
			||||||
    addrs: list[str] = MutableDefault(lambda: [])
 | 
					    addrs: list[str] = MutableDefault(lambda: []) # type: ignore
 | 
				
			||||||
    addr_rexs: list[str] = MutableDefault(lambda: [])
 | 
					    addr_rexs: list[str] = MutableDefault(lambda: []) # type: ignore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
DEFAULT_MATCH_ALL = "default_match_all"
 | 
					DEFAULT_MATCH_ALL = "default_match_all"
 | 
				
			||||||
@@ -100,3 +100,8 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
 | 
				
			|||||||
                    return
 | 
					                    return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return list(inner())
 | 
					    return list(inner())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], [str]]:
 | 
				
			||||||
 | 
					    checks = parse_checkers(cfg)
 | 
				
			||||||
 | 
					    return lambda addr: get_mboxes(addr, checks)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,41 +3,61 @@ import io
 | 
				
			|||||||
import logging
 | 
					import logging
 | 
				
			||||||
import mailbox
 | 
					import mailbox
 | 
				
			||||||
import ssl
 | 
					import ssl
 | 
				
			||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					import shutil
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
 | 
					from typing import Callable
 | 
				
			||||||
from . import config
 | 
					from . import config
 | 
				
			||||||
 | 
					from email.message import Message
 | 
				
			||||||
 | 
					import email.policy
 | 
				
			||||||
 | 
					from email.generator import BytesGenerator
 | 
				
			||||||
 | 
					import tempfile
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from aiosmtpd.handlers import Mailbox
 | 
					from aiosmtpd.handlers import Mailbox, AsyncMessage
 | 
				
			||||||
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
 | 
					from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
 | 
				
			||||||
 | 
					from aiosmtpd.smtp import SMTP as SMTPServer
 | 
				
			||||||
 | 
					from aiosmtpd.smtp import Envelope as SMTPEnvelope
 | 
				
			||||||
 | 
					from aiosmtpd.smtp import Session as SMTPSession
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MaildirCRLF(mailbox.Maildir):
 | 
					class MyHandler(AsyncMessage):
 | 
				
			||||||
    _append_newline = True
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _dump_message(self, message, target, mangle_from_=False):
 | 
					    def __init__(self, mails_path: Path, mbox_finder: Callable[[str], [str]]):
 | 
				
			||||||
        temp_buffer = io.BytesIO()
 | 
					        super().__init__()
 | 
				
			||||||
        super()._dump_message(message, temp_buffer, mangle_from_=mangle_from_)
 | 
					        self.mails_path = mails_path
 | 
				
			||||||
        temp_buffer.seek(0)
 | 
					        self.mbox_finder = mbox_finder
 | 
				
			||||||
        data = temp_buffer.read()
 | 
					 | 
				
			||||||
        data = data.replace(b'\n', b'\r\n')
 | 
					 | 
				
			||||||
        target.write(data)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
 | 
				
			||||||
 | 
					                          envelope: SMTPEnvelope) -> str:
 | 
				
			||||||
 | 
					        self.rcpt_tos = envelope.rcpt_tos
 | 
				
			||||||
 | 
					        return await super().handle_DATA(server, session, envelope)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MailboxCRLF(Mailbox):
 | 
					    async def handle_message(self, m: Message):  # type: ignore[override]
 | 
				
			||||||
 | 
					        all_mboxes: set[str] = set()
 | 
				
			||||||
    def __init__(self, mail_dir: Path):
 | 
					        for addr in self.rcpt_tos:
 | 
				
			||||||
        super().__init__(mail_dir)
 | 
					            all_mboxes.union(self.mbox_finder(addr))
 | 
				
			||||||
 | 
					        if not all_mboxes:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        for mbox in all_mboxes:
 | 
				
			||||||
            for sub in ('new', 'tmp', 'cur'):
 | 
					            for sub in ('new', 'tmp', 'cur'):
 | 
				
			||||||
            sub_path = mail_dir / sub
 | 
					                sub_path = self.mails_path / mbox / sub
 | 
				
			||||||
                sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
 | 
					                sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
 | 
				
			||||||
        self.mailbox = MaildirCRLF(mail_dir)
 | 
					        with tempfile.TemporaryDirectory() as tmpdir:
 | 
				
			||||||
 | 
					            temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml"
 | 
				
			||||||
 | 
					            with open(temp_email_path, "wb") as fp:
 | 
				
			||||||
 | 
					                gen = BytesGenerator(fp, policy=email.policy.SMTP)
 | 
				
			||||||
 | 
					                gen.flatten(m)
 | 
				
			||||||
 | 
					            for mbox in all_mboxes:
 | 
				
			||||||
 | 
					                shutil.copy(temp_email_path, self.mails_path / mbox / 'new')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def protocol_factory_starttls(dirpath: Path,
 | 
					def protocol_factory_starttls(mails_path: Path,
 | 
				
			||||||
 | 
					                              mbox_finder: Callable[[str], [str]],
 | 
				
			||||||
                              context: ssl.SSLContext | None = None):
 | 
					                              context: ssl.SSLContext | None = None):
 | 
				
			||||||
    logging.info("Got smtp client cb")
 | 
					    logging.info("Got smtp client cb starttls")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        handler = MailboxCRLF(dirpath)
 | 
					        handler = MyHandler(mails_path, mbox_finder)
 | 
				
			||||||
        smtp = SMTP(handler=handler,
 | 
					        smtp = SMTP(handler=handler,
 | 
				
			||||||
                    require_starttls=True,
 | 
					                    require_starttls=True,
 | 
				
			||||||
                    tls_context=context,
 | 
					                    tls_context=context,
 | 
				
			||||||
@@ -48,10 +68,10 @@ def protocol_factory_starttls(dirpath: Path,
 | 
				
			|||||||
    return smtp
 | 
					    return smtp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def protocol_factory(dirpath: Path):
 | 
					def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], [str]]):
 | 
				
			||||||
    logging.info("Got smtp client cb")
 | 
					    logging.info("Got smtp client cb")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        handler = MailboxCRLF(dirpath)
 | 
					        handler = MyHandler(mails_path, mbox_finder)
 | 
				
			||||||
        smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
 | 
					        smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
 | 
				
			||||||
    except Exception as e:
 | 
					    except Exception as e:
 | 
				
			||||||
        logging.error("Something went wrong", e)
 | 
					        logging.error("Something went wrong", e)
 | 
				
			||||||
@@ -59,19 +79,21 @@ def protocol_factory(dirpath: Path):
 | 
				
			|||||||
    return smtp
 | 
					    return smtp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def create_smtp_server_starttls(dirpath: Path,
 | 
					async def create_smtp_server_starttls(host: str,
 | 
				
			||||||
                                      port: int,
 | 
					                                      port: int,
 | 
				
			||||||
                                      host="",
 | 
					                                      mails_path: Path,
 | 
				
			||||||
 | 
					                                      mbox_finder: Callable[[str], [str]],
 | 
				
			||||||
                                      context: ssl.SSLContext | None = None):
 | 
					                                      context: ssl.SSLContext | None = None):
 | 
				
			||||||
    loop = asyncio.get_event_loop()
 | 
					    loop = asyncio.get_event_loop()
 | 
				
			||||||
    return await loop.create_server(partial(protocol_factory_starttls, dirpath,
 | 
					    return await loop.create_server(partial(protocol_factory_starttls,
 | 
				
			||||||
                                            context),
 | 
					                                            mails_path, mbox_finder, context),
 | 
				
			||||||
                                    host=host,
 | 
					                                    host=host,
 | 
				
			||||||
                                    port=port,
 | 
					                                    port=port,
 | 
				
			||||||
                                    start_serving=False)
 | 
					                                    start_serving=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def create_smtp_server_tls(dirpath: Path,
 | 
					async def create_smtp_server_tls(dirpath: Path,
 | 
				
			||||||
 | 
					                                 mbox_finder: Callable[[str], [str]],
 | 
				
			||||||
                                 port: int,
 | 
					                                 port: int,
 | 
				
			||||||
                                 host="",
 | 
					                                 host="",
 | 
				
			||||||
                                 context: ssl.SSLContext | None = None):
 | 
					                                 context: ssl.SSLContext | None = None):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user