Compare commits

...

8 Commits
v1.1 ... main

9 changed files with 150 additions and 86 deletions

38
DEVNOTES.md Normal file
View File

@ -0,0 +1,38 @@
Notes for developers
## Running just one test
```
python -m unittest tests.test_pop.TestPop3.test_CAPA
```
## Patch for enable logging in test
Patch generated using below
```
git diff --patch -U1 tests >> ./DEVNOTES.md
```
Apply with below
```bash
git apply - <<PATCH
diff --git a/tests/test_pop.py b/tests/test_pop.py
index 55c1a91..a825665 100644
--- a/tests/test_pop.py
+++ b/tests/test_pop.py
@@ -55,3 +55,3 @@ def setUpModule() -> None:
global MAILS_PATH
- logging.basicConfig(level=logging.CRITICAL)
+ logging.basicConfig(level=logging.DEBUG)
td = tempfile.TemporaryDirectory(prefix="m41.pop.")
PATCH
```
## pylint
```
pylint mail4one/*py > /tmp/errs
vim +"cfile /tmp/errs"
```

View File

@ -1,10 +1,10 @@
# Needs python3 >= 3.9, sed, git for build, docker for tests # Needs python3 >= 3.9, sed, git for build
build: clean mail4one.pyz: requirements.txt mail4one/*py
python3 -m pip install -r requirements.txt --no-compile --target build python3 -m pip install -r requirements.txt --no-compile --target build
cp -r mail4one/ build/ cp -r mail4one/ build/
sed -i "s/DEVELOMENT/$(shell scripts/get_version.sh)/" build/mail4one/version.py sed -i "s/DEVELOMENT/$(shell scripts/get_version.sh)/" build/mail4one/version.py
find build -name "*.pyi" -o -name "py.typed" | xargs -I typefile rm typefile find build -name "*.pyi" -o -name "py.typed" | xargs -I typefile rm typefile
rm -rf build/bin rm -rf build/bin build/aiosmtpd/{docs,tests,qa}
rm -rf build/mail4one/__pycache__ rm -rf build/mail4one/__pycache__
rm -rf build/*.dist-info rm -rf build/*.dist-info
python3 -m zipapp \ python3 -m zipapp \
@ -13,10 +13,19 @@ build: clean
--main mail4one.server:main \ --main mail4one.server:main \
--compress build --compress build
.PHONY: build
build: clean mail4one.pyz
.PHONY: test
test: mail4one.pyz
PYTHONPATH=mail4one.pyz python3 -m unittest discover
.PHONY: clean
clean: clean:
rm -rf build rm -rf build
rm -rf mail4one.pyz rm -rf mail4one.pyz
.PHONY: docker-tests
docker-tests: docker-tests:
docker run --pull=always -v `pwd`:/app -w /app --rm python:3.11-alpine sh scripts/runtests.sh docker run --pull=always -v `pwd`:/app -w /app --rm python:3.11-alpine sh scripts/runtests.sh
docker run --pull=always -v `pwd`:/app -w /app --rm python:3.10-alpine sh scripts/runtests.sh docker run --pull=always -v `pwd`:/app -w /app --rm python:3.10-alpine sh scripts/runtests.sh
@ -31,24 +40,31 @@ docker-tests:
requirements.txt: Pipfile.lock requirements.txt: Pipfile.lock
pipenv requirements > requirements.txt pipenv requirements > requirements.txt
.PHONY: format
format: format:
black mail4one/*py tests/*py black mail4one/*py tests/*py
.PHONY: build-dev
build-dev: requirements.txt build build-dev: requirements.txt build
.PHONY: setup
setup: setup:
pipenv install pipenv install
.PHONY: cleanup
cleanup: cleanup:
pipenv --rm pipenv --rm
.PHONY: update
update: update:
rm requirements.txt Pipfile.lock rm requirements.txt Pipfile.lock
pipenv update pipenv update
pipenv requirements > requirements.txt pipenv requirements > requirements.txt
.PHONY: shell
shell: shell:
MYPYPATH=`pipenv --venv`/lib/python3.11/site-packages pipenv shell MYPYPATH=$(shell ls -d `pipenv --venv`/lib/python3*/site-packages) pipenv shell
test: .PHONY: dev-test
dev-test:
pipenv run python -m unittest discover pipenv run python -m unittest discover

View File

@ -1,4 +1,5 @@
import json """Module for parsing mail4one config.json"""
import re import re
import logging import logging
from typing import Callable, Union, Optional from typing import Callable, Union, Optional
@ -56,13 +57,14 @@ class PopCfg(ServerCfg):
class SmtpStartTLSCfg(ServerCfg): class SmtpStartTLSCfg(ServerCfg):
server_type = "smtp_starttls" server_type = "smtp_starttls"
smtputf8 = True # Not used yet require_starttls = True
smtputf8 = True
port = 25 port = 25
class SmtpCfg(ServerCfg): class SmtpCfg(ServerCfg):
server_type = "smtp_starttls" server_type = "smtp"
smtputf8 = True # Not used yet smtputf8 = True
port = 465 port = 465
@ -94,11 +96,10 @@ def parse_checkers(cfg: Config) -> list[Checker]:
raise Exception("Both addrs and addr_rexs is set") raise Exception("Both addrs and addr_rexs is set")
if m.addrs: if m.addrs:
return lambda malias: malias in m.addrs return lambda malias: malias in m.addrs
elif m.addr_rexs: if m.addr_rexs:
compiled_res = [re.compile(reg) for reg in m.addr_rexs] compiled_res = [re.compile(reg) for reg in m.addr_rexs]
return lambda malias: any(reg.match(malias) for reg in compiled_res) return lambda malias: any(reg.match(malias) for reg in compiled_res)
else: raise Exception("Neither addrs nor addr_rexs is set")
raise Exception("Neither addrs nor addr_rexs is set")
matches = {m.name: make_match_fn(Match(m)) for m in cfg.matches or []} matches = {m.name: make_match_fn(Match(m)) for m in cfg.matches or []}
matches[DEFAULT_MATCH_ALL] = lambda _: True matches[DEFAULT_MATCH_ALL] = lambda _: True

View File

@ -2,18 +2,15 @@ import asyncio
import contextlib import contextlib
import contextvars import contextvars
import logging import logging
import os
import ssl import ssl
import uuid import random
from typing import Optional
from asyncio import StreamReader, StreamWriter
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256
from pathlib import Path from pathlib import Path
from .config import User from .config import User
from .pwhash import parse_hash, check_pass, PWInfo from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter
import random
from typing import Optional
from .poputils import ( from .poputils import (
InvalidCommand, InvalidCommand,
@ -75,14 +72,14 @@ class PopLogger(logging.LoggerAdapter):
def __init__(self): def __init__(self):
super().__init__(logging.getLogger("pop3"), None) super().__init__(logging.getLogger("pop3"), None)
def process(self, msg, kwargs): def process(self, log_msg, kwargs):
state: State = c_state.get(None) st: State = c_state.get(None)
if not state: if not st:
return super().process(msg, kwargs) return super().process(log_msg, kwargs)
user = "NA" user = "NA"
if state.username: if st.username:
user = state.username user = st.username
return super().process(f"{state.ip} {state.req_id} {user} {msg}", kwargs) return super().process(f"{st.ip} {st.req_id} {user} {log_msg}", kwargs)
logger = PopLogger() logger = PopLogger()
@ -104,8 +101,7 @@ async def next_req() -> Request:
if request.cmd == Command.QUIT: if request.cmd == Command.QUIT:
raise ClientQuit raise ClientQuit
return request return request
else: raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(*commands: Command) -> Request: async def expect_cmd(*commands: Command) -> Request:
@ -153,25 +149,23 @@ async def auth_stage() -> None:
write(ok("Following are supported")) write(ok("Following are supported"))
write(msg("USER")) write(msg("USER"))
write(end()) write(end())
else: continue
await handle_user_pass_auth(req) await handle_user_pass_auth(req)
if state().username in scfg().loggedin_users: if state().username in scfg().loggedin_users:
logger.warning( logger.warning(
f"User: {state().username} already has an active session" f"User: {state().username} already has an active session"
) )
raise AuthError("Already logged in") raise AuthError("Already logged in")
else: scfg().loggedin_users.add(state().username)
scfg().loggedin_users.add(state().username) write(ok("Login successful"))
write(ok("Login successful")) return
return
except AuthError as ae: except AuthError as ae:
write(err(f"Auth Failed: {ae}")) write(err(f"Auth Failed: {ae}"))
except ClientQuit as c: except ClientQuit:
write(ok("Bye")) write(ok("Bye"))
logger.warning("Client has QUIT before auth succeeded") logger.warning("Client has QUIT before auth succeeded")
raise raise
else: raise ClientError("Failed to authenticate")
raise ClientError("Failed to authenticate")
def trans_command_capa(_, __) -> None: def trans_command_capa(_, __) -> None:
@ -272,9 +266,8 @@ async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
except KeyError: except KeyError:
write(err("Not implemented")) write(err("Not implemented"))
raise ClientError("We shouldn't reach here") raise ClientError("We shouldn't reach here")
else: func(mails, req)
func(mails, req) await state().writer.drain()
await state().writer.drain()
def get_deleted_items(deleted_items_path: Path) -> set[str]: def get_deleted_items(deleted_items_path: Path) -> set[str]:
@ -305,7 +298,7 @@ async def transaction_stage() -> None:
deleted_items_path, existing_deleted_items.union(new_deleted_items) deleted_items_path, existing_deleted_items.union(new_deleted_items)
) )
logger.info(f"Saved deleted items") logger.info("Saved deleted items")
async def start_session() -> None: async def start_session() -> None:
@ -316,12 +309,10 @@ async def start_session() -> None:
assert state().mbox assert state().mbox
await transaction_stage() await transaction_stage()
logger.info(f"User:{state().username} done") logger.info(f"User:{state().username} done")
except ClientDisconnected as c: except ClientDisconnected:
logger.info("Client disconnected") logger.info("Client disconnected")
pass
except ClientQuit: except ClientQuit:
logger.info("Client QUIT") logger.info("Client QUIT")
pass
except ClientError as c: except ClientError as c:
write(err("Something went wrong")) write(err("Something went wrong"))
logger.error(f"Unexpected client error: {c}") logger.error(f"Unexpected client error: {c}")
@ -344,13 +335,13 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
def make_pop_server_callback(mails_path: Path, users: list[User], timeout_seconds: int): def make_pop_server_callback(mails_path: Path, users: list[User], timeout_seconds: int):
scfg = SharedState(mails_path=mails_path, users=parse_users(users)) s_state = SharedState(mails_path=mails_path, users=parse_users(users))
async def session_cb(reader: StreamReader, writer: StreamWriter): async def session_cb(reader: StreamReader, writer: StreamWriter):
c_shared_state.set(scfg) c_shared_state.set(s_state)
ip, _ = writer.get_extra_info("peername") ip, _ = writer.get_extra_info("peername")
c_state.set(State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id())) c_state.set(State(reader=reader, writer=writer, ip=ip, req_id=s_state.next_id()))
logger.info(f"Got pop server callback") logger.info("Got pop server callback")
try: try:
try: try:
return await asyncio.wait_for(start_session(), timeout_seconds) return await asyncio.wait_for(start_session(), timeout_seconds)
@ -372,7 +363,7 @@ async def create_pop_server(
timeout_seconds: int = 60, timeout_seconds: int = 60,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info( logging.info(
f"Starting POP3 server {host=}, {port=}, {mails_path=!s}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}" f"Starting POP3 server {host=}, {port=}, {mails_path=!s}, {len(users)=}, {bool(ssl_context)=}, {timeout_seconds=}"
) )
return await asyncio.start_server( return await asyncio.start_server(
make_pop_server_callback(mails_path, users, timeout_seconds), make_pop_server_callback(mails_path, users, timeout_seconds),

View File

@ -20,12 +20,10 @@ class ClientDisconnected(ClientError):
class InvalidCommand(ClientError): class InvalidCommand(ClientError):
RETRIES = 3 RETRIES = 3
"""WIll allow NUM_BAD_COMMANDS times""" """WIll allow NUM_BAD_COMMANDS times"""
pass
class AuthError(ClientError): class AuthError(ClientError):
RETRIES = 3 RETRIES = 3
pass
class Command(Enum): class Command(Enum):

View File

@ -1,11 +1,10 @@
import asyncio import asyncio
import logging import logging
import os
import ssl import ssl
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from getpass import getpass from getpass import getpass
from typing import Optional, Union
from .smtp import create_smtp_server_starttls, create_smtp_server from .smtp import create_smtp_server_starttls, create_smtp_server
from .pop3 import create_pop_server from .pop3 import create_pop_server
@ -13,7 +12,6 @@ from .version import VERSION
from . import config from . import config
from . import pwhash from . import pwhash
from typing import Optional, Union
def create_tls_context(certfile, keyfile) -> ssl.SSLContext: def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
@ -44,17 +42,15 @@ async def a_main(cfg: config.Config) -> None:
def get_tls_context(tls: Union[config.TLSCfg, str]): def get_tls_context(tls: Union[config.TLSCfg, str]):
if tls == "default": if tls == "default":
return default_tls_context return default_tls_context
elif tls == "disable": if tls == "disable":
return None return None
else: tls_cfg = config.TLSCfg(tls)
tls_cfg = config.TLSCfg(tls) return create_tls_context(tls_cfg.certfile, tls_cfg.keyfile)
return create_tls_context(tls_cfg.certfile, tls_cfg.keyfile)
def get_host(host): def get_host(host):
if host == "default": if host == "default":
return cfg.default_host return cfg.default_host
else: return host
return host
mbox_finder = config.gen_addr_to_mboxes(cfg) mbox_finder = config.gen_addr_to_mboxes(cfg)
servers: list[asyncio.Server] = [] servers: list[asyncio.Server] = []
@ -86,6 +82,8 @@ async def a_main(cfg: config.Config) -> None:
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder, mbox_finder=mbox_finder,
ssl_context=stls_context, ssl_context=stls_context,
require_starttls=stls.require_starttls,
smtputf8=stls.smtputf8,
) )
servers.append(smtp_server_starttls) servers.append(smtp_server_starttls)
elif scfg.server_type == "smtp": elif scfg.server_type == "smtp":
@ -96,6 +94,7 @@ async def a_main(cfg: config.Config) -> None:
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder, mbox_finder=mbox_finder,
ssl_context=get_tls_context(smtp.tls), ssl_context=get_tls_context(smtp.tls),
smtputf8=smtp.smtputf8,
) )
servers.append(smtp_server) servers.append(smtp_server)
else: else:

View File

@ -1,23 +1,18 @@
import asyncio import asyncio
import io
import logging import logging
import mailbox
import ssl import ssl
import uuid import uuid
import shutil import shutil
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional
from . import config
from email.message import Message from email.message import Message
import email.policy import email.policy
from email.generator import BytesGenerator from email.generator import BytesGenerator
import tempfile import tempfile
import random
from aiosmtpd.handlers import Mailbox, AsyncMessage from aiosmtpd.handlers import AsyncMessage
from aiosmtpd.smtp import SMTP, DATA_SIZE_DEFAULT from aiosmtpd.smtp import SMTP
from aiosmtpd.smtp import SMTP as SMTPServer
from aiosmtpd.smtp import Envelope as SMTPEnvelope from aiosmtpd.smtp import Envelope as SMTPEnvelope
from aiosmtpd.smtp import Session as SMTPSession from aiosmtpd.smtp import Session as SMTPSession
@ -29,15 +24,17 @@ class MyHandler(AsyncMessage):
super().__init__() super().__init__()
self.mails_path = mails_path self.mails_path = mails_path
self.mbox_finder = mbox_finder self.mbox_finder = mbox_finder
self.rcpt_tos = []
self.peer = None
async def handle_DATA( async def handle_DATA(
self, server: SMTPServer, session: SMTPSession, envelope: SMTPEnvelope self, server: SMTP, session: SMTPSession, envelope: SMTPEnvelope
) -> str: ) -> str:
self.rcpt_tos = envelope.rcpt_tos self.rcpt_tos = envelope.rcpt_tos
self.peer = session.peer self.peer = session.peer
return await super().handle_DATA(server, session, envelope) return await super().handle_DATA(server, session, envelope)
async def handle_message(self, m: Message): # type: ignore[override] async def handle_message(self, message: Message): # type: ignore[override]
all_mboxes: set[str] = set() all_mboxes: set[str] = set()
for addr in self.rcpt_tos: for addr in self.rcpt_tos:
for mbox in self.mbox_finder(addr.lower()): for mbox in self.mbox_finder(addr.lower()):
@ -54,7 +51,7 @@ class MyHandler(AsyncMessage):
temp_email_path = Path(tmpdir) / filename temp_email_path = Path(tmpdir) / filename
with open(temp_email_path, "wb") as fp: with open(temp_email_path, "wb") as fp:
gen = BytesGenerator(fp, policy=email.policy.SMTP) gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m) gen.flatten(message)
for mbox in all_mboxes: for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / "new") shutil.copy(temp_email_path, self.mails_path / mbox / "new")
logger.info( logger.info(
@ -63,16 +60,20 @@ class MyHandler(AsyncMessage):
def protocol_factory_starttls( def protocol_factory_starttls(
mails_path: Path, mbox_finder: Callable[[str], list[str]], context: ssl.SSLContext mails_path: Path,
mbox_finder: Callable[[str], list[str]],
context: ssl.SSLContext,
require_starttls: bool,
smtputf8: bool,
): ):
logger.info("Got smtp client cb starttls") logger.info("Got smtp client cb starttls")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP( smtp = SMTP(
handler=handler, handler=handler,
require_starttls=True, require_starttls=require_starttls,
tls_context=context, tls_context=context,
enable_SMTPUTF8=True, enable_SMTPUTF8=smtputf8,
) )
except: except:
logger.exception("Something went wrong") logger.exception("Something went wrong")
@ -80,11 +81,13 @@ def protocol_factory_starttls(
return smtp return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]): def protocol_factory(
mails_path: Path, mbox_finder: Callable[[str], list[str]], smtputf8: bool
):
logger.info("Got smtp client cb") logger.info("Got smtp client cb")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, enable_SMTPUTF8=True) smtp = SMTP(handler=handler, enable_SMTPUTF8=smtputf8)
except: except:
logger.exception("Something went wrong") logger.exception("Something went wrong")
raise raise
@ -97,13 +100,22 @@ async def create_smtp_server_starttls(
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext, ssl_context: ssl.SSLContext,
require_starttls: bool,
smtputf8: bool,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info( logging.info(
f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=!s}, {ssl_context != None=}" f"Starting SMTP STARTTLS server {host=}, {port=}, {mails_path=!s}, {bool(ssl_context)=}"
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory_starttls, mails_path, mbox_finder, ssl_context), partial(
protocol_factory_starttls,
mails_path,
mbox_finder,
ssl_context,
require_starttls,
smtputf8,
),
host=host, host=host,
port=port, port=port,
start_serving=False, start_serving=False,
@ -115,14 +127,15 @@ async def create_smtp_server(
port: int, port: int,
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: Optional[ssl.SSLContext] = None, ssl_context: Optional[ssl.SSLContext],
smtputf8: bool,
) -> asyncio.Server: ) -> asyncio.Server:
logging.info( logging.info(
f"Starting SMTP server {host=}, {port=}, {mails_path=!s}, {ssl_context != None=}" f"Starting SMTP server {host=}, {port=}, {mails_path=!s}, {bool(ssl_context)=}"
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server( return await loop.create_server(
partial(protocol_factory, mails_path, mbox_finder), partial(protocol_factory, mails_path, mbox_finder, smtputf8),
host=host, host=host,
port=port, port=port,
ssl=ssl_context, ssl=ssl_context,

View File

@ -89,6 +89,9 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
self.task = asyncio.create_task(pop_server.serve_forever()) self.task = asyncio.create_task(pop_server.serve_forever())
self.reader, self.writer = await asyncio.open_connection("127.0.0.1", 7995) self.reader, self.writer = await asyncio.open_connection("127.0.0.1", 7995)
# Additional writers to close
self.ws: list[asyncio.StreamWriter] = []
async def test_QUIT(self) -> None: async def test_QUIT(self) -> None:
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
@ -133,6 +136,7 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
async def test_dupe_AUTH(self) -> None: async def test_dupe_AUTH(self) -> None:
r1, w1 = await asyncio.open_connection("127.0.0.1", 7995) r1, w1 = await asyncio.open_connection("127.0.0.1", 7995)
r2, w2 = await asyncio.open_connection("127.0.0.1", 7995) r2, w2 = await asyncio.open_connection("127.0.0.1", 7995)
self.ws += w1, w2
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
C: USER foobar C: USER foobar
@ -231,8 +235,10 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
async def asyncTearDown(self) -> None: async def asyncTearDown(self) -> None:
logging.debug("at teardown") logging.debug("at teardown")
self.writer.close() for w in self.ws + [self.writer]:
await self.writer.wait_closed() w.close()
await w.wait_closed()
self.ws.clear()
self.task.cancel("test done") self.task.cancel("test done")
async def dialog_checker(self, dialog: str) -> None: async def dialog_checker(self, dialog: str) -> None:

View File

@ -33,6 +33,8 @@ class TestSMTP(unittest.IsolatedAsyncioTestCase):
port=7996, port=7996,
mails_path=MAILS_PATH, mails_path=MAILS_PATH,
mbox_finder=lambda addr: [TEST_MBOX], mbox_finder=lambda addr: [TEST_MBOX],
ssl_context=None,
smtputf8=True,
) )
self.task = asyncio.create_task(smtp_server.serve_forever()) self.task = asyncio.create_task(smtp_server.serve_forever())