add return type hinting for pop3

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-07 22:07:38 -04:00
parent 9b0400583c
commit fc7ff2a2b6

View File

@ -14,10 +14,10 @@ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, C
Request, MailEntry, get_mail, get_mails_list, MailList Request, MailEntry, get_mail, get_mails_list, MailList
async def next_req(): async def next_req() -> Request:
for _ in range(InvalidCommand.RETRIES): for _ in range(InvalidCommand.RETRIES):
line = await state().reader.readline() line = await state().reader.readline()
logging.debug(f"Client: {line}") # logging.debug(f"Client: {line}")
if not line: if not line:
continue continue
try: try:
@ -32,7 +32,7 @@ async def next_req():
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times") raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(*commands: Command): async def expect_cmd(*commands: Command) -> Request:
req = await next_req() req = await next_req()
if req.cmd not in commands: if req.cmd not in commands:
logging.error(f"Unexpected command: {req.cmd} is not in {commands}") logging.error(f"Unexpected command: {req.cmd} is not in {commands}")
@ -40,12 +40,12 @@ async def expect_cmd(*commands: Command):
return req return req
def write(data): def write(data) -> None:
logging.debug(f"Server: {data}") logging.debug(f"Server: {data}")
state().writer.write(data) state().writer.write(data)
def validate_password(username, password): def validate_password(username, password) -> None:
try: try:
pwinfo, mbox = config().users[username] pwinfo, mbox = config().users[username]
except: except:
@ -57,7 +57,7 @@ def validate_password(username, password):
state().mbox = mbox state().mbox = mbox
async def handle_user_pass_auth(user_cmd): async def handle_user_pass_auth(user_cmd) -> None:
username = user_cmd.arg1 username = user_cmd.arg1
if not username: if not username:
raise AuthError("Invalid USER command. username empty") raise AuthError("Invalid USER command. username empty")
@ -68,7 +68,7 @@ async def handle_user_pass_auth(user_cmd):
logging.info(f"{username=} has logged in successfully") logging.info(f"{username=} has logged in successfully")
async def auth_stage(): async def auth_stage() -> None:
write(ok("Server Ready")) write(ok("Server Ready"))
for _ in range(AuthError.RETRIES): for _ in range(AuthError.RETRIES):
try: try:
@ -78,8 +78,8 @@ async def auth_stage():
write(msg("USER")) write(msg("USER"))
write(end()) write(end())
else: else:
username = await handle_user_pass_auth(req) await handle_user_pass_auth(req)
if username in config().loggedin_users: if (username:=state().username) in config().loggedin_users:
logging.warning( logging.warning(
f"User: {username} already has an active session") f"User: {username} already has an active session")
raise AuthError("Already logged in") raise AuthError("Already logged in")
@ -96,18 +96,18 @@ async def auth_stage():
raise ClientError("Failed to authenticate") raise ClientError("Failed to authenticate")
def trans_command_capa(_, __): def trans_command_capa(_, __) -> None:
write(ok("CAPA follows")) write(ok("CAPA follows"))
write(msg("UIDL")) write(msg("UIDL"))
write(end()) write(end())
def trans_command_stat(mails: MailList, _): def trans_command_stat(mails: MailList, _) -> None:
num, size = mails.compute_stat() num, size = mails.compute_stat()
write(ok(f"{num} {size}")) write(ok(f"{num} {size}"))
def trans_command_list(mails: MailList, req: Request): def trans_command_list(mails: MailList, req: Request) -> None:
if req.arg1: if req.arg1:
entry = mails.get(req.arg1) entry = mails.get(req.arg1)
if entry: if entry:
@ -121,7 +121,7 @@ def trans_command_list(mails: MailList, req: Request):
write(end()) write(end())
def trans_command_uidl(mails: MailList, req: Request): def trans_command_uidl(mails: MailList, req: Request) -> None:
if req.arg1: if req.arg1:
entry = mails.get(req.arg1) entry = mails.get(req.arg1)
if entry: if entry:
@ -135,7 +135,7 @@ def trans_command_uidl(mails: MailList, req: Request):
write(end()) write(end())
def trans_command_retr(mails: MailList, req: Request): def trans_command_retr(mails: MailList, req: Request) -> None:
entry = mails.get(req.arg1) entry = mails.get(req.arg1)
if entry: if entry:
write(ok("Contents follow")) write(ok("Contents follow"))
@ -146,7 +146,7 @@ def trans_command_retr(mails: MailList, req: Request):
write(err("Not found")) write(err("Not found"))
def trans_command_dele(mails: MailList, req: Request): def trans_command_dele(mails: MailList, req: Request) -> None:
entry = mails.get(req.arg1) entry = mails.get(req.arg1)
if entry: if entry:
mails.delete(req.arg1) mails.delete(req.arg1)
@ -155,11 +155,11 @@ def trans_command_dele(mails: MailList, req: Request):
write(err("Not found")) write(err("Not found"))
def trans_command_noop(_, __): def trans_command_noop(_, __) -> None:
write(ok("Hmm")) write(ok("Hmm"))
async def process_transactions(mails_list: list[MailEntry]): async def process_transactions(mails_list: list[MailEntry]) -> set[str]:
mails = MailList(mails_list) mails = MailList(mails_list)
def reset(_, __): def reset(_, __):
@ -194,19 +194,19 @@ async def process_transactions(mails_list: list[MailEntry]):
await state().writer.drain() await state().writer.drain()
def get_deleted_items(deleted_items_path: Path): def get_deleted_items(deleted_items_path: Path) -> set[str]:
if deleted_items_path.exists(): if deleted_items_path.exists():
with deleted_items_path.open() as f: with deleted_items_path.open() as f:
return set(f.read().splitlines()) return set(f.read().splitlines())
return set() return set()
def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]): def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]) -> None:
with deleted_items_path.open(mode="w") as f: with deleted_items_path.open(mode="w") as f:
f.writelines(f"{did}\n" for did in deleted_items) f.writelines(f"{did}\n" for did in deleted_items)
async def transaction_stage(): async def transaction_stage() -> None:
deleted_items_path = config().mails_path / state().mbox / state().username deleted_items_path = config().mails_path / state().mbox / state().username
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path) existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
mails_list = [ mails_list = [
@ -215,7 +215,7 @@ async def transaction_stage():
if entry.uid not in existing_deleted_items if entry.uid not in existing_deleted_items
] ]
new_deleted_items: Set = await process_transactions(mails_list) new_deleted_items: set[str] = await process_transactions(mails_list)
logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}") logging.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
if new_deleted_items: if new_deleted_items:
save_deleted_items(deleted_items_path, save_deleted_items(deleted_items_path,
@ -224,7 +224,7 @@ async def transaction_stage():
logging.info(f"Saved deleted items") logging.info(f"Saved deleted items")
async def start_session(): async def start_session() -> None:
logging.info("New session started") logging.info("New session started")
try: try:
await auth_stage() await auth_stage()
@ -243,7 +243,7 @@ async def start_session():
config().loggedin_users.remove(state().username) config().loggedin_users.remove(state().username)
def parse_users(users: list[User]): def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
def inner(): def inner():
for user in users: for user in users:
@ -315,7 +315,7 @@ async def create_pop_server(host: str,
ssl=ssl_context) ssl=ssl_context)
async def a_main(*args, **kwargs): async def a_main(*args, **kwargs) -> None:
server = await create_pop_server(*args, **kwargs) server = await create_pop_server(*args, **kwargs)
await server.serve_forever() await server.serve_forever()