format using black

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-26 17:20:50 -04:00
parent 7cb1b69744
commit 59abf24ce5
8 changed files with 38 additions and 50 deletions

View File

@ -28,3 +28,6 @@ build: clean requirements.txt
clean: clean:
rm -rf build rm -rf build
rm -rf mail4one.pyz rm -rf mail4one.pyz
format:
black mail4one/*py

View File

@ -56,13 +56,13 @@ class PopCfg(ServerCfg):
class SmtpStartTLSCfg(ServerCfg): class SmtpStartTLSCfg(ServerCfg):
server_type = "smtp_starttls" server_type = "smtp_starttls"
smtputf8 = True # Not used yet smtputf8 = True # Not used yet
port = 25 port = 25
class SmtpCfg(ServerCfg): class SmtpCfg(ServerCfg):
server_type = "smtp_starttls" server_type = "smtp_starttls"
smtputf8 = True # Not used yet smtputf8 = True # Not used yet
port = 465 port = 465
@ -89,7 +89,6 @@ Checker = tuple[str, CheckerFn, bool]
def parse_checkers(cfg: Config) -> list[Checker]: def parse_checkers(cfg: Config) -> list[Checker]:
def make_match_fn(m: Match): def make_match_fn(m: Match):
if m.addrs and m.addr_rexs: if m.addrs and m.addr_rexs:
raise Exception("Both addrs and addr_rexs is set") raise Exception("Both addrs and addr_rexs is set")
@ -97,8 +96,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
return lambda malias: malias in m.addrs return lambda malias: malias in m.addrs
elif m.addr_rexs: elif m.addr_rexs:
compiled_res = [re.compile(reg) for reg in m.addr_rexs] compiled_res = [re.compile(reg) for reg in m.addr_rexs]
return lambda malias: any( return lambda malias: any(reg.match(malias) for reg in compiled_res)
reg.match(malias) for reg in compiled_res)
else: else:
raise Exception("Neither addrs nor addr_rexs is set") raise Exception("Neither addrs nor addr_rexs is set")
@ -114,13 +112,13 @@ def parse_checkers(cfg: Config) -> list[Checker]:
return mbox_name, match_fn, rule.stop_check return mbox_name, match_fn, rule.stop_check
return [ return [
make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes or [] make_checker(mbox.name, Rule(rule))
for mbox in cfg.boxes or []
for rule in mbox.rules for rule in mbox.rules
] ]
def get_mboxes(addr: str, checks: list[Checker]) -> list[str]: def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
def inner(): def inner():
for mbox, match_fn, stop_check in checks: for mbox, match_fn, stop_check in checks:
if match_fn(addr): if match_fn(addr):

View File

@ -44,7 +44,6 @@ class State:
class SharedState: class SharedState:
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]): def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
self.mails_path = mails_path self.mails_path = mails_path
self.users = users self.users = users
@ -56,8 +55,7 @@ class SharedState:
return self.counter return self.counter
c_shared_state: contextvars.ContextVar = contextvars.ContextVar( c_shared_state: contextvars.ContextVar = contextvars.ContextVar("pop_shared_state")
"pop_shared_state")
def scfg() -> SharedState: def scfg() -> SharedState:
@ -72,7 +70,6 @@ def state() -> State:
class PopLogger(logging.LoggerAdapter): class PopLogger(logging.LoggerAdapter):
def __init__(self): def __init__(self):
super().__init__(logging.getLogger("pop3"), None) super().__init__(logging.getLogger("pop3"), None)
@ -280,8 +277,7 @@ def get_deleted_items(deleted_items_path: Path) -> set[str]:
return set() return set()
def save_deleted_items(deleted_items_path: Path, def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]) -> None:
deleted_items: set[str]) -> None:
with deleted_items_path.open(mode="w") as f: with deleted_items_path.open(mode="w") as f:
f.writelines(f"{did}\n" for did in deleted_items) f.writelines(f"{did}\n" for did in deleted_items)
@ -298,8 +294,9 @@ async def transaction_stage() -> None:
new_deleted_items: set[str] = await process_transactions(mails_list) new_deleted_items: set[str] = await process_transactions(mails_list)
logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}") logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
if new_deleted_items: if new_deleted_items:
save_deleted_items(deleted_items_path, save_deleted_items(
existing_deleted_items.union(new_deleted_items)) deleted_items_path, existing_deleted_items.union(new_deleted_items)
)
logger.info(f"Saved deleted items") logger.info(f"Saved deleted items")
@ -330,7 +327,6 @@ async def start_session() -> None:
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]: def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
def inner(): def inner():
for user in users: for user in users:
user = User(user) user = User(user)
@ -340,15 +336,13 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
return dict(inner()) return dict(inner())
def make_pop_server_callback(mails_path: Path, users: list[User], def make_pop_server_callback(mails_path: Path, users: list[User], timeout_seconds: int):
timeout_seconds: int):
scfg = SharedState(mails_path=mails_path, users=parse_users(users)) scfg = SharedState(mails_path=mails_path, users=parse_users(users))
async def session_cb(reader: StreamReader, writer: StreamWriter): async def session_cb(reader: StreamReader, writer: StreamWriter):
c_shared_state.set(scfg) c_shared_state.set(scfg)
ip, _ = writer.get_extra_info("peername") ip, _ = writer.get_extra_info("peername")
c_state.set( c_state.set(State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
logger.info(f"Got pop server callback") logger.info(f"Got pop server callback")
try: try:
try: try:

View File

@ -86,7 +86,7 @@ def parse_command(bline: bytes) -> Request:
if parts: if parts:
request.arg2, *parts = parts request.arg2, *parts = parts
if parts: if parts:
(request.rest, ) = parts (request.rest,) = parts
return request return request
@ -130,7 +130,6 @@ def get_mail(entry: MailEntry) -> bytes:
class MailList: class MailList:
def __init__(self, entries: list[MailEntry]): def __init__(self, entries: list[MailEntry]):
self.entries = entries self.entries = entries
set_nid(self.entries) set_nid(self.entries)

View File

@ -19,17 +19,13 @@ KEY_LEN = 64 # This is python default
def gen_pwhash(password: str) -> str: def gen_pwhash(password: str) -> str:
salt = os.urandom(SALT_LEN) salt = os.urandom(SALT_LEN)
sh = scrypt(password.encode(), sh = scrypt(
salt=salt, password.encode(), salt=salt, n=SCRYPT_N, r=SCRYPT_R, p=SCRYPT_P, dklen=KEY_LEN
n=SCRYPT_N, )
r=SCRYPT_R,
p=SCRYPT_P,
dklen=KEY_LEN)
return b32encode(VERSION + salt + sh).decode() return b32encode(VERSION + salt + sh).decode()
class PWInfo: class PWInfo:
def __init__(self, salt: bytes, sh: bytes): def __init__(self, salt: bytes, sh: bytes):
self.salt = salt self.salt = salt
self.scrypt_hash = sh self.scrypt_hash = sh
@ -40,12 +36,13 @@ def parse_hash(pwhash_str: str) -> PWInfo:
if not len(pwhash) == 1 + SALT_LEN + KEY_LEN: if not len(pwhash) == 1 + SALT_LEN + KEY_LEN:
raise Exception( raise Exception(
f"Invalid hash size, {len(pwhash)} != {1 + SALT_LEN + KEY_LEN}") f"Invalid hash size, {len(pwhash)} != {1 + SALT_LEN + KEY_LEN}"
)
if (ver := pwhash[0:1]) != VERSION: if (ver := pwhash[0:1]) != VERSION:
raise Exception(f"Invalid hash version, {ver!r} != {VERSION!r}") raise Exception(f"Invalid hash version, {ver!r} != {VERSION!r}")
salt, sh = pwhash[1:SALT_LEN + 1], pwhash[-KEY_LEN:] salt, sh = pwhash[1 : SALT_LEN + 1], pwhash[-KEY_LEN:]
return PWInfo(salt, sh) return PWInfo(salt, sh)

View File

@ -22,21 +22,22 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
def setup_logging(cfg: config.LogCfg): def setup_logging(cfg: config.LogCfg):
logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d" logging_format = (
"%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
)
if cfg.logfile == "CONSOLE": if cfg.logfile == "CONSOLE":
logging.basicConfig(level=cfg.level, format=logging_format) logging.basicConfig(level=cfg.level, format=logging_format)
else: else:
logging.basicConfig(filename=cfg.logfile, logging.basicConfig(
level=cfg.level, filename=cfg.logfile, level=cfg.level, format=logging_format
format=logging_format) )
async def a_main(cfg: config.Config) -> None: async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls: if tls := cfg.default_tls:
logging.info( logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
default_tls_context = create_tls_context(tls.certfile, tls.keyfile) default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
def get_tls_context(tls: config.TLSCfg | str): def get_tls_context(tls: config.TLSCfg | str):

View File

@ -25,15 +25,14 @@ logger = logging.getLogger("smtp")
class MyHandler(AsyncMessage): class MyHandler(AsyncMessage):
def __init__(self, mails_path: Path, mbox_finder: Callable[[str], list[str]]):
def __init__(self, mails_path: Path, mbox_finder: Callable[[str],
list[str]]):
super().__init__() super().__init__()
self.mails_path = mails_path self.mails_path = mails_path
self.mbox_finder = mbox_finder self.mbox_finder = mbox_finder
async def handle_DATA(self, server: SMTPServer, session: SMTPSession, async def handle_DATA(
envelope: SMTPEnvelope) -> str: self, server: SMTPServer, session: SMTPSession, envelope: SMTPEnvelope
) -> str:
self.rcpt_tos = envelope.rcpt_tos self.rcpt_tos = envelope.rcpt_tos
self.peer = session.peer self.peer = session.peer
return await super().handle_DATA(server, session, envelope) return await super().handle_DATA(server, session, envelope)
@ -63,9 +62,9 @@ class MyHandler(AsyncMessage):
) )
def protocol_factory_starttls(mails_path: Path, def protocol_factory_starttls(
mbox_finder: Callable[[str], list[str]], mails_path: Path, mbox_finder: Callable[[str], list[str]], context: ssl.SSLContext
context: ssl.SSLContext): ):
logger.info("Got smtp client cb starttls") logger.info("Got smtp client cb starttls")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
@ -81,8 +80,7 @@ def protocol_factory_starttls(mails_path: Path,
return smtp return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]):
list[str]]):
logger.info("Got smtp client cb") logger.info("Got smtp client cb")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
@ -105,8 +103,7 @@ async def create_smtp_server_starttls(
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory_starttls, mails_path, mbox_finder, partial(protocol_factory_starttls, mails_path, mbox_finder, ssl_context),
ssl_context),
host=host, host=host,
port=port, port=port,
start_serving=False, start_serving=False,

View File

@ -1,2 +1 @@
VERSION = "DEVELOMENT" VERSION = "DEVELOMENT"