Improve logging
This commit is contained in:
parent
2fa748c444
commit
f407c8b395
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import logging
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from jata import Jata, MutableDefault
|
from jata import Jata, MutableDefault
|
||||||
|
|
||||||
@ -61,10 +62,15 @@ class SmtpCfg(ServerCfg):
|
|||||||
port = 465
|
port = 465
|
||||||
|
|
||||||
|
|
||||||
|
class LogCfg(Jata):
|
||||||
|
logfile = "STDOUT"
|
||||||
|
level = "INFO"
|
||||||
|
|
||||||
|
|
||||||
class Config(Jata):
|
class Config(Jata):
|
||||||
default_tls: TLSCfg | None
|
default_tls: TLSCfg | None
|
||||||
default_host: str = "0.0.0.0"
|
default_host: str = "0.0.0.0"
|
||||||
debug: bool = False
|
logging: LogCfg | None = None
|
||||||
|
|
||||||
mails_path: str
|
mails_path: str
|
||||||
matches: list[Match]
|
matches: list[Match]
|
||||||
@ -127,4 +133,5 @@ def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
|
|||||||
|
|
||||||
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]:
|
def gen_addr_to_mboxes(cfg: Config) -> Callable[[str], list[str]]:
|
||||||
checks = parse_checkers(cfg)
|
checks = parse_checkers(cfg)
|
||||||
|
logging.info(f"Parsed checkers from config, {len(checks)=}")
|
||||||
return lambda addr: get_mboxes(addr, checks)
|
return lambda addr: get_mboxes(addr, checks)
|
||||||
|
143
mail4one/pop3.py
143
mail4one/pop3.py
@ -1,15 +1,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
import contextvars
|
import uuid
|
||||||
import contextlib
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .config import User
|
from .config import User
|
||||||
from .pwhash import parse_hash, check_pass, PWInfo
|
from .pwhash import parse_hash, check_pass, PWInfo
|
||||||
from asyncio import StreamReader, StreamWriter
|
from asyncio import StreamReader, StreamWriter
|
||||||
|
import random
|
||||||
|
|
||||||
from .poputils import (
|
from .poputils import (
|
||||||
InvalidCommand,
|
InvalidCommand,
|
||||||
@ -31,10 +33,66 @@ from .poputils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class State:
|
||||||
|
reader: StreamReader
|
||||||
|
writer: StreamWriter
|
||||||
|
ip: str
|
||||||
|
req_id: int
|
||||||
|
username: str = ""
|
||||||
|
mbox: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class SharedState:
|
||||||
|
|
||||||
|
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
|
||||||
|
self.mails_path = mails_path
|
||||||
|
self.users = users
|
||||||
|
self.loggedin_users: set[str] = set()
|
||||||
|
self.counter = random.randint(10000, 99999) * 100000
|
||||||
|
|
||||||
|
def next_id(self) -> int:
|
||||||
|
self.counter = self.counter + 1
|
||||||
|
return self.counter
|
||||||
|
|
||||||
|
|
||||||
|
c_shared_state: contextvars.ContextVar = contextvars.ContextVar(
|
||||||
|
"pop_shared_state")
|
||||||
|
|
||||||
|
|
||||||
|
def scfg() -> SharedState:
|
||||||
|
return c_shared_state.get()
|
||||||
|
|
||||||
|
|
||||||
|
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
|
||||||
|
|
||||||
|
|
||||||
|
def state() -> State:
|
||||||
|
return c_state.get()
|
||||||
|
|
||||||
|
|
||||||
|
class PopLogger(logging.LoggerAdapter):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(logging.getLogger("pop3"), None)
|
||||||
|
|
||||||
|
def process(self, msg, kwargs):
|
||||||
|
state: State = c_state.get(None)
|
||||||
|
if not state:
|
||||||
|
return super().process(msg, kwargs)
|
||||||
|
user = "NA"
|
||||||
|
if state.username:
|
||||||
|
user = state.username
|
||||||
|
return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
logger = PopLogger()
|
||||||
|
|
||||||
|
|
||||||
async def next_req() -> Request:
|
async def next_req() -> Request:
|
||||||
for _ in range(InvalidCommand.RETRIES):
|
for _ in range(InvalidCommand.RETRIES):
|
||||||
line = await state().reader.readline()
|
line = await state().reader.readline()
|
||||||
logging.debug(f"Client: {line!r}")
|
logger.debug(f"Client: {line!r}")
|
||||||
if not line:
|
if not line:
|
||||||
if state().reader.at_eof():
|
if state().reader.at_eof():
|
||||||
raise ClientDisconnected
|
raise ClientDisconnected
|
||||||
@ -54,19 +112,19 @@ async def next_req() -> Request:
|
|||||||
async def expect_cmd(*commands: Command) -> Request:
|
async def expect_cmd(*commands: Command) -> Request:
|
||||||
req = await next_req()
|
req = await next_req()
|
||||||
if req.cmd not in commands:
|
if req.cmd not in commands:
|
||||||
logging.error(f"Unexpected command: {req.cmd} is not in {commands}")
|
logger.error(f"Unexpected command: {req.cmd} is not in {commands}")
|
||||||
raise ClientError
|
raise ClientError
|
||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
||||||
def write(data) -> None:
|
def write(data) -> None:
|
||||||
logging.debug(f"Server: {data}")
|
logger.debug(f"Server: {data}")
|
||||||
state().writer.write(data)
|
state().writer.write(data)
|
||||||
|
|
||||||
|
|
||||||
def validate_password(username, password) -> None:
|
def validate_password(username, password) -> None:
|
||||||
try:
|
try:
|
||||||
pwinfo, mbox = config().users[username]
|
pwinfo, mbox = scfg().users[username]
|
||||||
except:
|
except:
|
||||||
raise AuthError("Invalid user pass")
|
raise AuthError("Invalid user pass")
|
||||||
|
|
||||||
@ -84,7 +142,7 @@ async def handle_user_pass_auth(user_cmd) -> None:
|
|||||||
cmd = await expect_cmd(Command.PASS)
|
cmd = await expect_cmd(Command.PASS)
|
||||||
password = cmd.arg1
|
password = cmd.arg1
|
||||||
validate_password(username, password)
|
validate_password(username, password)
|
||||||
logging.info(f"{username=} has logged in successfully")
|
logger.info(f"{username=} has logged in successfully")
|
||||||
|
|
||||||
|
|
||||||
async def auth_stage() -> None:
|
async def auth_stage() -> None:
|
||||||
@ -98,20 +156,20 @@ async def auth_stage() -> None:
|
|||||||
write(end())
|
write(end())
|
||||||
else:
|
else:
|
||||||
await handle_user_pass_auth(req)
|
await handle_user_pass_auth(req)
|
||||||
if state().username in config().loggedin_users:
|
if state().username in scfg().loggedin_users:
|
||||||
logging.warning(
|
logger.warning(
|
||||||
f"User: {state().username} already has an active session"
|
f"User: {state().username} already has an active session"
|
||||||
)
|
)
|
||||||
raise AuthError("Already logged in")
|
raise AuthError("Already logged in")
|
||||||
else:
|
else:
|
||||||
config().loggedin_users.add(state().username)
|
scfg().loggedin_users.add(state().username)
|
||||||
write(ok("Login successful"))
|
write(ok("Login successful"))
|
||||||
return
|
return
|
||||||
except AuthError as ae:
|
except AuthError as ae:
|
||||||
write(err(f"Auth Failed: {ae}"))
|
write(err(f"Auth Failed: {ae}"))
|
||||||
except ClientQuit as c:
|
except ClientQuit as c:
|
||||||
write(ok("Bye"))
|
write(ok("Bye"))
|
||||||
logging.warning("Client has QUIT before auth succeeded")
|
logger.warning("Client has QUIT before auth succeeded")
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
raise ClientError("Failed to authenticate")
|
raise ClientError("Failed to authenticate")
|
||||||
@ -204,7 +262,7 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
|
|||||||
except ClientQuit:
|
except ClientQuit:
|
||||||
write(ok("Bye"))
|
write(ok("Bye"))
|
||||||
return mails.deleted_uids
|
return mails.deleted_uids
|
||||||
logging.debug(f"Request: {req}")
|
logger.debug(f"Request: {req}")
|
||||||
try:
|
try:
|
||||||
func = handle_map[req.cmd]
|
func = handle_map[req.cmd]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -229,46 +287,46 @@ def save_deleted_items(deleted_items_path: Path,
|
|||||||
|
|
||||||
|
|
||||||
async def transaction_stage() -> None:
|
async def transaction_stage() -> None:
|
||||||
deleted_items_path = config().mails_path / state().mbox / state().username
|
deleted_items_path = scfg().mails_path / state().mbox / state().username
|
||||||
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
|
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
|
||||||
mails_list = [
|
mails_list = [
|
||||||
entry
|
entry
|
||||||
for entry in get_mails_list(config().mails_path / state().mbox / "new")
|
for entry in get_mails_list(scfg().mails_path / state().mbox / "new")
|
||||||
if entry.uid not in existing_deleted_items
|
if entry.uid not in existing_deleted_items
|
||||||
]
|
]
|
||||||
|
|
||||||
new_deleted_items: set[str] = await process_transactions(mails_list)
|
new_deleted_items: set[str] = await process_transactions(mails_list)
|
||||||
logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
|
logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
|
||||||
if new_deleted_items:
|
if new_deleted_items:
|
||||||
save_deleted_items(deleted_items_path,
|
save_deleted_items(deleted_items_path,
|
||||||
existing_deleted_items.union(new_deleted_items))
|
existing_deleted_items.union(new_deleted_items))
|
||||||
|
|
||||||
logging.info(f"Saved deleted items")
|
logger.info(f"Saved deleted items")
|
||||||
|
|
||||||
|
|
||||||
async def start_session() -> None:
|
async def start_session() -> None:
|
||||||
logging.info("New session started")
|
logger.info("New session started")
|
||||||
try:
|
try:
|
||||||
await auth_stage()
|
await auth_stage()
|
||||||
assert state().username
|
assert state().username
|
||||||
assert state().mbox
|
assert state().mbox
|
||||||
await transaction_stage()
|
await transaction_stage()
|
||||||
logging.info(f"User:{state().username} done")
|
logger.info(f"User:{state().username} done")
|
||||||
except ClientDisconnected as c:
|
except ClientDisconnected as c:
|
||||||
logging.info("Client disconnected")
|
logger.info("Client disconnected")
|
||||||
pass
|
pass
|
||||||
except ClientQuit:
|
except ClientQuit:
|
||||||
logging.info("Client QUIT")
|
logger.info("Client QUIT")
|
||||||
pass
|
pass
|
||||||
except ClientError as c:
|
except ClientError as c:
|
||||||
write(err("Something went wrong"))
|
write(err("Something went wrong"))
|
||||||
logging.error(f"Unexpected client error: {c}")
|
logger.error(f"Unexpected client error: {c}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Serious client error: {e}")
|
logger.error(f"Serious client error: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
with contextlib.suppress(KeyError):
|
with contextlib.suppress(KeyError):
|
||||||
config().loggedin_users.remove(state().username)
|
scfg().loggedin_users.remove(state().username)
|
||||||
|
|
||||||
|
|
||||||
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
|
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
|
||||||
@ -282,43 +340,16 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
|
|||||||
return dict(inner())
|
return dict(inner())
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class State:
|
|
||||||
reader: StreamReader
|
|
||||||
writer: StreamWriter
|
|
||||||
username: str = ""
|
|
||||||
mbox: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
|
|
||||||
def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
|
|
||||||
self.mails_path = mails_path
|
|
||||||
self.users = users
|
|
||||||
self.loggedin_users: set[str] = set()
|
|
||||||
|
|
||||||
|
|
||||||
c_config: contextvars.ContextVar = contextvars.ContextVar("config")
|
|
||||||
|
|
||||||
|
|
||||||
def config() -> Config:
|
|
||||||
return c_config.get()
|
|
||||||
|
|
||||||
|
|
||||||
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
|
|
||||||
|
|
||||||
|
|
||||||
def state() -> State:
|
|
||||||
return c_state.get()
|
|
||||||
|
|
||||||
|
|
||||||
def make_pop_server_callback(mails_path: Path, users: list[User],
|
def make_pop_server_callback(mails_path: Path, users: list[User],
|
||||||
timeout_seconds: int):
|
timeout_seconds: int):
|
||||||
config = Config(mails_path=mails_path, users=parse_users(users))
|
scfg = SharedState(mails_path=mails_path, users=parse_users(users))
|
||||||
|
|
||||||
async def session_cb(reader: StreamReader, writer: StreamWriter):
|
async def session_cb(reader: StreamReader, writer: StreamWriter):
|
||||||
c_config.set(config)
|
c_shared_state.set(scfg)
|
||||||
c_state.set(State(reader=reader, writer=writer))
|
ip, _ = writer.get_extra_info("peername")
|
||||||
|
c_state.set(
|
||||||
|
State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
|
||||||
|
logger.info(f"Got pop server callback")
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(start_session(), timeout_seconds)
|
return await asyncio.wait_for(start_session(), timeout_seconds)
|
||||||
finally:
|
finally:
|
||||||
|
@ -20,17 +20,20 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
|
|||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(args):
|
def setup_logging(cfg: config.LogCfg):
|
||||||
if args.debug:
|
logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
if cfg.logfile == "STDOUT":
|
||||||
|
logging.basicConfig(level=cfg.level, format=logging_format)
|
||||||
else:
|
else:
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(filename=cfg.logfile, level=cfg.level, format=logging_format)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def a_main(cfg: config.Config) -> None:
|
async def a_main(cfg: config.Config) -> None:
|
||||||
default_tls_context: ssl.SSLContext | None = None
|
default_tls_context: ssl.SSLContext | None = None
|
||||||
|
|
||||||
if tls := cfg.default_tls:
|
if tls := cfg.default_tls:
|
||||||
|
logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
|
||||||
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
|
default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
|
||||||
|
|
||||||
def get_tls_context(tls: config.TLSCfg | str):
|
def get_tls_context(tls: config.TLSCfg | str):
|
||||||
@ -146,8 +149,9 @@ def main() -> None:
|
|||||||
print("✗ password and hash do not match")
|
print("✗ password and hash do not match")
|
||||||
else:
|
else:
|
||||||
cfg = config.Config(args.config.read_text())
|
cfg = config.Config(args.config.read_text())
|
||||||
setup_logging(cfg)
|
setup_logging(config.LogCfg(cfg.logging))
|
||||||
asyncio.run(a_main(cfg), debug=cfg.debug)
|
logging.info(f"Starting mail4one {args.config=}")
|
||||||
|
asyncio.run(a_main(cfg))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -13,6 +13,7 @@ from email.message import Message
|
|||||||
import email.policy
|
import email.policy
|
||||||
from email.generator import BytesGenerator
|
from email.generator import BytesGenerator
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import random
|
||||||
|
|
||||||
from aiosmtpd.handlers import Mailbox, AsyncMessage
|
from aiosmtpd.handlers import Mailbox, AsyncMessage
|
||||||
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
|
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT
|
||||||
@ -20,6 +21,8 @@ from aiosmtpd.smtp import SMTP as SMTPServer
|
|||||||
from aiosmtpd.smtp import Envelope as SMTPEnvelope
|
from aiosmtpd.smtp import Envelope as SMTPEnvelope
|
||||||
from aiosmtpd.smtp import Session as SMTPSession
|
from aiosmtpd.smtp import Session as SMTPSession
|
||||||
|
|
||||||
|
logger = logging.getLogger("smtp")
|
||||||
|
|
||||||
|
|
||||||
class MyHandler(AsyncMessage):
|
class MyHandler(AsyncMessage):
|
||||||
|
|
||||||
@ -32,6 +35,7 @@ class MyHandler(AsyncMessage):
|
|||||||
async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
|
async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
|
||||||
envelope: SMTPEnvelope) -> str:
|
envelope: SMTPEnvelope) -> str:
|
||||||
self.rcpt_tos = envelope.rcpt_tos
|
self.rcpt_tos = envelope.rcpt_tos
|
||||||
|
self.peer = session.peer
|
||||||
return await super().handle_DATA(server, session, envelope)
|
return await super().handle_DATA(server, session, envelope)
|
||||||
|
|
||||||
async def handle_message(self, m: Message): # type: ignore[override]
|
async def handle_message(self, m: Message): # type: ignore[override]
|
||||||
@ -40,24 +44,29 @@ class MyHandler(AsyncMessage):
|
|||||||
for mbox in self.mbox_finder(addr):
|
for mbox in self.mbox_finder(addr):
|
||||||
all_mboxes.add(mbox)
|
all_mboxes.add(mbox)
|
||||||
if not all_mboxes:
|
if not all_mboxes:
|
||||||
|
logger.warning(f"dropping message from: {self.peer}")
|
||||||
return
|
return
|
||||||
for mbox in all_mboxes:
|
for mbox in all_mboxes:
|
||||||
for sub in ("new", "tmp", "cur"):
|
for sub in ("new", "tmp", "cur"):
|
||||||
sub_path = self.mails_path / mbox / sub
|
sub_path = self.mails_path / mbox / sub
|
||||||
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
|
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
temp_email_path = Path(tmpdir) / f"{uuid.uuid4()}.eml"
|
filename = f"{uuid.uuid4()}.eml"
|
||||||
|
temp_email_path = Path(tmpdir) / filename
|
||||||
with open(temp_email_path, "wb") as fp:
|
with open(temp_email_path, "wb") as fp:
|
||||||
gen = BytesGenerator(fp, policy=email.policy.SMTP)
|
gen = BytesGenerator(fp, policy=email.policy.SMTP)
|
||||||
gen.flatten(m)
|
gen.flatten(m)
|
||||||
for mbox in all_mboxes:
|
for mbox in all_mboxes:
|
||||||
shutil.copy(temp_email_path, self.mails_path / mbox / "new")
|
shutil.copy(temp_email_path, self.mails_path / mbox / "new")
|
||||||
|
logger.info(
|
||||||
|
f"Saved mail at {filename} addrs: {','.join(self.rcpt_tos)}, mboxes: {','.join(all_mboxes)} peer: {self.peer}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def protocol_factory_starttls(mails_path: Path,
|
def protocol_factory_starttls(mails_path: Path,
|
||||||
mbox_finder: Callable[[str], list[str]],
|
mbox_finder: Callable[[str], list[str]],
|
||||||
context: ssl.SSLContext):
|
context: ssl.SSLContext):
|
||||||
logging.info("Got smtp client cb starttls")
|
logger.info("Got smtp client cb starttls")
|
||||||
try:
|
try:
|
||||||
handler = MyHandler(mails_path, mbox_finder)
|
handler = MyHandler(mails_path, mbox_finder)
|
||||||
smtp = SMTP(
|
smtp = SMTP(
|
||||||
@ -67,19 +76,19 @@ def protocol_factory_starttls(mails_path: Path,
|
|||||||
enable_SMTPUTF8=True,
|
enable_SMTPUTF8=True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Something went wrong", e)
|
logger.error("Something went wrong", e)
|
||||||
raise
|
raise
|
||||||
return smtp
|
return smtp
|
||||||
|
|
||||||
|
|
||||||
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
|
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
|
||||||
list[str]]):
|
list[str]]):
|
||||||
logging.info("Got smtp client cb")
|
logger.info("Got smtp client cb")
|
||||||
try:
|
try:
|
||||||
handler = MyHandler(mails_path, mbox_finder)
|
handler = MyHandler(mails_path, mbox_finder)
|
||||||
smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
|
smtp = SMTP(handler=handler, enable_SMTPUTF8=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Something went wrong", e)
|
logger.error("Something went wrong", e)
|
||||||
raise
|
raise
|
||||||
return smtp
|
return smtp
|
||||||
|
|
||||||
@ -91,6 +100,9 @@ async def create_smtp_server_starttls(
|
|||||||
mbox_finder: Callable[[str], list[str]],
|
mbox_finder: Callable[[str], list[str]],
|
||||||
ssl_context: ssl.SSLContext,
|
ssl_context: ssl.SSLContext,
|
||||||
) -> asyncio.Server:
|
) -> asyncio.Server:
|
||||||
|
logging.info(
|
||||||
|
f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
|
||||||
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return await loop.create_server(
|
return await loop.create_server(
|
||||||
partial(protocol_factory_starttls, mails_path, mbox_finder,
|
partial(protocol_factory_starttls, mails_path, mbox_finder,
|
||||||
@ -108,6 +120,9 @@ async def create_smtp_server(
|
|||||||
mbox_finder: Callable[[str], list[str]],
|
mbox_finder: Callable[[str], list[str]],
|
||||||
ssl_context: ssl.SSLContext | None = None,
|
ssl_context: ssl.SSLContext | None = None,
|
||||||
) -> asyncio.Server:
|
) -> asyncio.Server:
|
||||||
|
logging.info(
|
||||||
|
f"Starting SMTP server {host=}, {port=}, {mails_path=}, {ssl_context != None=}"
|
||||||
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return await loop.create_server(
|
return await loop.create_server(
|
||||||
partial(protocol_factory, mails_path, mbox_finder),
|
partial(protocol_factory, mails_path, mbox_finder),
|
||||||
|
@ -3,6 +3,7 @@ import logging
|
|||||||
import unittest
|
import unittest
|
||||||
import smtplib
|
import smtplib
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -27,7 +28,6 @@ def setUpModule() -> None:
|
|||||||
class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
async def asyncSetUp(self) -> None:
|
async def asyncSetUp(self) -> None:
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
smtp_server = await create_smtp_server(
|
smtp_server = await create_smtp_server(
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port=7996,
|
port=7996,
|
||||||
@ -43,16 +43,15 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
|
|||||||
Byee
|
Byee
|
||||||
"""
|
"""
|
||||||
msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines())
|
msg = b"".join(l.strip() + b"\r\n" for l in msg.splitlines())
|
||||||
local_port: str
|
|
||||||
|
|
||||||
def send_mail():
|
def send_mail():
|
||||||
nonlocal local_port
|
with contextlib.closing(smtplib.SMTP(host="127.0.0.1",
|
||||||
server = smtplib.SMTP(host="127.0.0.1", port=7996)
|
port=7996)) as client:
|
||||||
server.sendmail("foo@sender.com", "foo@bar.com", msg)
|
client.sendmail("foo@sender.com", "foo@bar.com", msg)
|
||||||
_, local_port = server.sock.getsockname()
|
_, local_port = client.sock.getsockname()
|
||||||
server.close()
|
return local_port
|
||||||
|
|
||||||
await asyncio.to_thread(send_mail)
|
local_port = await asyncio.to_thread(send_mail)
|
||||||
expected = f"""From: foo@sender.com
|
expected = f"""From: foo@sender.com
|
||||||
To: "foo@bar.com"
|
To: "foo@bar.com"
|
||||||
X-Peer: ('127.0.0.1', {local_port})
|
X-Peer: ('127.0.0.1', {local_port})
|
||||||
|
Loading…
Reference in New Issue
Block a user