4 Commits

6 changed files with 44 additions and 25 deletions

View File

@ -56,13 +56,14 @@ class PopCfg(ServerCfg):
class SmtpStartTLSCfg(ServerCfg): class SmtpStartTLSCfg(ServerCfg):
server_type = "smtp_starttls" server_type = "smtp_starttls"
smtputf8 = True # Not used yet require_starttls = True
smtputf8 = True
port = 25 port = 25
class SmtpCfg(ServerCfg): class SmtpCfg(ServerCfg):
server_type = "smtp_starttls" server_type = "smtp"
smtputf8 = True # Not used yet smtputf8 = True
port = 465 port = 465

View File

@ -2,11 +2,8 @@ import asyncio
import contextlib import contextlib
import contextvars import contextvars
import logging import logging
import os
import ssl import ssl
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256
from pathlib import Path from pathlib import Path
from .config import User from .config import User
from .pwhash import parse_hash, check_pass, PWInfo from .pwhash import parse_hash, check_pass, PWInfo

View File

@ -1,8 +1,6 @@
import asyncio import asyncio
import logging import logging
import os
import ssl import ssl
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from getpass import getpass from getpass import getpass
@ -86,6 +84,8 @@ async def a_main(cfg: config.Config) -> None:
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder, mbox_finder=mbox_finder,
ssl_context=stls_context, ssl_context=stls_context,
require_starttls=stls.require_starttls,
smtputf8=stls.smtputf8,
) )
servers.append(smtp_server_starttls) servers.append(smtp_server_starttls)
elif scfg.server_type == "smtp": elif scfg.server_type == "smtp":
@ -96,6 +96,7 @@ async def a_main(cfg: config.Config) -> None:
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder, mbox_finder=mbox_finder,
ssl_context=get_tls_context(smtp.tls), ssl_context=get_tls_context(smtp.tls),
smtputf8=smtp.smtputf8,
) )
servers.append(smtp_server) servers.append(smtp_server)
else: else:

View File

@ -1,7 +1,5 @@
import asyncio import asyncio
import io
import logging import logging
import mailbox
import ssl import ssl
import uuid import uuid
import shutil import shutil
@ -13,11 +11,9 @@ from email.message import Message
import email.policy import email.policy
from email.generator import BytesGenerator from email.generator import BytesGenerator
import tempfile import tempfile
import random
from aiosmtpd.handlers import Mailbox, AsyncMessage from aiosmtpd.handlers import AsyncMessage
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT from aiosmtpd.smtp import SMTP
from aiosmtpd.smtp import SMTP as SMTPServer
from aiosmtpd.smtp import Envelope as SMTPEnvelope from aiosmtpd.smtp import Envelope as SMTPEnvelope
from aiosmtpd.smtp import Session as SMTPSession from aiosmtpd.smtp import Session as SMTPSession
@ -31,7 +27,7 @@ class MyHandler(AsyncMessage):
self.mbox_finder = mbox_finder self.mbox_finder = mbox_finder
async def handle_DATA( async def handle_DATA(
self, server: SMTPServer, session: SMTPSession, envelope: SMTPEnvelope self, server: SMTP, session: SMTPSession, envelope: SMTPEnvelope
) -> str: ) -> str:
self.rcpt_tos = envelope.rcpt_tos self.rcpt_tos = envelope.rcpt_tos
self.peer = session.peer self.peer = session.peer
@ -63,16 +59,20 @@ class MyHandler(AsyncMessage):
def protocol_factory_starttls( def protocol_factory_starttls(
mails_path: Path, mbox_finder: Callable[[str], list[str]], context: ssl.SSLContext mails_path: Path,
mbox_finder: Callable[[str], list[str]],
context: ssl.SSLContext,
require_starttls: bool,
smtputf8: bool,
): ):
logger.info("Got smtp client cb starttls") logger.info("Got smtp client cb starttls")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP( smtp = SMTP(
handler=handler, handler=handler,
require_starttls=True, require_starttls=require_starttls,
tls_context=context, tls_context=context,
enable_SMTPUTF8=True, enable_SMTPUTF8=smtputf8,
) )
except: except:
logger.exception("Something went wrong") logger.exception("Something went wrong")
@ -80,11 +80,13 @@ def protocol_factory_starttls(
return smtp return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]): def protocol_factory(
mails_path: Path, mbox_finder: Callable[[str], list[str]], smtputf8: bool
):
logger.info("Got smtp client cb") logger.info("Got smtp client cb")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True) smtp = SMTP(handler=handler, enable_SMTPUTF8=smtputf8)
except: except:
logger.exception("Something went wrong") logger.exception("Something went wrong")
raise raise
@ -97,13 +99,22 @@ async def create_smtp_server_starttls(
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext, ssl_context: ssl.SSLContext,
require_starttls: bool,
smtputf8: bool,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info( logging.info(
f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=!s}, {ssl_context != None=}" f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=!s}, {ssl_context != None=}"
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory_starttls, mails_path, mbox_finder, ssl_context), partial(
protocol_factory_starttls,
mails_path,
mbox_finder,
ssl_context,
require_starttls,
smtputf8,
),
host=host, host=host,
port=port, port=port,
start_serving=False, start_serving=False,
@ -115,14 +126,15 @@ async def create_smtp_server(
port: int, port: int,
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: Optional[ssl.SSLContext] = None, ssl_context: Optional[ssl.SSLContext],
smtputf8: bool,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info( logging.info(
f"Starting SMTP server {host=}, {port=}, {mails_path=!s}, {ssl_context != None=}" f"Starting SMTP server {host=}, {port=}, {mails_path=!s}, {ssl_context != None=}"
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory, mails_path, mbox_finder), partial(protocol_factory, mails_path, mbox_finder, smtputf8),
host=host, host=host,
port=port, port=port,
ssl=ssl_context, ssl=ssl_context,

View File

@ -89,6 +89,9 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
self.task = asyncio.create_task(pop_server.serve_forever()) self.task = asyncio.create_task(pop_server.serve_forever())
self.reader, self.writer = await asyncio.open_connection("127.0.0.1", 7995) self.reader, self.writer = await asyncio.open_connection("127.0.0.1", 7995)
# Additional writers to close
self.ws: list[asyncio.StreamWriter] = []
async def test_QUIT(self) -> None: async def test_QUIT(self) -> None:
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
@ -133,6 +136,7 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
async def test_dupe_AUTH(self) -> None: async def test_dupe_AUTH(self) -> None:
r1, w1 = await asyncio.open_connection("127.0.0.1", 7995) r1, w1 = await asyncio.open_connection("127.0.0.1", 7995)
r2, w2 = await asyncio.open_connection("127.0.0.1", 7995) r2, w2 = await asyncio.open_connection("127.0.0.1", 7995)
self.ws += w1, w2
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
C: USER foobar C: USER foobar
@ -231,8 +235,10 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
async def asyncTearDown(self) -> None: async def asyncTearDown(self) -> None:
logging.debug("at teardown") logging.debug("at teardown")
self.writer.close() for w in self.ws + [self.writer]:
await self.writer.wait_closed() w.close()
await w.wait_closed()
self.ws.clear()
self.task.cancel("test done") self.task.cancel("test done")
async def dialog_checker(self, dialog: str) -> None: async def dialog_checker(self, dialog: str) -> None:

View File

@ -33,6 +33,8 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
port=7996, port=7996,
mails_path=MAILS_PATH, mails_path=MAILS_PATH,
mbox_finder=lambda addr: [TEST_MBOX], mbox_finder=lambda addr: [TEST_MBOX],
ssl_context=None,
smtputf8=True,
) )
self.task = asyncio.create_task(smtp_server.serve_forever()) self.task = asyncio.create_task(smtp_server.serve_forever())