fix mypy errors and remove typing

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-06 22:16:26 -04:00
parent aa6e102b07
commit 39d8036563
3 changed files with 16 additions and 21 deletions

View File

@ -6,7 +6,6 @@ import contextvars
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from typing import ClassVar, List, Set
from .config import User from .config import User
from .pwhash import parse_hash, check_pass, PWInfo from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
@ -160,7 +159,7 @@ def trans_command_noop(_, __):
write(ok("Hmm")) write(ok("Hmm"))
async def process_transactions(mails_list: List[MailEntry]): async def process_transactions(mails_list: list[MailEntry]):
mails = MailList(mails_list) mails = MailList(mails_list)
def reset(_, __): def reset(_, __):
@ -202,7 +201,7 @@ def get_deleted_items(deleted_items_path: Path):
return set() return set()
def save_deleted_items(deleted_items_path: Path, deleted_items: Set): def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]):
with deleted_items_path.open(mode="w") as f: with deleted_items_path.open(mode="w") as f:
f.writelines(f"{did}\n" for did in deleted_items) f.writelines(f"{did}\n" for did in deleted_items)
@ -271,14 +270,14 @@ class Config:
self.loggedin_users: set[str] = set() self.loggedin_users: set[str] = set()
c_config = contextvars.ContextVar('config') c_config: contextvars.ContextVar = contextvars.ContextVar('config')
def config() -> Config: def config() -> Config:
return c_config.get() return c_config.get()
c_state = contextvars.ContextVar('state') c_state: contextvars.ContextVar = contextvars.ContextVar('state')
def state() -> State: def state() -> State:
@ -295,7 +294,7 @@ def make_pop_server_callback(mails_path: Path, users: list[User],
try: try:
return await asyncio.wait_for(start_session(), timeout_seconds) return await asyncio.wait_for(start_session(), timeout_seconds)
finally: finally:
stream_writer.close() writer.close()
return session_cb return session_cb
@ -304,7 +303,7 @@ async def create_pop_server(host: str,
port: int, port: int,
mails_path: Path, mails_path: Path,
users: list[User], users: list[User],
ssl_context: ssl.SSLContext = None, ssl_context: ssl.SSLContext | None = None,
timeout_seconds: int = 60): timeout_seconds: int = 60):
logging.info( logging.info(
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}" f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"

View File

@ -2,7 +2,6 @@ import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path from pathlib import Path
from typing import NewType, List
class ClientError(Exception): class ClientError(Exception):
@ -24,9 +23,6 @@ class AuthError(ClientError):
pass pass
User = NewType('User', str)
class Command(Enum): class Command(Enum):
CAPA = auto() CAPA = auto()
USER = auto() USER = auto()
@ -65,8 +61,8 @@ def err(arg):
return f"-ERR {arg}\r\n".encode() return f"-ERR {arg}\r\n".encode()
def parse_command(line: bytes) -> Request: def parse_command(bline: bytes) -> Request:
line = line.decode() line = bline.decode()
if not line.endswith("\r\n"): if not line.endswith("\r\n"):
raise ClientError("Invalid line ending") raise ClientError("Invalid line ending")
@ -112,13 +108,13 @@ def files_in_path(path):
return [] return []
def get_mails_list(dirpath: Path) -> List[MailEntry]: def get_mails_list(dirpath: Path) -> list[MailEntry]:
files = files_in_path(dirpath) files = files_in_path(dirpath)
entries = [MailEntry(filename, path) for filename, path in files] entries = [MailEntry(filename, path) for filename, path in files]
return entries return entries
def set_nid(entries: List[MailEntry]): def set_nid(entries: list[MailEntry]):
entries.sort(reverse=True, key=lambda e: e.c_time) entries.sort(reverse=True, key=lambda e: e.c_time)
entries = sorted(entries, reverse=True, key=lambda e: e.c_time) entries = sorted(entries, reverse=True, key=lambda e: e.c_time)
for i, entry in enumerate(entries, start=1): for i, entry in enumerate(entries, start=1):
@ -131,11 +127,12 @@ def get_mail(entry: MailEntry) -> bytes:
class MailList: class MailList:
def __init__(self, entries: List[MailEntry]):
def __init__(self, entries: list[MailEntry]):
self.entries = entries self.entries = entries
set_nid(self.entries) set_nid(self.entries)
self.mails_map = {str(e.nid): e for e in entries} self.mails_map = {str(e.nid): e for e in entries}
self.deleted_uids = set() self.deleted_uids: set[str] = set()
def delete(self, nid: str): def delete(self, nid: str):
self.deleted_uids.add(self.mails_map.pop(nid).uid) self.deleted_uids.add(self.mails_map.pop(nid).uid)
@ -149,4 +146,3 @@ class MailList:
def compute_stat(self): def compute_stat(self):
entries = self.get_all() entries = self.get_all()
return len(entries), sum(entry.size for entry in entries) return len(entries), sum(entry.size for entry in entries)

View File

@ -32,7 +32,7 @@ class MailboxCRLF(Mailbox):
self.mailbox = MaildirCRLF(mail_dir) self.mailbox = MaildirCRLF(mail_dir)
def protocol_factory_starttls(dirpath: Path, context: ssl.SSLContext = None): def protocol_factory_starttls(dirpath: Path, context: ssl.SSLContext | None = None):
logging.info("Got smtp client cb") logging.info("Got smtp client cb")
try: try:
handler = MailboxCRLF(dirpath) handler = MailboxCRLF(dirpath)
@ -63,7 +63,7 @@ def protocol_factory(dirpath: Path):
async def create_smtp_server_starttls(dirpath: Path, async def create_smtp_server_starttls(dirpath: Path,
port: int, port: int,
host="", host="",
context: ssl.SSLContext = None): context: ssl.SSLContext | None= None):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, dirpath, return await loop.create_server(partial(protocol_factory_starttls, dirpath,
context), context),
@ -75,7 +75,7 @@ async def create_smtp_server_starttls(dirpath: Path,
async def create_smtp_server_tls(dirpath: Path, async def create_smtp_server_tls(dirpath: Path,
port: int, port: int,
host="", host="",
context: ssl.SSLContext = None): context: ssl.SSLContext | None= None):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, dirpath), return await loop.create_server(partial(protocol_factory, dirpath),
host=host, host=host,