From 7abd6e9e13165b732e571b38bacb808247759371 Mon Sep 17 00:00:00 2001 From: balki <3070606-balki@users.noreply.gitlab.com> Date: Wed, 19 Dec 2018 02:01:43 -0500 Subject: [PATCH] mix --- mail4one/pop3.py | 284 +++++++++++++++++++++++++++-------------------- 1 file changed, 161 insertions(+), 123 deletions(-) diff --git a/mail4one/pop3.py b/mail4one/pop3.py index 7bd4517..75888c5 100644 --- a/mail4one/pop3.py +++ b/mail4one/pop3.py @@ -14,93 +14,97 @@ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, C class Session: reader: asyncio.StreamReader writer: asyncio.StreamWriter - username: str = "" - read_items: Path = None + # common state all_sessions: ClassVar[Set] = set() mails_path: ClassVar[Path] = Path("") - pending_request: Request = None + current_session: ClassVar = ContextVar("session") - def pop_request(self): - request = self.pending_request - self.pending_request = None - return request + @classmethod + def get(cls): + return cls.current_session.get() - async def next_req(self): - if self.pending_request: - return self.pop_request() + @classmethod + def reader(cls): + return cls.get().reader - for _ in range(InvalidCommand.RETRIES): - line = await self.reader.readline() - logging.debug(f"Client: {line}") - if not line: - continue - try: - request: Request = parse_command(line) - except InvalidCommand: - write(err("Bad command")) - else: - if request.cmd == Command.QUIT: - raise ClientQuit - return request + @classmethod + def writer(cls): + return cls.get().writer + + +async def next_req(): + for _ in range(InvalidCommand.RETRIES): + line = await Session.reader().readline() + logging.debug(f"Client: {line}") + if not line: + continue + try: + request: Request = parse_command(line) + except InvalidCommand: + write(err("Bad command")) else: - raise ClientError(f"Bad command {InvalidCommand.RETRIES} times") - - async def expect_cmd(self, *commands: Command, optional=False): - req = await self.next_req() - if req.cmd not in commands: - if not optional: - logging.error(f"{req.cmd} is not in {commands}") - raise ClientError - else: - self.pending_request = req - return - return req + if request.cmd == Command.QUIT: + raise ClientQuit + return request + else: + raise ClientError(f"Bad command {InvalidCommand.RETRIES} times") -current_session: ContextVar[Session] = ContextVar("session") +async def expect_cmd(*commands: Command): + req = await next_req() + if req.cmd not in commands: + logging.error(f"Unexpected command: {req.cmd} is not in {commands}") + raise ClientError + return req def write(data): logging.debug(f"Server: {data}") - session: Session = current_session.get() + session: Session = Session.current_session.get() session.writer.write(data) +async def drain(): + session: Session = Session.current_session.get() + await session.writer.drain() + + def validate_user_and_pass(username, password): if username != password: raise AuthError("Invalid user pass") async def handle_user_pass_auth(user_cmd): - session: Session = current_session.get() username = user_cmd.arg1 if not username: raise AuthError("Invalid USER command. username empty") write(ok("Welcome")) - cmd = await session.expect_cmd(Command.PASS) + cmd = await expect_cmd(Command.PASS) password = cmd.arg1 validate_user_and_pass(username, password) - write(ok("Good")) logging.info(f"User: {username} has logged in successfully") - session.username = username - Session.all_sessions.add(username) + return username async def auth_stage(): - session: Session = current_session.get() write(ok("Server Ready")) for _ in range(AuthError.RETRIES): try: - req = await session.expect_cmd(Command.USER, Command.CAPA) + req = await expect_cmd(Command.USER, Command.CAPA) if req.cmd is Command.CAPA: write(ok("Following are supported")) write(msg("USER")) write(end()) else: - return await handle_user_pass_auth(req) - except AuthError: - write(err("Wrong auth")) + username = await handle_user_pass_auth(req) + if username in Session.all_sessions: + logging.warning(f"User: {username} already has an active session") + raise AuthError("Already logged in") + else: + write(ok("Login successful")) + except AuthError as ae: + write(err(f"Auth Failed: {ae}")) except ClientQuit as c: write(ok("Bye")) logging.warning("Client has QUIT before auth succeeded") @@ -109,101 +113,135 @@ async def auth_stage(): raise ClientError("Failed to authenticate") -async def process_transactions(mails_list: List[MailEntry]): - session: Session = current_session.get() +def trans_command_capa(_, __): + write(ok("CAPA follows")) + write(msg("UIDL")) + write(end()) + +def trans_command_stat(mails, _): + num, size = mails.compute_stat() + write(ok(f"{num} {size}")) + + +def trans_command_list(mails, req): + if req.arg1: + entry = mails.get(req.arg1) + if entry: + write(ok(f"{req.arg1} {entry.size}")) + else: + write(err("Not found")) + else: + write(ok("Mails follow")) + for entry in mails.get_all(): + write(msg(f"{entry.nid} {entry.size}")) + write(end()) + await drain() + + +def trans_command_uidl(mails, req): + if req.arg1: + entry = mails.get(req.arg1) + if entry: + write(ok(f"{req.arg1} {entry.uid}")) + else: + write(err("Not found")) + else: + write(ok("Mails follow")) + for entry in mails.get_all(): + write(msg(f"{entry.nid} {entry.uid}")) + write(end()) + await drain() + + +def trans_command_retr(mails, req): + entry = mails.get(req.arg1) + if entry: + write(ok("Contents follow")) + write(get_mail(entry)) + write(end()) + drain() + else: + write(err("Not found")) + + +def trans_command_dele(mails, req): + entry = mails.get(req.arg1) + if entry: + mails.delete(req.arg1) + else: + write(err("Not found")) + + +def trans_command_noop(_, __): + write(ok("Hmm")) + + +async def process_transactions(mails_list: List[MailEntry]): mails = MailList(mails_list) + def reset(_, __): + nonlocal mails + mails = MailList(mails_list) + + handle_map = { + Command.CAPA: trans_command_capa, + Command.STAT: trans_command_stat, + Command.LIST: trans_command_list, + Command.UIDL: trans_command_uidl, + Command.RETR: trans_command_retr, + Command.DELE: trans_command_dele, + Command.RSET: reset, + Command.NOOP: trans_command_noop, + } + while True: try: - req = await session.next_req() - logging.debug(f"Request: {req}") - if req.cmd is Command.CAPA: - write(ok("CAPA follows")) - write(msg("UIDL")) - write(end()) - elif req.cmd is Command.STAT: - num, size = mails.compute_stat() - write(ok(f"{num} {size}")) - elif req.cmd is Command.LIST: - if req.arg1: - entry = mails.get(req.arg1) - if entry: - write(ok(f"{req.arg1} {entry.size}")) - else: - write(err("Not found")) - else: - write(ok("Mails follow")) - for entry in mails.get_all(): - write(msg(f"{entry.nid} {entry.size}")) - write(end()) - elif req.cmd is Command.UIDL: - if req.arg1: - entry = mails.get(req.arg1) - if entry: - write(ok(f"{req.arg1} {entry.uid}")) - else: - write(err("Not found")) - else: - write(ok("Mails follow")) - for entry in mails.get_all(): - write(msg(f"{entry.nid} {entry.uid}")) - write(end()) - await session.writer.drain() - elif req.cmd is Command.RETR: - entry = mails.get(req.arg1) - if entry: - write(ok("Contents follow")) - write(get_mail(entry)) - write(end()) - await session.writer.drain() - else: - write(err("Not found")) - elif req.cmd is Command.DELE: - entry = mails.get(req.arg1) - if entry: - mails.delete(req.arg1) - else: - write(err("Not found")) - elif req.cmd is Command.RSET: - mails = MailList(mails_list) - elif req.cmd is Command.NOOP: - pass - else: - write(err("Not implemented")) - raise ClientError("We shouldn't reach here") + req = await next_req() except ClientQuit: write(ok("Bye")) return mails.deleted_uids + logging.debug(f"Request: {req}") + try: + func = handle_map[req.cmd] + except KeyError: + write(err("Not implemented")) + raise ClientError("We shouldn't reach here") + else: + func(mails, req) -async def transaction_stage(): - session: Session = current_session.get() - logging.debug(f"Entering transaction stage for {session.username}") - session.read_items = Session.mails_path / session.username +async def transaction_stage(deleted_items_path: Path): + with deleted_items_path.open() as f: + deleted_items = set(f.read().splitlines()) - with session.read_items.open() as f: - read_items = set(f.read().splitlines()) - - mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items] + mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in deleted_items] return await process_transactions(mails_list) -def delete_messages(delete_ids): - session: Session = current_session.get() - with session.read_items.open(mode="w") as f: +def delete_messages(delete_ids, deleted_items_path: Path): + with deleted_items_path.open(mode="w") as f: f.writelines(delete_ids) - logging.info(f"Client deleted these ids {delete_ids}") async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): session = Session(stream_reader, stream_writer) - current_session.set(session) + Session.current_session.set(session) logging.info(f"New session started with {stream_reader} and {stream_writer}") + username = None try: - await auth_stage() - delete_ids = await transaction_stage() - delete_messages(delete_ids) + username = await auth_stage() + assert username is not None + Session.all_sessions.add(username) + deleted_items_path = Session.mails_path / username + logging.info(f"User:{username} logged in successfully") + + delete_ids = await transaction_stage(deleted_items_path) + logging.info(f"User:{username} completed transactions. Deleted:{delete_ids}") + + delete_messages(delete_ids, deleted_items_path) + logging.info(f"User:{username} Saved deleted items") + except ClientError as c: write(err("Something went wrong")) logging.error(f"Unexpected client error", c) @@ -211,8 +249,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci logging.error(f"Serious client error", e) raise finally: - if session.username: - Session.all_sessions.remove(session.username) + if username: + Session.all_sessions.remove(username) stream_writer.close()