mix
This commit is contained in:
		
							
								
								
									
										284
									
								
								mail4one/pop3.py
									
									
									
									
									
								
							
							
						
						
									
										284
									
								
								mail4one/pop3.py
									
									
									
									
									
								
							@@ -14,93 +14,97 @@ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, C
 | 
			
		||||
class Session:
 | 
			
		||||
    reader: asyncio.StreamReader
 | 
			
		||||
    writer: asyncio.StreamWriter
 | 
			
		||||
    username: str = ""
 | 
			
		||||
    read_items: Path = None
 | 
			
		||||
 | 
			
		||||
    # common state
 | 
			
		||||
    all_sessions: ClassVar[Set] = set()
 | 
			
		||||
    mails_path: ClassVar[Path] = Path("")
 | 
			
		||||
    pending_request: Request = None
 | 
			
		||||
    current_session: ClassVar = ContextVar("session")
 | 
			
		||||
 | 
			
		||||
    def pop_request(self):
 | 
			
		||||
        request = self.pending_request
 | 
			
		||||
        self.pending_request = None
 | 
			
		||||
        return request
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get(cls):
 | 
			
		||||
        return cls.current_session.get()
 | 
			
		||||
 | 
			
		||||
    async def next_req(self):
 | 
			
		||||
        if self.pending_request:
 | 
			
		||||
            return self.pop_request()
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def reader(cls):
 | 
			
		||||
        return cls.get().reader
 | 
			
		||||
 | 
			
		||||
        for _ in range(InvalidCommand.RETRIES):
 | 
			
		||||
            line = await self.reader.readline()
 | 
			
		||||
            logging.debug(f"Client: {line}")
 | 
			
		||||
            if not line:
 | 
			
		||||
                continue
 | 
			
		||||
            try:
 | 
			
		||||
                request: Request = parse_command(line)
 | 
			
		||||
            except InvalidCommand:
 | 
			
		||||
                write(err("Bad command"))
 | 
			
		||||
            else:
 | 
			
		||||
                if request.cmd == Command.QUIT:
 | 
			
		||||
                    raise ClientQuit
 | 
			
		||||
                return request
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def writer(cls):
 | 
			
		||||
        return cls.get().writer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def next_req():
 | 
			
		||||
    for _ in range(InvalidCommand.RETRIES):
 | 
			
		||||
        line = await Session.reader().readline()
 | 
			
		||||
        logging.debug(f"Client: {line}")
 | 
			
		||||
        if not line:
 | 
			
		||||
            continue
 | 
			
		||||
        try:
 | 
			
		||||
            request: Request = parse_command(line)
 | 
			
		||||
        except InvalidCommand:
 | 
			
		||||
            write(err("Bad command"))
 | 
			
		||||
        else:
 | 
			
		||||
            raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
 | 
			
		||||
 | 
			
		||||
    async def expect_cmd(self, *commands: Command, optional=False):
 | 
			
		||||
        req = await self.next_req()
 | 
			
		||||
        if req.cmd not in commands:
 | 
			
		||||
            if not optional:
 | 
			
		||||
                logging.error(f"{req.cmd} is not in {commands}")
 | 
			
		||||
                raise ClientError
 | 
			
		||||
            else:
 | 
			
		||||
                self.pending_request = req
 | 
			
		||||
                return
 | 
			
		||||
        return req
 | 
			
		||||
            if request.cmd == Command.QUIT:
 | 
			
		||||
                raise ClientQuit
 | 
			
		||||
            return request
 | 
			
		||||
    else:
 | 
			
		||||
        raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
current_session: ContextVar[Session] = ContextVar("session")
 | 
			
		||||
async def expect_cmd(*commands: Command):
 | 
			
		||||
    req = await next_req()
 | 
			
		||||
    if req.cmd not in commands:
 | 
			
		||||
        logging.error(f"Unexpected command: {req.cmd} is not in {commands}")
 | 
			
		||||
        raise ClientError
 | 
			
		||||
    return req
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write(data):
 | 
			
		||||
    logging.debug(f"Server: {data}")
 | 
			
		||||
    session: Session = current_session.get()
 | 
			
		||||
    session: Session = Session.current_session.get()
 | 
			
		||||
    session.writer.write(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def drain():
 | 
			
		||||
    session: Session = Session.current_session.get()
 | 
			
		||||
    await session.writer.drain()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_user_and_pass(username, password):
 | 
			
		||||
    if username != password:
 | 
			
		||||
        raise AuthError("Invalid user pass")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def handle_user_pass_auth(user_cmd):
 | 
			
		||||
    session: Session = current_session.get()
 | 
			
		||||
    username = user_cmd.arg1
 | 
			
		||||
    if not username:
 | 
			
		||||
        raise AuthError("Invalid USER command. username empty")
 | 
			
		||||
    write(ok("Welcome"))
 | 
			
		||||
    cmd = await session.expect_cmd(Command.PASS)
 | 
			
		||||
    cmd = await expect_cmd(Command.PASS)
 | 
			
		||||
    password = cmd.arg1
 | 
			
		||||
    validate_user_and_pass(username, password)
 | 
			
		||||
    write(ok("Good"))
 | 
			
		||||
    logging.info(f"User: {username} has logged in successfully")
 | 
			
		||||
    session.username = username
 | 
			
		||||
    Session.all_sessions.add(username)
 | 
			
		||||
    return username
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def auth_stage():
 | 
			
		||||
    session: Session = current_session.get()
 | 
			
		||||
    write(ok("Server Ready"))
 | 
			
		||||
    for _ in range(AuthError.RETRIES):
 | 
			
		||||
        try:
 | 
			
		||||
            req = await session.expect_cmd(Command.USER, Command.CAPA)
 | 
			
		||||
            req = await expect_cmd(Command.USER, Command.CAPA)
 | 
			
		||||
            if req.cmd is Command.CAPA:
 | 
			
		||||
                write(ok("Following are supported"))
 | 
			
		||||
                write(msg("USER"))
 | 
			
		||||
                write(end())
 | 
			
		||||
            else:
 | 
			
		||||
                return await handle_user_pass_auth(req)
 | 
			
		||||
        except AuthError:
 | 
			
		||||
            write(err("Wrong auth"))
 | 
			
		||||
                username = await handle_user_pass_auth(req)
 | 
			
		||||
                if username in Session.all_sessions:
 | 
			
		||||
                    logging.warning(f"User: {username} already has an active session")
 | 
			
		||||
                    raise AuthError("Already logged in")
 | 
			
		||||
                else:
 | 
			
		||||
                    write(ok("Login successful"))
 | 
			
		||||
        except AuthError as ae:
 | 
			
		||||
            write(err(f"Auth Failed: {ae}"))
 | 
			
		||||
        except ClientQuit as c:
 | 
			
		||||
            write(ok("Bye"))
 | 
			
		||||
            logging.warning("Client has QUIT before auth succeeded")
 | 
			
		||||
@@ -109,101 +113,135 @@ async def auth_stage():
 | 
			
		||||
        raise ClientError("Failed to authenticate")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def process_transactions(mails_list: List[MailEntry]):
 | 
			
		||||
    session: Session = current_session.get()
 | 
			
		||||
def trans_command_capa(_, __):
 | 
			
		||||
    write(ok("CAPA follows"))
 | 
			
		||||
    write(msg("UIDL"))
 | 
			
		||||
    write(end())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def trans_command_stat(mails, _):
 | 
			
		||||
    num, size = mails.compute_stat()
 | 
			
		||||
    write(ok(f"{num} {size}"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def trans_command_list(mails, req):
 | 
			
		||||
    if req.arg1:
 | 
			
		||||
        entry = mails.get(req.arg1)
 | 
			
		||||
        if entry:
 | 
			
		||||
            write(ok(f"{req.arg1} {entry.size}"))
 | 
			
		||||
        else:
 | 
			
		||||
            write(err("Not found"))
 | 
			
		||||
    else:
 | 
			
		||||
        write(ok("Mails follow"))
 | 
			
		||||
        for entry in mails.get_all():
 | 
			
		||||
            write(msg(f"{entry.nid} {entry.size}"))
 | 
			
		||||
        write(end())
 | 
			
		||||
        await drain()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def trans_command_uidl(mails, req):
 | 
			
		||||
    if req.arg1:
 | 
			
		||||
        entry = mails.get(req.arg1)
 | 
			
		||||
        if entry:
 | 
			
		||||
            write(ok(f"{req.arg1} {entry.uid}"))
 | 
			
		||||
        else:
 | 
			
		||||
            write(err("Not found"))
 | 
			
		||||
    else:
 | 
			
		||||
        write(ok("Mails follow"))
 | 
			
		||||
        for entry in mails.get_all():
 | 
			
		||||
            write(msg(f"{entry.nid} {entry.uid}"))
 | 
			
		||||
        write(end())
 | 
			
		||||
        await drain()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def trans_command_retr(mails, req):
 | 
			
		||||
    entry = mails.get(req.arg1)
 | 
			
		||||
    if entry:
 | 
			
		||||
        write(ok("Contents follow"))
 | 
			
		||||
        write(get_mail(entry))
 | 
			
		||||
        write(end())
 | 
			
		||||
        drain()
 | 
			
		||||
    else:
 | 
			
		||||
        write(err("Not found"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def trans_command_dele(mails, req):
 | 
			
		||||
    entry = mails.get(req.arg1)
 | 
			
		||||
    if entry:
 | 
			
		||||
        mails.delete(req.arg1)
 | 
			
		||||
    else:
 | 
			
		||||
        write(err("Not found"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def trans_command_noop(_, __):
 | 
			
		||||
    write(ok("Hmm"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def process_transactions(mails_list: List[MailEntry]):
 | 
			
		||||
    mails = MailList(mails_list)
 | 
			
		||||
 | 
			
		||||
    def reset(_, __):
 | 
			
		||||
        nonlocal mails
 | 
			
		||||
        mails = MailList(mails_list)
 | 
			
		||||
 | 
			
		||||
    handle_map = {
 | 
			
		||||
        Command.CAPA: trans_command_capa,
 | 
			
		||||
        Command.STAT: trans_command_stat,
 | 
			
		||||
        Command.LIST: trans_command_list,
 | 
			
		||||
        Command.UIDL: trans_command_uidl,
 | 
			
		||||
        Command.RETR: trans_command_retr,
 | 
			
		||||
        Command.DELE: trans_command_dele,
 | 
			
		||||
        Command.RSET: reset,
 | 
			
		||||
        Command.NOOP: trans_command_noop,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    while True:
 | 
			
		||||
        try:
 | 
			
		||||
            req = await session.next_req()
 | 
			
		||||
            logging.debug(f"Request: {req}")
 | 
			
		||||
            if req.cmd is Command.CAPA:
 | 
			
		||||
                write(ok("CAPA follows"))
 | 
			
		||||
                write(msg("UIDL"))
 | 
			
		||||
                write(end())
 | 
			
		||||
            elif req.cmd is Command.STAT:
 | 
			
		||||
                num, size = mails.compute_stat()
 | 
			
		||||
                write(ok(f"{num} {size}"))
 | 
			
		||||
            elif req.cmd is Command.LIST:
 | 
			
		||||
                if req.arg1:
 | 
			
		||||
                    entry = mails.get(req.arg1)
 | 
			
		||||
                    if entry:
 | 
			
		||||
                        write(ok(f"{req.arg1} {entry.size}"))
 | 
			
		||||
                    else:
 | 
			
		||||
                        write(err("Not found"))
 | 
			
		||||
                else:
 | 
			
		||||
                    write(ok("Mails follow"))
 | 
			
		||||
                    for entry in mails.get_all():
 | 
			
		||||
                        write(msg(f"{entry.nid} {entry.size}"))
 | 
			
		||||
                    write(end())
 | 
			
		||||
            elif req.cmd is Command.UIDL:
 | 
			
		||||
                if req.arg1:
 | 
			
		||||
                    entry = mails.get(req.arg1)
 | 
			
		||||
                    if entry:
 | 
			
		||||
                        write(ok(f"{req.arg1} {entry.uid}"))
 | 
			
		||||
                    else:
 | 
			
		||||
                        write(err("Not found"))
 | 
			
		||||
                else:
 | 
			
		||||
                    write(ok("Mails follow"))
 | 
			
		||||
                    for entry in mails.get_all():
 | 
			
		||||
                        write(msg(f"{entry.nid} {entry.uid}"))
 | 
			
		||||
                    write(end())
 | 
			
		||||
                    await session.writer.drain()
 | 
			
		||||
            elif req.cmd is Command.RETR:
 | 
			
		||||
                entry = mails.get(req.arg1)
 | 
			
		||||
                if entry:
 | 
			
		||||
                    write(ok("Contents follow"))
 | 
			
		||||
                    write(get_mail(entry))
 | 
			
		||||
                    write(end())
 | 
			
		||||
                    await session.writer.drain()
 | 
			
		||||
                else:
 | 
			
		||||
                    write(err("Not found"))
 | 
			
		||||
            elif req.cmd is Command.DELE:
 | 
			
		||||
                entry = mails.get(req.arg1)
 | 
			
		||||
                if entry:
 | 
			
		||||
                    mails.delete(req.arg1)
 | 
			
		||||
                else:
 | 
			
		||||
                    write(err("Not found"))
 | 
			
		||||
            elif req.cmd is Command.RSET:
 | 
			
		||||
                mails = MailList(mails_list)
 | 
			
		||||
            elif req.cmd is Command.NOOP:
 | 
			
		||||
                pass
 | 
			
		||||
            else:
 | 
			
		||||
                write(err("Not implemented"))
 | 
			
		||||
                raise ClientError("We shouldn't reach here")
 | 
			
		||||
            req = await next_req()
 | 
			
		||||
        except ClientQuit:
 | 
			
		||||
            write(ok("Bye"))
 | 
			
		||||
            return mails.deleted_uids
 | 
			
		||||
        logging.debug(f"Request: {req}")
 | 
			
		||||
        try:
 | 
			
		||||
            func = handle_map[req.cmd]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            write(err("Not implemented"))
 | 
			
		||||
            raise ClientError("We shouldn't reach here")
 | 
			
		||||
        else:
 | 
			
		||||
            func(mails, req)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def transaction_stage():
 | 
			
		||||
    session: Session = current_session.get()
 | 
			
		||||
    logging.debug(f"Entering transaction stage for {session.username}")
 | 
			
		||||
    session.read_items = Session.mails_path / session.username
 | 
			
		||||
async def transaction_stage(deleted_items_path: Path):
 | 
			
		||||
    with deleted_items_path.open() as f:
 | 
			
		||||
        deleted_items = set(f.read().splitlines())
 | 
			
		||||
 | 
			
		||||
    with session.read_items.open() as f:
 | 
			
		||||
        read_items = set(f.read().splitlines())
 | 
			
		||||
 | 
			
		||||
    mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items]
 | 
			
		||||
    mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in deleted_items]
 | 
			
		||||
    return await process_transactions(mails_list)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def delete_messages(delete_ids):
 | 
			
		||||
    session: Session = current_session.get()
 | 
			
		||||
    with session.read_items.open(mode="w") as f:
 | 
			
		||||
def delete_messages(delete_ids, deleted_items_path: Path):
 | 
			
		||||
    with deleted_items_path.open(mode="w") as f:
 | 
			
		||||
        f.writelines(delete_ids)
 | 
			
		||||
    logging.info(f"Client deleted these ids {delete_ids}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
 | 
			
		||||
    session = Session(stream_reader, stream_writer)
 | 
			
		||||
    current_session.set(session)
 | 
			
		||||
    Session.current_session.set(session)
 | 
			
		||||
    logging.info(f"New session started with {stream_reader} and {stream_writer}")
 | 
			
		||||
    username = None
 | 
			
		||||
    try:
 | 
			
		||||
        await auth_stage()
 | 
			
		||||
        delete_ids = await transaction_stage()
 | 
			
		||||
        delete_messages(delete_ids)
 | 
			
		||||
        username = await auth_stage()
 | 
			
		||||
        assert username is not None
 | 
			
		||||
        Session.all_sessions.add(username)
 | 
			
		||||
        deleted_items_path = Session.mails_path / username
 | 
			
		||||
        logging.info(f"User:{username} logged in successfully")
 | 
			
		||||
 | 
			
		||||
        delete_ids = await transaction_stage(deleted_items_path)
 | 
			
		||||
        logging.info(f"User:{username} completed transactions. Deleted:{delete_ids}")
 | 
			
		||||
 | 
			
		||||
        delete_messages(delete_ids, deleted_items_path)
 | 
			
		||||
        logging.info(f"User:{username} Saved deleted items")
 | 
			
		||||
 | 
			
		||||
    except ClientError as c:
 | 
			
		||||
        write(err("Something went wrong"))
 | 
			
		||||
        logging.error(f"Unexpected client error", c)
 | 
			
		||||
@@ -211,8 +249,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
 | 
			
		||||
        logging.error(f"Serious client error", e)
 | 
			
		||||
        raise
 | 
			
		||||
    finally:
 | 
			
		||||
        if session.username:
 | 
			
		||||
            Session.all_sessions.remove(session.username)
 | 
			
		||||
        if username:
 | 
			
		||||
            Session.all_sessions.remove(username)
 | 
			
		||||
        stream_writer.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user