From f407c8b395ba0ab89b249bfca0f036e0a174be1c Mon Sep 17 00:00:00 2001 From: Balakrishnan Balasubramanian Date: Fri, 16 Jun 2023 23:42:30 -0400 Subject: [PATCH] Improve logging --- mail4one/config.py | 9 ++- mail4one/pop3.py | 143 +++++++++++++++++++++++++++------------------ mail4one/server.py | 16 +++-- mail4one/smtp.py | 25 ++++++-- tests/test_smtp.py | 15 +++-- 5 files changed, 132 insertions(+), 76 deletions(-) diff --git a/mail4one/config.py b/mail4one/config.py index aa3f7ad..9372b06 100644 --- a/mail4one/config.py +++ b/mail4one/config.py @@ -1,5 +1,6 @@ import json import re +import logging from typing import Callable from jata import Jata, MutableDefault @@ -61,10 +62,15 @@ class SmtpCfg(ServerCfg): port = 465 +class LogCfg(Jata): + logfile = "STDOUT" + level = "INFO" + + class Config(Jata): default_tls: TLSCfg | None default_host: str = "0.0.0.0" - debug: bool = False + logging: LogCfg | None = None mails_path: str matches: list[Match] @@ -127,4 +133,5 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]: def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]: checks = parse_checkers(cfg) + logging.info(f"Parsed checkers from config, {len(checks)=}") return lambda addr: get_mboxes(addr, checks) diff --git a/mail4one/pop3.py b/mail4one/pop3.py index 5b78ea0..47e6f8e 100644 --- a/mail4one/pop3.py +++ b/mail4one/pop3.py @@ -1,15 +1,17 @@ import asyncio +import contextlib +import contextvars import logging import os import ssl -import contextvars -import contextlib +import uuid from dataclasses import dataclass from hashlib import sha256 from pathlib import Path from .config import User from .pwhash import parse_hash, check_pass, PWInfo from asyncio import StreamReader, StreamWriter +import random from .poputils import ( InvalidCommand, @@ -31,10 +33,66 @@ from .poputils import ( ) +@dataclass +class State: + reader: StreamReader + writer: StreamWriter + ip: str + req_id: int + username: str = "" + mbox: str = "" + + +class SharedState: + + def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]): + self.mails_path = mails_path + self.users = users + self.loggedin_users: set[str] = set() + self.counter = random.randint(10000, 99999) * 100000 + + def next_id(self) -> int: + self.counter = self.counter + 1 + return self.counter + + +c_shared_state: contextvars.ContextVar = contextvars.ContextVar( + "pop_shared_state") + + +def scfg() -> SharedState: + return c_shared_state.get() + + +c_state: contextvars.ContextVar = contextvars.ContextVar("state") + + +def state() -> State: + return c_state.get() + + +class PopLogger(logging.LoggerAdapter): + + def __init__(self): + super().__init__(logging.getLogger("pop3"), None) + + def process(self, msg, kwargs): + state: State = c_state.get(None) + if not state: + return super().process(msg, kwargs) + user = "NA" + if state.username: + user = state.username + return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs) + + +logger = PopLogger() + + async def next_req() -> Request: for _ in range(InvalidCommand.RETRIES): line = await state().reader.readline() - logging.debug(f"Client: {line!r}") + logger.debug(f"Client: {line!r}") if not line: if state().reader.at_eof(): raise ClientDisconnected @@ -54,19 +112,19 @@ async def next_req() -> Request: async def expect_cmd(*commands: Command) -> Request: req = await next_req() if req.cmd not in commands: - logging.error(f"Unexpected command: {req.cmd} is not in {commands}") + logger.error(f"Unexpected command: {req.cmd} is not in {commands}") raise ClientError return req def write(data) -> None: - logging.debug(f"Server: {data}") + logger.debug(f"Server: {data}") state().writer.write(data) def validate_password(username, password) -> None: try: - pwinfo, mbox = config().users[username] + pwinfo, mbox = scfg().users[username] except: raise AuthError("Invalid user pass") @@ -84,7 +142,7 @@ async def handle_user_pass_auth(user_cmd) -> None: cmd = await expect_cmd(Command.PASS) password = cmd.arg1 validate_password(username, password) - logging.info(f"{username=} has logged in successfully") + logger.info(f"{username=} has logged in successfully") async def auth_stage() -> None: @@ -98,20 +156,20 @@ async def auth_stage() -> None: write(end()) else: await handle_user_pass_auth(req) - if state().username in config().loggedin_users: - logging.warning( + if state().username in scfg().loggedin_users: + logger.warning( f"User: {state().username} already has an active session" ) raise AuthError("Already logged in") else: - config().loggedin_users.add(state().username) + scfg().loggedin_users.add(state().username) write(ok("Login successful")) return except AuthError as ae: write(err(f"Auth Failed: {ae}")) except ClientQuit as c: write(ok("Bye")) - logging.warning("Client has QUIT before auth succeeded") + logger.warning("Client has QUIT before auth succeeded") raise else: raise ClientError("Failed to authenticate") @@ -204,7 +262,7 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]: except ClientQuit: write(ok("Bye")) return mails.deleted_uids - logging.debug(f"Request: {req}") + logger.debug(f"Request: {req}") try: func = handle_map[req.cmd] except KeyError: @@ -229,46 +287,46 @@ def save_deleted_items(deleted_items_path: Path, async def transaction_stage() -> None: - deleted_items_path = config().mails_path / state().mbox / state().username + deleted_items_path = scfg().mails_path / state().mbox / state().username existing_deleted_items: set[str] = get_deleted_items(deleted_items_path) mails_list = [ entry - for entry in get_mails_list(config().mails_path / state().mbox / "new") + for entry in get_mails_list(scfg().mails_path / state().mbox / "new") if entry.uid not in existing_deleted_items ] new_deleted_items: set[str] = await process_transactions(mails_list) - logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}") + logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}") if new_deleted_items: save_deleted_items(deleted_items_path, existing_deleted_items.union(new_deleted_items)) - logging.info(f"Saved deleted items") + logger.info(f"Saved deleted items") async def start_session() -> None: - logging.info("New session started") + logger.info("New session started") try: await auth_stage() assert state().username assert state().mbox await transaction_stage() - logging.info(f"User:{state().username} done") + logger.info(f"User:{state().username} done") except ClientDisconnected as c: - logging.info("Client disconnected") + logger.info("Client disconnected") pass except ClientQuit: - logging.info("Client QUIT") + logger.info("Client QUIT") pass except ClientError as c: write(err("Something went wrong")) - logging.error(f"Unexpected client error: {c}") + logger.error(f"Unexpected client error: {c}") except Exception as e: - logging.error(f"Serious client error: {e}") + logger.error(f"Serious client error: {e}") raise finally: with contextlib.suppress(KeyError): - config().loggedin_users.remove(state().username) + scfg().loggedin_users.remove(state().username) def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]: @@ -282,43 +340,16 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]: return dict(inner()) -@dataclass -class State: - reader: StreamReader - writer: StreamWriter - username: str = "" - mbox: str = "" - - -class Config: - - def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]): - self.mails_path = mails_path - self.users = users - self.loggedin_users: set[str] = set() - - -c_config: contextvars.ContextVar = contextvars.ContextVar("config") - - -def config() -> Config: - return c_config.get() - - -c_state: contextvars.ContextVar = contextvars.ContextVar("state") - - -def state() -> State: - return c_state.get() - - def make_pop_server_callback(mails_path: Path, users: list[User], timeout_seconds: int): - config = Config(mails_path=mails_path, users=parse_users(users)) + scfg = SharedState(mails_path=mails_path, users=parse_users(users)) async def session_cb(reader: StreamReader, writer: StreamWriter): - c_config.set(config) - c_state.set(State(reader=reader, writer=writer)) + c_shared_state.set(scfg) + ip, _ = writer.get_extra_info("peername") + c_state.set( + State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id())) + logger.info(f"Got pop server callback") try: return await asyncio.wait_for(start_session(), timeout_seconds) finally: diff --git a/mail4one/server.py b/mail4one/server.py index eac8cfe..3a400b5 100644 --- a/mail4one/server.py +++ b/mail4one/server.py @@ -20,17 +20,20 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext: return context -def setup_logging(args): - if args.debug: - logging.basicConfig(level=logging.DEBUG) +def setup_logging(cfg: config.LogCfg): + logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d" + if cfg.logfile == "STDOUT": + logging.basicConfig(level=cfg.level, format=logging_format) else: - logging.basicConfig(level=logging.INFO) + logging.basicConfig(filename=cfg.logfile, level=cfg.level, format=logging_format) + async def a_main(cfg: config.Config) -> None: default_tls_context: ssl.SSLContext | None = None if tls := cfg.default_tls: + logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}") default_tls_context = create_tls_context(tls.certfile, tls.keyfile) def get_tls_context(tls: config.TLSCfg | str): @@ -146,8 +149,9 @@ def main() -> None: print("✗ password and hash do not match") else: cfg = config.Config(args.config.read_text()) - setup_logging(cfg) - asyncio.run(a_main(cfg), debug=cfg.debug) + setup_logging(config.LogCfg(cfg.logging)) + logging.info(f"Starting mail4one {args.config=}") + asyncio.run(a_main(cfg)) if __name__ == "__main__": diff --git a/mail4one/smtp.py b/mail4one/smtp.py index 5f084b2..5c04864 100644 --- a/mail4one/smtp.py +++ b/mail4one/smtp.py @@ -13,6 +13,7 @@ from email.message import Message import email.policy from email.generator import BytesGenerator import tempfile +import random from aiosmtpd.handlers import Mailbox, AsyncMessage from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT @@ -20,6 +21,8 @@ from aiosmtpd.smtp import SMTP as SMTPServer from aiosmtpd.smtp import Envelope as SMTPEnvelope from aiosmtpd.smtp import Session as SMTPSession +logger = logging.getLogger("smtp") + class MyHandler(AsyncMessage): @@ -32,6 +35,7 @@ class MyHandler(AsyncMessage): async def handle_DATA(self, server: SMTPServer, session: SMTPSession, envelope: SMTPEnvelope) -> str: self.rcpt_tos = envelope.rcpt_tos + self.peer = session.peer return await super().handle_DATA(server, session, envelope) async def handle_message(self, m: Message): # type: ignore[override] @@ -40,24 +44,29 @@ class MyHandler(AsyncMessage): for mbox in self.mbox_finder(addr): all_mboxes.add(mbox) if not all_mboxes: + logger.warning(f"dropping message from: {self.peer}") return for mbox in all_mboxes: for sub in ("new", "tmp", "cur"): sub_path = self.mails_path / mbox / sub sub_path.mkdir(mode=0o755, exist_ok=True, parents=True) with tempfile.TemporaryDirectory() as tmpdir: - temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml" + filename = f"{uuid.uuid4()}.eml" + temp_email_path = Path(tmpdir) / filename with open(temp_email_path, "wb") as fp: gen = BytesGenerator(fp, policy=email.policy.SMTP) gen.flatten(m) for mbox in all_mboxes: shutil.copy(temp_email_path, self.mails_path / mbox / "new") + logger.info( + f"Saved mail at {filename} addrs: {','.join(self.rcpt_tos)}, mboxes: {','.join(all_mboxes)} peer: {self.peer}" + ) def protocol_factory_starttls(mails_path: Path, mbox_finder: Callable[[str], list[str]], context: ssl.SSLContext): - logging.info("Got smtp client cb starttls") + logger.info("Got smtp client cb starttls") try: handler = MyHandler(mails_path, mbox_finder) smtp = SMTP( @@ -67,19 +76,19 @@ def protocol_factory_starttls(mails_path: Path, enable_SMTPUTF8=True, ) except Exception as e: - logging.error("Something went wrong", e) + logger.error("Something went wrong", e) raise return smtp def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]): - logging.info("Got smtp client cb") + logger.info("Got smtp client cb") try: handler = MyHandler(mails_path, mbox_finder) smtp = SMTP(handler=handler, enable_SMTPUTF8=True) except Exception as e: - logging.error("Something went wrong", e) + logger.error("Something went wrong", e) raise return smtp @@ -91,6 +100,9 @@ async def create_smtp_server_starttls( mbox_finder: Callable[[str], list[str]], ssl_context: ssl.SSLContext, ) -> asyncio.Server: + logging.info( + f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=}, {ssl_context != None=}" + ) loop = asyncio.get_event_loop() return await loop.create_server( partial(protocol_factory_starttls, mails_path, mbox_finder, @@ -108,6 +120,9 @@ async def create_smtp_server( mbox_finder: Callable[[str], list[str]], ssl_context: ssl.SSLContext | None = None, ) -> asyncio.Server: + logging.info( + f"Starting SMTP server {host=}, {port=}, {mails_path=}, {ssl_context != None=}" + ) loop = asyncio.get_event_loop() return await loop.create_server( partial(protocol_factory, mails_path, mbox_finder), diff --git a/tests/test_smtp.py b/tests/test_smtp.py index 1b5e087..855f637 100644 --- a/tests/test_smtp.py +++ b/tests/test_smtp.py @@ -3,6 +3,7 @@ import logging import unittest import smtplib import tempfile +import contextlib import os from pathlib import Path @@ -27,7 +28,6 @@ def setUpModule() -> None: class TestSMTP(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self) -> None: - logging.basicConfig(level=logging.DEBUG) smtp_server = await create_smtp_server( host="127.0.0.1", port=7996, @@ -43,16 +43,15 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase): Byee """ msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines()) - local_port: str def send_mail(): - nonlocal local_port - server = smtplib.SMTP(host="127.0.0.1", port=7996) - server.sendmail("foo@sender.com", "foo@bar.com", msg) - _, local_port = server.sock.getsockname() - server.close() + with contextlib.closing(smtplib.SMTP(host="127.0.0.1", + port=7996)) as client: + client.sendmail("foo@sender.com", "foo@bar.com", msg) + _, local_port = client.sock.getsockname() + return local_port - await asyncio.to_thread(send_mail) + local_port = await asyncio.to_thread(send_mail) expected = f"""From: foo@sender.com To: "foo@bar.com" X-Peer: ('127.0.0.1', {local_port})