Improve logging

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-16 23:42:30 -04:00
parent 2fa748c444
commit f407c8b395
5 changed files with 132 additions and 76 deletions

View File

@ -1,5 +1,6 @@
import json
import re
import logging
from typing import Callable
from jata import Jata, MutableDefault
@ -61,10 +62,15 @@ class SmtpCfg(ServerCfg):
port = 465
class LogCfg(Jata):
logfile = "STDOUT"
level = "INFO"
class Config(Jata):
default_tls: TLSCfg | None
default_host: str = "0.0.0.0"
debug: bool = False
logging: LogCfg | None = None
mails_path: str
matches: list[Match]
@ -127,4 +133,5 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]:
checks = parse_checkers(cfg)
logging.info(f"Parsed checkers from config, {len(checks)=}")
return lambda addr: get_mboxes(addr, checks)

View File

@ -1,15 +1,17 @@
import asyncio
import contextlib
import contextvars
import logging
import os
import ssl
import contextvars
import contextlib
import uuid
from dataclasses import dataclass
from hashlib import sha256
from pathlib import Path
from .config import User
from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter
import random
from .poputils import (
InvalidCommand,
@ -31,10 +33,66 @@ from .poputils import (
)
@dataclass
class State:
reader: StreamReader
writer: StreamWriter
ip: str
req_id: int
username: str = ""
mbox: str = ""
class SharedState:
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
self.mails_path = mails_path
self.users = users
self.loggedin_users: set[str] = set()
self.counter = random.randint(10000, 99999) * 100000
def next_id(self) -> int:
self.counter = self.counter + 1
return self.counter
c_shared_state: contextvars.ContextVar = contextvars.ContextVar(
"pop_shared_state")
def scfg() -> SharedState:
return c_shared_state.get()
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
def state() -> State:
return c_state.get()
class PopLogger(logging.LoggerAdapter):
def __init__(self):
super().__init__(logging.getLogger("pop3"), None)
def process(self, msg, kwargs):
state: State = c_state.get(None)
if not state:
return super().process(msg, kwargs)
user = "NA"
if state.username:
user = state.username
return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs)
logger = PopLogger()
async def next_req() -> Request:
for _ in range(InvalidCommand.RETRIES):
line = await state().reader.readline()
logging.debug(f"Client: {line!r}")
logger.debug(f"Client: {line!r}")
if not line:
if state().reader.at_eof():
raise ClientDisconnected
@ -54,19 +112,19 @@ async def next_req() -> Request:
async def expect_cmd(*commands: Command) -> Request:
req = await next_req()
if req.cmd not in commands:
logging.error(f"Unexpected command: {req.cmd} is not in {commands}")
logger.error(f"Unexpected command: {req.cmd} is not in {commands}")
raise ClientError
return req
def write(data) -> None:
logging.debug(f"Server: {data}")
logger.debug(f"Server: {data}")
state().writer.write(data)
def validate_password(username, password) -> None:
try:
pwinfo, mbox = config().users[username]
pwinfo, mbox = scfg().users[username]
except:
raise AuthError("Invalid user pass")
@ -84,7 +142,7 @@ async def handle_user_pass_auth(user_cmd) -> None:
cmd = await expect_cmd(Command.PASS)
password = cmd.arg1
validate_password(username, password)
logging.info(f"{username=} has logged in successfully")
logger.info(f"{username=} has logged in successfully")
async def auth_stage() -> None:
@ -98,20 +156,20 @@ async def auth_stage() -> None:
write(end())
else:
await handle_user_pass_auth(req)
if state().username in config().loggedin_users:
logging.warning(
if state().username in scfg().loggedin_users:
logger.warning(
f"User: {state().username} already has an active session"
)
raise AuthError("Already logged in")
else:
config().loggedin_users.add(state().username)
scfg().loggedin_users.add(state().username)
write(ok("Login successful"))
return
except AuthError as ae:
write(err(f"Auth Failed: {ae}"))
except ClientQuit as c:
write(ok("Bye"))
logging.warning("Client has QUIT before auth succeeded")
logger.warning("Client has QUIT before auth succeeded")
raise
else:
raise ClientError("Failed to authenticate")
@ -204,7 +262,7 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
except ClientQuit:
write(ok("Bye"))
return mails.deleted_uids
logging.debug(f"Request: {req}")
logger.debug(f"Request: {req}")
try:
func = handle_map[req.cmd]
except KeyError:
@ -229,46 +287,46 @@ def save_deleted_items(deleted_items_path: Path,
async def transaction_stage() -> None:
deleted_items_path = config().mails_path / state().mbox / state().username
deleted_items_path = scfg().mails_path / state().mbox / state().username
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
mails_list = [
entry
for entry in get_mails_list(config().mails_path / state().mbox / "new")
for entry in get_mails_list(scfg().mails_path / state().mbox / "new")
if entry.uid not in existing_deleted_items
]
new_deleted_items: set[str] = await process_transactions(mails_list)
logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
if new_deleted_items:
save_deleted_items(deleted_items_path,
existing_deleted_items.union(new_deleted_items))
logging.info(f"Saved deleted items")
logger.info(f"Saved deleted items")
async def start_session() -> None:
logging.info("New session started")
logger.info("New session started")
try:
await auth_stage()
assert state().username
assert state().mbox
await transaction_stage()
logging.info(f"User:{state().username} done")
logger.info(f"User:{state().username} done")
except ClientDisconnected as c:
logging.info("Client disconnected")
logger.info("Client disconnected")
pass
except ClientQuit:
logging.info("Client QUIT")
logger.info("Client QUIT")
pass
except ClientError as c:
write(err("Something went wrong"))
logging.error(f"Unexpected client error: {c}")
logger.error(f"Unexpected client error: {c}")
except Exception as e:
logging.error(f"Serious client error: {e}")
logger.error(f"Serious client error: {e}")
raise
finally:
with contextlib.suppress(KeyError):
config().loggedin_users.remove(state().username)
scfg().loggedin_users.remove(state().username)
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
@ -282,43 +340,16 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
return dict(inner())
@dataclass
class State:
reader: StreamReader
writer: StreamWriter
username: str = ""
mbox: str = ""
class Config:
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
self.mails_path = mails_path
self.users = users
self.loggedin_users: set[str] = set()
c_config: contextvars.ContextVar = contextvars.ContextVar("config")
def config() -> Config:
return c_config.get()
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
def state() -> State:
return c_state.get()
def make_pop_server_callback(mails_path: Path, users: list[User],
timeout_seconds: int):
config = Config(mails_path=mails_path, users=parse_users(users))
scfg = SharedState(mails_path=mails_path, users=parse_users(users))
async def session_cb(reader: StreamReader, writer: StreamWriter):
c_config.set(config)
c_state.set(State(reader=reader, writer=writer))
c_shared_state.set(scfg)
ip, _ = writer.get_extra_info("peername")
c_state.set(
State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
logger.info(f"Got pop server callback")
try:
return await asyncio.wait_for(start_session(), timeout_seconds)
finally:

View File

@ -20,17 +20,20 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
return context
def setup_logging(args):
if args.debug:
logging.basicConfig(level=logging.DEBUG)
def setup_logging(cfg: config.LogCfg):
logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
if cfg.logfile == "STDOUT":
logging.basicConfig(level=cfg.level, format=logging_format)
else:
logging.basicConfig(level=logging.INFO)
logging.basicConfig(filename=cfg.logfile, level=cfg.level, format=logging_format)
async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls:
logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
def get_tls_context(tls: config.TLSCfg | str):
@ -146,8 +149,9 @@ def main() -> None:
print("✗ password and hash do not match")
else:
cfg = config.Config(args.config.read_text())
setup_logging(cfg)
asyncio.run(a_main(cfg), debug=cfg.debug)
setup_logging(config.LogCfg(cfg.logging))
logging.info(f"Starting mail4one {args.config=}")
asyncio.run(a_main(cfg))
if __name__ == "__main__":

View File

@ -13,6 +13,7 @@ from email.message import Message
import email.policy
from email.generator import BytesGenerator
import tempfile
import random
from aiosmtpd.handlers import Mailbox, AsyncMessage
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
@ -20,6 +21,8 @@ from aiosmtpd.smtp import SMTP as SMTPServer
from aiosmtpd.smtp import Envelope as SMTPEnvelope
from aiosmtpd.smtp import Session as SMTPSession
logger = logging.getLogger("smtp")
class MyHandler(AsyncMessage):
@ -32,6 +35,7 @@ class MyHandler(AsyncMessage):
async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
envelope: SMTPEnvelope) -> str:
self.rcpt_tos = envelope.rcpt_tos
self.peer = session.peer
return await super().handle_DATA(server, session, envelope)
async def handle_message(self, m: Message): # type: ignore[override]
@ -40,24 +44,29 @@ class MyHandler(AsyncMessage):
for mbox in self.mbox_finder(addr):
all_mboxes.add(mbox)
if not all_mboxes:
logger.warning(f"dropping message from: {self.peer}")
return
for mbox in all_mboxes:
for sub in ("new", "tmp", "cur"):
sub_path = self.mails_path / mbox / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
with tempfile.TemporaryDirectory() as tmpdir:
temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml"
filename = f"{uuid.uuid4()}.eml"
temp_email_path = Path(tmpdir) / filename
with open(temp_email_path, "wb") as fp:
gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m)
for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / "new")
logger.info(
f"Saved mail at {filename} addrs: {','.join(self.rcpt_tos)}, mboxes: {','.join(all_mboxes)} peer: {self.peer}"
)
def protocol_factory_starttls(mails_path: Path,
mbox_finder: Callable[[str], list[str]],
context: ssl.SSLContext):
logging.info("Got smtp client cb starttls")
logger.info("Got smtp client cb starttls")
try:
handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(
@ -67,19 +76,19 @@ def protocol_factory_starttls(mails_path: Path,
enable_SMTPUTF8=True,
)
except Exception as e:
logging.error("Something went wrong", e)
logger.error("Something went wrong", e)
raise
return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
list[str]]):
logging.info("Got smtp client cb")
logger.info("Got smtp client cb")
try:
handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
except Exception as e:
logging.error("Something went wrong", e)
logger.error("Something went wrong", e)
raise
return smtp
@ -91,6 +100,9 @@ async def create_smtp_server_starttls(
mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext,
) -> asyncio.Server:
logging.info(
f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
)
loop = asyncio.get_event_loop()
return await loop.create_server(
partial(protocol_factory_starttls, mails_path, mbox_finder,
@ -108,6 +120,9 @@ async def create_smtp_server(
mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext | None = None,
) -> asyncio.Server:
logging.info(
f"Starting SMTP server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
)
loop = asyncio.get_event_loop()
return await loop.create_server(
partial(protocol_factory, mails_path, mbox_finder),

View File

@ -3,6 +3,7 @@ import logging
import unittest
import smtplib
import tempfile
import contextlib
import os
from pathlib import Path
@ -27,7 +28,6 @@ def setUpModule() -> None:
class TestSMTP(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
logging.basicConfig(level=logging.DEBUG)
smtp_server = await create_smtp_server(
host="127.0.0.1",
port=7996,
@ -43,16 +43,15 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
Byee
"""
msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines())
local_port: str
def send_mail():
nonlocal local_port
server = smtplib.SMTP(host="127.0.0.1", port=7996)
server.sendmail("foo@sender.com", "foo@bar.com", msg)
_, local_port = server.sock.getsockname()
server.close()
with contextlib.closing(smtplib.SMTP(host="127.0.0.1",
port=7996)) as client:
client.sendmail("foo@sender.com", "foo@bar.com", msg)
_, local_port = client.sock.getsockname()
return local_port
await asyncio.to_thread(send_mail)
local_port = await asyncio.to_thread(send_mail)
expected = f"""From: foo@sender.com
To: "foo@bar.com"
X-Peer: ('127.0.0.1', {local_port})