diff --git a/mail4one/pop3.py b/mail4one/pop3.py index 4136bac..7bd4517 100644 --- a/mail4one/pop3.py +++ b/mail4one/pop3.py @@ -4,11 +4,10 @@ import ssl from _contextvars import ContextVar from dataclasses import dataclass from pathlib import Path -from typing import ClassVar, List, Coroutine -from collections import deque +from typing import ClassVar, List, Set from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \ - MailStorage, Request + Request, MailEntry, get_mail, get_mails_list, MailList @dataclass @@ -18,9 +17,8 @@ class Session: username: str = "" read_items: Path = None # common state - all_sessions: ClassVar[List] = [] + all_sessions: ClassVar[Set] = set() mails_path: ClassVar[Path] = Path("") - wait_for_privileges_to_drop: ClassVar[Coroutine] = None pending_request: Request = None def pop_request(self): @@ -84,7 +82,9 @@ async def handle_user_pass_auth(user_cmd): password = cmd.arg1 validate_user_and_pass(username, password) write(ok("Good")) - return username, password + logging.info(f"User: {username} has logged in successfully") + session.username = username + Session.all_sessions.add(username) async def auth_stage(): @@ -98,9 +98,7 @@ async def auth_stage(): write(msg("USER")) write(end()) else: - username, password = await handle_user_pass_auth(req) - logging.info(f"User: {username} has logged in successfully") - return username + return await handle_user_pass_auth(req) except AuthError: write(err("Wrong auth")) except ClientQuit as c: @@ -111,65 +109,96 @@ async def auth_stage(): raise ClientError("Failed to authenticate") -async def transaction_stage(): +async def process_transactions(mails_list: List[MailEntry]): session: Session = current_session.get() - logging.debug(f"Entering transaction stage for {session.username}") - deleted_message_ids = [] - mailbox = MailStorage(Session.mails_path / 'new') - with session.read_items.open() as f: - read_items = set(f.read().splitlines()) + mails = MailList(mails_list) - mails_list = mailbox.get_mails_list() - mails_map = {str(entry.nid): entry for entry in mails_list} while True: try: - req = await next_req() + req = await session.next_req() logging.debug(f"Request: {req}") if req.cmd is Command.CAPA: - write(ok("No CAPA")) + write(ok("CAPA follows")) + write(msg("UIDL")) write(end()) elif req.cmd is Command.STAT: - num, size = mailbox.get_mailbox_size() + num, size = mails.compute_stat() write(ok(f"{num} {size}")) elif req.cmd is Command.LIST: if req.arg1: - write(ok(f"{req.arg1} {mails_map[req.arg1].size}")) + entry = mails.get(req.arg1) + if entry: + write(ok(f"{req.arg1} {entry.size}")) + else: + write(err("Not found")) else: write(ok("Mails follow")) - for entry in mails_list: + for entry in mails.get_all(): write(msg(f"{entry.nid} {entry.size}")) write(end()) elif req.cmd is Command.UIDL: if req.arg1: - write(ok(f"{req.arg1} {mails_map[req.arg1].uid}")) + entry = mails.get(req.arg1) + if entry: + write(ok(f"{req.arg1} {entry.uid}")) + else: + write(err("Not found")) else: write(ok("Mails follow")) - for entry in mails_list: + for entry in mails.get_all(): write(msg(f"{entry.nid} {entry.uid}")) write(end()) await session.writer.drain() elif req.cmd is Command.RETR: - if req.arg1 not in mails_map: - write(err("Not found")) - else: + entry = mails.get(req.arg1) + if entry: write(ok("Contents follow")) - write(mailbox.get_mail(mails_map[req.arg1])) + write(get_mail(entry)) write(end()) await session.writer.drain() + else: + write(err("Not found")) + elif req.cmd is Command.DELE: + entry = mails.get(req.arg1) + if entry: + mails.delete(req.arg1) + else: + write(err("Not found")) + elif req.cmd is Command.RSET: + mails = MailList(mails_list) + elif req.cmd is Command.NOOP: + pass else: write(err("Not implemented")) + raise ClientError("We shouldn't reach here") except ClientQuit: write(ok("Bye")) - return deleted_message_ids + return mails.deleted_uids + + +async def transaction_stage(): + session: Session = current_session.get() + logging.debug(f"Entering transaction stage for {session.username}") + session.read_items = Session.mails_path / session.username + + with session.read_items.open() as f: + read_items = set(f.read().splitlines()) + + mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items] + return await process_transactions(mails_list) def delete_messages(delete_ids): + session: Session = current_session.get() + with session.read_items.open(mode="w") as f: + f.writelines(delete_ids) logging.info(f"Client deleted these ids {delete_ids}") async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): - current_session.set(Session(stream_reader, stream_writer)) + session = Session(stream_reader, stream_writer) + current_session.set(session) logging.info(f"New session started with {stream_reader} and {stream_writer}") try: await auth_stage() @@ -182,6 +211,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci logging.error(f"Serious client error", e) raise finally: + if session.username: + Session.all_sessions.remove(session.username) stream_writer.close() diff --git a/mail4one/poputils.py b/mail4one/poputils.py index 46f2ce3..643035e 100644 --- a/mail4one/poputils.py +++ b/mail4one/poputils.py @@ -28,15 +28,17 @@ User = NewType('User', str) class Command(Enum): + CAPA = auto() USER = auto() PASS = auto() - CAPA = auto() QUIT = auto() + STAT = auto() LIST = auto() UIDL = auto() RETR = auto() DELE = auto() - STAT = auto() + RSET = auto() + NOOP = auto() @dataclass @@ -88,11 +90,6 @@ def parse_command(line: bytes) -> Request: return request -def files_in_path(path): - for _, _, files in os.walk(path): - return [(f, os.path.join(path, f)) for f in files] - - @dataclass class MailEntry: uid: str @@ -109,23 +106,46 @@ class MailEntry: self.path = path -class MailStorage: - def __init__(self, dirpath: Path): - self.dirpath = dirpath - self.files = files_in_path(self.dirpath) - self.entries = [MailEntry(filename, path) for filename, path in self.files] - self.entries = sorted(self.entries, reverse=True, key=lambda e: e.c_time) - for i, entry in enumerate(self.entries, start=1): - entry.nid = i +def files_in_path(path): + for _, _, files in os.walk(path): + return [(f, os.path.join(path, f)) for f in files] - def get_mailbox_size(self) -> (int, int): - return len(self.entries), sum(entry.size for entry in self.entries) - def get_mails_list(self) -> List[MailEntry]: - return self.entries +def get_mails_list(dirpath: Path) -> List[MailEntry]: + files = files_in_path(dirpath) + entries = [MailEntry(filename, path) for filename, path in files] + return entries - @staticmethod - def get_mail(entry: MailEntry) -> bytes: - with open(entry.path, mode='rb') as fp: - return fp.read() + +def set_nid(entries: List[MailEntry]): + entries.sort(reverse=True, key=lambda e: e.c_time) + entries = sorted(entries, reverse=True, key=lambda e: e.c_time) + for i, entry in enumerate(entries, start=1): + entry.nid = i + + +def get_mail(entry: MailEntry) -> bytes: + with open(entry.path, mode='rb') as fp: + return fp.read() + + +class MailList: + def __init__(self, entries: List[MailEntry]): + self.entries = entries + set_nid(self.entries) + self.mails_map = {str(e.nid): e for e in entries} + self.deleted_uids = set() + + def delete(self, nid: str): + self.deleted_uids.add(self.mails_map.pop(nid).uid) + + def get(self, nid: str): + self.mails_map.get(nid) + + def get_all(self): + return [e for e in self.entries if str(e.nid) in self.mails_map] + + def compute_stat(self): + entries = self.get_all() + return len(entries), sum(entry.size for entry in entries)