refactor server
This commit is contained in:
parent
622bb0f473
commit
93bfb3d41c
@ -6,8 +6,8 @@ from jata import Jata, MutableDefault
|
||||
|
||||
class Match(Jata):
|
||||
name: str
|
||||
addrs: list[str] = MutableDefault(lambda: []) # type: ignore
|
||||
addr_rexs: list[str] = MutableDefault(lambda: []) # type: ignore
|
||||
addrs: list[str] = MutableDefault(lambda: []) # type: ignore
|
||||
addr_rexs: list[str] = MutableDefault(lambda: []) # type: ignore
|
||||
|
||||
|
||||
DEFAULT_MATCH_ALL = "default_match_all"
|
||||
@ -34,21 +34,46 @@ class User(Jata):
|
||||
mbox: str
|
||||
|
||||
|
||||
class Config(Jata):
|
||||
class TLSCfg(Jata):
|
||||
certfile: str
|
||||
keyfile: str
|
||||
debug: bool = False
|
||||
mails_path: str
|
||||
host = '0.0.0.0'
|
||||
smtp_port = 25
|
||||
smtp_port_tls = 465
|
||||
smtp_port_submission = 587
|
||||
pop_port = 995
|
||||
pop_timeout_seconds = 60
|
||||
|
||||
|
||||
class ServerCfg(Jata):
|
||||
host: str = "default"
|
||||
port: int
|
||||
tls: TLSCfg | str = "default"
|
||||
|
||||
|
||||
class PopCfg(ServerCfg):
|
||||
port = 995
|
||||
timeout_seconds = 60
|
||||
|
||||
|
||||
class SmtpStartTLSCfg(ServerCfg):
|
||||
smtputf8 = True
|
||||
port = 25
|
||||
|
||||
|
||||
class SmtpCfg(ServerCfg):
|
||||
smtputf8 = True
|
||||
port = 465
|
||||
|
||||
|
||||
class Config(Jata):
|
||||
default_tls: TLSCfg | None
|
||||
default_host: str = '0.0.0.0'
|
||||
|
||||
mails_path: str
|
||||
users: list[User]
|
||||
boxes: list[Mbox]
|
||||
matches: list[Match]
|
||||
debug: bool = False
|
||||
|
||||
pop: PopCfg | None
|
||||
smtp_starttls: SmtpStartTLSCfg | None
|
||||
smtp: SmtpCfg | None
|
||||
# smtp_port_submission = 587
|
||||
|
||||
|
||||
CheckerFn = Callable[[str], bool]
|
||||
@ -69,10 +94,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
|
||||
else:
|
||||
raise Exception("Neither addrs nor addr_rexs is set")
|
||||
|
||||
matches = {
|
||||
m.name: make_match_fn(m)
|
||||
for match in cfg.matches if (m := Match(match)) is not None
|
||||
}
|
||||
matches = {m.name: make_match_fn(Match(m)) for m in cfg.matches or []}
|
||||
matches[DEFAULT_MATCH_ALL] = lambda _: True
|
||||
|
||||
def make_checker(mbox_name: str, rule: Rule) -> Checker:
|
||||
@ -84,7 +106,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
|
||||
return mbox_name, match_fn, rule.stop_check
|
||||
|
||||
return [
|
||||
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes
|
||||
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes or []
|
||||
for rule in mbox.rules
|
||||
]
|
||||
|
||||
|
@ -317,7 +317,7 @@ async def create_pop_server(host: str,
|
||||
mails_path: Path,
|
||||
users: list[User],
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
timeout_seconds: int = 60):
|
||||
timeout_seconds: int = 60) -> asyncio.Server:
|
||||
logging.info(
|
||||
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
|
||||
)
|
||||
|
@ -6,13 +6,13 @@ import sys
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from .smtp import create_smtp_server_starttls, create_smtp_server_tls
|
||||
from .smtp import create_smtp_server_starttls, create_smtp_server
|
||||
from .pop3 import create_pop_server
|
||||
|
||||
from .config import Config
|
||||
from . import config
|
||||
|
||||
|
||||
def create_tls_context(certfile, keyfile):
|
||||
def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||
return context
|
||||
@ -25,44 +25,82 @@ def setup_logging(args):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
async def a_main(config, tls_context):
|
||||
pop_server = await create_pop_server(
|
||||
host=config.host,
|
||||
port=config.pop_port,
|
||||
mails_path=config.mails_path,
|
||||
users=config.users,
|
||||
ssl_context=tls_context,
|
||||
timeout_seconds=config.pop_timeout_seconds)
|
||||
async def a_main(cfg: config.Config) -> None:
|
||||
|
||||
smtp_server_starttls = await create_smtp_server_starttls(
|
||||
config.mail_dir_path,
|
||||
port=config.smtp_port,
|
||||
host=config.host,
|
||||
context=tls_context)
|
||||
default_tls_context: ssl.SSLContext | None = None
|
||||
|
||||
smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path,
|
||||
port=config.smtp_port_tls,
|
||||
host=config.host,
|
||||
context=tls_context)
|
||||
if tls := cfg.default_tls:
|
||||
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
|
||||
|
||||
await asyncio.gather(pop_server.serve_forever(),
|
||||
smtp_server_starttls.serve_forever(),
|
||||
smtp_server_tls.serve_forever())
|
||||
def get_tls_context(tls: config.TLSCfg | str):
|
||||
if tls == "default":
|
||||
return default_tls_context
|
||||
elif tls == "disable":
|
||||
return None
|
||||
else:
|
||||
tls_cfg = config.TLSCfg(pop.tls)
|
||||
return create_tls_context(tls_cfg.certfile, tls_cfg.keyfile)
|
||||
|
||||
def get_host(host):
|
||||
if host == "default":
|
||||
return cfg.default_host
|
||||
else:
|
||||
return host
|
||||
|
||||
mbox_finder = config.gen_addr_to_mboxes(cfg)
|
||||
servers: list[asyncio.Server] = []
|
||||
|
||||
if cfg.pop:
|
||||
pop = config.PopCfg(cfg.pop)
|
||||
pop_server = await create_pop_server(
|
||||
host=get_host(pop.host),
|
||||
port=pop.port,
|
||||
mails_path=Path(cfg.mails_path),
|
||||
users=cfg.users,
|
||||
ssl_context=get_tls_context(pop.tls),
|
||||
timeout_seconds=pop.timeout_seconds)
|
||||
servers.append(pop_server)
|
||||
|
||||
if cfg.smtp_starttls:
|
||||
stls = config.SmtpStartTLSCfg(cfg.smtp_starttls)
|
||||
stls_context = get_tls_context(stls.tls)
|
||||
if not stls_context:
|
||||
raise Exception("starttls requires ssl_context")
|
||||
smtp_server_starttls = await create_smtp_server_starttls(
|
||||
host=get_host(stls.host),
|
||||
port=stls.port,
|
||||
mails_path=Path(cfg.mails_path),
|
||||
mbox_finder=mbox_finder,
|
||||
ssl_context=stls_context)
|
||||
servers.append(smtp_server_starttls)
|
||||
|
||||
if cfg.smtp:
|
||||
smtp = config.SmtpCfg(cfg.smtp)
|
||||
smtp_server = await create_smtp_server(host=get_host(smtp.host),
|
||||
port=smtp.port,
|
||||
mails_path=Path(cfg.mails_path),
|
||||
mbox_finder=mbox_finder,
|
||||
ssl_context=get_tls_context(
|
||||
smtp.tls))
|
||||
servers.append(smtp_server)
|
||||
|
||||
if servers:
|
||||
await asyncio.gather(server.serve_forever() for server in servers)
|
||||
else:
|
||||
logging.warn("Nothing to do!")
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("config_path")
|
||||
parser.add_argument("config_path", type=Path)
|
||||
args = parser.parse_args()
|
||||
config = Config(open(args.config_path).read())
|
||||
config = Config(args.config_path.read_text())
|
||||
|
||||
setup_logging(args)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.set_debug(config.debug)
|
||||
|
||||
tls_context = create_tls_context(config.certfile, config.keyfile)
|
||||
|
||||
asyncio.run(a_main(config, tls_context))
|
||||
asyncio.run(a_main(config))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -55,7 +55,7 @@ class MyHandler(AsyncMessage):
|
||||
|
||||
def protocol_factory_starttls(mails_path: Path,
|
||||
mbox_finder: Callable[[str], [str]],
|
||||
context: ssl.SSLContext | None = None):
|
||||
context: ssl.SSLContext):
|
||||
logging.info("Got smtp client cb starttls")
|
||||
try:
|
||||
handler = MyHandler(mails_path, mbox_finder)
|
||||
@ -84,25 +84,25 @@ async def create_smtp_server_starttls(host: str,
|
||||
port: int,
|
||||
mails_path: Path,
|
||||
mbox_finder: Callable[[str], [str]],
|
||||
context: ssl.SSLContext | None = None):
|
||||
ssl_context: ssl.SSLContext) -> asyncio.Server:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.create_server(partial(protocol_factory_starttls,
|
||||
mails_path, mbox_finder, context),
|
||||
mails_path, mbox_finder, ssl_context),
|
||||
host=host,
|
||||
port=port,
|
||||
start_serving=False)
|
||||
|
||||
|
||||
async def create_smtp_server_tls(host: str,
|
||||
async def create_smtp_server(host: str,
|
||||
port: int,
|
||||
mails_path: Path,
|
||||
mbox_finder: Callable[[str], [str]],
|
||||
context: ssl.SSLContext | None = None):
|
||||
ssl_context: ssl.SSLContext | None = None) -> asyncio.Server:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder),
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=context,
|
||||
ssl=ssl_context,
|
||||
start_serving=False)
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ import os
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .smtp import create_smtp_server_tls
|
||||
from .smtp import create_smtp_server
|
||||
|
||||
TEST_MBOX = 'foobar_mails'
|
||||
MAILS_PATH: Path
|
||||
@ -28,7 +28,7 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self) -> None:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
smtp_server = await create_smtp_server_tls(
|
||||
smtp_server = await create_smtp_server(
|
||||
host="127.0.0.1",
|
||||
port=7996,
|
||||
mails_path=MAILS_PATH,
|
||||
|
Loading…
Reference in New Issue
Block a user