refactor server

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-13 13:44:10 -04:00
parent 622bb0f473
commit 93bfb3d41c
5 changed files with 113 additions and 53 deletions

View File

@ -34,21 +34,46 @@ class User(Jata):
mbox: str mbox: str
class Config(Jata): class TLSCfg(Jata):
certfile: str certfile: str
keyfile: str keyfile: str
debug: bool = False
mails_path: str
host = '0.0.0.0' class ServerCfg(Jata):
smtp_port = 25 host: str = "default"
smtp_port_tls = 465 port: int
smtp_port_submission = 587 tls: TLSCfg | str = "default"
pop_port = 995
pop_timeout_seconds = 60
class PopCfg(ServerCfg):
port = 995
timeout_seconds = 60
class SmtpStartTLSCfg(ServerCfg):
smtputf8 = True smtputf8 = True
port = 25
class SmtpCfg(ServerCfg):
smtputf8 = True
port = 465
class Config(Jata):
default_tls: TLSCfg | None
default_host: str = '0.0.0.0'
mails_path: str
users: list[User] users: list[User]
boxes: list[Mbox] boxes: list[Mbox]
matches: list[Match] matches: list[Match]
debug: bool = False
pop: PopCfg | None
smtp_starttls: SmtpStartTLSCfg | None
smtp: SmtpCfg | None
# smtp_port_submission = 587
CheckerFn = Callable[[str], bool] CheckerFn = Callable[[str], bool]
@ -69,10 +94,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
else: else:
raise Exception("Neither addrs nor addr_rexs is set") raise Exception("Neither addrs nor addr_rexs is set")
matches = { matches = {m.name: make_match_fn(Match(m)) for m in cfg.matches or []}
m.name: make_match_fn(m)
for match in cfg.matches if (m := Match(match)) is not None
}
matches[DEFAULT_MATCH_ALL] = lambda _: True matches[DEFAULT_MATCH_ALL] = lambda _: True
def make_checker(mbox_name: str, rule: Rule) -> Checker: def make_checker(mbox_name: str, rule: Rule) -> Checker:
@ -84,7 +106,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
return mbox_name, match_fn, rule.stop_check return mbox_name, match_fn, rule.stop_check
return [ return [
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes or []
for rule in mbox.rules for rule in mbox.rules
] ]

View File

@ -317,7 +317,7 @@ async def create_pop_server(host: str,
mails_path: Path, mails_path: Path,
users: list[User], users: list[User],
ssl_context: ssl.SSLContext | None = None, ssl_context: ssl.SSLContext | None = None,
timeout_seconds: int = 60): timeout_seconds: int = 60) -> asyncio.Server:
logging.info( logging.info(
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}" f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
) )

View File

@ -6,13 +6,13 @@ import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from .smtp import create_smtp_server_starttls, create_smtp_server_tls from .smtp import create_smtp_server_starttls, create_smtp_server
from .pop3 import create_pop_server from .pop3 import create_pop_server
from .config import Config from . import config
def create_tls_context(certfile, keyfile): def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=certfile, keyfile=keyfile) context.load_cert_chain(certfile=certfile, keyfile=keyfile)
return context return context
@ -25,44 +25,82 @@ def setup_logging(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
async def a_main(config, tls_context): async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls:
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
def get_tls_context(tls: config.TLSCfg | str):
if tls == "default":
return default_tls_context
elif tls == "disable":
return None
else:
tls_cfg = config.TLSCfg(pop.tls)
return create_tls_context(tls_cfg.certfile, tls_cfg.keyfile)
def get_host(host):
if host == "default":
return cfg.default_host
else:
return host
mbox_finder = config.gen_addr_to_mboxes(cfg)
servers: list[asyncio.Server] = []
if cfg.pop:
pop = config.PopCfg(cfg.pop)
pop_server = await create_pop_server( pop_server = await create_pop_server(
host=config.host, host=get_host(pop.host),
port=config.pop_port, port=pop.port,
mails_path=config.mails_path, mails_path=Path(cfg.mails_path),
users=config.users, users=cfg.users,
ssl_context=tls_context, ssl_context=get_tls_context(pop.tls),
timeout_seconds=config.pop_timeout_seconds) timeout_seconds=pop.timeout_seconds)
servers.append(pop_server)
if cfg.smtp_starttls:
stls = config.SmtpStartTLSCfg(cfg.smtp_starttls)
stls_context = get_tls_context(stls.tls)
if not stls_context:
raise Exception("starttls requires ssl_context")
smtp_server_starttls = await create_smtp_server_starttls( smtp_server_starttls = await create_smtp_server_starttls(
config.mail_dir_path, host=get_host(stls.host),
port=config.smtp_port, port=stls.port,
host=config.host, mails_path=Path(cfg.mails_path),
context=tls_context) mbox_finder=mbox_finder,
ssl_context=stls_context)
servers.append(smtp_server_starttls)
smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path, if cfg.smtp:
port=config.smtp_port_tls, smtp = config.SmtpCfg(cfg.smtp)
host=config.host, smtp_server = await create_smtp_server(host=get_host(smtp.host),
context=tls_context) port=smtp.port,
mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder,
ssl_context=get_tls_context(
smtp.tls))
servers.append(smtp_server)
await asyncio.gather(pop_server.serve_forever(), if servers:
smtp_server_starttls.serve_forever(), await asyncio.gather(server.serve_forever() for server in servers)
smtp_server_tls.serve_forever()) else:
logging.warn("Nothing to do!")
def main(): def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("config_path") parser.add_argument("config_path", type=Path)
args = parser.parse_args() args = parser.parse_args()
config = Config(open(args.config_path).read()) config = Config(args.config_path.read_text())
setup_logging(args) setup_logging(args)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.set_debug(config.debug) loop.set_debug(config.debug)
tls_context = create_tls_context(config.certfile, config.keyfile) asyncio.run(a_main(config))
asyncio.run(a_main(config, tls_context))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -55,7 +55,7 @@ class MyHandler(AsyncMessage):
def protocol_factory_starttls(mails_path: Path, def protocol_factory_starttls(mails_path: Path,
mbox_finder: Callable[[str], [str]], mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None): context: ssl.SSLContext):
logging.info("Got smtp client cb starttls") logging.info("Got smtp client cb starttls")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
@ -84,25 +84,25 @@ async def create_smtp_server_starttls(host: str,
port: int, port: int,
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], [str]], mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None): ssl_context: ssl.SSLContext) -> asyncio.Server:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, return await loop.create_server(partial(protocol_factory_starttls,
mails_path, mbox_finder, context), mails_path, mbox_finder, ssl_context),
host=host, host=host,
port=port, port=port,
start_serving=False) start_serving=False)
async def create_smtp_server_tls(host: str, async def create_smtp_server(host: str,
port: int, port: int,
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], [str]], mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None): ssl_context: ssl.SSLContext | None = None) -> asyncio.Server:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder), return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder),
host=host, host=host,
port=port, port=port,
ssl=context, ssl=ssl_context,
start_serving=False) start_serving=False)

View File

@ -7,7 +7,7 @@ import os
from pathlib import Path from pathlib import Path
from .smtp import create_smtp_server_tls from .smtp import create_smtp_server
TEST_MBOX = 'foobar_mails' TEST_MBOX = 'foobar_mails'
MAILS_PATH: Path MAILS_PATH: Path
@ -28,7 +28,7 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None: async def asyncSetUp(self) -> None:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
smtp_server = await create_smtp_server_tls( smtp_server = await create_smtp_server(
host="127.0.0.1", host="127.0.0.1",
port=7996, port=7996,
mails_path=MAILS_PATH, mails_path=MAILS_PATH,