beat
This commit is contained in:
parent
bd85de78e0
commit
06acf78b82
@ -4,11 +4,10 @@ import ssl
|
|||||||
from _contextvars import ContextVar
|
from _contextvars import ContextVar
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import ClassVar, List, Coroutine
|
from typing import ClassVar, List, Set
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
|
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
|
||||||
MailStorage, Request
|
Request, MailEntry, get_mail, get_mails_list, MailList
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -18,9 +17,8 @@ class Session:
|
|||||||
username: str = ""
|
username: str = ""
|
||||||
read_items: Path = None
|
read_items: Path = None
|
||||||
# common state
|
# common state
|
||||||
all_sessions: ClassVar[List] = []
|
all_sessions: ClassVar[Set] = set()
|
||||||
mails_path: ClassVar[Path] = Path("")
|
mails_path: ClassVar[Path] = Path("")
|
||||||
wait_for_privileges_to_drop: ClassVar[Coroutine] = None
|
|
||||||
pending_request: Request = None
|
pending_request: Request = None
|
||||||
|
|
||||||
def pop_request(self):
|
def pop_request(self):
|
||||||
@ -84,7 +82,9 @@ async def handle_user_pass_auth(user_cmd):
|
|||||||
password = cmd.arg1
|
password = cmd.arg1
|
||||||
validate_user_and_pass(username, password)
|
validate_user_and_pass(username, password)
|
||||||
write(ok("Good"))
|
write(ok("Good"))
|
||||||
return username, password
|
logging.info(f"User: {username} has logged in successfully")
|
||||||
|
session.username = username
|
||||||
|
Session.all_sessions.add(username)
|
||||||
|
|
||||||
|
|
||||||
async def auth_stage():
|
async def auth_stage():
|
||||||
@ -98,9 +98,7 @@ async def auth_stage():
|
|||||||
write(msg("USER"))
|
write(msg("USER"))
|
||||||
write(end())
|
write(end())
|
||||||
else:
|
else:
|
||||||
username, password = await handle_user_pass_auth(req)
|
return await handle_user_pass_auth(req)
|
||||||
logging.info(f"User: {username} has logged in successfully")
|
|
||||||
return username
|
|
||||||
except AuthError:
|
except AuthError:
|
||||||
write(err("Wrong auth"))
|
write(err("Wrong auth"))
|
||||||
except ClientQuit as c:
|
except ClientQuit as c:
|
||||||
@ -111,65 +109,96 @@ async def auth_stage():
|
|||||||
raise ClientError("Failed to authenticate")
|
raise ClientError("Failed to authenticate")
|
||||||
|
|
||||||
|
|
||||||
async def transaction_stage():
|
async def process_transactions(mails_list: List[MailEntry]):
|
||||||
session: Session = current_session.get()
|
session: Session = current_session.get()
|
||||||
logging.debug(f"Entering transaction stage for {session.username}")
|
|
||||||
deleted_message_ids = []
|
|
||||||
mailbox = MailStorage(Session.mails_path / 'new')
|
|
||||||
|
|
||||||
with session.read_items.open() as f:
|
mails = MailList(mails_list)
|
||||||
read_items = set(f.read().splitlines())
|
|
||||||
|
|
||||||
mails_list = mailbox.get_mails_list()
|
|
||||||
mails_map = {str(entry.nid): entry for entry in mails_list}
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
req = await next_req()
|
req = await session.next_req()
|
||||||
logging.debug(f"Request: {req}")
|
logging.debug(f"Request: {req}")
|
||||||
if req.cmd is Command.CAPA:
|
if req.cmd is Command.CAPA:
|
||||||
write(ok("No CAPA"))
|
write(ok("CAPA follows"))
|
||||||
|
write(msg("UIDL"))
|
||||||
write(end())
|
write(end())
|
||||||
elif req.cmd is Command.STAT:
|
elif req.cmd is Command.STAT:
|
||||||
num, size = mailbox.get_mailbox_size()
|
num, size = mails.compute_stat()
|
||||||
write(ok(f"{num} {size}"))
|
write(ok(f"{num} {size}"))
|
||||||
elif req.cmd is Command.LIST:
|
elif req.cmd is Command.LIST:
|
||||||
if req.arg1:
|
if req.arg1:
|
||||||
write(ok(f"{req.arg1} {mails_map[req.arg1].size}"))
|
entry = mails.get(req.arg1)
|
||||||
|
if entry:
|
||||||
|
write(ok(f"{req.arg1} {entry.size}"))
|
||||||
|
else:
|
||||||
|
write(err("Not found"))
|
||||||
else:
|
else:
|
||||||
write(ok("Mails follow"))
|
write(ok("Mails follow"))
|
||||||
for entry in mails_list:
|
for entry in mails.get_all():
|
||||||
write(msg(f"{entry.nid} {entry.size}"))
|
write(msg(f"{entry.nid} {entry.size}"))
|
||||||
write(end())
|
write(end())
|
||||||
elif req.cmd is Command.UIDL:
|
elif req.cmd is Command.UIDL:
|
||||||
if req.arg1:
|
if req.arg1:
|
||||||
write(ok(f"{req.arg1} {mails_map[req.arg1].uid}"))
|
entry = mails.get(req.arg1)
|
||||||
|
if entry:
|
||||||
|
write(ok(f"{req.arg1} {entry.uid}"))
|
||||||
|
else:
|
||||||
|
write(err("Not found"))
|
||||||
else:
|
else:
|
||||||
write(ok("Mails follow"))
|
write(ok("Mails follow"))
|
||||||
for entry in mails_list:
|
for entry in mails.get_all():
|
||||||
write(msg(f"{entry.nid} {entry.uid}"))
|
write(msg(f"{entry.nid} {entry.uid}"))
|
||||||
write(end())
|
write(end())
|
||||||
await session.writer.drain()
|
await session.writer.drain()
|
||||||
elif req.cmd is Command.RETR:
|
elif req.cmd is Command.RETR:
|
||||||
if req.arg1 not in mails_map:
|
entry = mails.get(req.arg1)
|
||||||
write(err("Not found"))
|
if entry:
|
||||||
else:
|
|
||||||
write(ok("Contents follow"))
|
write(ok("Contents follow"))
|
||||||
write(mailbox.get_mail(mails_map[req.arg1]))
|
write(get_mail(entry))
|
||||||
write(end())
|
write(end())
|
||||||
await session.writer.drain()
|
await session.writer.drain()
|
||||||
|
else:
|
||||||
|
write(err("Not found"))
|
||||||
|
elif req.cmd is Command.DELE:
|
||||||
|
entry = mails.get(req.arg1)
|
||||||
|
if entry:
|
||||||
|
mails.delete(req.arg1)
|
||||||
|
else:
|
||||||
|
write(err("Not found"))
|
||||||
|
elif req.cmd is Command.RSET:
|
||||||
|
mails = MailList(mails_list)
|
||||||
|
elif req.cmd is Command.NOOP:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
write(err("Not implemented"))
|
write(err("Not implemented"))
|
||||||
|
raise ClientError("We shouldn't reach here")
|
||||||
except ClientQuit:
|
except ClientQuit:
|
||||||
write(ok("Bye"))
|
write(ok("Bye"))
|
||||||
return deleted_message_ids
|
return mails.deleted_uids
|
||||||
|
|
||||||
|
|
||||||
|
async def transaction_stage():
|
||||||
|
session: Session = current_session.get()
|
||||||
|
logging.debug(f"Entering transaction stage for {session.username}")
|
||||||
|
session.read_items = Session.mails_path / session.username
|
||||||
|
|
||||||
|
with session.read_items.open() as f:
|
||||||
|
read_items = set(f.read().splitlines())
|
||||||
|
|
||||||
|
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items]
|
||||||
|
return await process_transactions(mails_list)
|
||||||
|
|
||||||
|
|
||||||
def delete_messages(delete_ids):
|
def delete_messages(delete_ids):
|
||||||
|
session: Session = current_session.get()
|
||||||
|
with session.read_items.open(mode="w") as f:
|
||||||
|
f.writelines(delete_ids)
|
||||||
logging.info(f"Client deleted these ids {delete_ids}")
|
logging.info(f"Client deleted these ids {delete_ids}")
|
||||||
|
|
||||||
|
|
||||||
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
|
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
|
||||||
current_session.set(Session(stream_reader, stream_writer))
|
session = Session(stream_reader, stream_writer)
|
||||||
|
current_session.set(session)
|
||||||
logging.info(f"New session started with {stream_reader} and {stream_writer}")
|
logging.info(f"New session started with {stream_reader} and {stream_writer}")
|
||||||
try:
|
try:
|
||||||
await auth_stage()
|
await auth_stage()
|
||||||
@ -182,6 +211,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
|
|||||||
logging.error(f"Serious client error", e)
|
logging.error(f"Serious client error", e)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
if session.username:
|
||||||
|
Session.all_sessions.remove(session.username)
|
||||||
stream_writer.close()
|
stream_writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,15 +28,17 @@ User = NewType('User', str)
|
|||||||
|
|
||||||
|
|
||||||
class Command(Enum):
|
class Command(Enum):
|
||||||
|
CAPA = auto()
|
||||||
USER = auto()
|
USER = auto()
|
||||||
PASS = auto()
|
PASS = auto()
|
||||||
CAPA = auto()
|
|
||||||
QUIT = auto()
|
QUIT = auto()
|
||||||
|
STAT = auto()
|
||||||
LIST = auto()
|
LIST = auto()
|
||||||
UIDL = auto()
|
UIDL = auto()
|
||||||
RETR = auto()
|
RETR = auto()
|
||||||
DELE = auto()
|
DELE = auto()
|
||||||
STAT = auto()
|
RSET = auto()
|
||||||
|
NOOP = auto()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -88,11 +90,6 @@ def parse_command(line: bytes) -> Request:
|
|||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
def files_in_path(path):
|
|
||||||
for _, _, files in os.walk(path):
|
|
||||||
return [(f, os.path.join(path, f)) for f in files]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MailEntry:
|
class MailEntry:
|
||||||
uid: str
|
uid: str
|
||||||
@ -109,23 +106,46 @@ class MailEntry:
|
|||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
|
|
||||||
class MailStorage:
|
def files_in_path(path):
|
||||||
def __init__(self, dirpath: Path):
|
for _, _, files in os.walk(path):
|
||||||
self.dirpath = dirpath
|
return [(f, os.path.join(path, f)) for f in files]
|
||||||
self.files = files_in_path(self.dirpath)
|
|
||||||
self.entries = [MailEntry(filename, path) for filename, path in self.files]
|
|
||||||
self.entries = sorted(self.entries, reverse=True, key=lambda e: e.c_time)
|
|
||||||
for i, entry in enumerate(self.entries, start=1):
|
|
||||||
entry.nid = i
|
|
||||||
|
|
||||||
def get_mailbox_size(self) -> (int, int):
|
|
||||||
return len(self.entries), sum(entry.size for entry in self.entries)
|
|
||||||
|
|
||||||
def get_mails_list(self) -> List[MailEntry]:
|
def get_mails_list(dirpath: Path) -> List[MailEntry]:
|
||||||
return self.entries
|
files = files_in_path(dirpath)
|
||||||
|
entries = [MailEntry(filename, path) for filename, path in files]
|
||||||
|
return entries
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_mail(entry: MailEntry) -> bytes:
|
def set_nid(entries: List[MailEntry]):
|
||||||
with open(entry.path, mode='rb') as fp:
|
entries.sort(reverse=True, key=lambda e: e.c_time)
|
||||||
return fp.read()
|
entries = sorted(entries, reverse=True, key=lambda e: e.c_time)
|
||||||
|
for i, entry in enumerate(entries, start=1):
|
||||||
|
entry.nid = i
|
||||||
|
|
||||||
|
|
||||||
|
def get_mail(entry: MailEntry) -> bytes:
|
||||||
|
with open(entry.path, mode='rb') as fp:
|
||||||
|
return fp.read()
|
||||||
|
|
||||||
|
|
||||||
|
class MailList:
|
||||||
|
def __init__(self, entries: List[MailEntry]):
|
||||||
|
self.entries = entries
|
||||||
|
set_nid(self.entries)
|
||||||
|
self.mails_map = {str(e.nid): e for e in entries}
|
||||||
|
self.deleted_uids = set()
|
||||||
|
|
||||||
|
def delete(self, nid: str):
|
||||||
|
self.deleted_uids.add(self.mails_map.pop(nid).uid)
|
||||||
|
|
||||||
|
def get(self, nid: str):
|
||||||
|
self.mails_map.get(nid)
|
||||||
|
|
||||||
|
def get_all(self):
|
||||||
|
return [e for e in self.entries if str(e.nid) in self.mails_map]
|
||||||
|
|
||||||
|
def compute_stat(self):
|
||||||
|
entries = self.get_all()
|
||||||
|
return len(entries), sum(entry.size for entry in entries)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user