Improve logging

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-16 23:42:30 -04:00
parent 2fa748c444
commit f407c8b395
5 changed files with 132 additions and 76 deletions

View File

@ -1,5 +1,6 @@
import json import json
import re import re
import logging
from typing import Callable from typing import Callable
from jata import Jata, MutableDefault from jata import Jata, MutableDefault
@ -61,10 +62,15 @@ class SmtpCfg(ServerCfg):
port = 465 port = 465
class LogCfg(Jata):
logfile = "STDOUT"
level = "INFO"
class Config(Jata): class Config(Jata):
default_tls: TLSCfg | None default_tls: TLSCfg | None
default_host: str = "0.0.0.0" default_host: str = "0.0.0.0"
debug: bool = False logging: LogCfg | None = None
mails_path: str mails_path: str
matches: list[Match] matches: list[Match]
@ -127,4 +133,5 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]: def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]:
checks = parse_checkers(cfg) checks = parse_checkers(cfg)
logging.info(f"Parsed checkers from config, {len(checks)=}")
return lambda addr: get_mboxes(addr, checks) return lambda addr: get_mboxes(addr, checks)

View File

@ -1,15 +1,17 @@
import asyncio import asyncio
import contextlib
import contextvars
import logging import logging
import os import os
import ssl import ssl
import contextvars import uuid
import contextlib
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from .config import User from .config import User
from .pwhash import parse_hash, check_pass, PWInfo from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
import random
from .poputils import ( from .poputils import (
InvalidCommand, InvalidCommand,
@ -31,10 +33,66 @@ from .poputils import (
) )
@dataclass
class State:
reader: StreamReader
writer: StreamWriter
ip: str
req_id: int
username: str = ""
mbox: str = ""
class SharedState:
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
self.mails_path = mails_path
self.users = users
self.loggedin_users: set[str] = set()
self.counter = random.randint(10000, 99999) * 100000
def next_id(self) -> int:
self.counter = self.counter + 1
return self.counter
c_shared_state: contextvars.ContextVar = contextvars.ContextVar(
"pop_shared_state")
def scfg() -> SharedState:
return c_shared_state.get()
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
def state() -> State:
return c_state.get()
class PopLogger(logging.LoggerAdapter):
def __init__(self):
super().__init__(logging.getLogger("pop3"), None)
def process(self, msg, kwargs):
state: State = c_state.get(None)
if not state:
return super().process(msg, kwargs)
user = "NA"
if state.username:
user = state.username
return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs)
logger = PopLogger()
async def next_req() -> Request: async def next_req() -> Request:
for _ in range(InvalidCommand.RETRIES): for _ in range(InvalidCommand.RETRIES):
line = await state().reader.readline() line = await state().reader.readline()
logging.debug(f"Client: {line!r}") logger.debug(f"Client: {line!r}")
if not line: if not line:
if state().reader.at_eof(): if state().reader.at_eof():
raise ClientDisconnected raise ClientDisconnected
@ -54,19 +112,19 @@ async def next_req() -> Request:
async def expect_cmd(*commands: Command) -> Request: async def expect_cmd(*commands: Command) -> Request:
req = await next_req() req = await next_req()
if req.cmd not in commands: if req.cmd not in commands:
logging.error(f"Unexpected command: {req.cmd} is not in {commands}") logger.error(f"Unexpected command: {req.cmd} is not in {commands}")
raise ClientError raise ClientError
return req return req
def write(data) -> None: def write(data) -> None:
logging.debug(f"Server: {data}") logger.debug(f"Server: {data}")
state().writer.write(data) state().writer.write(data)
def validate_password(username, password) -> None: def validate_password(username, password) -> None:
try: try:
pwinfo, mbox = config().users[username] pwinfo, mbox = scfg().users[username]
except: except:
raise AuthError("Invalid user pass") raise AuthError("Invalid user pass")
@ -84,7 +142,7 @@ async def handle_user_pass_auth(user_cmd) -> None:
cmd = await expect_cmd(Command.PASS) cmd = await expect_cmd(Command.PASS)
password = cmd.arg1 password = cmd.arg1
validate_password(username, password) validate_password(username, password)
logging.info(f"{username=} has logged in successfully") logger.info(f"{username=} has logged in successfully")
async def auth_stage() -> None: async def auth_stage() -> None:
@ -98,20 +156,20 @@ async def auth_stage() -> None:
write(end()) write(end())
else: else:
await handle_user_pass_auth(req) await handle_user_pass_auth(req)
if state().username in config().loggedin_users: if state().username in scfg().loggedin_users:
logging.warning( logger.warning(
f"User: {state().username} already has an active session" f"User: {state().username} already has an active session"
) )
raise AuthError("Already logged in") raise AuthError("Already logged in")
else: else:
config().loggedin_users.add(state().username) scfg().loggedin_users.add(state().username)
write(ok("Login successful")) write(ok("Login successful"))
return return
except AuthError as ae: except AuthError as ae:
write(err(f"Auth Failed: {ae}")) write(err(f"Auth Failed: {ae}"))
except ClientQuit as c: except ClientQuit as c:
write(ok("Bye")) write(ok("Bye"))
logging.warning("Client has QUIT before auth succeeded") logger.warning("Client has QUIT before auth succeeded")
raise raise
else: else:
raise ClientError("Failed to authenticate") raise ClientError("Failed to authenticate")
@ -204,7 +262,7 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
except ClientQuit: except ClientQuit:
write(ok("Bye")) write(ok("Bye"))
return mails.deleted_uids return mails.deleted_uids
logging.debug(f"Request: {req}") logger.debug(f"Request: {req}")
try: try:
func = handle_map[req.cmd] func = handle_map[req.cmd]
except KeyError: except KeyError:
@ -229,46 +287,46 @@ def save_deleted_items(deleted_items_path: Path,
async def transaction_stage() -> None: async def transaction_stage() -> None:
deleted_items_path = config().mails_path / state().mbox / state().username deleted_items_path = scfg().mails_path / state().mbox / state().username
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path) existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
mails_list = [ mails_list = [
entry entry
for entry in get_mails_list(config().mails_path / state().mbox / "new") for entry in get_mails_list(scfg().mails_path / state().mbox / "new")
if entry.uid not in existing_deleted_items if entry.uid not in existing_deleted_items
] ]
new_deleted_items: set[str] = await process_transactions(mails_list) new_deleted_items: set[str] = await process_transactions(mails_list)
logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}") logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
if new_deleted_items: if new_deleted_items:
save_deleted_items(deleted_items_path, save_deleted_items(deleted_items_path,
existing_deleted_items.union(new_deleted_items)) existing_deleted_items.union(new_deleted_items))
logging.info(f"Saved deleted items") logger.info(f"Saved deleted items")
async def start_session() -> None: async def start_session() -> None:
logging.info("New session started") logger.info("New session started")
try: try:
await auth_stage() await auth_stage()
assert state().username assert state().username
assert state().mbox assert state().mbox
await transaction_stage() await transaction_stage()
logging.info(f"User:{state().username} done") logger.info(f"User:{state().username} done")
except ClientDisconnected as c: except ClientDisconnected as c:
logging.info("Client disconnected") logger.info("Client disconnected")
pass pass
except ClientQuit: except ClientQuit:
logging.info("Client QUIT") logger.info("Client QUIT")
pass pass
except ClientError as c: except ClientError as c:
write(err("Something went wrong")) write(err("Something went wrong"))
logging.error(f"Unexpected client error: {c}") logger.error(f"Unexpected client error: {c}")
except Exception as e: except Exception as e:
logging.error(f"Serious client error: {e}") logger.error(f"Serious client error: {e}")
raise raise
finally: finally:
with contextlib.suppress(KeyError): with contextlib.suppress(KeyError):
config().loggedin_users.remove(state().username) scfg().loggedin_users.remove(state().username)
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]: def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
@ -282,43 +340,16 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
return dict(inner()) return dict(inner())
@dataclass
class State:
reader: StreamReader
writer: StreamWriter
username: str = ""
mbox: str = ""
class Config:
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
self.mails_path = mails_path
self.users = users
self.loggedin_users: set[str] = set()
c_config: contextvars.ContextVar = contextvars.ContextVar("config")
def config() -> Config:
return c_config.get()
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
def state() -> State:
return c_state.get()
def make_pop_server_callback(mails_path: Path, users: list[User], def make_pop_server_callback(mails_path: Path, users: list[User],
timeout_seconds: int): timeout_seconds: int):
config = Config(mails_path=mails_path, users=parse_users(users)) scfg = SharedState(mails_path=mails_path, users=parse_users(users))
async def session_cb(reader: StreamReader, writer: StreamWriter): async def session_cb(reader: StreamReader, writer: StreamWriter):
c_config.set(config) c_shared_state.set(scfg)
c_state.set(State(reader=reader, writer=writer)) ip, _ = writer.get_extra_info("peername")
c_state.set(
State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
logger.info(f"Got pop server callback")
try: try:
return await asyncio.wait_for(start_session(), timeout_seconds) return await asyncio.wait_for(start_session(), timeout_seconds)
finally: finally:

View File

@ -20,17 +20,20 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
return context return context
def setup_logging(args): def setup_logging(cfg: config.LogCfg):
if args.debug: logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
logging.basicConfig(level=logging.DEBUG) if cfg.logfile == "STDOUT":
logging.basicConfig(level=cfg.level, format=logging_format)
else: else:
logging.basicConfig(level=logging.INFO) logging.basicConfig(filename=cfg.logfile, level=cfg.level, format=logging_format)
async def a_main(cfg: config.Config) -> None: async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls: if tls := cfg.default_tls:
logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
default_tls_context = create_tls_context(tls.certfile, tls.keyfile) default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
def get_tls_context(tls: config.TLSCfg | str): def get_tls_context(tls: config.TLSCfg | str):
@ -146,8 +149,9 @@ def main() -> None:
print("✗ password and hash do not match") print("✗ password and hash do not match")
else: else:
cfg = config.Config(args.config.read_text()) cfg = config.Config(args.config.read_text())
setup_logging(cfg) setup_logging(config.LogCfg(cfg.logging))
asyncio.run(a_main(cfg), debug=cfg.debug) logging.info(f"Starting mail4one {args.config=}")
asyncio.run(a_main(cfg))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -13,6 +13,7 @@ from email.message import Message
import email.policy import email.policy
from email.generator import BytesGenerator from email.generator import BytesGenerator
import tempfile import tempfile
import random
from aiosmtpd.handlers import Mailbox, AsyncMessage from aiosmtpd.handlers import Mailbox, AsyncMessage
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
@ -20,6 +21,8 @@ from aiosmtpd.smtp import SMTP as SMTPServer
from aiosmtpd.smtp import Envelope as SMTPEnvelope from aiosmtpd.smtp import Envelope as SMTPEnvelope
from aiosmtpd.smtp import Session as SMTPSession from aiosmtpd.smtp import Session as SMTPSession
logger = logging.getLogger("smtp")
class MyHandler(AsyncMessage): class MyHandler(AsyncMessage):
@ -32,6 +35,7 @@ class MyHandler(AsyncMessage):
async def handle_DATA(self, server: SMTPServer, session: SMTPSession, async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
envelope: SMTPEnvelope) -> str: envelope: SMTPEnvelope) -> str:
self.rcpt_tos = envelope.rcpt_tos self.rcpt_tos = envelope.rcpt_tos
self.peer = session.peer
return await super().handle_DATA(server, session, envelope) return await super().handle_DATA(server, session, envelope)
async def handle_message(self, m: Message): # type: ignore[override] async def handle_message(self, m: Message): # type: ignore[override]
@ -40,24 +44,29 @@ class MyHandler(AsyncMessage):
for mbox in self.mbox_finder(addr): for mbox in self.mbox_finder(addr):
all_mboxes.add(mbox) all_mboxes.add(mbox)
if not all_mboxes: if not all_mboxes:
logger.warning(f"dropping message from: {self.peer}")
return return
for mbox in all_mboxes: for mbox in all_mboxes:
for sub in ("new", "tmp", "cur"): for sub in ("new", "tmp", "cur"):
sub_path = self.mails_path / mbox / sub sub_path = self.mails_path / mbox / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True) sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml" filename = f"{uuid.uuid4()}.eml"
temp_email_path = Path(tmpdir) / filename
with open(temp_email_path, "wb") as fp: with open(temp_email_path, "wb") as fp:
gen = BytesGenerator(fp, policy=email.policy.SMTP) gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m) gen.flatten(m)
for mbox in all_mboxes: for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / "new") shutil.copy(temp_email_path, self.mails_path / mbox / "new")
logger.info(
f"Saved mail at {filename} addrs: {','.join(self.rcpt_tos)}, mboxes: {','.join(all_mboxes)} peer: {self.peer}"
)
def protocol_factory_starttls(mails_path: Path, def protocol_factory_starttls(mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
context: ssl.SSLContext): context: ssl.SSLContext):
logging.info("Got smtp client cb starttls") logger.info("Got smtp client cb starttls")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP( smtp = SMTP(
@ -67,19 +76,19 @@ def protocol_factory_starttls(mails_path: Path,
enable_SMTPUTF8=True, enable_SMTPUTF8=True,
) )
except Exception as e: except Exception as e:
logging.error("Something went wrong", e) logger.error("Something went wrong", e)
raise raise
return smtp return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
list[str]]): list[str]]):
logging.info("Got smtp client cb") logger.info("Got smtp client cb")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True) smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
except Exception as e: except Exception as e:
logging.error("Something went wrong", e) logger.error("Something went wrong", e)
raise raise
return smtp return smtp
@ -91,6 +100,9 @@ async def create_smtp_server_starttls(
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext, ssl_context: ssl.SSLContext,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info(
f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory_starttls, mails_path, mbox_finder, partial(protocol_factory_starttls, mails_path, mbox_finder,
@ -108,6 +120,9 @@ async def create_smtp_server(
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext | None = None, ssl_context: ssl.SSLContext | None = None,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info(
f"Starting SMTP server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory, mails_path, mbox_finder), partial(protocol_factory, mails_path, mbox_finder),

View File

@ -3,6 +3,7 @@ import logging
import unittest import unittest
import smtplib import smtplib
import tempfile import tempfile
import contextlib
import os import os
from pathlib import Path from pathlib import Path
@ -27,7 +28,6 @@ def setUpModule() -> None:
class TestSMTP(unittest.IsolatedAsyncioTestCase): class TestSMTP(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None: async def asyncSetUp(self) -> None:
logging.basicConfig(level=logging.DEBUG)
smtp_server = await create_smtp_server( smtp_server = await create_smtp_server(
host="127.0.0.1", host="127.0.0.1",
port=7996, port=7996,
@ -43,16 +43,15 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
Byee Byee
""" """
msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines()) msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines())
local_port: str
def send_mail(): def send_mail():
nonlocal local_port with contextlib.closing(smtplib.SMTP(host="127.0.0.1",
server = smtplib.SMTP(host="127.0.0.1", port=7996) port=7996)) as client:
server.sendmail("foo@sender.com", "foo@bar.com", msg) client.sendmail("foo@sender.com", "foo@bar.com", msg)
_, local_port = server.sock.getsockname() _, local_port = client.sock.getsockname()
server.close() return local_port
await asyncio.to_thread(send_mail) local_port = await asyncio.to_thread(send_mail)
expected = f"""From: foo@sender.com expected = f"""From: foo@sender.com
To: "foo@bar.com" To: "foo@bar.com"
X-Peer: ('127.0.0.1', {local_port}) X-Peer: ('127.0.0.1', {local_port})