This commit is contained in:
Balakrishnan Balasubramanian 2018-12-19 00:57:50 -05:00
parent bd85de78e0
commit 06acf78b82
2 changed files with 104 additions and 53 deletions

View File

@ -4,11 +4,10 @@ import ssl
from _contextvars import ContextVar
from dataclasses import dataclass
from pathlib import Path
from typing import ClassVar, List, Coroutine
from collections import deque
from typing import ClassVar, List, Set
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
MailStorage, Request
Request, MailEntry, get_mail, get_mails_list, MailList
@dataclass
@ -18,9 +17,8 @@ class Session:
username: str = ""
read_items: Path = None
# common state
all_sessions: ClassVar[List] = []
all_sessions: ClassVar[Set] = set()
mails_path: ClassVar[Path] = Path("")
wait_for_privileges_to_drop: ClassVar[Coroutine] = None
pending_request: Request = None
def pop_request(self):
@ -84,7 +82,9 @@ async def handle_user_pass_auth(user_cmd):
password = cmd.arg1
validate_user_and_pass(username, password)
write(ok("Good"))
return username, password
logging.info(f"User: {username} has logged in successfully")
session.username = username
Session.all_sessions.add(username)
async def auth_stage():
@ -98,9 +98,7 @@ async def auth_stage():
write(msg("USER"))
write(end())
else:
username, password = await handle_user_pass_auth(req)
logging.info(f"User: {username} has logged in successfully")
return username
return await handle_user_pass_auth(req)
except AuthError:
write(err("Wrong auth"))
except ClientQuit as c:
@ -111,65 +109,96 @@ async def auth_stage():
raise ClientError("Failed to authenticate")
async def transaction_stage():
async def process_transactions(mails_list: List[MailEntry]):
session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
deleted_message_ids = []
mailbox = MailStorage(Session.mails_path / 'new')
with session.read_items.open() as f:
read_items = set(f.read().splitlines())
mails = MailList(mails_list)
mails_list = mailbox.get_mails_list()
mails_map = {str(entry.nid): entry for entry in mails_list}
while True:
try:
req = await next_req()
req = await session.next_req()
logging.debug(f"Request: {req}")
if req.cmd is Command.CAPA:
write(ok("No CAPA"))
write(ok("CAPA follows"))
write(msg("UIDL"))
write(end())
elif req.cmd is Command.STAT:
num, size = mailbox.get_mailbox_size()
num, size = mails.compute_stat()
write(ok(f"{num} {size}"))
elif req.cmd is Command.LIST:
if req.arg1:
write(ok(f"{req.arg1} {mails_map[req.arg1].size}"))
entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.size}"))
else:
write(err("Not found"))
else:
write(ok("Mails follow"))
for entry in mails_list:
for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.size}"))
write(end())
elif req.cmd is Command.UIDL:
if req.arg1:
write(ok(f"{req.arg1} {mails_map[req.arg1].uid}"))
entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.uid}"))
else:
write(err("Not found"))
else:
write(ok("Mails follow"))
for entry in mails_list:
for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.uid}"))
write(end())
await session.writer.drain()
elif req.cmd is Command.RETR:
if req.arg1 not in mails_map:
write(err("Not found"))
else:
entry = mails.get(req.arg1)
if entry:
write(ok("Contents follow"))
write(mailbox.get_mail(mails_map[req.arg1]))
write(get_mail(entry))
write(end())
await session.writer.drain()
else:
write(err("Not found"))
elif req.cmd is Command.DELE:
entry = mails.get(req.arg1)
if entry:
mails.delete(req.arg1)
else:
write(err("Not found"))
elif req.cmd is Command.RSET:
mails = MailList(mails_list)
elif req.cmd is Command.NOOP:
pass
else:
write(err("Not implemented"))
raise ClientError("We shouldn't reach here")
except ClientQuit:
write(ok("Bye"))
return deleted_message_ids
return mails.deleted_uids
async def transaction_stage():
session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
session.read_items = Session.mails_path / session.username
with session.read_items.open() as f:
read_items = set(f.read().splitlines())
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items]
return await process_transactions(mails_list)
def delete_messages(delete_ids):
session: Session = current_session.get()
with session.read_items.open(mode="w") as f:
f.writelines(delete_ids)
logging.info(f"Client deleted these ids {delete_ids}")
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
current_session.set(Session(stream_reader, stream_writer))
session = Session(stream_reader, stream_writer)
current_session.set(session)
logging.info(f"New session started with {stream_reader} and {stream_writer}")
try:
await auth_stage()
@ -182,6 +211,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
logging.error(f"Serious client error", e)
raise
finally:
if session.username:
Session.all_sessions.remove(session.username)
stream_writer.close()

View File

@ -28,15 +28,17 @@ User = NewType('User', str)
class Command(Enum):
CAPA = auto()
USER = auto()
PASS = auto()
CAPA = auto()
QUIT = auto()
STAT = auto()
LIST = auto()
UIDL = auto()
RETR = auto()
DELE = auto()
STAT = auto()
RSET = auto()
NOOP = auto()
@dataclass
@ -88,11 +90,6 @@ def parse_command(line: bytes) -> Request:
return request
def files_in_path(path):
for _, _, files in os.walk(path):
return [(f, os.path.join(path, f)) for f in files]
@dataclass
class MailEntry:
uid: str
@ -109,23 +106,46 @@ class MailEntry:
self.path = path
class MailStorage:
def __init__(self, dirpath: Path):
self.dirpath = dirpath
self.files = files_in_path(self.dirpath)
self.entries = [MailEntry(filename, path) for filename, path in self.files]
self.entries = sorted(self.entries, reverse=True, key=lambda e: e.c_time)
for i, entry in enumerate(self.entries, start=1):
entry.nid = i
def files_in_path(path):
for _, _, files in os.walk(path):
return [(f, os.path.join(path, f)) for f in files]
def get_mailbox_size(self) -> (int, int):
return len(self.entries), sum(entry.size for entry in self.entries)
def get_mails_list(self) -> List[MailEntry]:
return self.entries
def get_mails_list(dirpath: Path) -> List[MailEntry]:
files = files_in_path(dirpath)
entries = [MailEntry(filename, path) for filename, path in files]
return entries
@staticmethod
def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp:
return fp.read()
def set_nid(entries: List[MailEntry]):
entries.sort(reverse=True, key=lambda e: e.c_time)
entries = sorted(entries, reverse=True, key=lambda e: e.c_time)
for i, entry in enumerate(entries, start=1):
entry.nid = i
def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp:
return fp.read()
class MailList:
def __init__(self, entries: List[MailEntry]):
self.entries = entries
set_nid(self.entries)
self.mails_map = {str(e.nid): e for e in entries}
self.deleted_uids = set()
def delete(self, nid: str):
self.deleted_uids.add(self.mails_map.pop(nid).uid)
def get(self, nid: str):
self.mails_map.get(nid)
def get_all(self):
return [e for e in self.entries if str(e.nid) in self.mails_map]
def compute_stat(self):
entries = self.get_all()
return len(entries), sum(entry.size for entry in entries)