pop3 refactor wip

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-02 11:36:18 -04:00
parent c441863af7
commit 5600d30f54
3 changed files with 146 additions and 77 deletions

View File

@ -26,7 +26,14 @@ class User(Jata):
class Config(Jata):
certfile: str
keyfile: str
debug: bool = False
mails_path: str
host = '0.0.0.0'
smtp_port = 25
smtp_port_tls = 465
smtp_port_submission = 587
pop_port = 995
smtputf8 = True
rules: list[Rule]
boxes: list[Mbox]
users: list[User]

View File

@ -2,11 +2,14 @@ import asyncio
import logging
import os
import ssl
from _contextvars import ContextVar
import contextvars
from dataclasses import dataclass
from hashlib import sha256
from pathlib import Path
from typing import ClassVar, List, Set
from .config import User
from .pwhash import parse_hash, check_pass
from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
Request, MailEntry, get_mail, get_mails_list, MailList
@ -19,21 +22,16 @@ def add_season(content: bytes, season: bytes):
# noinspection PyProtectedMember
@dataclass
class Session:
_reader: asyncio.StreamReader
_reader: StreamReader
_writer: asyncio.StreamWriter
username: str
mbox: str
# common state
all_sessions: ClassVar[Set] = set()
mails_path: ClassVar[Path] = Path("")
users: ClassVar[list[User]] = list()
current_session: ClassVar = ContextVar("session")
password_hash: ClassVar[str] = ""
SALT: ClassVar[bytes] = b"balki is awesome+"
pepper: ClassVar[bytes]
@classmethod
def init_password(cls, salted_hash: str):
cls.pepper = os.urandom(32)
cls.password_hash = add_season(bytes.fromhex(salted_hash), cls.pepper)
@classmethod
def get(cls):
@ -50,7 +48,7 @@ class Session:
async def next_req():
for _ in range(InvalidCommand.RETRIES):
line = await Session.reader().readline()
line = await state().reader.readline()
logging.debug(f"Client: {line}")
if not line:
continue
@ -76,13 +74,20 @@ async def expect_cmd(*commands: Command):
def write(data):
logging.debug(f"Server: {data}")
Session.writer().write(data)
state().writer.write(data)
def validate_password(password):
if Session.password_hash != add_season(add_season(password.encode(), Session.SALT), Session.pepper):
def validate_password(username, password):
try:
pwinfo, mbox = config().users[username]
except:
raise AuthError("Invalid user pass")
if not check_pass(password, pwinfo):
raise AuthError("Invalid user pass")
state().username = username
state().mbox = mbox
async def handle_user_pass_auth(user_cmd):
username = user_cmd.arg1
@ -91,9 +96,8 @@ async def handle_user_pass_auth(user_cmd):
write(ok("Welcome"))
cmd = await expect_cmd(Command.PASS)
password = cmd.arg1
validate_password(password)
logging.info(f"User: {username} has logged in successfully")
return username
validate_password(username, password)
logging.info(f"{username=} has logged in successfully")
async def auth_stage():
@ -107,12 +111,12 @@ async def auth_stage():
write(end())
else:
username = await handle_user_pass_auth(req)
if username in Session.all_sessions:
logging.warning(f"User: {username} already has an active session")
if username in config().loggedin_users:
logging.warning(
f"User: {username} already has an active session")
raise AuthError("Already logged in")
else:
write(ok("Login successful"))
return username
except AuthError as ae:
write(err(f"Auth Failed: {ae}"))
except ClientQuit as c:
@ -218,7 +222,7 @@ async def process_transactions(mails_list: List[MailEntry]):
raise ClientError("We shouldn't reach here")
else:
func(mails, req)
await Session.writer().drain()
await state().writer.drain()
def get_deleted_items(deleted_items_path: Path):
@ -234,32 +238,35 @@ def save_deleted_items(deleted_items_path: Path, deleted_items: Set):
async def transaction_stage(existing_deleted_items: Set):
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if
entry.uid not in existing_deleted_items]
mails_list = [
entry for entry in get_mails_list(config().mails_path / 'new')
if entry.uid not in existing_deleted_items
]
new_deleted_items: Set = await process_transactions(mails_list)
return new_deleted_items
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
session = Session(stream_reader, stream_writer)
Session.current_session.set(session)
logging.info(f"New session started with {stream_reader} and {stream_writer}")
username = None
async def start_session():
logging.info("New session started")
try:
username = await auth_stage()
await auth_stage()
assert username is not None
Session.all_sessions.add(username)
deleted_items_path = Session.mails_path / username
config().loggedin_users.add(username)
_, mbox = config().users[username]
deleted_items_path = config().mails_path/ mbox / username
logging.info(f"User:{username} logged in successfully")
existing_deleted_items: Set = get_deleted_items(deleted_items_path)
new_deleted_items: Set = await transaction_stage(existing_deleted_items)
logging.info(f"User:{username} completed transactions. Deleted:{new_deleted_items}")
logging.info(
f"{username=} completed transactions. Deleted:{len(new_deleted_items)}"
)
if new_deleted_items:
save_deleted_items(deleted_items_path, existing_deleted_items.union(new_deleted_items))
save_deleted_items(deleted_items_path,
existing_deleted_items.union(new_deleted_items))
logging.info(f"User:{username} Saved deleted items")
@ -271,23 +278,78 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
raise
finally:
if username:
Session.all_sessions.remove(username)
stream_writer.close()
config().loggedin_users.remove(username)
async def timed_cb(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
try:
return await asyncio.wait_for(new_session(stream_reader, stream_writer), 60)
finally:
stream_writer.close()
def parse_users(users: list[User]):
def inner():
for user in Users:
user = User(user)
pwinfo = parse_hash(user.password_hash)
yield user.username, (pwinfo, user.mbox)
return dict(inner())
async def create_pop_server(dirpath: Path, port: int, password_hash: str, host="", context: ssl.SSLContext = None):
Session.mails_path = dirpath
Session.init_password(password_hash)
@dataclass
class State:
reader: StreamReader
writer: StreamWriter
username: str = ""
mbox: str = ""
@dataclass
class Config:
mails_path: Path
users: dict[str, tuple[pwhash.PWInfo, str]]
loggedin_users: set[str] = set()
c_config = contextvars.ContextVar('config')
def config() -> Config:
return c_config.get()
c_state = contextvars.ContextVar('state')
def state() -> State:
return c_state.get()
def make_pop_server_callback(dirpath: Path, users: list[User],
timeout_seconds: int):
config = Config(mails_path=dirpath, users=parse_users(users))
async def session_cb(reader: StreamReader, writer: StreamWriter):
c_config.set(config)
c_state.set(State(reader=reader, writer=writer))
try:
return await asyncio.wait_for(start_session(), timeout_seconds)
finally:
stream_writer.close()
return session_cb
async def create_pop_server(dirpath: Path,
port: int,
users: list[User],
host="",
context: ssl.SSLContext = None,
timeout_seconds: int = 60):
logging.info(
f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}")
return await asyncio.start_server(timed_cb, host=host, port=port, ssl=context)
f"Starting POP3 server {dirpath=}, {host=}, {port=}, {timeout_seconds=}, ssl={context != None}"
)
return await asyncio.start_server(make_pop_server_callback(
dirpath, users, timeout_seconds),
host=host,
port=port,
ssl=context)
async def a_main(*args, **kwargs):
@ -304,7 +366,6 @@ def debug_main():
mails_path = Path(mails_path)
port = int(port)
password_hash = add_season(password.encode(), Session.SALT).hex()
asyncio.run(a_main(mails_path, port, password_hash=password_hash))

View File

@ -13,6 +13,8 @@ from pathlib import Path
from .smtp import create_smtp_server_starttls, create_smtp_server_tls
from .pop3 import create_pop_server
from .config import Config
def create_tls_context(certfile, keyfile):
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
@ -48,44 +50,43 @@ def setup_logging(args):
logging.basicConfig(level=logging.INFO)
def drop_privileges():
try:
import pwd
except ImportError:
logging.error("Cannot import pwd; run as root")
sys.exit(1)
nobody = pwd.getpwnam('nobody')
try:
os.setgid(nobody.pw_gid)
os.setuid(nobody.pw_uid)
except PermissionError:
logging.error("Cannot setuid nobody; run as root")
sys.exit(1)
logging.info("Dropped privileges")
logging.debug("Signalled! Clients can come in")
async def a_main(config, tls_context):
pop_server = await create_pop_server(config.mails_path,
port=config.pop_port,
host=config.host,
context=tls_context,
users=config.users)
async def a_main(args, tls_context):
pop_server = await create_pop_server(
args.mail_dir_path, port=args.pop_port, host=args.host, context=tls_context, password_hash=args.password_hash)
smtp_server_starttls = await create_smtp_server_starttls(
args.mail_dir_path, port=args.smtp_port, host=args.host, context=tls_context)
smtp_server_tls = await create_smtp_server_tls(
args.mail_dir_path, port=args.smtp_port_tls, host=args.host, context=tls_context)
drop_privileges()
await asyncio.gather(
pop_server.serve_forever(),
smtp_server_starttls.serve_forever(),
smtp_server_tls.serve_forever())
config.mail_dir_path,
port=config.smtp_port,
host=config.host,
context=tls_context)
smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path,
port=config.smtp_port_tls,
host=config.host,
context=tls_context)
await asyncio.gather(pop_server.serve_forever(),
smtp_server_starttls.serve_forever(),
smtp_server_tls.serve_forever())
def main():
args = parse_args()
tls_context = create_tls_context(args.certfile, args.keyfile)
config_path = sys.argv[1]
parser = ArgumentParser()
parser.add_argument("config_path")
args = parser.parse_args()
config = Config(open(args.config_path).read())
setup_logging(args)
loop = asyncio.get_event_loop()
loop.set_debug(args.debug)
asyncio.run(a_main(args, tls_context))
loop.set_debug(config.debug)
tls_context = create_tls_context(config.certfile, config.keyfile)
asyncio.run(a_main(config, tls_context))
if __name__ == '__main__':