This commit is contained in:
Balakrishnan Balasubramanian 2018-12-19 00:57:50 -05:00
parent bd85de78e0
commit 06acf78b82
2 changed files with 104 additions and 53 deletions

View File

@ -4,11 +4,10 @@ import ssl
from _contextvars import ContextVar from _contextvars import ContextVar
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import ClassVar, List, Coroutine from typing import ClassVar, List, Set
from collections import deque
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
MailStorage, Request Request, MailEntry, get_mail, get_mails_list, MailList
@dataclass @dataclass
@ -18,9 +17,8 @@ class Session:
username: str = "" username: str = ""
read_items: Path = None read_items: Path = None
# common state # common state
all_sessions: ClassVar[List] = [] all_sessions: ClassVar[Set] = set()
mails_path: ClassVar[Path] = Path("") mails_path: ClassVar[Path] = Path("")
wait_for_privileges_to_drop: ClassVar[Coroutine] = None
pending_request: Request = None pending_request: Request = None
def pop_request(self): def pop_request(self):
@ -84,7 +82,9 @@ async def handle_user_pass_auth(user_cmd):
password = cmd.arg1 password = cmd.arg1
validate_user_and_pass(username, password) validate_user_and_pass(username, password)
write(ok("Good")) write(ok("Good"))
return username, password logging.info(f"User: {username} has logged in successfully")
session.username = username
Session.all_sessions.add(username)
async def auth_stage(): async def auth_stage():
@ -98,9 +98,7 @@ async def auth_stage():
write(msg("USER")) write(msg("USER"))
write(end()) write(end())
else: else:
username, password = await handle_user_pass_auth(req) return await handle_user_pass_auth(req)
logging.info(f"User: {username} has logged in successfully")
return username
except AuthError: except AuthError:
write(err("Wrong auth")) write(err("Wrong auth"))
except ClientQuit as c: except ClientQuit as c:
@ -111,65 +109,96 @@ async def auth_stage():
raise ClientError("Failed to authenticate") raise ClientError("Failed to authenticate")
async def transaction_stage(): async def process_transactions(mails_list: List[MailEntry]):
session: Session = current_session.get() session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
deleted_message_ids = []
mailbox = MailStorage(Session.mails_path / 'new')
with session.read_items.open() as f: mails = MailList(mails_list)
read_items = set(f.read().splitlines())
mails_list = mailbox.get_mails_list()
mails_map = {str(entry.nid): entry for entry in mails_list}
while True: while True:
try: try:
req = await next_req() req = await session.next_req()
logging.debug(f"Request: {req}") logging.debug(f"Request: {req}")
if req.cmd is Command.CAPA: if req.cmd is Command.CAPA:
write(ok("No CAPA")) write(ok("CAPA follows"))
write(msg("UIDL"))
write(end()) write(end())
elif req.cmd is Command.STAT: elif req.cmd is Command.STAT:
num, size = mailbox.get_mailbox_size() num, size = mails.compute_stat()
write(ok(f"{num} {size}")) write(ok(f"{num} {size}"))
elif req.cmd is Command.LIST: elif req.cmd is Command.LIST:
if req.arg1: if req.arg1:
write(ok(f"{req.arg1} {mails_map[req.arg1].size}")) entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.size}"))
else:
write(err("Not found"))
else: else:
write(ok("Mails follow")) write(ok("Mails follow"))
for entry in mails_list: for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.size}")) write(msg(f"{entry.nid} {entry.size}"))
write(end()) write(end())
elif req.cmd is Command.UIDL: elif req.cmd is Command.UIDL:
if req.arg1: if req.arg1:
write(ok(f"{req.arg1} {mails_map[req.arg1].uid}")) entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.uid}"))
else:
write(err("Not found"))
else: else:
write(ok("Mails follow")) write(ok("Mails follow"))
for entry in mails_list: for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.uid}")) write(msg(f"{entry.nid} {entry.uid}"))
write(end()) write(end())
await session.writer.drain() await session.writer.drain()
elif req.cmd is Command.RETR: elif req.cmd is Command.RETR:
if req.arg1 not in mails_map: entry = mails.get(req.arg1)
write(err("Not found")) if entry:
else:
write(ok("Contents follow")) write(ok("Contents follow"))
write(mailbox.get_mail(mails_map[req.arg1])) write(get_mail(entry))
write(end()) write(end())
await session.writer.drain() await session.writer.drain()
else:
write(err("Not found"))
elif req.cmd is Command.DELE:
entry = mails.get(req.arg1)
if entry:
mails.delete(req.arg1)
else:
write(err("Not found"))
elif req.cmd is Command.RSET:
mails = MailList(mails_list)
elif req.cmd is Command.NOOP:
pass
else: else:
write(err("Not implemented")) write(err("Not implemented"))
raise ClientError("We shouldn't reach here")
except ClientQuit: except ClientQuit:
write(ok("Bye")) write(ok("Bye"))
return deleted_message_ids return mails.deleted_uids
async def transaction_stage():
session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
session.read_items = Session.mails_path / session.username
with session.read_items.open() as f:
read_items = set(f.read().splitlines())
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items]
return await process_transactions(mails_list)
def delete_messages(delete_ids): def delete_messages(delete_ids):
session: Session = current_session.get()
with session.read_items.open(mode="w") as f:
f.writelines(delete_ids)
logging.info(f"Client deleted these ids {delete_ids}") logging.info(f"Client deleted these ids {delete_ids}")
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
current_session.set(Session(stream_reader, stream_writer)) session = Session(stream_reader, stream_writer)
current_session.set(session)
logging.info(f"New session started with {stream_reader} and {stream_writer}") logging.info(f"New session started with {stream_reader} and {stream_writer}")
try: try:
await auth_stage() await auth_stage()
@ -182,6 +211,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
logging.error(f"Serious client error", e) logging.error(f"Serious client error", e)
raise raise
finally: finally:
if session.username:
Session.all_sessions.remove(session.username)
stream_writer.close() stream_writer.close()

View File

@ -28,15 +28,17 @@ User = NewType('User', str)
class Command(Enum): class Command(Enum):
CAPA = auto()
USER = auto() USER = auto()
PASS = auto() PASS = auto()
CAPA = auto()
QUIT = auto() QUIT = auto()
STAT = auto()
LIST = auto() LIST = auto()
UIDL = auto() UIDL = auto()
RETR = auto() RETR = auto()
DELE = auto() DELE = auto()
STAT = auto() RSET = auto()
NOOP = auto()
@dataclass @dataclass
@ -88,11 +90,6 @@ def parse_command(line: bytes) -> Request:
return request return request
def files_in_path(path):
for _, _, files in os.walk(path):
return [(f, os.path.join(path, f)) for f in files]
@dataclass @dataclass
class MailEntry: class MailEntry:
uid: str uid: str
@ -109,23 +106,46 @@ class MailEntry:
self.path = path self.path = path
class MailStorage: def files_in_path(path):
def __init__(self, dirpath: Path): for _, _, files in os.walk(path):
self.dirpath = dirpath return [(f, os.path.join(path, f)) for f in files]
self.files = files_in_path(self.dirpath)
self.entries = [MailEntry(filename, path) for filename, path in self.files]
self.entries = sorted(self.entries, reverse=True, key=lambda e: e.c_time)
for i, entry in enumerate(self.entries, start=1):
entry.nid = i
def get_mailbox_size(self) -> (int, int):
return len(self.entries), sum(entry.size for entry in self.entries)
def get_mails_list(self) -> List[MailEntry]: def get_mails_list(dirpath: Path) -> List[MailEntry]:
return self.entries files = files_in_path(dirpath)
entries = [MailEntry(filename, path) for filename, path in files]
return entries
@staticmethod
def get_mail(entry: MailEntry) -> bytes: def set_nid(entries: List[MailEntry]):
with open(entry.path, mode='rb') as fp: entries.sort(reverse=True, key=lambda e: e.c_time)
return fp.read() entries = sorted(entries, reverse=True, key=lambda e: e.c_time)
for i, entry in enumerate(entries, start=1):
entry.nid = i
def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp:
return fp.read()
class MailList:
def __init__(self, entries: List[MailEntry]):
self.entries = entries
set_nid(self.entries)
self.mails_map = {str(e.nid): e for e in entries}
self.deleted_uids = set()
def delete(self, nid: str):
self.deleted_uids.add(self.mails_map.pop(nid).uid)
def get(self, nid: str):
self.mails_map.get(nid)
def get_all(self):
return [e for e in self.entries if str(e.nid) in self.mails_map]
def compute_stat(self):
entries = self.get_all()
return len(entries), sum(entry.size for entry in entries)