refactor server

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-13 13:44:10 -04:00
parent 622bb0f473
commit 93bfb3d41c
5 changed files with 113 additions and 53 deletions

View File

@ -34,21 +34,46 @@ class User(Jata):
mbox: str
class Config(Jata):
class TLSCfg(Jata):
certfile: str
keyfile: str
debug: bool = False
mails_path: str
host = '0.0.0.0'
smtp_port = 25
smtp_port_tls = 465
smtp_port_submission = 587
pop_port = 995
pop_timeout_seconds = 60
class ServerCfg(Jata):
host: str = "default"
port: int
tls: TLSCfg | str = "default"
class PopCfg(ServerCfg):
port = 995
timeout_seconds = 60
class SmtpStartTLSCfg(ServerCfg):
smtputf8 = True
port = 25
class SmtpCfg(ServerCfg):
smtputf8 = True
port = 465
class Config(Jata):
default_tls: TLSCfg | None
default_host: str = '0.0.0.0'
mails_path: str
users: list[User]
boxes: list[Mbox]
matches: list[Match]
debug: bool = False
pop: PopCfg | None
smtp_starttls: SmtpStartTLSCfg | None
smtp: SmtpCfg | None
# smtp_port_submission = 587
CheckerFn = Callable[[str], bool]
@ -69,10 +94,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
else:
raise Exception("Neither addrs nor addr_rexs is set")
matches = {
m.name: make_match_fn(m)
for match in cfg.matches if (m := Match(match)) is not None
}
matches = {m.name: make_match_fn(Match(m)) for m in cfg.matches or []}
matches[DEFAULT_MATCH_ALL] = lambda _: True
def make_checker(mbox_name: str, rule: Rule) -> Checker:
@ -84,7 +106,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
return mbox_name, match_fn, rule.stop_check
return [
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes or []
for rule in mbox.rules
]

View File

@ -317,7 +317,7 @@ async def create_pop_server(host: str,
mails_path: Path,
users: list[User],
ssl_context: ssl.SSLContext | None = None,
timeout_seconds: int = 60):
timeout_seconds: int = 60) -> asyncio.Server:
logging.info(
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
)

View File

@ -6,13 +6,13 @@ import sys
from argparse import ArgumentParser
from pathlib import Path
from .smtp import create_smtp_server_starttls, create_smtp_server_tls
from .smtp import create_smtp_server_starttls, create_smtp_server
from .pop3 import create_pop_server
from .config import Config
from . import config
def create_tls_context(certfile, keyfile):
def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
return context
@ -25,44 +25,82 @@ def setup_logging(args):
logging.basicConfig(level=logging.INFO)
async def a_main(config, tls_context):
async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls:
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
def get_tls_context(tls: config.TLSCfg | str):
if tls == "default":
return default_tls_context
elif tls == "disable":
return None
else:
tls_cfg = config.TLSCfg(pop.tls)
return create_tls_context(tls_cfg.certfile, tls_cfg.keyfile)
def get_host(host):
if host == "default":
return cfg.default_host
else:
return host
mbox_finder = config.gen_addr_to_mboxes(cfg)
servers: list[asyncio.Server] = []
if cfg.pop:
pop = config.PopCfg(cfg.pop)
pop_server = await create_pop_server(
host=config.host,
port=config.pop_port,
mails_path=config.mails_path,
users=config.users,
ssl_context=tls_context,
timeout_seconds=config.pop_timeout_seconds)
host=get_host(pop.host),
port=pop.port,
mails_path=Path(cfg.mails_path),
users=cfg.users,
ssl_context=get_tls_context(pop.tls),
timeout_seconds=pop.timeout_seconds)
servers.append(pop_server)
if cfg.smtp_starttls:
stls = config.SmtpStartTLSCfg(cfg.smtp_starttls)
stls_context = get_tls_context(stls.tls)
if not stls_context:
raise Exception("starttls requires ssl_context")
smtp_server_starttls = await create_smtp_server_starttls(
config.mail_dir_path,
port=config.smtp_port,
host=config.host,
context=tls_context)
host=get_host(stls.host),
port=stls.port,
mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder,
ssl_context=stls_context)
servers.append(smtp_server_starttls)
smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path,
port=config.smtp_port_tls,
host=config.host,
context=tls_context)
if cfg.smtp:
smtp = config.SmtpCfg(cfg.smtp)
smtp_server = await create_smtp_server(host=get_host(smtp.host),
port=smtp.port,
mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder,
ssl_context=get_tls_context(
smtp.tls))
servers.append(smtp_server)
await asyncio.gather(pop_server.serve_forever(),
smtp_server_starttls.serve_forever(),
smtp_server_tls.serve_forever())
if servers:
await asyncio.gather(server.serve_forever() for server in servers)
else:
logging.warn("Nothing to do!")
def main():
parser = ArgumentParser()
parser.add_argument("config_path")
parser.add_argument("config_path", type=Path)
args = parser.parse_args()
config = Config(open(args.config_path).read())
config = Config(args.config_path.read_text())
setup_logging(args)
loop = asyncio.get_event_loop()
loop.set_debug(config.debug)
tls_context = create_tls_context(config.certfile, config.keyfile)
asyncio.run(a_main(config, tls_context))
asyncio.run(a_main(config))
if __name__ == '__main__':

View File

@ -55,7 +55,7 @@ class MyHandler(AsyncMessage):
def protocol_factory_starttls(mails_path: Path,
mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None):
context: ssl.SSLContext):
logging.info("Got smtp client cb starttls")
try:
handler = MyHandler(mails_path, mbox_finder)
@ -84,25 +84,25 @@ async def create_smtp_server_starttls(host: str,
port: int,
mails_path: Path,
mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None):
ssl_context: ssl.SSLContext) -> asyncio.Server:
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls,
mails_path, mbox_finder, context),
mails_path, mbox_finder, ssl_context),
host=host,
port=port,
start_serving=False)
async def create_smtp_server_tls(host: str,
async def create_smtp_server(host: str,
port: int,
mails_path: Path,
mbox_finder: Callable[[str], [str]],
context: ssl.SSLContext | None = None):
ssl_context: ssl.SSLContext | None = None) -> asyncio.Server:
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder),
host=host,
port=port,
ssl=context,
ssl=ssl_context,
start_serving=False)

View File

@ -7,7 +7,7 @@ import os
from pathlib import Path
from .smtp import create_smtp_server_tls
from .smtp import create_smtp_server
TEST_MBOX = 'foobar_mails'
MAILS_PATH: Path
@ -28,7 +28,7 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
logging.basicConfig(level=logging.DEBUG)
smtp_server = await create_smtp_server_tls(
smtp_server = await create_smtp_server(
host="127.0.0.1",
port=7996,
mails_path=MAILS_PATH,