fix syntax errors and format

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-03 20:34:13 -04:00
parent 5600d30f54
commit 5cdf36fe32
4 changed files with 41 additions and 78 deletions

View File

@ -33,6 +33,7 @@ class Config(Jata):
smtp_port_tls = 465 smtp_port_tls = 465
smtp_port_submission = 587 smtp_port_submission = 587
pop_port = 995 pop_port = 995
pop_timeout_seconds = 60
smtputf8 = True smtputf8 = True
rules: list[Rule] rules: list[Rule]
boxes: list[Mbox] boxes: list[Mbox]

View File

@ -8,44 +8,13 @@ from hashlib import sha256
from pathlib import Path from pathlib import Path
from typing import ClassVar, List, Set from typing import ClassVar, List, Set
from .config import User from .config import User
from .pwhash import parse_hash, check_pass from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
Request, MailEntry, get_mail, get_mails_list, MailList Request, MailEntry, get_mail, get_mails_list, MailList
def add_season(content: bytes, season: bytes):
return sha256(season + content).digest()
# noinspection PyProtectedMember
@dataclass
class Session:
_reader: StreamReader
_writer: asyncio.StreamWriter
username: str
mbox: str
# common state
all_sessions: ClassVar[Set] = set()
mails_path: ClassVar[Path] = Path("")
users: ClassVar[list[User]] = list()
current_session: ClassVar = ContextVar("session")
@classmethod
def get(cls):
return cls.current_session.get()
@classmethod
def reader(cls):
return cls.get()._reader
@classmethod
def writer(cls):
return cls.get()._writer
async def next_req(): async def next_req():
for _ in range(InvalidCommand.RETRIES): for _ in range(InvalidCommand.RETRIES):
line = await state().reader.readline() line = await state().reader.readline()
@ -259,7 +228,8 @@ async def start_session():
existing_deleted_items: Set = get_deleted_items(deleted_items_path) existing_deleted_items: Set = get_deleted_items(deleted_items_path)
new_deleted_items: Set = await transaction_stage(existing_deleted_items) new_deleted_items: Set = await transaction_stage(existing_deleted_items
)
logging.info( logging.info(
f"{username=} completed transactions. Deleted:{len(new_deleted_items)}" f"{username=} completed transactions. Deleted:{len(new_deleted_items)}"
) )
@ -300,11 +270,12 @@ class State:
mbox: str = "" mbox: str = ""
@dataclass
class Config: class Config:
mails_path: Path
users: dict[str, tuple[pwhash.PWInfo, str]] def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
loggedin_users: set[str] = set() self.mails_path = mails_path
self.users = users
self.loggedin_users: set[str] = set()
c_config = contextvars.ContextVar('config') c_config = contextvars.ContextVar('config')
@ -336,11 +307,11 @@ def make_pop_server_callback(dirpath: Path, users: list[User],
return session_cb return session_cb
async def create_pop_server(dirpath: Path, async def create_pop_server(host: str,
port: int, port: int,
mails_path: Path,
users: list[User], users: list[User],
host="", ssl_context: ssl.SSLContext = None,
context: ssl.SSLContext = None,
timeout_seconds: int = 60): timeout_seconds: int = 60):
logging.info( logging.info(
f"Starting POP3 server {dirpath=}, {host=}, {port=}, {timeout_seconds=}, ssl={context != None}" f"Starting POP3 server {dirpath=}, {host=}, {port=}, {timeout_seconds=}, ssl={context != None}"

View File

@ -1,8 +1,4 @@
import asyncio import asyncio
# Though we don't use requests, without the below import, we crash https://stackoverflow.com/a/13057751
# When running on privilege port after dropping privileges.
# noinspection PyUnresolvedReferences
import encodings.idna
import logging import logging
import os import os
import ssl import ssl
@ -22,27 +18,6 @@ def create_tls_context(certfile, keyfile):
return context return context
def parse_args():
parser = ArgumentParser()
parser.add_argument('--certfile')
parser.add_argument('--keyfile')
parser.add_argument('--password_hash')
parser.add_argument("mail_dir_path")
args = parser.parse_args()
args.mail_dir_path = Path(args.mail_dir_path)
# Hardcoded args
args.host = '0.0.0.0'
args.smtp_port = 25
args.smtp_port_tls = 465
args.smtp_port_submission = 587
args.pop_port = 995
args.smtputf8 = True
args.debug = True
return args
def setup_logging(args): def setup_logging(args):
if args.debug: if args.debug:
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -51,11 +26,13 @@ def setup_logging(args):
async def a_main(config, tls_context): async def a_main(config, tls_context):
pop_server = await create_pop_server(config.mails_path, pop_server = await create_pop_server(
port=config.pop_port,
host=config.host, host=config.host,
context=tls_context, port=config.pop_port,
users=config.users) mails_path=config.mails_path,
users=config.users,
ssl_context=tls_context,
timeout_seconds=config.pop_timeout_seconds)
smtp_server_starttls = await create_smtp_server_starttls( smtp_server_starttls = await create_smtp_server_starttls(
config.mail_dir_path, config.mail_dir_path,
@ -74,7 +51,6 @@ async def a_main(config, tls_context):
def main(): def main():
config_path = sys.argv[1]
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("config_path") parser.add_argument("config_path")
args = parser.parse_args() args = parser.parse_args()

View File

@ -23,6 +23,7 @@ class MaildirCRLF(mailbox.Maildir):
class MailboxCRLF(Mailbox): class MailboxCRLF(Mailbox):
def __init__(self, mail_dir: Path): def __init__(self, mail_dir: Path):
super().__init__(mail_dir) super().__init__(mail_dir)
for sub in ('new', 'tmp', 'cur'): for sub in ('new', 'tmp', 'cur'):
@ -50,23 +51,37 @@ def protocol_factory(dirpath: Path):
logging.info("Got smtp client cb") logging.info("Got smtp client cb")
try: try:
handler = MailboxCRLF(dirpath) handler = MailboxCRLF(dirpath)
smtp = SMTP(handler=handler, data_size_limit=DATA_SIZE_DEFAULT, enable_SMTPUTF8=True) smtp = SMTP(handler=handler,
data_size_limit=DATA_SIZE_DEFAULT,
enable_SMTPUTF8=True)
except Exception as e: except Exception as e:
logging.error("Something went wrong", e) logging.error("Something went wrong", e)
raise raise
return smtp return smtp
async def create_smtp_server_starttls(dirpath: Path, port: int, host="", context: ssl.SSLContext = None): async def create_smtp_server_starttls(dirpath: Path,
port: int,
host="",
context: ssl.SSLContext = None):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, dirpath, context), return await loop.create_server(partial(protocol_factory_starttls, dirpath,
host=host, port=port, start_serving=False) context),
host=host,
port=port,
start_serving=False)
async def create_smtp_server_tls(dirpath: Path, port: int, host="", context: ssl.SSLContext = None): async def create_smtp_server_tls(dirpath: Path,
port: int,
host="",
context: ssl.SSLContext = None):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, dirpath), return await loop.create_server(partial(protocol_factory, dirpath),
host=host, port=port, ssl=context, start_serving=False) host=host,
port=port,
ssl=context,
start_serving=False)
async def a_main(*args, **kwargs): async def a_main(*args, **kwargs):