beat
This commit is contained in:
		@@ -4,11 +4,10 @@ import ssl
 | 
				
			|||||||
from _contextvars import ContextVar
 | 
					from _contextvars import ContextVar
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import ClassVar, List, Coroutine
 | 
					from typing import ClassVar, List, Set
 | 
				
			||||||
from collections import deque
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
 | 
					from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
 | 
				
			||||||
    MailStorage, Request
 | 
					    Request, MailEntry, get_mail, get_mails_list, MailList
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
@@ -18,9 +17,8 @@ class Session:
 | 
				
			|||||||
    username: str = ""
 | 
					    username: str = ""
 | 
				
			||||||
    read_items: Path = None
 | 
					    read_items: Path = None
 | 
				
			||||||
    # common state
 | 
					    # common state
 | 
				
			||||||
    all_sessions: ClassVar[List] = []
 | 
					    all_sessions: ClassVar[Set] = set()
 | 
				
			||||||
    mails_path: ClassVar[Path] = Path("")
 | 
					    mails_path: ClassVar[Path] = Path("")
 | 
				
			||||||
    wait_for_privileges_to_drop: ClassVar[Coroutine] = None
 | 
					 | 
				
			||||||
    pending_request: Request = None
 | 
					    pending_request: Request = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def pop_request(self):
 | 
					    def pop_request(self):
 | 
				
			||||||
@@ -84,7 +82,9 @@ async def handle_user_pass_auth(user_cmd):
 | 
				
			|||||||
    password = cmd.arg1
 | 
					    password = cmd.arg1
 | 
				
			||||||
    validate_user_and_pass(username, password)
 | 
					    validate_user_and_pass(username, password)
 | 
				
			||||||
    write(ok("Good"))
 | 
					    write(ok("Good"))
 | 
				
			||||||
    return username, password
 | 
					    logging.info(f"User: {username} has logged in successfully")
 | 
				
			||||||
 | 
					    session.username = username
 | 
				
			||||||
 | 
					    Session.all_sessions.add(username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def auth_stage():
 | 
					async def auth_stage():
 | 
				
			||||||
@@ -98,9 +98,7 @@ async def auth_stage():
 | 
				
			|||||||
                write(msg("USER"))
 | 
					                write(msg("USER"))
 | 
				
			||||||
                write(end())
 | 
					                write(end())
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                username, password = await handle_user_pass_auth(req)
 | 
					                return await handle_user_pass_auth(req)
 | 
				
			||||||
                logging.info(f"User: {username} has logged in successfully")
 | 
					 | 
				
			||||||
                return username
 | 
					 | 
				
			||||||
        except AuthError:
 | 
					        except AuthError:
 | 
				
			||||||
            write(err("Wrong auth"))
 | 
					            write(err("Wrong auth"))
 | 
				
			||||||
        except ClientQuit as c:
 | 
					        except ClientQuit as c:
 | 
				
			||||||
@@ -111,65 +109,96 @@ async def auth_stage():
 | 
				
			|||||||
        raise ClientError("Failed to authenticate")
 | 
					        raise ClientError("Failed to authenticate")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def transaction_stage():
 | 
					async def process_transactions(mails_list: List[MailEntry]):
 | 
				
			||||||
    session: Session = current_session.get()
 | 
					    session: Session = current_session.get()
 | 
				
			||||||
    logging.debug(f"Entering transaction stage for {session.username}")
 | 
					 | 
				
			||||||
    deleted_message_ids = []
 | 
					 | 
				
			||||||
    mailbox = MailStorage(Session.mails_path / 'new')
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with session.read_items.open() as f:
 | 
					    mails = MailList(mails_list)
 | 
				
			||||||
        read_items = set(f.read().splitlines())
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mails_list = mailbox.get_mails_list()
 | 
					 | 
				
			||||||
    mails_map = {str(entry.nid): entry for entry in mails_list}
 | 
					 | 
				
			||||||
    while True:
 | 
					    while True:
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            req = await next_req()
 | 
					            req = await session.next_req()
 | 
				
			||||||
            logging.debug(f"Request: {req}")
 | 
					            logging.debug(f"Request: {req}")
 | 
				
			||||||
            if req.cmd is Command.CAPA:
 | 
					            if req.cmd is Command.CAPA:
 | 
				
			||||||
                write(ok("No CAPA"))
 | 
					                write(ok("CAPA follows"))
 | 
				
			||||||
 | 
					                write(msg("UIDL"))
 | 
				
			||||||
                write(end())
 | 
					                write(end())
 | 
				
			||||||
            elif req.cmd is Command.STAT:
 | 
					            elif req.cmd is Command.STAT:
 | 
				
			||||||
                num, size = mailbox.get_mailbox_size()
 | 
					                num, size = mails.compute_stat()
 | 
				
			||||||
                write(ok(f"{num} {size}"))
 | 
					                write(ok(f"{num} {size}"))
 | 
				
			||||||
            elif req.cmd is Command.LIST:
 | 
					            elif req.cmd is Command.LIST:
 | 
				
			||||||
                if req.arg1:
 | 
					                if req.arg1:
 | 
				
			||||||
                    write(ok(f"{req.arg1} {mails_map[req.arg1].size}"))
 | 
					                    entry = mails.get(req.arg1)
 | 
				
			||||||
 | 
					                    if entry:
 | 
				
			||||||
 | 
					                        write(ok(f"{req.arg1} {entry.size}"))
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        write(err("Not found"))
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    write(ok("Mails follow"))
 | 
					                    write(ok("Mails follow"))
 | 
				
			||||||
                    for entry in mails_list:
 | 
					                    for entry in mails.get_all():
 | 
				
			||||||
                        write(msg(f"{entry.nid} {entry.size}"))
 | 
					                        write(msg(f"{entry.nid} {entry.size}"))
 | 
				
			||||||
                    write(end())
 | 
					                    write(end())
 | 
				
			||||||
            elif req.cmd is Command.UIDL:
 | 
					            elif req.cmd is Command.UIDL:
 | 
				
			||||||
                if req.arg1:
 | 
					                if req.arg1:
 | 
				
			||||||
                    write(ok(f"{req.arg1} {mails_map[req.arg1].uid}"))
 | 
					                    entry = mails.get(req.arg1)
 | 
				
			||||||
 | 
					                    if entry:
 | 
				
			||||||
 | 
					                        write(ok(f"{req.arg1} {entry.uid}"))
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        write(err("Not found"))
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    write(ok("Mails follow"))
 | 
					                    write(ok("Mails follow"))
 | 
				
			||||||
                    for entry in mails_list:
 | 
					                    for entry in mails.get_all():
 | 
				
			||||||
                        write(msg(f"{entry.nid} {entry.uid}"))
 | 
					                        write(msg(f"{entry.nid} {entry.uid}"))
 | 
				
			||||||
                    write(end())
 | 
					                    write(end())
 | 
				
			||||||
                    await session.writer.drain()
 | 
					                    await session.writer.drain()
 | 
				
			||||||
            elif req.cmd is Command.RETR:
 | 
					            elif req.cmd is Command.RETR:
 | 
				
			||||||
                if req.arg1 not in mails_map:
 | 
					                entry = mails.get(req.arg1)
 | 
				
			||||||
                    write(err("Not found"))
 | 
					                if entry:
 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    write(ok("Contents follow"))
 | 
					                    write(ok("Contents follow"))
 | 
				
			||||||
                    write(mailbox.get_mail(mails_map[req.arg1]))
 | 
					                    write(get_mail(entry))
 | 
				
			||||||
                    write(end())
 | 
					                    write(end())
 | 
				
			||||||
                    await session.writer.drain()
 | 
					                    await session.writer.drain()
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    write(err("Not found"))
 | 
				
			||||||
 | 
					            elif req.cmd is Command.DELE:
 | 
				
			||||||
 | 
					                entry = mails.get(req.arg1)
 | 
				
			||||||
 | 
					                if entry:
 | 
				
			||||||
 | 
					                    mails.delete(req.arg1)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    write(err("Not found"))
 | 
				
			||||||
 | 
					            elif req.cmd is Command.RSET:
 | 
				
			||||||
 | 
					                mails = MailList(mails_list)
 | 
				
			||||||
 | 
					            elif req.cmd is Command.NOOP:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                write(err("Not implemented"))
 | 
					                write(err("Not implemented"))
 | 
				
			||||||
 | 
					                raise ClientError("We shouldn't reach here")
 | 
				
			||||||
        except ClientQuit:
 | 
					        except ClientQuit:
 | 
				
			||||||
            write(ok("Bye"))
 | 
					            write(ok("Bye"))
 | 
				
			||||||
            return deleted_message_ids
 | 
					            return mails.deleted_uids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def transaction_stage():
 | 
				
			||||||
 | 
					    session: Session = current_session.get()
 | 
				
			||||||
 | 
					    logging.debug(f"Entering transaction stage for {session.username}")
 | 
				
			||||||
 | 
					    session.read_items = Session.mails_path / session.username
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with session.read_items.open() as f:
 | 
				
			||||||
 | 
					        read_items = set(f.read().splitlines())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items]
 | 
				
			||||||
 | 
					    return await process_transactions(mails_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def delete_messages(delete_ids):
 | 
					def delete_messages(delete_ids):
 | 
				
			||||||
 | 
					    session: Session = current_session.get()
 | 
				
			||||||
 | 
					    with session.read_items.open(mode="w") as f:
 | 
				
			||||||
 | 
					        f.writelines(delete_ids)
 | 
				
			||||||
    logging.info(f"Client deleted these ids {delete_ids}")
 | 
					    logging.info(f"Client deleted these ids {delete_ids}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
 | 
					async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
 | 
				
			||||||
    current_session.set(Session(stream_reader, stream_writer))
 | 
					    session = Session(stream_reader, stream_writer)
 | 
				
			||||||
 | 
					    current_session.set(session)
 | 
				
			||||||
    logging.info(f"New session started with {stream_reader} and {stream_writer}")
 | 
					    logging.info(f"New session started with {stream_reader} and {stream_writer}")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        await auth_stage()
 | 
					        await auth_stage()
 | 
				
			||||||
@@ -182,6 +211,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
 | 
				
			|||||||
        logging.error(f"Serious client error", e)
 | 
					        logging.error(f"Serious client error", e)
 | 
				
			||||||
        raise
 | 
					        raise
 | 
				
			||||||
    finally:
 | 
					    finally:
 | 
				
			||||||
 | 
					        if session.username:
 | 
				
			||||||
 | 
					            Session.all_sessions.remove(session.username)
 | 
				
			||||||
        stream_writer.close()
 | 
					        stream_writer.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,15 +28,17 @@ User = NewType('User', str)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Command(Enum):
 | 
					class Command(Enum):
 | 
				
			||||||
 | 
					    CAPA = auto()
 | 
				
			||||||
    USER = auto()
 | 
					    USER = auto()
 | 
				
			||||||
    PASS = auto()
 | 
					    PASS = auto()
 | 
				
			||||||
    CAPA = auto()
 | 
					 | 
				
			||||||
    QUIT = auto()
 | 
					    QUIT = auto()
 | 
				
			||||||
 | 
					    STAT = auto()
 | 
				
			||||||
    LIST = auto()
 | 
					    LIST = auto()
 | 
				
			||||||
    UIDL = auto()
 | 
					    UIDL = auto()
 | 
				
			||||||
    RETR = auto()
 | 
					    RETR = auto()
 | 
				
			||||||
    DELE = auto()
 | 
					    DELE = auto()
 | 
				
			||||||
    STAT = auto()
 | 
					    RSET = auto()
 | 
				
			||||||
 | 
					    NOOP = auto()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
@@ -88,11 +90,6 @@ def parse_command(line: bytes) -> Request:
 | 
				
			|||||||
    return request
 | 
					    return request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def files_in_path(path):
 | 
					 | 
				
			||||||
    for _, _, files in os.walk(path):
 | 
					 | 
				
			||||||
        return [(f, os.path.join(path, f)) for f in files]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class MailEntry:
 | 
					class MailEntry:
 | 
				
			||||||
    uid: str
 | 
					    uid: str
 | 
				
			||||||
@@ -109,23 +106,46 @@ class MailEntry:
 | 
				
			|||||||
        self.path = path
 | 
					        self.path = path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MailStorage:
 | 
					def files_in_path(path):
 | 
				
			||||||
    def __init__(self, dirpath: Path):
 | 
					    for _, _, files in os.walk(path):
 | 
				
			||||||
        self.dirpath = dirpath
 | 
					        return [(f, os.path.join(path, f)) for f in files]
 | 
				
			||||||
        self.files = files_in_path(self.dirpath)
 | 
					
 | 
				
			||||||
        self.entries = [MailEntry(filename, path) for filename, path in self.files]
 | 
					
 | 
				
			||||||
        self.entries = sorted(self.entries, reverse=True, key=lambda e: e.c_time)
 | 
					def get_mails_list(dirpath: Path) -> List[MailEntry]:
 | 
				
			||||||
        for i, entry in enumerate(self.entries, start=1):
 | 
					    files = files_in_path(dirpath)
 | 
				
			||||||
 | 
					    entries = [MailEntry(filename, path) for filename, path in files]
 | 
				
			||||||
 | 
					    return entries
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def set_nid(entries: List[MailEntry]):
 | 
				
			||||||
 | 
					    entries.sort(reverse=True, key=lambda e: e.c_time)
 | 
				
			||||||
 | 
					    entries = sorted(entries, reverse=True, key=lambda e: e.c_time)
 | 
				
			||||||
 | 
					    for i, entry in enumerate(entries, start=1):
 | 
				
			||||||
        entry.nid = i
 | 
					        entry.nid = i
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_mailbox_size(self) -> (int, int):
 | 
					 | 
				
			||||||
        return len(self.entries), sum(entry.size for entry in self.entries)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_mails_list(self) -> List[MailEntry]:
 | 
					 | 
				
			||||||
        return self.entries
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					 | 
				
			||||||
def get_mail(entry: MailEntry) -> bytes:
 | 
					def get_mail(entry: MailEntry) -> bytes:
 | 
				
			||||||
    with open(entry.path, mode='rb') as fp:
 | 
					    with open(entry.path, mode='rb') as fp:
 | 
				
			||||||
        return fp.read()
 | 
					        return fp.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MailList:
 | 
				
			||||||
 | 
					    def __init__(self, entries: List[MailEntry]):
 | 
				
			||||||
 | 
					        self.entries = entries
 | 
				
			||||||
 | 
					        set_nid(self.entries)
 | 
				
			||||||
 | 
					        self.mails_map = {str(e.nid): e for e in entries}
 | 
				
			||||||
 | 
					        self.deleted_uids = set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def delete(self, nid: str):
 | 
				
			||||||
 | 
					        self.deleted_uids.add(self.mails_map.pop(nid).uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, nid: str):
 | 
				
			||||||
 | 
					        self.mails_map.get(nid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_all(self):
 | 
				
			||||||
 | 
					        return [e for e in self.entries if str(e.nid) in self.mails_map]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def compute_stat(self):
 | 
				
			||||||
 | 
					        entries = self.get_all()
 | 
				
			||||||
 | 
					        return len(entries), sum(entry.size for entry in entries)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user