Improve logging
This commit is contained in:
		| @@ -1,5 +1,6 @@ | ||||
| import json | ||||
| import re | ||||
| import logging | ||||
| from typing import Callable | ||||
| from jata import Jata, MutableDefault | ||||
|  | ||||
| @@ -61,10 +62,15 @@ class SmtpCfg(ServerCfg): | ||||
|     port = 465 | ||||
|  | ||||
|  | ||||
| class LogCfg(Jata): | ||||
|     logfile = "STDOUT" | ||||
|     level = "INFO" | ||||
|  | ||||
|  | ||||
| class Config(Jata): | ||||
|     default_tls: TLSCfg | None | ||||
|     default_host: str = "0.0.0.0" | ||||
|     debug: bool = False | ||||
|     logging: LogCfg | None = None | ||||
|  | ||||
|     mails_path: str | ||||
|     matches: list[Match] | ||||
| @@ -127,4 +133,5 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]: | ||||
|  | ||||
| def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]: | ||||
|     checks = parse_checkers(cfg) | ||||
|     logging.info(f"Parsed checkers from config, {len(checks)=}") | ||||
|     return lambda addr: get_mboxes(addr, checks) | ||||
|   | ||||
							
								
								
									
										143
									
								
								mail4one/pop3.py
									
									
									
									
									
								
							
							
						
						
									
										143
									
								
								mail4one/pop3.py
									
									
									
									
									
								
							| @@ -1,15 +1,17 @@ | ||||
| import asyncio | ||||
| import contextlib | ||||
| import contextvars | ||||
| import logging | ||||
| import os | ||||
| import ssl | ||||
| import contextvars | ||||
| import contextlib | ||||
| import uuid | ||||
| from dataclasses import dataclass | ||||
| from hashlib import sha256 | ||||
| from pathlib import Path | ||||
| from .config import User | ||||
| from .pwhash import parse_hash, check_pass, PWInfo | ||||
| from asyncio import StreamReader, StreamWriter | ||||
| import random | ||||
|  | ||||
| from .poputils import ( | ||||
|     InvalidCommand, | ||||
| @@ -31,10 +33,66 @@ from .poputils import ( | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class State: | ||||
|     reader: StreamReader | ||||
|     writer: StreamWriter | ||||
|     ip: str | ||||
|     req_id: int | ||||
|     username: str = "" | ||||
|     mbox: str = "" | ||||
|  | ||||
|  | ||||
| class SharedState: | ||||
|  | ||||
|     def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]): | ||||
|         self.mails_path = mails_path | ||||
|         self.users = users | ||||
|         self.loggedin_users: set[str] = set() | ||||
|         self.counter = random.randint(10000, 99999) * 100000 | ||||
|  | ||||
|     def next_id(self) -> int: | ||||
|         self.counter = self.counter + 1 | ||||
|         return self.counter | ||||
|  | ||||
|  | ||||
| c_shared_state: contextvars.ContextVar = contextvars.ContextVar( | ||||
|     "pop_shared_state") | ||||
|  | ||||
|  | ||||
| def scfg() -> SharedState: | ||||
|     return c_shared_state.get() | ||||
|  | ||||
|  | ||||
| c_state: contextvars.ContextVar = contextvars.ContextVar("state") | ||||
|  | ||||
|  | ||||
| def state() -> State: | ||||
|     return c_state.get() | ||||
|  | ||||
|  | ||||
| class PopLogger(logging.LoggerAdapter): | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__(logging.getLogger("pop3"), None) | ||||
|  | ||||
|     def process(self, msg, kwargs): | ||||
|         state: State = c_state.get(None) | ||||
|         if not state: | ||||
|             return super().process(msg, kwargs) | ||||
|         user = "NA" | ||||
|         if state.username: | ||||
|             user = state.username | ||||
|         return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs) | ||||
|  | ||||
|  | ||||
| logger = PopLogger() | ||||
|  | ||||
|  | ||||
| async def next_req() -> Request: | ||||
|     for _ in range(InvalidCommand.RETRIES): | ||||
|         line = await state().reader.readline() | ||||
|         logging.debug(f"Client: {line!r}") | ||||
|         logger.debug(f"Client: {line!r}") | ||||
|         if not line: | ||||
|             if state().reader.at_eof(): | ||||
|                 raise ClientDisconnected | ||||
| @@ -54,19 +112,19 @@ async def next_req() -> Request: | ||||
| async def expect_cmd(*commands: Command) -> Request: | ||||
|     req = await next_req() | ||||
|     if req.cmd not in commands: | ||||
|         logging.error(f"Unexpected command: {req.cmd} is not in {commands}") | ||||
|         logger.error(f"Unexpected command: {req.cmd} is not in {commands}") | ||||
|         raise ClientError | ||||
|     return req | ||||
|  | ||||
|  | ||||
| def write(data) -> None: | ||||
|     logging.debug(f"Server: {data}") | ||||
|     logger.debug(f"Server: {data}") | ||||
|     state().writer.write(data) | ||||
|  | ||||
|  | ||||
| def validate_password(username, password) -> None: | ||||
|     try: | ||||
|         pwinfo, mbox = config().users[username] | ||||
|         pwinfo, mbox = scfg().users[username] | ||||
|     except: | ||||
|         raise AuthError("Invalid user pass") | ||||
|  | ||||
| @@ -84,7 +142,7 @@ async def handle_user_pass_auth(user_cmd) -> None: | ||||
|     cmd = await expect_cmd(Command.PASS) | ||||
|     password = cmd.arg1 | ||||
|     validate_password(username, password) | ||||
|     logging.info(f"{username=} has logged in successfully") | ||||
|     logger.info(f"{username=} has logged in successfully") | ||||
|  | ||||
|  | ||||
| async def auth_stage() -> None: | ||||
| @@ -98,20 +156,20 @@ async def auth_stage() -> None: | ||||
|                 write(end()) | ||||
|             else: | ||||
|                 await handle_user_pass_auth(req) | ||||
|                 if state().username in config().loggedin_users: | ||||
|                     logging.warning( | ||||
|                 if state().username in scfg().loggedin_users: | ||||
|                     logger.warning( | ||||
|                         f"User: {state().username} already has an active session" | ||||
|                     ) | ||||
|                     raise AuthError("Already logged in") | ||||
|                 else: | ||||
|                     config().loggedin_users.add(state().username) | ||||
|                     scfg().loggedin_users.add(state().username) | ||||
|                     write(ok("Login successful")) | ||||
|                     return | ||||
|         except AuthError as ae: | ||||
|             write(err(f"Auth Failed: {ae}")) | ||||
|         except ClientQuit as c: | ||||
|             write(ok("Bye")) | ||||
|             logging.warning("Client has QUIT before auth succeeded") | ||||
|             logger.warning("Client has QUIT before auth succeeded") | ||||
|             raise | ||||
|     else: | ||||
|         raise ClientError("Failed to authenticate") | ||||
| @@ -204,7 +262,7 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]: | ||||
|         except ClientQuit: | ||||
|             write(ok("Bye")) | ||||
|             return mails.deleted_uids | ||||
|         logging.debug(f"Request: {req}") | ||||
|         logger.debug(f"Request: {req}") | ||||
|         try: | ||||
|             func = handle_map[req.cmd] | ||||
|         except KeyError: | ||||
| @@ -229,46 +287,46 @@ def save_deleted_items(deleted_items_path: Path, | ||||
|  | ||||
|  | ||||
| async def transaction_stage() -> None: | ||||
|     deleted_items_path = config().mails_path / state().mbox / state().username | ||||
|     deleted_items_path = scfg().mails_path / state().mbox / state().username | ||||
|     existing_deleted_items: set[str] = get_deleted_items(deleted_items_path) | ||||
|     mails_list = [ | ||||
|         entry | ||||
|         for entry in get_mails_list(config().mails_path / state().mbox / "new") | ||||
|         for entry in get_mails_list(scfg().mails_path / state().mbox / "new") | ||||
|         if entry.uid not in existing_deleted_items | ||||
|     ] | ||||
|  | ||||
|     new_deleted_items: set[str] = await process_transactions(mails_list) | ||||
|     logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}") | ||||
|     logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}") | ||||
|     if new_deleted_items: | ||||
|         save_deleted_items(deleted_items_path, | ||||
|                            existing_deleted_items.union(new_deleted_items)) | ||||
|  | ||||
|     logging.info(f"Saved deleted items") | ||||
|     logger.info(f"Saved deleted items") | ||||
|  | ||||
|  | ||||
| async def start_session() -> None: | ||||
|     logging.info("New session started") | ||||
|     logger.info("New session started") | ||||
|     try: | ||||
|         await auth_stage() | ||||
|         assert state().username | ||||
|         assert state().mbox | ||||
|         await transaction_stage() | ||||
|         logging.info(f"User:{state().username} done") | ||||
|         logger.info(f"User:{state().username} done") | ||||
|     except ClientDisconnected as c: | ||||
|         logging.info("Client disconnected") | ||||
|         logger.info("Client disconnected") | ||||
|         pass | ||||
|     except ClientQuit: | ||||
|         logging.info("Client QUIT") | ||||
|         logger.info("Client QUIT") | ||||
|         pass | ||||
|     except ClientError as c: | ||||
|         write(err("Something went wrong")) | ||||
|         logging.error(f"Unexpected client error: {c}") | ||||
|         logger.error(f"Unexpected client error: {c}") | ||||
|     except Exception as e: | ||||
|         logging.error(f"Serious client error: {e}") | ||||
|         logger.error(f"Serious client error: {e}") | ||||
|         raise | ||||
|     finally: | ||||
|         with contextlib.suppress(KeyError): | ||||
|             config().loggedin_users.remove(state().username) | ||||
|             scfg().loggedin_users.remove(state().username) | ||||
|  | ||||
|  | ||||
| def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]: | ||||
| @@ -282,43 +340,16 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]: | ||||
|     return dict(inner()) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class State: | ||||
|     reader: StreamReader | ||||
|     writer: StreamWriter | ||||
|     username: str = "" | ||||
|     mbox: str = "" | ||||
|  | ||||
|  | ||||
| class Config: | ||||
|  | ||||
|     def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]): | ||||
|         self.mails_path = mails_path | ||||
|         self.users = users | ||||
|         self.loggedin_users: set[str] = set() | ||||
|  | ||||
|  | ||||
| c_config: contextvars.ContextVar = contextvars.ContextVar("config") | ||||
|  | ||||
|  | ||||
| def config() -> Config: | ||||
|     return c_config.get() | ||||
|  | ||||
|  | ||||
| c_state: contextvars.ContextVar = contextvars.ContextVar("state") | ||||
|  | ||||
|  | ||||
| def state() -> State: | ||||
|     return c_state.get() | ||||
|  | ||||
|  | ||||
| def make_pop_server_callback(mails_path: Path, users: list[User], | ||||
|                              timeout_seconds: int): | ||||
|     config = Config(mails_path=mails_path, users=parse_users(users)) | ||||
|     scfg = SharedState(mails_path=mails_path, users=parse_users(users)) | ||||
|  | ||||
|     async def session_cb(reader: StreamReader, writer: StreamWriter): | ||||
|         c_config.set(config) | ||||
|         c_state.set(State(reader=reader, writer=writer)) | ||||
|         c_shared_state.set(scfg) | ||||
|         ip, _ = writer.get_extra_info("peername") | ||||
|         c_state.set( | ||||
|             State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id())) | ||||
|         logger.info(f"Got pop server callback") | ||||
|         try: | ||||
|             return await asyncio.wait_for(start_session(), timeout_seconds) | ||||
|         finally: | ||||
|   | ||||
| @@ -20,17 +20,20 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext: | ||||
|     return context | ||||
|  | ||||
|  | ||||
| def setup_logging(args): | ||||
|     if args.debug: | ||||
|         logging.basicConfig(level=logging.DEBUG) | ||||
| def setup_logging(cfg: config.LogCfg): | ||||
|     logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d" | ||||
|     if cfg.logfile == "STDOUT": | ||||
|         logging.basicConfig(level=cfg.level, format=logging_format) | ||||
|     else: | ||||
|         logging.basicConfig(level=logging.INFO) | ||||
|         logging.basicConfig(filename=cfg.logfile, level=cfg.level, format=logging_format) | ||||
|  | ||||
|  | ||||
|  | ||||
| async def a_main(cfg: config.Config) -> None: | ||||
|     default_tls_context: ssl.SSLContext | None = None | ||||
|  | ||||
|     if tls := cfg.default_tls: | ||||
|         logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}") | ||||
|         default_tls_context = create_tls_context(tls.certfile, tls.keyfile) | ||||
|  | ||||
|     def get_tls_context(tls: config.TLSCfg | str): | ||||
| @@ -146,8 +149,9 @@ def main() -> None: | ||||
|             print("✗ password and hash do not match") | ||||
|     else: | ||||
|         cfg = config.Config(args.config.read_text()) | ||||
|         setup_logging(cfg) | ||||
|         asyncio.run(a_main(cfg), debug=cfg.debug) | ||||
|         setup_logging(config.LogCfg(cfg.logging)) | ||||
|         logging.info(f"Starting mail4one {args.config=}") | ||||
|         asyncio.run(a_main(cfg)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
| @@ -13,6 +13,7 @@ from email.message import Message | ||||
| import email.policy | ||||
| from email.generator import BytesGenerator | ||||
| import tempfile | ||||
| import random | ||||
|  | ||||
| from aiosmtpd.handlers import Mailbox, AsyncMessage | ||||
| from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT | ||||
| @@ -20,6 +21,8 @@ from aiosmtpd.smtp import SMTP as SMTPServer | ||||
| from aiosmtpd.smtp import Envelope as SMTPEnvelope | ||||
| from aiosmtpd.smtp import Session as SMTPSession | ||||
|  | ||||
| logger = logging.getLogger("smtp") | ||||
|  | ||||
|  | ||||
| class MyHandler(AsyncMessage): | ||||
|  | ||||
| @@ -32,6 +35,7 @@ class MyHandler(AsyncMessage): | ||||
|     async def handle_DATA(self, server: SMTPServer, session: SMTPSession, | ||||
|                           envelope: SMTPEnvelope) -> str: | ||||
|         self.rcpt_tos = envelope.rcpt_tos | ||||
|         self.peer = session.peer | ||||
|         return await super().handle_DATA(server, session, envelope) | ||||
|  | ||||
|     async def handle_message(self, m: Message):  # type: ignore[override] | ||||
| @@ -40,24 +44,29 @@ class MyHandler(AsyncMessage): | ||||
|             for mbox in self.mbox_finder(addr): | ||||
|                 all_mboxes.add(mbox) | ||||
|         if not all_mboxes: | ||||
|             logger.warning(f"dropping message from: {self.peer}") | ||||
|             return | ||||
|         for mbox in all_mboxes: | ||||
|             for sub in ("new", "tmp", "cur"): | ||||
|                 sub_path = self.mails_path / mbox / sub | ||||
|                 sub_path.mkdir(mode=0o755, exist_ok=True, parents=True) | ||||
|         with tempfile.TemporaryDirectory() as tmpdir: | ||||
|             temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml" | ||||
|             filename = f"{uuid.uuid4()}.eml" | ||||
|             temp_email_path = Path(tmpdir) / filename | ||||
|             with open(temp_email_path, "wb") as fp: | ||||
|                 gen = BytesGenerator(fp, policy=email.policy.SMTP) | ||||
|                 gen.flatten(m) | ||||
|             for mbox in all_mboxes: | ||||
|                 shutil.copy(temp_email_path, self.mails_path / mbox / "new") | ||||
|             logger.info( | ||||
|                 f"Saved mail at {filename} addrs: {','.join(self.rcpt_tos)}, mboxes: {','.join(all_mboxes)} peer: {self.peer}" | ||||
|             ) | ||||
|  | ||||
|  | ||||
| def protocol_factory_starttls(mails_path: Path, | ||||
|                               mbox_finder: Callable[[str], list[str]], | ||||
|                               context: ssl.SSLContext): | ||||
|     logging.info("Got smtp client cb starttls") | ||||
|     logger.info("Got smtp client cb starttls") | ||||
|     try: | ||||
|         handler = MyHandler(mails_path, mbox_finder) | ||||
|         smtp = SMTP( | ||||
| @@ -67,19 +76,19 @@ def protocol_factory_starttls(mails_path: Path, | ||||
|             enable_SMTPUTF8=True, | ||||
|         ) | ||||
|     except Exception as e: | ||||
|         logging.error("Something went wrong", e) | ||||
|         logger.error("Something went wrong", e) | ||||
|         raise | ||||
|     return smtp | ||||
|  | ||||
|  | ||||
| def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], | ||||
|                                                              list[str]]): | ||||
|     logging.info("Got smtp client cb") | ||||
|     logger.info("Got smtp client cb") | ||||
|     try: | ||||
|         handler = MyHandler(mails_path, mbox_finder) | ||||
|         smtp = SMTP(handler=handler, enable_SMTPUTF8=True) | ||||
|     except Exception as e: | ||||
|         logging.error("Something went wrong", e) | ||||
|         logger.error("Something went wrong", e) | ||||
|         raise | ||||
|     return smtp | ||||
|  | ||||
| @@ -91,6 +100,9 @@ async def create_smtp_server_starttls( | ||||
|     mbox_finder: Callable[[str], list[str]], | ||||
|     ssl_context: ssl.SSLContext, | ||||
| ) -> asyncio.Server: | ||||
|     logging.info( | ||||
|         f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=}, {ssl_context != None=}" | ||||
|     ) | ||||
|     loop = asyncio.get_event_loop() | ||||
|     return await loop.create_server( | ||||
|         partial(protocol_factory_starttls, mails_path, mbox_finder, | ||||
| @@ -108,6 +120,9 @@ async def create_smtp_server( | ||||
|     mbox_finder: Callable[[str], list[str]], | ||||
|     ssl_context: ssl.SSLContext | None = None, | ||||
| ) -> asyncio.Server: | ||||
|     logging.info( | ||||
|         f"Starting SMTP server {host=}, {port=}, {mails_path=}, {ssl_context != None=}" | ||||
|     ) | ||||
|     loop = asyncio.get_event_loop() | ||||
|     return await loop.create_server( | ||||
|         partial(protocol_factory, mails_path, mbox_finder), | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import logging | ||||
| import unittest | ||||
| import smtplib | ||||
| import tempfile | ||||
| import contextlib | ||||
| import os | ||||
|  | ||||
| from pathlib import Path | ||||
| @@ -27,7 +28,6 @@ def setUpModule() -> None: | ||||
| class TestSMTP(unittest.IsolatedAsyncioTestCase): | ||||
|  | ||||
|     async def asyncSetUp(self) -> None: | ||||
|         logging.basicConfig(level=logging.DEBUG) | ||||
|         smtp_server = await create_smtp_server( | ||||
|             host="127.0.0.1", | ||||
|             port=7996, | ||||
| @@ -43,16 +43,15 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase): | ||||
|         Byee | ||||
|         """ | ||||
|         msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines()) | ||||
|         local_port: str | ||||
|  | ||||
|         def send_mail(): | ||||
|             nonlocal local_port | ||||
|             server = smtplib.SMTP(host="127.0.0.1", port=7996) | ||||
|             server.sendmail("foo@sender.com", "foo@bar.com", msg) | ||||
|             _, local_port = server.sock.getsockname() | ||||
|             server.close() | ||||
|             with contextlib.closing(smtplib.SMTP(host="127.0.0.1", | ||||
|                                                  port=7996)) as client: | ||||
|                 client.sendmail("foo@sender.com", "foo@bar.com", msg) | ||||
|                 _, local_port = client.sock.getsockname() | ||||
|                 return local_port | ||||
|  | ||||
|         await asyncio.to_thread(send_mail) | ||||
|         local_port = await asyncio.to_thread(send_mail) | ||||
|         expected = f"""From: foo@sender.com | ||||
|         To: "foo@bar.com" | ||||
|         X-Peer: ('127.0.0.1', {local_port}) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user