This commit is contained in:
Balakrishnan Balasubramanian 2018-12-19 02:01:43 -05:00
parent 06acf78b82
commit 7abd6e9e13

View File

@ -14,93 +14,97 @@ from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, C
class Session: class Session:
reader: asyncio.StreamReader reader: asyncio.StreamReader
writer: asyncio.StreamWriter writer: asyncio.StreamWriter
username: str = ""
read_items: Path = None
# common state # common state
all_sessions: ClassVar[Set] = set() all_sessions: ClassVar[Set] = set()
mails_path: ClassVar[Path] = Path("") mails_path: ClassVar[Path] = Path("")
pending_request: Request = None current_session: ClassVar = ContextVar("session")
def pop_request(self): @classmethod
request = self.pending_request def get(cls):
self.pending_request = None return cls.current_session.get()
return request
async def next_req(self): @classmethod
if self.pending_request: def reader(cls):
return self.pop_request() return cls.get().reader
for _ in range(InvalidCommand.RETRIES): @classmethod
line = await self.reader.readline() def writer(cls):
logging.debug(f"Client: {line}") return cls.get().writer
if not line:
continue
try: async def next_req():
request: Request = parse_command(line) for _ in range(InvalidCommand.RETRIES):
except InvalidCommand: line = await Session.reader().readline()
write(err("Bad command")) logging.debug(f"Client: {line}")
else: if not line:
if request.cmd == Command.QUIT: continue
raise ClientQuit try:
return request request: Request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else: else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times") if request.cmd == Command.QUIT:
raise ClientQuit
async def expect_cmd(self, *commands: Command, optional=False): return request
req = await self.next_req() else:
if req.cmd not in commands: raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
if not optional:
logging.error(f"{req.cmd} is not in {commands}")
raise ClientError
else:
self.pending_request = req
return
return req
current_session: ContextVar[Session] = ContextVar("session") async def expect_cmd(*commands: Command):
req = await next_req()
if req.cmd not in commands:
logging.error(f"Unexpected command: {req.cmd} is not in {commands}")
raise ClientError
return req
def write(data): def write(data):
logging.debug(f"Server: {data}") logging.debug(f"Server: {data}")
session: Session = current_session.get() session: Session = Session.current_session.get()
session.writer.write(data) session.writer.write(data)
async def drain():
session: Session = Session.current_session.get()
await session.writer.drain()
def validate_user_and_pass(username, password): def validate_user_and_pass(username, password):
if username != password: if username != password:
raise AuthError("Invalid user pass") raise AuthError("Invalid user pass")
async def handle_user_pass_auth(user_cmd): async def handle_user_pass_auth(user_cmd):
session: Session = current_session.get()
username = user_cmd.arg1 username = user_cmd.arg1
if not username: if not username:
raise AuthError("Invalid USER command. username empty") raise AuthError("Invalid USER command. username empty")
write(ok("Welcome")) write(ok("Welcome"))
cmd = await session.expect_cmd(Command.PASS) cmd = await expect_cmd(Command.PASS)
password = cmd.arg1 password = cmd.arg1
validate_user_and_pass(username, password) validate_user_and_pass(username, password)
write(ok("Good"))
logging.info(f"User: {username} has logged in successfully") logging.info(f"User: {username} has logged in successfully")
session.username = username return username
Session.all_sessions.add(username)
async def auth_stage(): async def auth_stage():
session: Session = current_session.get()
write(ok("Server Ready")) write(ok("Server Ready"))
for _ in range(AuthError.RETRIES): for _ in range(AuthError.RETRIES):
try: try:
req = await session.expect_cmd(Command.USER, Command.CAPA) req = await expect_cmd(Command.USER, Command.CAPA)
if req.cmd is Command.CAPA: if req.cmd is Command.CAPA:
write(ok("Following are supported")) write(ok("Following are supported"))
write(msg("USER")) write(msg("USER"))
write(end()) write(end())
else: else:
return await handle_user_pass_auth(req) username = await handle_user_pass_auth(req)
except AuthError: if username in Session.all_sessions:
write(err("Wrong auth")) logging.warning(f"User: {username} already has an active session")
raise AuthError("Already logged in")
else:
write(ok("Login successful"))
except AuthError as ae:
write(err(f"Auth Failed: {ae}"))
except ClientQuit as c: except ClientQuit as c:
write(ok("Bye")) write(ok("Bye"))
logging.warning("Client has QUIT before auth succeeded") logging.warning("Client has QUIT before auth succeeded")
@ -109,101 +113,135 @@ async def auth_stage():
raise ClientError("Failed to authenticate") raise ClientError("Failed to authenticate")
async def process_transactions(mails_list: List[MailEntry]): def trans_command_capa(_, __):
session: Session = current_session.get() write(ok("CAPA follows"))
write(msg("UIDL"))
write(end())
def trans_command_stat(mails, _):
num, size = mails.compute_stat()
write(ok(f"{num} {size}"))
def trans_command_list(mails, req):
if req.arg1:
entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.size}"))
else:
write(err("Not found"))
else:
write(ok("Mails follow"))
for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.size}"))
write(end())
await drain()
def trans_command_uidl(mails, req):
if req.arg1:
entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.uid}"))
else:
write(err("Not found"))
else:
write(ok("Mails follow"))
for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.uid}"))
write(end())
await drain()
def trans_command_retr(mails, req):
entry = mails.get(req.arg1)
if entry:
write(ok("Contents follow"))
write(get_mail(entry))
write(end())
drain()
else:
write(err("Not found"))
def trans_command_dele(mails, req):
entry = mails.get(req.arg1)
if entry:
mails.delete(req.arg1)
else:
write(err("Not found"))
def trans_command_noop(_, __):
write(ok("Hmm"))
async def process_transactions(mails_list: List[MailEntry]):
mails = MailList(mails_list) mails = MailList(mails_list)
def reset(_, __):
nonlocal mails
mails = MailList(mails_list)
handle_map = {
Command.CAPA: trans_command_capa,
Command.STAT: trans_command_stat,
Command.LIST: trans_command_list,
Command.UIDL: trans_command_uidl,
Command.RETR: trans_command_retr,
Command.DELE: trans_command_dele,
Command.RSET: reset,
Command.NOOP: trans_command_noop,
}
while True: while True:
try: try:
req = await session.next_req() req = await next_req()
logging.debug(f"Request: {req}")
if req.cmd is Command.CAPA:
write(ok("CAPA follows"))
write(msg("UIDL"))
write(end())
elif req.cmd is Command.STAT:
num, size = mails.compute_stat()
write(ok(f"{num} {size}"))
elif req.cmd is Command.LIST:
if req.arg1:
entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.size}"))
else:
write(err("Not found"))
else:
write(ok("Mails follow"))
for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.size}"))
write(end())
elif req.cmd is Command.UIDL:
if req.arg1:
entry = mails.get(req.arg1)
if entry:
write(ok(f"{req.arg1} {entry.uid}"))
else:
write(err("Not found"))
else:
write(ok("Mails follow"))
for entry in mails.get_all():
write(msg(f"{entry.nid} {entry.uid}"))
write(end())
await session.writer.drain()
elif req.cmd is Command.RETR:
entry = mails.get(req.arg1)
if entry:
write(ok("Contents follow"))
write(get_mail(entry))
write(end())
await session.writer.drain()
else:
write(err("Not found"))
elif req.cmd is Command.DELE:
entry = mails.get(req.arg1)
if entry:
mails.delete(req.arg1)
else:
write(err("Not found"))
elif req.cmd is Command.RSET:
mails = MailList(mails_list)
elif req.cmd is Command.NOOP:
pass
else:
write(err("Not implemented"))
raise ClientError("We shouldn't reach here")
except ClientQuit: except ClientQuit:
write(ok("Bye")) write(ok("Bye"))
return mails.deleted_uids return mails.deleted_uids
logging.debug(f"Request: {req}")
try:
func = handle_map[req.cmd]
except KeyError:
write(err("Not implemented"))
raise ClientError("We shouldn't reach here")
else:
func(mails, req)
async def transaction_stage(): async def transaction_stage(deleted_items_path: Path):
session: Session = current_session.get() with deleted_items_path.open() as f:
logging.debug(f"Entering transaction stage for {session.username}") deleted_items = set(f.read().splitlines())
session.read_items = Session.mails_path / session.username
with session.read_items.open() as f: mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in deleted_items]
read_items = set(f.read().splitlines())
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if entry.uid not in read_items]
return await process_transactions(mails_list) return await process_transactions(mails_list)
def delete_messages(delete_ids): def delete_messages(delete_ids, deleted_items_path: Path):
session: Session = current_session.get() with deleted_items_path.open(mode="w") as f:
with session.read_items.open(mode="w") as f:
f.writelines(delete_ids) f.writelines(delete_ids)
logging.info(f"Client deleted these ids {delete_ids}")
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
session = Session(stream_reader, stream_writer) session = Session(stream_reader, stream_writer)
current_session.set(session) Session.current_session.set(session)
logging.info(f"New session started with {stream_reader} and {stream_writer}") logging.info(f"New session started with {stream_reader} and {stream_writer}")
username = None
try: try:
await auth_stage() username = await auth_stage()
delete_ids = await transaction_stage() assert username is not None
delete_messages(delete_ids) Session.all_sessions.add(username)
deleted_items_path = Session.mails_path / username
logging.info(f"User:{username} logged in successfully")
delete_ids = await transaction_stage(deleted_items_path)
logging.info(f"User:{username} completed transactions. Deleted:{delete_ids}")
delete_messages(delete_ids, deleted_items_path)
logging.info(f"User:{username} Saved deleted items")
except ClientError as c: except ClientError as c:
write(err("Something went wrong")) write(err("Something went wrong"))
logging.error(f"Unexpected client error", c) logging.error(f"Unexpected client error", c)
@ -211,8 +249,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
logging.error(f"Serious client error", e) logging.error(f"Serious client error", e)
raise raise
finally: finally:
if session.username: if username:
Session.all_sessions.remove(session.username) Session.all_sessions.remove(username)
stream_writer.close() stream_writer.close()