Improve logging
This commit is contained in:
parent
2fa748c444
commit
f407c8b395
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import Callable
|
||||
from jata import Jata, MutableDefault
|
||||
|
||||
@ -61,10 +62,15 @@ class SmtpCfg(ServerCfg):
|
||||
port = 465
|
||||
|
||||
|
||||
class LogCfg(Jata):
|
||||
logfile = "STDOUT"
|
||||
level = "INFO"
|
||||
|
||||
|
||||
class Config(Jata):
|
||||
default_tls: TLSCfg | None
|
||||
default_host: str = "0.0.0.0"
|
||||
debug: bool = False
|
||||
logging: LogCfg | None = None
|
||||
|
||||
mails_path: str
|
||||
matches: list[Match]
|
||||
@ -127,4 +133,5 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
|
||||
|
||||
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]:
|
||||
checks = parse_checkers(cfg)
|
||||
logging.info(f"Parsed checkers from config, {len(checks)=}")
|
||||
return lambda addr: get_mboxes(addr, checks)
|
||||
|
143
mail4one/pop3.py
143
mail4one/pop3.py
@ -1,15 +1,17 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import contextvars
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import contextvars
|
||||
import contextlib
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from .config import User
|
||||
from .pwhash import parse_hash, check_pass, PWInfo
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
import random
|
||||
|
||||
from .poputils import (
|
||||
InvalidCommand,
|
||||
@ -31,10 +33,66 @@ from .poputils import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
reader: StreamReader
|
||||
writer: StreamWriter
|
||||
ip: str
|
||||
req_id: int
|
||||
username: str = ""
|
||||
mbox: str = ""
|
||||
|
||||
|
||||
class SharedState:
|
||||
|
||||
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
|
||||
self.mails_path = mails_path
|
||||
self.users = users
|
||||
self.loggedin_users: set[str] = set()
|
||||
self.counter = random.randint(10000, 99999) * 100000
|
||||
|
||||
def next_id(self) -> int:
|
||||
self.counter = self.counter + 1
|
||||
return self.counter
|
||||
|
||||
|
||||
c_shared_state: contextvars.ContextVar = contextvars.ContextVar(
|
||||
"pop_shared_state")
|
||||
|
||||
|
||||
def scfg() -> SharedState:
|
||||
return c_shared_state.get()
|
||||
|
||||
|
||||
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
|
||||
|
||||
|
||||
def state() -> State:
|
||||
return c_state.get()
|
||||
|
||||
|
||||
class PopLogger(logging.LoggerAdapter):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(logging.getLogger("pop3"), None)
|
||||
|
||||
def process(self, msg, kwargs):
|
||||
state: State = c_state.get(None)
|
||||
if not state:
|
||||
return super().process(msg, kwargs)
|
||||
user = "NA"
|
||||
if state.username:
|
||||
user = state.username
|
||||
return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs)
|
||||
|
||||
|
||||
logger = PopLogger()
|
||||
|
||||
|
||||
async def next_req() -> Request:
|
||||
for _ in range(InvalidCommand.RETRIES):
|
||||
line = await state().reader.readline()
|
||||
logging.debug(f"Client: {line!r}")
|
||||
logger.debug(f"Client: {line!r}")
|
||||
if not line:
|
||||
if state().reader.at_eof():
|
||||
raise ClientDisconnected
|
||||
@ -54,19 +112,19 @@ async def next_req() -> Request:
|
||||
async def expect_cmd(*commands: Command) -> Request:
|
||||
req = await next_req()
|
||||
if req.cmd not in commands:
|
||||
logging.error(f"Unexpected command: {req.cmd} is not in {commands}")
|
||||
logger.error(f"Unexpected command: {req.cmd} is not in {commands}")
|
||||
raise ClientError
|
||||
return req
|
||||
|
||||
|
||||
def write(data) -> None:
|
||||
logging.debug(f"Server: {data}")
|
||||
logger.debug(f"Server: {data}")
|
||||
state().writer.write(data)
|
||||
|
||||
|
||||
def validate_password(username, password) -> None:
|
||||
try:
|
||||
pwinfo, mbox = config().users[username]
|
||||
pwinfo, mbox = scfg().users[username]
|
||||
except:
|
||||
raise AuthError("Invalid user pass")
|
||||
|
||||
@ -84,7 +142,7 @@ async def handle_user_pass_auth(user_cmd) -> None:
|
||||
cmd = await expect_cmd(Command.PASS)
|
||||
password = cmd.arg1
|
||||
validate_password(username, password)
|
||||
logging.info(f"{username=} has logged in successfully")
|
||||
logger.info(f"{username=} has logged in successfully")
|
||||
|
||||
|
||||
async def auth_stage() -> None:
|
||||
@ -98,20 +156,20 @@ async def auth_stage() -> None:
|
||||
write(end())
|
||||
else:
|
||||
await handle_user_pass_auth(req)
|
||||
if state().username in config().loggedin_users:
|
||||
logging.warning(
|
||||
if state().username in scfg().loggedin_users:
|
||||
logger.warning(
|
||||
f"User: {state().username} already has an active session"
|
||||
)
|
||||
raise AuthError("Already logged in")
|
||||
else:
|
||||
config().loggedin_users.add(state().username)
|
||||
scfg().loggedin_users.add(state().username)
|
||||
write(ok("Login successful"))
|
||||
return
|
||||
except AuthError as ae:
|
||||
write(err(f"Auth Failed: {ae}"))
|
||||
except ClientQuit as c:
|
||||
write(ok("Bye"))
|
||||
logging.warning("Client has QUIT before auth succeeded")
|
||||
logger.warning("Client has QUIT before auth succeeded")
|
||||
raise
|
||||
else:
|
||||
raise ClientError("Failed to authenticate")
|
||||
@ -204,7 +262,7 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
|
||||
except ClientQuit:
|
||||
write(ok("Bye"))
|
||||
return mails.deleted_uids
|
||||
logging.debug(f"Request: {req}")
|
||||
logger.debug(f"Request: {req}")
|
||||
try:
|
||||
func = handle_map[req.cmd]
|
||||
except KeyError:
|
||||
@ -229,46 +287,46 @@ def save_deleted_items(deleted_items_path: Path,
|
||||
|
||||
|
||||
async def transaction_stage() -> None:
|
||||
deleted_items_path = config().mails_path / state().mbox / state().username
|
||||
deleted_items_path = scfg().mails_path / state().mbox / state().username
|
||||
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
|
||||
mails_list = [
|
||||
entry
|
||||
for entry in get_mails_list(config().mails_path / state().mbox / "new")
|
||||
for entry in get_mails_list(scfg().mails_path / state().mbox / "new")
|
||||
if entry.uid not in existing_deleted_items
|
||||
]
|
||||
|
||||
new_deleted_items: set[str] = await process_transactions(mails_list)
|
||||
logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
|
||||
logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
|
||||
if new_deleted_items:
|
||||
save_deleted_items(deleted_items_path,
|
||||
existing_deleted_items.union(new_deleted_items))
|
||||
|
||||
logging.info(f"Saved deleted items")
|
||||
logger.info(f"Saved deleted items")
|
||||
|
||||
|
||||
async def start_session() -> None:
|
||||
logging.info("New session started")
|
||||
logger.info("New session started")
|
||||
try:
|
||||
await auth_stage()
|
||||
assert state().username
|
||||
assert state().mbox
|
||||
await transaction_stage()
|
||||
logging.info(f"User:{state().username} done")
|
||||
logger.info(f"User:{state().username} done")
|
||||
except ClientDisconnected as c:
|
||||
logging.info("Client disconnected")
|
||||
logger.info("Client disconnected")
|
||||
pass
|
||||
except ClientQuit:
|
||||
logging.info("Client QUIT")
|
||||
logger.info("Client QUIT")
|
||||
pass
|
||||
except ClientError as c:
|
||||
write(err("Something went wrong"))
|
||||
logging.error(f"Unexpected client error: {c}")
|
||||
logger.error(f"Unexpected client error: {c}")
|
||||
except Exception as e:
|
||||
logging.error(f"Serious client error: {e}")
|
||||
logger.error(f"Serious client error: {e}")
|
||||
raise
|
||||
finally:
|
||||
with contextlib.suppress(KeyError):
|
||||
config().loggedin_users.remove(state().username)
|
||||
scfg().loggedin_users.remove(state().username)
|
||||
|
||||
|
||||
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
|
||||
@ -282,43 +340,16 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
|
||||
return dict(inner())
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
reader: StreamReader
|
||||
writer: StreamWriter
|
||||
username: str = ""
|
||||
mbox: str = ""
|
||||
|
||||
|
||||
class Config:
|
||||
|
||||
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
|
||||
self.mails_path = mails_path
|
||||
self.users = users
|
||||
self.loggedin_users: set[str] = set()
|
||||
|
||||
|
||||
c_config: contextvars.ContextVar = contextvars.ContextVar("config")
|
||||
|
||||
|
||||
def config() -> Config:
|
||||
return c_config.get()
|
||||
|
||||
|
||||
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
|
||||
|
||||
|
||||
def state() -> State:
|
||||
return c_state.get()
|
||||
|
||||
|
||||
def make_pop_server_callback(mails_path: Path, users: list[User],
|
||||
timeout_seconds: int):
|
||||
config = Config(mails_path=mails_path, users=parse_users(users))
|
||||
scfg = SharedState(mails_path=mails_path, users=parse_users(users))
|
||||
|
||||
async def session_cb(reader: StreamReader, writer: StreamWriter):
|
||||
c_config.set(config)
|
||||
c_state.set(State(reader=reader, writer=writer))
|
||||
c_shared_state.set(scfg)
|
||||
ip, _ = writer.get_extra_info("peername")
|
||||
c_state.set(
|
||||
State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
|
||||
logger.info(f"Got pop server callback")
|
||||
try:
|
||||
return await asyncio.wait_for(start_session(), timeout_seconds)
|
||||
finally:
|
||||
|
@ -20,17 +20,20 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
|
||||
return context
|
||||
|
||||
|
||||
def setup_logging(args):
|
||||
if args.debug:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
def setup_logging(cfg: config.LogCfg):
|
||||
logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
|
||||
if cfg.logfile == "STDOUT":
|
||||
logging.basicConfig(level=cfg.level, format=logging_format)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.basicConfig(filename=cfg.logfile, level=cfg.level, format=logging_format)
|
||||
|
||||
|
||||
|
||||
async def a_main(cfg: config.Config) -> None:
|
||||
default_tls_context: ssl.SSLContext | None = None
|
||||
|
||||
if tls := cfg.default_tls:
|
||||
logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
|
||||
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
|
||||
|
||||
def get_tls_context(tls: config.TLSCfg | str):
|
||||
@ -146,8 +149,9 @@ def main() -> None:
|
||||
print("✗ password and hash do not match")
|
||||
else:
|
||||
cfg = config.Config(args.config.read_text())
|
||||
setup_logging(cfg)
|
||||
asyncio.run(a_main(cfg), debug=cfg.debug)
|
||||
setup_logging(config.LogCfg(cfg.logging))
|
||||
logging.info(f"Starting mail4one {args.config=}")
|
||||
asyncio.run(a_main(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -13,6 +13,7 @@ from email.message import Message
|
||||
import email.policy
|
||||
from email.generator import BytesGenerator
|
||||
import tempfile
|
||||
import random
|
||||
|
||||
from aiosmtpd.handlers import Mailbox, AsyncMessage
|
||||
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
|
||||
@ -20,6 +21,8 @@ from aiosmtpd.smtp import SMTP as SMTPServer
|
||||
from aiosmtpd.smtp import Envelope as SMTPEnvelope
|
||||
from aiosmtpd.smtp import Session as SMTPSession
|
||||
|
||||
logger = logging.getLogger("smtp")
|
||||
|
||||
|
||||
class MyHandler(AsyncMessage):
|
||||
|
||||
@ -32,6 +35,7 @@ class MyHandler(AsyncMessage):
|
||||
async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
|
||||
envelope: SMTPEnvelope) -> str:
|
||||
self.rcpt_tos = envelope.rcpt_tos
|
||||
self.peer = session.peer
|
||||
return await super().handle_DATA(server, session, envelope)
|
||||
|
||||
async def handle_message(self, m: Message): # type: ignore[override]
|
||||
@ -40,24 +44,29 @@ class MyHandler(AsyncMessage):
|
||||
for mbox in self.mbox_finder(addr):
|
||||
all_mboxes.add(mbox)
|
||||
if not all_mboxes:
|
||||
logger.warning(f"dropping message from: {self.peer}")
|
||||
return
|
||||
for mbox in all_mboxes:
|
||||
for sub in ("new", "tmp", "cur"):
|
||||
sub_path = self.mails_path / mbox / sub
|
||||
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml"
|
||||
filename = f"{uuid.uuid4()}.eml"
|
||||
temp_email_path = Path(tmpdir) / filename
|
||||
with open(temp_email_path, "wb") as fp:
|
||||
gen = BytesGenerator(fp, policy=email.policy.SMTP)
|
||||
gen.flatten(m)
|
||||
for mbox in all_mboxes:
|
||||
shutil.copy(temp_email_path, self.mails_path / mbox / "new")
|
||||
logger.info(
|
||||
f"Saved mail at {filename} addrs: {','.join(self.rcpt_tos)}, mboxes: {','.join(all_mboxes)} peer: {self.peer}"
|
||||
)
|
||||
|
||||
|
||||
def protocol_factory_starttls(mails_path: Path,
|
||||
mbox_finder: Callable[[str], list[str]],
|
||||
context: ssl.SSLContext):
|
||||
logging.info("Got smtp client cb starttls")
|
||||
logger.info("Got smtp client cb starttls")
|
||||
try:
|
||||
handler = MyHandler(mails_path, mbox_finder)
|
||||
smtp = SMTP(
|
||||
@ -67,19 +76,19 @@ def protocol_factory_starttls(mails_path: Path,
|
||||
enable_SMTPUTF8=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Something went wrong", e)
|
||||
logger.error("Something went wrong", e)
|
||||
raise
|
||||
return smtp
|
||||
|
||||
|
||||
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
|
||||
list[str]]):
|
||||
logging.info("Got smtp client cb")
|
||||
logger.info("Got smtp client cb")
|
||||
try:
|
||||
handler = MyHandler(mails_path, mbox_finder)
|
||||
smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
|
||||
except Exception as e:
|
||||
logging.error("Something went wrong", e)
|
||||
logger.error("Something went wrong", e)
|
||||
raise
|
||||
return smtp
|
||||
|
||||
@ -91,6 +100,9 @@ async def create_smtp_server_starttls(
|
||||
mbox_finder: Callable[[str], list[str]],
|
||||
ssl_context: ssl.SSLContext,
|
||||
) -> asyncio.Server:
|
||||
logging.info(
|
||||
f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.create_server(
|
||||
partial(protocol_factory_starttls, mails_path, mbox_finder,
|
||||
@ -108,6 +120,9 @@ async def create_smtp_server(
|
||||
mbox_finder: Callable[[str], list[str]],
|
||||
ssl_context: ssl.SSLContext | None = None,
|
||||
) -> asyncio.Server:
|
||||
logging.info(
|
||||
f"Starting SMTP server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.create_server(
|
||||
partial(protocol_factory, mails_path, mbox_finder),
|
||||
|
@ -3,6 +3,7 @@ import logging
|
||||
import unittest
|
||||
import smtplib
|
||||
import tempfile
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
from pathlib import Path
|
||||
@ -27,7 +28,6 @@ def setUpModule() -> None:
|
||||
class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def asyncSetUp(self) -> None:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
smtp_server = await create_smtp_server(
|
||||
host="127.0.0.1",
|
||||
port=7996,
|
||||
@ -43,16 +43,15 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
||||
Byee
|
||||
"""
|
||||
msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines())
|
||||
local_port: str
|
||||
|
||||
def send_mail():
|
||||
nonlocal local_port
|
||||
server = smtplib.SMTP(host="127.0.0.1", port=7996)
|
||||
server.sendmail("foo@sender.com", "foo@bar.com", msg)
|
||||
_, local_port = server.sock.getsockname()
|
||||
server.close()
|
||||
with contextlib.closing(smtplib.SMTP(host="127.0.0.1",
|
||||
port=7996)) as client:
|
||||
client.sendmail("foo@sender.com", "foo@bar.com", msg)
|
||||
_, local_port = client.sock.getsockname()
|
||||
return local_port
|
||||
|
||||
await asyncio.to_thread(send_mail)
|
||||
local_port = await asyncio.to_thread(send_mail)
|
||||
expected = f"""From: foo@sender.com
|
||||
To: "foo@bar.com"
|
||||
X-Peer: ('127.0.0.1', {local_port})
|
||||
|
Loading…
Reference in New Issue
Block a user