temp merge

This commit is contained in:
Balakrishnan Balasubramanian 2018-12-18 23:17:03 -05:00
commit bd85de78e0
3 changed files with 92 additions and 53 deletions

9
TODO.md Normal file
View File

@ -0,0 +1,9 @@
#TODO
1. User timeout for POP
1. unittests
1. Web interface
1. Custom email processing
1. Refactor smtp controller
1. pip installable package
1. Listen on port 465 for smtp too

View File

@ -1,44 +1,72 @@
import asyncio import asyncio
import logging
import ssl import ssl
from _contextvars import ContextVar from _contextvars import ContextVar
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import logging from typing import ClassVar, List, Coroutine
from collections import deque
from .poputils import * from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
MailStorage, Request
reader: ContextVar[asyncio.StreamReader] = ContextVar("reader")
writer: ContextVar[asyncio.StreamWriter] = ContextVar("writer") @dataclass
class Session:
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
username: str = ""
read_items: Path = None
# common state
all_sessions: ClassVar[List] = []
mails_path: ClassVar[Path] = Path("")
wait_for_privileges_to_drop: ClassVar[Coroutine] = None
pending_request: Request = None
def pop_request(self):
request = self.pending_request
self.pending_request = None
return request
async def next_req(self):
if self.pending_request:
return self.pop_request()
for _ in range(InvalidCommand.RETRIES):
line = await self.reader.readline()
logging.debug(f"Client: {line}")
if not line:
continue
try:
request: Request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else:
if request.cmd == Command.QUIT:
raise ClientQuit
return request
else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(self, *commands: Command, optional=False):
req = await self.next_req()
if req.cmd not in commands:
if not optional:
logging.error(f"{req.cmd} is not in {commands}")
raise ClientError
else:
self.pending_request = req
return
return req
current_session: ContextVar[Session] = ContextVar("session")
def write(data): def write(data):
logging.debug(f"Server: {data}") logging.debug(f"Server: {data}")
writer.get().write(data) session: Session = current_session.get()
session.writer.write(data)
async def next_req():
for _ in range(InvalidCommand.RETRIES):
line = await reader.get().readline()
logging.debug(f"Client: {line}")
if not line:
continue
try:
request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else:
if request.cmd == Command.QUIT:
raise ClientQuit
return request
else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(*commands: Command):
cmd = await next_req()
if cmd.cmd not in commands:
logging.error(f"{cmd.cmd} is not in {commands}")
raise ClientError
return cmd
def validate_user_and_pass(username, password): def validate_user_and_pass(username, password):
@ -47,11 +75,12 @@ def validate_user_and_pass(username, password):
async def handle_user_pass_auth(user_cmd): async def handle_user_pass_auth(user_cmd):
session: Session = current_session.get()
username = user_cmd.arg1 username = user_cmd.arg1
if not username: if not username:
raise AuthError("Invalid USER command. username empty") raise AuthError("Invalid USER command. username empty")
write(ok("Welcome")) write(ok("Welcome"))
cmd = await expect_cmd(Command.PASS) cmd = await session.expect_cmd(Command.PASS)
password = cmd.arg1 password = cmd.arg1
validate_user_and_pass(username, password) validate_user_and_pass(username, password)
write(ok("Good")) write(ok("Good"))
@ -59,10 +88,11 @@ async def handle_user_pass_auth(user_cmd):
async def auth_stage(): async def auth_stage():
session: Session = current_session.get()
write(ok("Server Ready")) write(ok("Server Ready"))
for _ in range(AuthError.RETRIES): for _ in range(AuthError.RETRIES):
try: try:
req = await expect_cmd(Command.USER, Command.CAPA) req = await session.expect_cmd(Command.USER, Command.CAPA)
if req.cmd is Command.CAPA: if req.cmd is Command.CAPA:
write(ok("Following are supported")) write(ok("Following are supported"))
write(msg("USER")) write(msg("USER"))
@ -73,21 +103,23 @@ async def auth_stage():
return username return username
except AuthError: except AuthError:
write(err("Wrong auth")) write(err("Wrong auth"))
except ClientQuit: except ClientQuit as c:
write(ok("Bye")) write(ok("Bye"))
logging.info("Client has QUIT") logging.warning("Client has QUIT before auth succeeded")
raise raise ClientError from c
else: else:
raise ClientError("Failed to authenticate") raise ClientError("Failed to authenticate")
MAILS_PATH = "" async def transaction_stage():
session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
async def transaction_stage(user: User):
logging.debug(f"Entering transaction stage for {user}")
deleted_message_ids = [] deleted_message_ids = []
mailbox = MailStorage(MAILS_PATH) mailbox = MailStorage(Session.mails_path / 'new')
with session.read_items.open() as f:
read_items = set(f.read().splitlines())
mails_list = mailbox.get_mails_list() mails_list = mailbox.get_mails_list()
mails_map = {str(entry.nid): entry for entry in mails_list} mails_map = {str(entry.nid): entry for entry in mails_list}
while True: while True:
@ -116,7 +148,7 @@ async def transaction_stage(user: User):
for entry in mails_list: for entry in mails_list:
write(msg(f"{entry.nid} {entry.uid}")) write(msg(f"{entry.nid} {entry.uid}"))
write(end()) write(end())
await writer.get().drain() await session.writer.drain()
elif req.cmd is Command.RETR: elif req.cmd is Command.RETR:
if req.arg1 not in mails_map: if req.arg1 not in mails_map:
write(err("Not found")) write(err("Not found"))
@ -124,7 +156,7 @@ async def transaction_stage(user: User):
write(ok("Contents follow")) write(ok("Contents follow"))
write(mailbox.get_mail(mails_map[req.arg1])) write(mailbox.get_mail(mails_map[req.arg1]))
write(end()) write(end())
await writer.get().drain() await session.writer.drain()
else: else:
write(err("Not implemented")) write(err("Not implemented"))
except ClientQuit: except ClientQuit:
@ -137,15 +169,12 @@ def delete_messages(delete_ids):
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter): async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
reader.set(stream_reader) current_session.set(Session(stream_reader, stream_writer))
writer.set(stream_writer)
logging.info(f"New session started with {stream_reader} and {stream_writer}") logging.info(f"New session started with {stream_reader} and {stream_writer}")
try: try:
username: User = await auth_stage() await auth_stage()
delete_ids = await transaction_stage(username) delete_ids = await transaction_stage()
delete_messages(delete_ids) delete_messages(delete_ids)
except ClientQuit:
pass
except ClientError as c: except ClientError as c:
write(err("Something went wrong")) write(err("Something went wrong"))
logging.error(f"Unexpected client error", c) logging.error(f"Unexpected client error", c)
@ -157,10 +186,9 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
async def create_pop_server(dirpath: Path, port: int, host="", context: ssl.SSLContext = None): async def create_pop_server(dirpath: Path, port: int, host="", context: ssl.SSLContext = None):
Session.mails_path = dirpath
logging.info( logging.info(
f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}") f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}")
global MAILS_PATH
MAILS_PATH = dirpath / 'new'
return await asyncio.start_server(new_session, host=host, port=port, ssl=context) return await asyncio.start_server(new_session, host=host, port=port, ssl=context)

View File

@ -1,6 +1,7 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path
from typing import NewType, List from typing import NewType, List
@ -109,7 +110,7 @@ class MailEntry:
class MailStorage: class MailStorage:
def __init__(self, dirpath: str): def __init__(self, dirpath: Path):
self.dirpath = dirpath self.dirpath = dirpath
self.files = files_in_path(self.dirpath) self.files = files_in_path(self.dirpath)
self.entries = [MailEntry(filename, path) for filename, path in self.files] self.entries = [MailEntry(filename, path) for filename, path in self.files]
@ -127,3 +128,4 @@ class MailStorage:
def get_mail(entry: MailEntry) -> bytes: def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp: with open(entry.path, mode='rb') as fp:
return fp.read() return fp.read()