mail4one/mail4one/pop3.py

401 lines
11 KiB
Python

import asyncio
import contextlib
import contextvars
import logging
import ssl
import random
from typing import Optional
from asyncio import StreamReader, StreamWriter
from dataclasses import dataclass
from pathlib import Path
from .config import User
from .pwhash import parse_hash, check_pass, PWInfo
from .poputils import (
InvalidCommand,
parse_command,
err,
Command,
ClientQuit,
ClientDisconnected,
ClientError,
AuthError,
ok,
msg,
end,
Request,
MailEntry,
get_mail_fp,
get_mails_list,
MailList,
)
@dataclass
class State:
    """Per-connection session data, stored in the ``c_state`` context variable."""

    # Stream pair of the client connection.
    reader: StreamReader
    writer: StreamWriter
    # Peer address; PopLogger prefixes every log line with it.
    ip: str
    # Session identifier; also part of the PopLogger prefix.
    req_id: int
    # Both stay empty until validate_password() accepts the credentials.
    username: str = ""
    mbox: str = ""
class SharedState:
    """Data shared by every session served through one server callback."""

    def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
        # Random five-digit prefix followed by five zero digits; next_id()
        # counts upwards from this value.
        self.counter = random.randint(10000, 99999) * 100000
        # Usernames that currently hold an authenticated session.
        self.loggedin_users: set[str] = set()
        # username -> (password info, mailbox name)
        self.users = users
        self.mails_path = mails_path

    def next_id(self) -> int:
        """Increment the counter and return its new value."""
        self.counter += 1
        return self.counter
# Context variable carrying the SharedState; session_cb sets it per connection.
c_shared_state: contextvars.ContextVar = contextvars.ContextVar("pop_shared_state")


def scfg() -> SharedState:
    """Return the SharedState bound to the current context.

    Raises:
        LookupError: if the context variable has not been set.
    """
    shared = c_shared_state.get()
    return shared
# Context variable carrying the per-connection State; session_cb sets it.
c_state: contextvars.ContextVar = contextvars.ContextVar("state")


def state() -> State:
    """Return the State of the session running in the current context.

    Raises:
        LookupError: if the context variable has not been set.
    """
    session = c_state.get()
    return session
class PopLogger(logging.LoggerAdapter):
    """Adapter around the ``pop3`` logger that prefixes session details."""

    def __init__(self):
        super().__init__(logging.getLogger("pop3"), None)

    def process(self, msg, kwargs):
        """Prefix *msg* with ip, request id and user when a session is active."""
        session: State = c_state.get(None)
        if not session:
            # No session bound to this context: leave the message untouched.
            return super().process(msg, kwargs)
        # "NA" stands in for the username until the client has authenticated.
        who = session.username or "NA"
        return super().process(f"{session.ip} {session.req_id} {who} {msg}", kwargs)


# Module-wide logger used by all POP3 code in this file.
logger = PopLogger()
async def next_req() -> Request:
    """Read the next valid command from the client.

    Makes up to ``InvalidCommand.RETRIES`` attempts. A line that fails to
    parse is answered with an error reply and uses up one attempt, as does an
    empty read that is not end-of-file.

    Raises:
        ClientDisconnected: an empty read with the reader at EOF.
        ClientQuit: the client sent QUIT.
        ClientError: every attempt was used without a valid command.
    """
    for _attempt in range(InvalidCommand.RETRIES):
        raw = await state().reader.readline()
        logger.debug(f"Client: {raw!r}")
        if not raw:
            if state().reader.at_eof():
                raise ClientDisconnected
            continue
        try:
            parsed: Request = parse_command(raw)
        except InvalidCommand:
            write(err("Bad command"))
            continue
        if parsed.cmd == Command.QUIT:
            raise ClientQuit
        return parsed
    # Reached only when the loop ran out of attempts.
    raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(*commands: Command) -> Request:
    """Return the next request, requiring its command to be one of *commands*.

    Raises:
        ClientError: the client sent a command outside *commands*.
    """
    request = await next_req()
    if request.cmd in commands:
        return request
    logger.error(f"Unexpected command: {request.cmd} is not in {commands}")
    raise ClientError
def write(data: bytes) -> None:
    """Log *data* at debug level and hand it to the session's writer.

    The writer is not drained here.
    """
    logger.debug(f"Server: {data!r}")
    out = state().writer
    out.write(data)
def validate_password(username, password) -> None:
    """Check *password* for *username* and record the login on the session.

    On success the session's ``username`` and ``mbox`` are filled in.

    Raises:
        AuthError: the user is unknown or the password does not match.
    """
    try:
        pwinfo, mbox = scfg().users[username]
    except KeyError:
        # Only an unknown user is an authentication failure; anything else
        # (e.g. a programming error) must not be disguised as one. The message
        # matches the wrong-password case below, so the reply does not reveal
        # which of the two checks failed.
        raise AuthError("Invalid user pass") from None
    if not check_pass(password, pwinfo):
        raise AuthError("Invalid user pass")
    session = state()
    session.username = username
    session.mbox = mbox
async def handle_user_pass_auth(user_cmd) -> None:
    """Complete a USER/PASS exchange that began with the request *user_cmd*.

    Raises:
        AuthError: the USER argument is empty, or validate_password() rejects
            the credentials.
    """
    # NOTE: the local must stay named `username`; the log line below uses the
    # `{username=}` f-string form, which prints the variable name.
    username = user_cmd.arg1
    if not username:
        raise AuthError("Invalid USER command. username empty")
    write(ok("Welcome"))
    pass_req = await expect_cmd(Command.PASS)
    validate_password(username, pass_req.arg1)
    logger.info(f"{username=} has logged in successfully")
async def auth_stage() -> None:
    """Greet the client and run the authorization stage of the session.

    Loops up to ``AuthError.RETRIES`` times; each iteration handles one USER
    or CAPA request (a CAPA request also uses up an iteration). Returns once
    a user has authenticated and been added to ``loggedin_users``.

    Raises:
        ClientQuit: the client sent QUIT before authenticating (re-raised
            after replying).
        ClientError: all iterations were used without a successful login, or
            the client sent an unexpected command.
    """
    write(ok("Server Ready"))
    for _ in range(AuthError.RETRIES):
        try:
            req = await expect_cmd(Command.USER, Command.CAPA)
            if req.cmd is Command.CAPA:
                write(ok("Following are supported"))
                write(msg("USER"))
                write(end())
                continue
            await handle_user_pass_auth(req)
            session = state()
            if session.username in scfg().loggedin_users:
                logger.warning(
                    f"User: {session.username} already has an active session"
                )
                # This session does not own the login entry. Forget the
                # credentials so that the cleanup in start_session(), which
                # removes state().username from loggedin_users, cannot drop
                # the entry that belongs to the other, active session.
                session.username = ""
                session.mbox = ""
                raise AuthError("Already logged in")
            scfg().loggedin_users.add(session.username)
            write(ok("Login successful"))
            return
        except AuthError as ae:
            write(err(f"Auth Failed: {ae}"))
        except ClientQuit:
            write(ok("Bye"))
            logger.warning("Client has QUIT before auth succeeded")
            raise
    # Reached only when every attempt failed.
    raise ClientError("Failed to authenticate")
def trans_command_capa(_, __) -> None:
    """Reply to CAPA in the transaction stage; both arguments are ignored."""
    for chunk in (ok("CAPA follows"), msg("UIDL"), end()):
        write(chunk)
def trans_command_stat(mails: MailList, _) -> None:
    """Reply to STAT with the two values returned by ``mails.compute_stat()``."""
    count, total_size = mails.compute_stat()
    reply = ok(f"{count} {total_size}")
    write(reply)
def trans_command_list(mails: MailList, req: Request) -> None:
    """Reply to LIST: one entry's size when an argument is given, else all."""
    if not req.arg1:
        # No argument: multi-line listing of every entry from get_all().
        write(ok("Mails follow"))
        for item in mails.get_all():
            write(msg(f"{item.nid} {item.size}"))
        write(end())
        return
    found = mails.get(req.arg1)
    if not found:
        write(err("Not found"))
        return
    write(ok(f"{req.arg1} {found.size}"))
def trans_command_uidl(mails: MailList, req: Request) -> None:
    """Reply to UIDL: one entry's uid when an argument is given, else all."""
    if not req.arg1:
        # No argument: multi-line listing of every entry from get_all().
        write(ok("Mails follow"))
        for item in mails.get_all():
            write(msg(f"{item.nid} {item.uid}"))
        write(end())
        return
    found = mails.get(req.arg1)
    if not found:
        write(err("Not found"))
        return
    write(ok(f"{req.arg1} {found.uid}"))
def trans_command_retr(mails: MailList, req: Request) -> None:
    """Reply to RETR with the mail's contents, then mark the mail deleted."""
    entry = mails.get(req.arg1)
    if not entry:
        write(err("Not found"))
        return
    write(ok("Contents follow"))
    with get_mail_fp(entry) as fp:
        for raw_line in fp:
            # Byte-stuffing: a line that starts with "." gets an extra "."
            # written in front of it.
            if raw_line.startswith(b"."):
                write(b".")
            write(raw_line)
    write(end())
    # A successfully retrieved mail is also deleted from the list.
    mails.delete(req.arg1)
def trans_command_dele(mails: MailList, req: Request) -> None:
    """Reply to DELE: delete the requested entry if ``mails.get`` finds it."""
    target = req.arg1
    if not mails.get(target):
        write(err("Not found"))
        return
    mails.delete(target)
    write(ok("Deleted"))
def trans_command_noop(_, __) -> None:
    """Reply to NOOP with a positive response; both arguments are ignored."""
    reply = ok("Hmm")
    write(reply)
async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
    """Serve transaction-stage commands until the client sends QUIT.

    Each request is dispatched to a handler called as ``handler(mails, req)``;
    the writer is drained after every handled request.

    Returns:
        ``mails.deleted_uids`` of the MailList in use when QUIT arrives.

    Raises:
        ClientError: a command without a handler was received.
    """
    mails = MailList(mails_list)

    def reset(_, __):
        # RSET: start over with a fresh MailList built from the same entries.
        nonlocal mails
        mails = MailList(mails_list)
        # RSET needs a status reply like every other command (RFC 1939);
        # without one the client waits until the session times out.
        write(ok("Reset"))

    handle_map = {
        Command.CAPA: trans_command_capa,
        Command.STAT: trans_command_stat,
        Command.LIST: trans_command_list,
        Command.UIDL: trans_command_uidl,
        Command.RETR: trans_command_retr,
        Command.DELE: trans_command_dele,
        Command.RSET: reset,
        Command.NOOP: trans_command_noop,
    }

    while True:
        try:
            req = await next_req()
        except ClientQuit:
            write(ok("Bye"))
            return mails.deleted_uids
        logger.debug(f"Request: {req}")
        handler = handle_map.get(req.cmd)
        if handler is None:
            write(err("Not implemented"))
            raise ClientError("We shouldn't reach here")
        # `mails` is looked up on every iteration so a RSET takes effect.
        handler(mails, req)
        await state().writer.drain()
def get_deleted_items(deleted_items_path: Path) -> set[str]:
    """Return the lines of *deleted_items_path* as a set.

    A path that does not exist yields an empty set.
    """
    if not deleted_items_path.exists():
        return set()
    with deleted_items_path.open() as src:
        contents = src.read()
    return set(contents.splitlines())
def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]) -> None:
    """Overwrite *deleted_items_path* with one line per item in *deleted_items*."""
    with deleted_items_path.open(mode="w") as out:
        for uid in deleted_items:
            out.write(f"{uid}\n")
async def transaction_stage() -> None:
    """Run the transaction stage for the authenticated session.

    Mails from the mailbox's ``new`` directory whose uid is already in the
    user's deleted-items file are hidden. Uids deleted during this session
    are merged into that file afterwards.
    """
    mbox_dir = scfg().mails_path / state().mbox
    # The deleted-items file lives in the mailbox directory, named after the user.
    deleted_items_path = mbox_dir / state().username
    already_deleted: set[str] = get_deleted_items(deleted_items_path)

    visible = []
    for entry in get_mails_list(mbox_dir / "new"):
        if entry.uid not in already_deleted:
            visible.append(entry)

    newly_deleted: set[str] = await process_transactions(visible)
    logger.info(f"completed transactions. Deleted:{len(newly_deleted)}")
    if not newly_deleted:
        return
    save_deleted_items(deleted_items_path, already_deleted.union(newly_deleted))
    logger.info("Saved deleted items")
async def start_session() -> None:
    """Run one POP3 session: authorization, then transactions.

    Disconnects, QUIT and ClientError are logged and absorbed here; any other
    exception is logged and re-raised. In every case the session's username
    is removed from ``loggedin_users`` on the way out.
    """
    logger.info("New session started")
    try:
        await auth_stage()
        assert state().username
        assert state().mbox
        await transaction_stage()
        logger.info(f"User:{state().username} done")
    except ClientDisconnected:
        logger.info("Client disconnected")
    except ClientQuit:
        logger.info("Client QUIT")
    except ClientError as client_err:
        write(err("Something went wrong"))
        logger.error(f"Unexpected client error: {client_err}")
    except BaseException:
        logger.exception("Serious client error")
        raise
    finally:
        # discard() is a no-op when the name is absent (e.g. never logged in).
        scfg().loggedin_users.discard(state().username)
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
    """Build the ``username -> (parsed password hash, mbox)`` lookup table."""
    table: dict[str, tuple[PWInfo, str]] = {}
    for raw in users:
        # Each item is passed through User() before its fields are read.
        user = User(raw)
        pwinfo = parse_hash(user.password_hash)
        table[user.username] = (pwinfo, user.mbox)
    return table
def make_pop_server_callback(mails_path: Path, users: list[User], timeout_seconds: int):
    """Create the per-connection callback for ``asyncio.start_server``.

    One SharedState is created here and shared by every connection that the
    returned callback handles.

    Args:
        mails_path: root directory of the mailboxes.
        users: user definitions, converted with parse_users().
        timeout_seconds: upper bound for the duration of a whole session.

    Returns:
        A coroutine function taking ``(reader, writer)``.
    """
    shared = SharedState(mails_path=mails_path, users=parse_users(users))

    async def session_cb(reader: StreamReader, writer: StreamWriter):
        c_shared_state.set(shared)
        # peername is (host, port) for IPv4 but a 4-tuple for IPv6, and may be
        # None or a plain string for other transports, so it cannot be
        # unpacked into exactly two values.
        peername = writer.get_extra_info("peername")
        if isinstance(peername, (tuple, list)) and peername:
            ip = str(peername[0])
        else:
            ip = str(peername) if peername else "unknown"
        c_state.set(State(reader=reader, writer=writer, ip=ip, req_id=shared.next_id()))
        logger.info("Got pop server callback")
        try:
            try:
                return await asyncio.wait_for(start_session(), timeout_seconds)
            finally:
                writer.close()
                await writer.wait_closed()
        except Exception:
            # Boundary of the session: log and absorb ordinary failures
            # (including the wait_for timeout). Cancellation and interpreter
            # exit are BaseException subclasses and are left to propagate.
            logger.exception("unexpected exception")

    return session_cb
async def create_pop_server(
    host: str,
    port: int,
    mails_path: Path,
    users: list[User],
    ssl_context: Optional[ssl.SSLContext] = None,
    timeout_seconds: int = 60,
) -> asyncio.Server:
    """Log the configuration and start the POP3 server.

    Returns:
        The ``asyncio.Server`` returned by ``asyncio.start_server``.
    """
    logging.info(
        f"Starting POP3 server {host=}, {port=}, {mails_path=!s}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
    )
    session_cb = make_pop_server_callback(mails_path, users, timeout_seconds)
    server = await asyncio.start_server(
        session_cb,
        host=host,
        port=port,
        ssl=ssl_context,
    )
    return server
async def a_main(*args, **kwargs) -> None:
    """Create the POP3 server from the given arguments and serve forever.

    All arguments are forwarded unchanged to create_pop_server().
    """
    pop_server = await create_pop_server(*args, **kwargs)
    await pop_server.serve_forever()
def debug_main():
    """Run a debug server on 127.0.0.1:1101 with a single ``dummy`` user.

    Expects exactly two command-line arguments: the mails path and the mbox
    name. The dummy user's password is ``dummy``.
    """
    logging.basicConfig(level=logging.DEBUG)

    import sys

    from .pwhash import gen_pwhash

    _, mails_dir, mbox = sys.argv
    root = Path(mails_dir)
    dummy = User(username="dummy", password_hash=gen_pwhash("dummy"), mbox=mbox)
    asyncio.run(a_main("127.0.0.1", 1101, root, users=[dummy]))


if __name__ == "__main__":
    debug_main()