add rules parsing and checking

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-11 19:49:41 -04:00
parent 9639e3b828
commit 12ac18a03e
3 changed files with 145 additions and 12 deletions

View File

@ -1,20 +1,31 @@
import json
import re
from typing import Callable
from jata import Jata, MutableDefault
class Match(Jata):
name: str
alias: list[str] = MutableDefault(lambda: [])
alias_regex: list[str] = MutableDefault(lambda: [])
addrs: list[str] = MutableDefault(lambda: [])
addr_rexs: list[str] = MutableDefault(lambda: [])
DEFAULT_MATCH_ALL = "default_match_all"
class Rule(Jata):
match_name: str
negate: bool = False
# Do not process further rules
stop_check: bool = False
class Mbox(Jata):
name: str
rules: list[str]
rules: list[Rule]
DEFAULT_NULL_MBOX = "default_null_mbox"
class User(Jata):
@ -35,8 +46,54 @@ class Config(Jata):
pop_port = 995
pop_timeout_seconds = 60
smtputf8 = True
rules: list[Rule]
boxes: list[Mbox]
users: list[User]
boxes: list[Mbox]
matches: list[Match]
def parse_rules(cfg: Config) -> list[tuple[str, Callable[[str], bool], bool]]:
def make_match_fn(m: Match):
if m.addrs and m.addr_rexs:
raise Exception("Both addrs and addr_rexs is set")
if m.addrs:
return lambda malias: malias in m.addrs
elif m.addr_rexs:
compiled_res = [re.compile(reg) for reg in m.addr_rexs]
return lambda malias: any(
reg.match(malias) for reg in compiled_res)
else:
raise Exception("Neither addrs nor addr_rexs is set")
matches = {
m.name: make_match_fn(m)
for match in cfg.matches if (m := Match(match)) is not None
}
matches[DEFAULT_MATCH_ALL] = lambda _: True
def flat_rules():
for mbox in cfg.boxes:
for rule in mbox.rules:
rule = Rule(rule)
fn = matches[rule.match_name]
if rule.negate:
match_fn = lambda malias, fn=fn: not fn(malias)
else:
match_fn = fn
yield (mbox.name, match_fn, rule.stop_check)
return list(flat_rules())
def get_mboxes(
addr: str, rules: list[tuple[str, Callable[[str], bool],
bool]]) -> list[str]:
def inner():
for mbox, match_fn, stop_check in rules:
if match_fn(addr):
yield mbox
if stop_check:
return
return list(inner())

77
mail4one/config_test.py Normal file
View File

@ -0,0 +1,77 @@
import unittest
from . import config
TEST_CONFIG = """
{
"mails_path": "/var/tmp/mails",
"matches": [
{
"name": "mydomain",
"addr_rexs": [
".*@mydomain.com",
".*@m.mydomain.com"
]
},
{
"name": "personal",
"addrs": [
"first.last@mydomain.com",
"secret.name@mydomain.com"
]
}
],
"boxes": [
{
"name": "spam",
"rules": [
{
"match_name": "mydomain",
"negate": true,
"stop_check": true
}
]
},
{
"name": "important",
"rules": [
{
"match_name": "personal"
}
]
},
{
"name": "all",
"rules": [
{
"match_name": "default_match_all"
}
]
}
]
}
"""
class TestConfig(unittest.TestCase):
def test_config(self) -> None:
cfg = config.Config(TEST_CONFIG)
self.assertEqual(cfg.mails_path, "/var/tmp/mails")
def test_parse_rules(self) -> None:
cfg = config.Config(TEST_CONFIG)
op = config.parse_rules(cfg)
self.assertEqual(len(op), 3)
def test_get_mboxes(self) -> None:
cfg = config.Config(TEST_CONFIG)
rules = config.parse_rules(cfg)
self.assertEqual(config.get_mboxes("foo@bar.com", rules), ['spam'])
self.assertEqual(config.get_mboxes("foo@mydomain.com", rules), ['all'])
self.assertEqual(config.get_mboxes("first.last@mydomain.com", rules),
['important', 'all'])
if __name__ == "__main__":
unittest.main()

View File

@ -5,6 +5,7 @@ import mailbox
import ssl
from functools import partial
from pathlib import Path
from . import config
from aiosmtpd.handlers import Mailbox
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
@ -32,14 +33,14 @@ class MailboxCRLF(Mailbox):
self.mailbox = MaildirCRLF(mail_dir)
def protocol_factory_starttls(dirpath: Path, context: ssl.SSLContext | None = None):
def protocol_factory_starttls(dirpath: Path,
context: ssl.SSLContext | None = None):
logging.info("Got smtp client cb")
try:
handler = MailboxCRLF(dirpath)
smtp = SMTP(handler=handler,
require_starttls=True,
tls_context=context,
data_size_limit=DATA_SIZE_DEFAULT,
enable_SMTPUTF8=True)
except Exception as e:
logging.error("Something went wrong", e)
@ -51,9 +52,7 @@ def protocol_factory(dirpath: Path):
logging.info("Got smtp client cb")
try:
handler = MailboxCRLF(dirpath)
smtp = SMTP(handler=handler,
data_size_limit=DATA_SIZE_DEFAULT,
enable_SMTPUTF8=True)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
except Exception as e:
logging.error("Something went wrong", e)
raise
@ -63,7 +62,7 @@ def protocol_factory(dirpath: Path):
async def create_smtp_server_starttls(dirpath: Path,
port: int,
host="",
context: ssl.SSLContext | None= None):
context: ssl.SSLContext | None = None):
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, dirpath,
context),
@ -75,7 +74,7 @@ async def create_smtp_server_starttls(dirpath: Path,
async def create_smtp_server_tls(dirpath: Path,
port: int,
host="",
context: ssl.SSLContext | None= None):
context: ssl.SSLContext | None = None):
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, dirpath),
host=host,