This commit is contained in:
2018-12-18 21:14:22 -05:00
parent dd321ab8b7
commit 35596a1aaa
3 changed files with 95 additions and 57 deletions

View File

@ -1,44 +1,72 @@
import asyncio
import logging
import ssl
from _contextvars import ContextVar
from dataclasses import dataclass
from pathlib import Path
import logging
from typing import ClassVar, List, Coroutine
from collections import deque
from .poputils import *
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
MailStorage, Request
reader: ContextVar[asyncio.StreamReader] = ContextVar("reader")
writer: ContextVar[asyncio.StreamWriter] = ContextVar("writer")
@dataclass
class Session:
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
username: str = ""
read_items: Path = None
# common state
all_sessions: ClassVar[List] = []
mails_path: ClassVar[Path] = Path("")
wait_for_privileges_to_drop: ClassVar[Coroutine] = None
pending_request: Request = None
def pop_request(self):
request = self.pending_request
self.pending_request = None
return request
async def next_req(self):
if self.pending_request:
return self.pop_request()
for _ in range(InvalidCommand.RETRIES):
line = await self.reader.readline()
logging.debug(f"Client: {line}")
if not line:
continue
try:
request: Request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else:
if request.cmd == Command.QUIT:
raise ClientQuit
return request
else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(self, *commands: Command, optional=False):
req = await self.next_req()
if req.cmd not in commands:
if not optional:
logging.error(f"{req.cmd} is not in {commands}")
raise ClientError
else:
self.pending_request = req
return
return req
current_session: ContextVar[Session] = ContextVar("session")
def write(data):
logging.debug(f"Server: {data}")
writer.get().write(data)
async def next_req():
for _ in range(InvalidCommand.RETRIES):
line = await reader.get().readline()
logging.debug(f"Client: {line}")
if not line:
continue
try:
request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else:
if request.cmd == Command.QUIT:
raise ClientQuit
return request
else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(*commands: Command):
cmd = await next_req()
if cmd.cmd not in commands:
logging.error(f"{cmd.cmd} is not in {commands}")
raise ClientError
return cmd
session: Session = current_session.get()
session.writer.write(data)
def validate_user_and_pass(username, password):
@ -47,11 +75,12 @@ def validate_user_and_pass(username, password):
async def handle_user_pass_auth(user_cmd):
session: Session = current_session.get()
username = user_cmd.arg1
if not username:
raise AuthError("Invalid USER command. username empty")
write(ok("Welcome"))
cmd = await expect_cmd(Command.PASS)
cmd = await session.expect_cmd(Command.PASS)
password = cmd.arg1
validate_user_and_pass(username, password)
write(ok("Good"))
@ -59,10 +88,11 @@ async def handle_user_pass_auth(user_cmd):
async def auth_stage():
session: Session = current_session.get()
write(ok("Server Ready"))
for _ in range(AuthError.RETRIES):
try:
req = await expect_cmd(Command.USER, Command.CAPA)
req = await session.expect_cmd(Command.USER, Command.CAPA)
if req.cmd is Command.CAPA:
write(ok("Following are supported"))
write(msg("USER"))
@ -73,22 +103,23 @@ async def auth_stage():
return username
except AuthError:
write(err("Wrong auth"))
except ClientQuit:
except ClientQuit as c:
write(ok("Bye"))
logging.info("Client has QUIT")
raise
logging.warning("Client has QUIT before auth succeeded")
raise ClientError from c
else:
raise ClientError("Failed to authenticate")
MAILS_PATH = ""
WAIT_FOR_PRIVILEGES_TO_DROP = None
async def transaction_stage(user: User):
logging.debug(f"Entering transaction stage for {user}")
async def transaction_stage():
session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
deleted_message_ids = []
mailbox = MailStorage(MAILS_PATH)
mailbox = MailStorage(Session.mails_path / 'new')
with session.read_items.open() as f:
read_items = set(f.read().splitlines())
mails_list = mailbox.get_mails_list()
mails_map = {str(entry.nid): entry for entry in mails_list}
while True:
@ -117,7 +148,7 @@ async def transaction_stage(user: User):
for entry in mails_list:
write(msg(f"{entry.nid} {entry.uid}"))
write(end())
await writer.get().drain()
await session.writer.drain()
elif req.cmd is Command.RETR:
if req.arg1 not in mails_map:
write(err("Not found"))
@ -125,7 +156,7 @@ async def transaction_stage(user: User):
write(ok("Contents follow"))
write(mailbox.get_mail(mails_map[req.arg1]))
write(end())
await writer.get().drain()
await session.writer.drain()
else:
write(err("Not implemented"))
except ClientQuit:
@ -138,18 +169,15 @@ def delete_messages(delete_ids):
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
if WAIT_FOR_PRIVILEGES_TO_DROP:
if Session.wait_for_privileges_to_drop:
logging.warning("Waiting for privileges to drop")
await WAIT_FOR_PRIVILEGES_TO_DROP
reader.set(stream_reader)
writer.set(stream_writer)
await Session.wait_for_privileges_to_drop
current_session.set(Session(stream_reader, stream_writer))
logging.info(f"New session started with {stream_reader} and {stream_writer}")
try:
username: User = await auth_stage()
delete_ids = await transaction_stage(username)
await auth_stage()
delete_ids = await transaction_stage()
delete_messages(delete_ids)
except ClientQuit:
pass
except ClientError as c:
write(err("Something went wrong"))
logging.error(f"Unexpected client error", c)
@ -163,9 +191,8 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
async def a_main(dirpath: Path, port: int, host="", context: ssl.SSLContext = None, waiter=None):
logging.info(
f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}, waiter={waiter}")
global MAILS_PATH, WAIT_FOR_PRIVILEGES_TO_DROP
MAILS_PATH = dirpath / 'new'
WAIT_FOR_PRIVILEGES_TO_DROP = waiter
Session.mails_path = dirpath
Session.wait_for_privileges_to_drop = waiter
server = await asyncio.start_server(new_session, host=host, port=port, ssl=context)
await server.serve_forever()

View File

@ -1,6 +1,7 @@
import os
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import NewType, List
@ -109,7 +110,7 @@ class MailEntry:
class MailStorage:
def __init__(self, dirpath: str):
def __init__(self, dirpath: Path):
self.dirpath = dirpath
self.files = files_in_path(self.dirpath)
self.entries = [MailEntry(filename, path) for filename, path in self.files]
@ -127,3 +128,4 @@ class MailStorage:
def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp:
return fp.read()