pop3 refactor wip

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-02 11:36:18 -04:00
parent c441863af7
commit 5600d30f54
3 changed files with 146 additions and 77 deletions

View File

@ -26,7 +26,14 @@ class User(Jata):
class Config(Jata): class Config(Jata):
certfile: str certfile: str
keyfile: str keyfile: str
debug: bool = False
mails_path: str mails_path: str
host = '0.0.0.0'
smtp_port = 25
smtp_port_tls = 465
smtp_port_submission = 587
pop_port = 995
smtputf8 = True
rules: list[Rule] rules: list[Rule]
boxes: list[Mbox] boxes: list[Mbox]
users: list[User] users: list[User]

View File

@ -2,11 +2,14 @@ import asyncio
import logging import logging
import os import os
import ssl import ssl
from _contextvars import ContextVar import contextvars
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from typing import ClassVar, List, Set from typing import ClassVar, List, Set
from .config import User
from .pwhash import parse_hash, check_pass
from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
Request, MailEntry, get_mail, get_mails_list, MailList Request, MailEntry, get_mail, get_mails_list, MailList
@ -19,21 +22,16 @@ def add_season(content: bytes, season: bytes):
# noinspection PyProtectedMember # noinspection PyProtectedMember
@dataclass @dataclass
class Session: class Session:
_reader: asyncio.StreamReader _reader: StreamReader
_writer: asyncio.StreamWriter _writer: asyncio.StreamWriter
username: str
mbox: str
# common state # common state
all_sessions: ClassVar[Set] = set() all_sessions: ClassVar[Set] = set()
mails_path: ClassVar[Path] = Path("") mails_path: ClassVar[Path] = Path("")
users: ClassVar[list[User]] = list()
current_session: ClassVar = ContextVar("session") current_session: ClassVar = ContextVar("session")
password_hash: ClassVar[str] = ""
SALT: ClassVar[bytes] = b"balki is awesome+"
pepper: ClassVar[bytes]
@classmethod
def init_password(cls, salted_hash: str):
cls.pepper = os.urandom(32)
cls.password_hash = add_season(bytes.fromhex(salted_hash), cls.pepper)
@classmethod @classmethod
def get(cls): def get(cls):
@ -50,7 +48,7 @@ class Session:
async def next_req(): async def next_req():
for _ in range(InvalidCommand.RETRIES): for _ in range(InvalidCommand.RETRIES):
line = await Session.reader().readline() line = await state().reader.readline()
logging.debug(f"Client: {line}") logging.debug(f"Client: {line}")
if not line: if not line:
continue continue
@ -76,13 +74,20 @@ async def expect_cmd(*commands: Command):
def write(data): def write(data):
logging.debug(f"Server: {data}") logging.debug(f"Server: {data}")
Session.writer().write(data) state().writer.write(data)
def validate_password(password): def validate_password(username, password):
if Session.password_hash != add_season(add_season(password.encode(), Session.SALT), Session.pepper): try:
pwinfo, mbox = config().users[username]
except:
raise AuthError("Invalid user pass") raise AuthError("Invalid user pass")
if not check_pass(password, pwinfo):
raise AuthError("Invalid user pass")
state().username = username
state().mbox = mbox
async def handle_user_pass_auth(user_cmd): async def handle_user_pass_auth(user_cmd):
username = user_cmd.arg1 username = user_cmd.arg1
@ -91,9 +96,8 @@ async def handle_user_pass_auth(user_cmd):
write(ok("Welcome")) write(ok("Welcome"))
cmd = await expect_cmd(Command.PASS) cmd = await expect_cmd(Command.PASS)
password = cmd.arg1 password = cmd.arg1
validate_password(password) validate_password(username, password)
logging.info(f"User: {username} has logged in successfully") logging.info(f"{username=} has logged in successfully")
return username
async def auth_stage(): async def auth_stage():
@ -107,12 +111,12 @@ async def auth_stage():
write(end()) write(end())
else: else:
username = await handle_user_pass_auth(req) username = await handle_user_pass_auth(req)
if username in Session.all_sessions: if username in config().loggedin_users:
logging.warning(f"User: {username} already has an active session") logging.warning(
f"User: {username} already has an active session")
raise AuthError("Already logged in") raise AuthError("Already logged in")
else: else:
write(ok("Login successful")) write(ok("Login successful"))
return username
except AuthError as ae: except AuthError as ae:
write(err(f"Auth Failed: {ae}")) write(err(f"Auth Failed: {ae}"))
except ClientQuit as c: except ClientQuit as c:
@ -218,7 +222,7 @@ async def process_transactions(mails_list: List[MailEntry]):
raise ClientError("We shouldn't reach here") raise ClientError("We shouldn't reach here")
else: else:
func(mails, req) func(mails, req)
await Session.writer().drain() await state().writer.drain()
def get_deleted_items(deleted_items_path: Path): def get_deleted_items(deleted_items_path: Path):
@ -234,32 +238,35 @@ def save_deleted_items(deleted_items_path: Path, deleted_items: Set):
async def transaction_stage(existing_deleted_items: Set): async def transaction_stage(existing_deleted_items: Set):
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if mails_list = [
entry.uid not in existing_deleted_items] entry for entry in get_mails_list(config().mails_path / 'new')
if entry.uid not in existing_deleted_items
]
new_deleted_items: Set = await process_transactions(mails_list) new_deleted_items: Set = await process_transactions(mails_list)
return new_deleted_items return new_deleted_items
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): async def start_session():
session = Session(stream_reader, stream_writer) logging.info("New session started")
Session.current_session.set(session)
logging.info(f"New session started with {stream_reader} and {stream_writer}")
username = None
try: try:
username = await auth_stage() await auth_stage()
assert username is not None assert username is not None
Session.all_sessions.add(username) config().loggedin_users.add(username)
deleted_items_path = Session.mails_path / username _, mbox = config().users[username]
deleted_items_path = config().mails_path/ mbox / username
logging.info(f"User:{username} logged in successfully") logging.info(f"User:{username} logged in successfully")
existing_deleted_items: Set = get_deleted_items(deleted_items_path) existing_deleted_items: Set = get_deleted_items(deleted_items_path)
new_deleted_items: Set = await transaction_stage(existing_deleted_items) new_deleted_items: Set = await transaction_stage(existing_deleted_items)
logging.info(f"User:{username} completed transactions. Deleted:{new_deleted_items}") logging.info(
f"{username=} completed transactions. Deleted:{len(new_deleted_items)}"
)
if new_deleted_items: if new_deleted_items:
save_deleted_items(deleted_items_path, existing_deleted_items.union(new_deleted_items)) save_deleted_items(deleted_items_path,
existing_deleted_items.union(new_deleted_items))
logging.info(f"User:{username} Saved deleted items") logging.info(f"User:{username} Saved deleted items")
@ -271,23 +278,78 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
raise raise
finally: finally:
if username: if username:
Session.all_sessions.remove(username) config().loggedin_users.remove(username)
stream_writer.close()
async def timed_cb(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): def parse_users(users: list[User]):
try:
return await asyncio.wait_for(new_session(stream_reader, stream_writer), 60) def inner():
finally: for user in Users:
stream_writer.close() user = User(user)
pwinfo = parse_hash(user.password_hash)
yield user.username, (pwinfo, user.mbox)
return dict(inner())
async def create_pop_server(dirpath: Path, port: int, password_hash: str, host="", context: ssl.SSLContext = None): @dataclass
Session.mails_path = dirpath class State:
Session.init_password(password_hash) reader: StreamReader
writer: StreamWriter
username: str = ""
mbox: str = ""
@dataclass
class Config:
mails_path: Path
users: dict[str, tuple[pwhash.PWInfo, str]]
loggedin_users: set[str] = set()
c_config = contextvars.ContextVar('config')
def config() -> Config:
return c_config.get()
c_state = contextvars.ContextVar('state')
def state() -> State:
return c_state.get()
def make_pop_server_callback(dirpath: Path, users: list[User],
timeout_seconds: int):
config = Config(mails_path=dirpath, users=parse_users(users))
async def session_cb(reader: StreamReader, writer: StreamWriter):
c_config.set(config)
c_state.set(State(reader=reader, writer=writer))
try:
return await asyncio.wait_for(start_session(), timeout_seconds)
finally:
stream_writer.close()
return session_cb
async def create_pop_server(dirpath: Path,
port: int,
users: list[User],
host="",
context: ssl.SSLContext = None,
timeout_seconds: int = 60):
logging.info( logging.info(
f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}") f"Starting POP3 server {dirpath=}, {host=}, {port=}, {timeout_seconds=}, ssl={context != None}"
return await asyncio.start_server(timed_cb, host=host, port=port, ssl=context) )
return await asyncio.start_server(make_pop_server_callback(
dirpath, users, timeout_seconds),
host=host,
port=port,
ssl=context)
async def a_main(*args, **kwargs): async def a_main(*args, **kwargs):
@ -304,7 +366,6 @@ def debug_main():
mails_path = Path(mails_path) mails_path = Path(mails_path)
port = int(port) port = int(port)
password_hash = add_season(password.encode(), Session.SALT).hex()
asyncio.run(a_main(mails_path, port, password_hash=password_hash)) asyncio.run(a_main(mails_path, port, password_hash=password_hash))

View File

@ -13,6 +13,8 @@ from pathlib import Path
from .smtp import create_smtp_server_starttls, create_smtp_server_tls from .smtp import create_smtp_server_starttls, create_smtp_server_tls
from .pop3 import create_pop_server from .pop3 import create_pop_server
from .config import Config
def create_tls_context(certfile, keyfile): def create_tls_context(certfile, keyfile):
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
@ -48,44 +50,43 @@ def setup_logging(args):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
def drop_privileges(): async def a_main(config, tls_context):
try: pop_server = await create_pop_server(config.mails_path,
import pwd port=config.pop_port,
except ImportError: host=config.host,
logging.error("Cannot import pwd; run as root") context=tls_context,
sys.exit(1) users=config.users)
nobody = pwd.getpwnam('nobody')
try:
os.setgid(nobody.pw_gid)
os.setuid(nobody.pw_uid)
except PermissionError:
logging.error("Cannot setuid nobody; run as root")
sys.exit(1)
logging.info("Dropped privileges")
logging.debug("Signalled! Clients can come in")
async def a_main(args, tls_context):
pop_server = await create_pop_server(
args.mail_dir_path, port=args.pop_port, host=args.host, context=tls_context, password_hash=args.password_hash)
smtp_server_starttls = await create_smtp_server_starttls( smtp_server_starttls = await create_smtp_server_starttls(
args.mail_dir_path, port=args.smtp_port, host=args.host, context=tls_context) config.mail_dir_path,
smtp_server_tls = await create_smtp_server_tls( port=config.smtp_port,
args.mail_dir_path, port=args.smtp_port_tls, host=args.host, context=tls_context) host=config.host,
drop_privileges() context=tls_context)
await asyncio.gather(
pop_server.serve_forever(), smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path,
smtp_server_starttls.serve_forever(), port=config.smtp_port_tls,
smtp_server_tls.serve_forever()) host=config.host,
context=tls_context)
await asyncio.gather(pop_server.serve_forever(),
smtp_server_starttls.serve_forever(),
smtp_server_tls.serve_forever())
def main(): def main():
args = parse_args() config_path = sys.argv[1]
tls_context = create_tls_context(args.certfile, args.keyfile) parser = ArgumentParser()
parser.add_argument("config_path")
args = parser.parse_args()
config = Config(open(args.config_path).read())
setup_logging(args) setup_logging(args)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.set_debug(args.debug) loop.set_debug(config.debug)
asyncio.run(a_main(args, tls_context))
tls_context = create_tls_context(config.certfile, config.keyfile)
asyncio.run(a_main(config, tls_context))
if __name__ == '__main__': if __name__ == '__main__':