refactor server
This commit is contained in:
parent
622bb0f473
commit
93bfb3d41c
@ -34,21 +34,46 @@ class User(Jata):
|
|||||||
mbox: str
|
mbox: str
|
||||||
|
|
||||||
|
|
||||||
class Config(Jata):
|
class TLSCfg(Jata):
|
||||||
certfile: str
|
certfile: str
|
||||||
keyfile: str
|
keyfile: str
|
||||||
debug: bool = False
|
|
||||||
mails_path: str
|
|
||||||
host = '0.0.0.0'
|
class ServerCfg(Jata):
|
||||||
smtp_port = 25
|
host: str = "default"
|
||||||
smtp_port_tls = 465
|
port: int
|
||||||
smtp_port_submission = 587
|
tls: TLSCfg | str = "default"
|
||||||
pop_port = 995
|
|
||||||
pop_timeout_seconds = 60
|
|
||||||
|
class PopCfg(ServerCfg):
|
||||||
|
port = 995
|
||||||
|
timeout_seconds = 60
|
||||||
|
|
||||||
|
|
||||||
|
class SmtpStartTLSCfg(ServerCfg):
|
||||||
smtputf8 = True
|
smtputf8 = True
|
||||||
|
port = 25
|
||||||
|
|
||||||
|
|
||||||
|
class SmtpCfg(ServerCfg):
|
||||||
|
smtputf8 = True
|
||||||
|
port = 465
|
||||||
|
|
||||||
|
|
||||||
|
class Config(Jata):
|
||||||
|
default_tls: TLSCfg | None
|
||||||
|
default_host: str = '0.0.0.0'
|
||||||
|
|
||||||
|
mails_path: str
|
||||||
users: list[User]
|
users: list[User]
|
||||||
boxes: list[Mbox]
|
boxes: list[Mbox]
|
||||||
matches: list[Match]
|
matches: list[Match]
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
pop: PopCfg | None
|
||||||
|
smtp_starttls: SmtpStartTLSCfg | None
|
||||||
|
smtp: SmtpCfg | None
|
||||||
|
# smtp_port_submission = 587
|
||||||
|
|
||||||
|
|
||||||
CheckerFn = Callable[[str], bool]
|
CheckerFn = Callable[[str], bool]
|
||||||
@ -69,10 +94,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
|
|||||||
else:
|
else:
|
||||||
raise Exception("Neither addrs nor addr_rexs is set")
|
raise Exception("Neither addrs nor addr_rexs is set")
|
||||||
|
|
||||||
matches = {
|
matches = {m.name: make_match_fn(Match(m)) for m in cfg.matches or []}
|
||||||
m.name: make_match_fn(m)
|
|
||||||
for match in cfg.matches if (m := Match(match)) is not None
|
|
||||||
}
|
|
||||||
matches[DEFAULT_MATCH_ALL] = lambda _: True
|
matches[DEFAULT_MATCH_ALL] = lambda _: True
|
||||||
|
|
||||||
def make_checker(mbox_name: str, rule: Rule) -> Checker:
|
def make_checker(mbox_name: str, rule: Rule) -> Checker:
|
||||||
@ -84,7 +106,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
|
|||||||
return mbox_name, match_fn, rule.stop_check
|
return mbox_name, match_fn, rule.stop_check
|
||||||
|
|
||||||
return [
|
return [
|
||||||
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes
|
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes or []
|
||||||
for rule in mbox.rules
|
for rule in mbox.rules
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -317,7 +317,7 @@ async def create_pop_server(host: str,
|
|||||||
mails_path: Path,
|
mails_path: Path,
|
||||||
users: list[User],
|
users: list[User],
|
||||||
ssl_context: ssl.SSLContext | None = None,
|
ssl_context: ssl.SSLContext | None = None,
|
||||||
timeout_seconds: int = 60):
|
timeout_seconds: int = 60) -> asyncio.Server:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
|
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
|
||||||
)
|
)
|
||||||
|
@ -6,13 +6,13 @@ import sys
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .smtp import create_smtp_server_starttls, create_smtp_server_tls
|
from .smtp import create_smtp_server_starttls, create_smtp_server
|
||||||
from .pop3 import create_pop_server
|
from .pop3 import create_pop_server
|
||||||
|
|
||||||
from .config import Config
|
from . import config
|
||||||
|
|
||||||
|
|
||||||
def create_tls_context(certfile, keyfile):
|
def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
|
||||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||||
return context
|
return context
|
||||||
@ -25,44 +25,82 @@ def setup_logging(args):
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
async def a_main(config, tls_context):
|
async def a_main(cfg: config.Config) -> None:
|
||||||
|
|
||||||
|
default_tls_context: ssl.SSLContext | None = None
|
||||||
|
|
||||||
|
if tls := cfg.default_tls:
|
||||||
|
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
|
||||||
|
|
||||||
|
def get_tls_context(tls: config.TLSCfg | str):
|
||||||
|
if tls == "default":
|
||||||
|
return default_tls_context
|
||||||
|
elif tls == "disable":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
tls_cfg = config.TLSCfg(pop.tls)
|
||||||
|
return create_tls_context(tls_cfg.certfile, tls_cfg.keyfile)
|
||||||
|
|
||||||
|
def get_host(host):
|
||||||
|
if host == "default":
|
||||||
|
return cfg.default_host
|
||||||
|
else:
|
||||||
|
return host
|
||||||
|
|
||||||
|
mbox_finder = config.gen_addr_to_mboxes(cfg)
|
||||||
|
servers: list[asyncio.Server] = []
|
||||||
|
|
||||||
|
if cfg.pop:
|
||||||
|
pop = config.PopCfg(cfg.pop)
|
||||||
pop_server = await create_pop_server(
|
pop_server = await create_pop_server(
|
||||||
host=config.host,
|
host=get_host(pop.host),
|
||||||
port=config.pop_port,
|
port=pop.port,
|
||||||
mails_path=config.mails_path,
|
mails_path=Path(cfg.mails_path),
|
||||||
users=config.users,
|
users=cfg.users,
|
||||||
ssl_context=tls_context,
|
ssl_context=get_tls_context(pop.tls),
|
||||||
timeout_seconds=config.pop_timeout_seconds)
|
timeout_seconds=pop.timeout_seconds)
|
||||||
|
servers.append(pop_server)
|
||||||
|
|
||||||
|
if cfg.smtp_starttls:
|
||||||
|
stls = config.SmtpStartTLSCfg(cfg.smtp_starttls)
|
||||||
|
stls_context = get_tls_context(stls.tls)
|
||||||
|
if not stls_context:
|
||||||
|
raise Exception("starttls requires ssl_context")
|
||||||
smtp_server_starttls = await create_smtp_server_starttls(
|
smtp_server_starttls = await create_smtp_server_starttls(
|
||||||
config.mail_dir_path,
|
host=get_host(stls.host),
|
||||||
port=config.smtp_port,
|
port=stls.port,
|
||||||
host=config.host,
|
mails_path=Path(cfg.mails_path),
|
||||||
context=tls_context)
|
mbox_finder=mbox_finder,
|
||||||
|
ssl_context=stls_context)
|
||||||
|
servers.append(smtp_server_starttls)
|
||||||
|
|
||||||
smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path,
|
if cfg.smtp:
|
||||||
port=config.smtp_port_tls,
|
smtp = config.SmtpCfg(cfg.smtp)
|
||||||
host=config.host,
|
smtp_server = await create_smtp_server(host=get_host(smtp.host),
|
||||||
context=tls_context)
|
port=smtp.port,
|
||||||
|
mails_path=Path(cfg.mails_path),
|
||||||
|
mbox_finder=mbox_finder,
|
||||||
|
ssl_context=get_tls_context(
|
||||||
|
smtp.tls))
|
||||||
|
servers.append(smtp_server)
|
||||||
|
|
||||||
await asyncio.gather(pop_server.serve_forever(),
|
if servers:
|
||||||
smtp_server_starttls.serve_forever(),
|
await asyncio.gather(server.serve_forever() for server in servers)
|
||||||
smtp_server_tls.serve_forever())
|
else:
|
||||||
|
logging.warn("Nothing to do!")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("config_path")
|
parser.add_argument("config_path", type=Path)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
config = Config(open(args.config_path).read())
|
config = Config(args.config_path.read_text())
|
||||||
|
|
||||||
setup_logging(args)
|
setup_logging(args)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.set_debug(config.debug)
|
loop.set_debug(config.debug)
|
||||||
|
|
||||||
tls_context = create_tls_context(config.certfile, config.keyfile)
|
asyncio.run(a_main(config))
|
||||||
|
|
||||||
asyncio.run(a_main(config, tls_context))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -55,7 +55,7 @@ class MyHandler(AsyncMessage):
|
|||||||
|
|
||||||
def protocol_factory_starttls(mails_path: Path,
|
def protocol_factory_starttls(mails_path: Path,
|
||||||
mbox_finder: Callable[[str], [str]],
|
mbox_finder: Callable[[str], [str]],
|
||||||
context: ssl.SSLContext | None = None):
|
context: ssl.SSLContext):
|
||||||
logging.info("Got smtp client cb starttls")
|
logging.info("Got smtp client cb starttls")
|
||||||
try:
|
try:
|
||||||
handler = MyHandler(mails_path, mbox_finder)
|
handler = MyHandler(mails_path, mbox_finder)
|
||||||
@ -84,25 +84,25 @@ async def create_smtp_server_starttls(host: str,
|
|||||||
port: int,
|
port: int,
|
||||||
mails_path: Path,
|
mails_path: Path,
|
||||||
mbox_finder: Callable[[str], [str]],
|
mbox_finder: Callable[[str], [str]],
|
||||||
context: ssl.SSLContext | None = None):
|
ssl_context: ssl.SSLContext) -> asyncio.Server:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return await loop.create_server(partial(protocol_factory_starttls,
|
return await loop.create_server(partial(protocol_factory_starttls,
|
||||||
mails_path, mbox_finder, context),
|
mails_path, mbox_finder, ssl_context),
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
start_serving=False)
|
start_serving=False)
|
||||||
|
|
||||||
|
|
||||||
async def create_smtp_server_tls(host: str,
|
async def create_smtp_server(host: str,
|
||||||
port: int,
|
port: int,
|
||||||
mails_path: Path,
|
mails_path: Path,
|
||||||
mbox_finder: Callable[[str], [str]],
|
mbox_finder: Callable[[str], [str]],
|
||||||
context: ssl.SSLContext | None = None):
|
ssl_context: ssl.SSLContext | None = None) -> asyncio.Server:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder),
|
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder),
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
ssl=context,
|
ssl=ssl_context,
|
||||||
start_serving=False)
|
start_serving=False)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import os
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .smtp import create_smtp_server_tls
|
from .smtp import create_smtp_server
|
||||||
|
|
||||||
TEST_MBOX = 'foobar_mails'
|
TEST_MBOX = 'foobar_mails'
|
||||||
MAILS_PATH: Path
|
MAILS_PATH: Path
|
||||||
@ -28,7 +28,7 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
async def asyncSetUp(self) -> None:
|
async def asyncSetUp(self) -> None:
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
smtp_server = await create_smtp_server_tls(
|
smtp_server = await create_smtp_server(
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port=7996,
|
port=7996,
|
||||||
mails_path=MAILS_PATH,
|
mails_path=MAILS_PATH,
|
||||||
|
Loading…
Reference in New Issue
Block a user