temp merge

This commit is contained in:
Balakrishnan Balasubramanian 2018-12-18 23:17:03 -05:00
commit bd85de78e0
3 changed files with 92 additions and 53 deletions

9
TODO.md Normal file
View File

@ -0,0 +1,9 @@
#TODO
1. User timeout for POP
1. unittests
1. Web interface
1. Custom email processing
1. Refactor smtp controller
1. pip installable package
1. Listen on port 465 for smtp too

View File

@ -1,44 +1,72 @@
import asyncio
import logging
import ssl
from _contextvars import ContextVar
from dataclasses import dataclass
from pathlib import Path
import logging
from typing import ClassVar, List, Coroutine
from collections import deque
from .poputils import *
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
MailStorage, Request
reader: ContextVar[asyncio.StreamReader] = ContextVar("reader")
writer: ContextVar[asyncio.StreamWriter] = ContextVar("writer")
@dataclass
class Session:
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
username: str = ""
read_items: Path = None
# common state
all_sessions: ClassVar[List] = []
mails_path: ClassVar[Path] = Path("")
wait_for_privileges_to_drop: ClassVar[Coroutine] = None
pending_request: Request = None
def pop_request(self):
request = self.pending_request
self.pending_request = None
return request
async def next_req(self):
if self.pending_request:
return self.pop_request()
for _ in range(InvalidCommand.RETRIES):
line = await self.reader.readline()
logging.debug(f"Client: {line}")
if not line:
continue
try:
request: Request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else:
if request.cmd == Command.QUIT:
raise ClientQuit
return request
else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(self, *commands: Command, optional=False):
req = await self.next_req()
if req.cmd not in commands:
if not optional:
logging.error(f"{req.cmd} is not in {commands}")
raise ClientError
else:
self.pending_request = req
return
return req
current_session: ContextVar[Session] = ContextVar("session")
def write(data):
logging.debug(f"Server: {data}")
writer.get().write(data)
async def next_req():
for _ in range(InvalidCommand.RETRIES):
line = await reader.get().readline()
logging.debug(f"Client: {line}")
if not line:
continue
try:
request = parse_command(line)
except InvalidCommand:
write(err("Bad command"))
else:
if request.cmd == Command.QUIT:
raise ClientQuit
return request
else:
raise ClientError(f"Bad command {InvalidCommand.RETRIES} times")
async def expect_cmd(*commands: Command):
cmd = await next_req()
if cmd.cmd not in commands:
logging.error(f"{cmd.cmd} is not in {commands}")
raise ClientError
return cmd
session: Session = current_session.get()
session.writer.write(data)
def validate_user_and_pass(username, password):
@ -47,11 +75,12 @@ def validate_user_and_pass(username, password):
async def handle_user_pass_auth(user_cmd):
session: Session = current_session.get()
username = user_cmd.arg1
if not username:
raise AuthError("Invalid USER command. username empty")
write(ok("Welcome"))
cmd = await expect_cmd(Command.PASS)
cmd = await session.expect_cmd(Command.PASS)
password = cmd.arg1
validate_user_and_pass(username, password)
write(ok("Good"))
@ -59,10 +88,11 @@ async def handle_user_pass_auth(user_cmd):
async def auth_stage():
session: Session = current_session.get()
write(ok("Server Ready"))
for _ in range(AuthError.RETRIES):
try:
req = await expect_cmd(Command.USER, Command.CAPA)
req = await session.expect_cmd(Command.USER, Command.CAPA)
if req.cmd is Command.CAPA:
write(ok("Following are supported"))
write(msg("USER"))
@ -73,21 +103,23 @@ async def auth_stage():
return username
except AuthError:
write(err("Wrong auth"))
except ClientQuit:
except ClientQuit as c:
write(ok("Bye"))
logging.info("Client has QUIT")
raise
logging.warning("Client has QUIT before auth succeeded")
raise ClientError from c
else:
raise ClientError("Failed to authenticate")
MAILS_PATH = ""
async def transaction_stage(user: User):
logging.debug(f"Entering transaction stage for {user}")
async def transaction_stage():
session: Session = current_session.get()
logging.debug(f"Entering transaction stage for {session.username}")
deleted_message_ids = []
mailbox = MailStorage(MAILS_PATH)
mailbox = MailStorage(Session.mails_path / 'new')
with session.read_items.open() as f:
read_items = set(f.read().splitlines())
mails_list = mailbox.get_mails_list()
mails_map = {str(entry.nid): entry for entry in mails_list}
while True:
@ -116,7 +148,7 @@ async def transaction_stage(user: User):
for entry in mails_list:
write(msg(f"{entry.nid} {entry.uid}"))
write(end())
await writer.get().drain()
await session.writer.drain()
elif req.cmd is Command.RETR:
if req.arg1 not in mails_map:
write(err("Not found"))
@ -124,7 +156,7 @@ async def transaction_stage(user: User):
write(ok("Contents follow"))
write(mailbox.get_mail(mails_map[req.arg1]))
write(end())
await writer.get().drain()
await session.writer.drain()
else:
write(err("Not implemented"))
except ClientQuit:
@ -137,15 +169,12 @@ def delete_messages(delete_ids):
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
reader.set(stream_reader)
writer.set(stream_writer)
current_session.set(Session(stream_reader, stream_writer))
logging.info(f"New session started with {stream_reader} and {stream_writer}")
try:
username: User = await auth_stage()
delete_ids = await transaction_stage(username)
await auth_stage()
delete_ids = await transaction_stage()
delete_messages(delete_ids)
except ClientQuit:
pass
except ClientError as c:
write(err("Something went wrong"))
logging.error(f"Unexpected client error", c)
@ -157,10 +186,9 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
async def create_pop_server(dirpath: Path, port: int, host="", context: ssl.SSLContext = None):
Session.mails_path = dirpath
logging.info(
f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}")
global MAILS_PATH
MAILS_PATH = dirpath / 'new'
return await asyncio.start_server(new_session, host=host, port=port, ssl=context)

View File

@ -1,6 +1,7 @@
import os
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import NewType, List
@ -109,7 +110,7 @@ class MailEntry:
class MailStorage:
def __init__(self, dirpath: str):
def __init__(self, dirpath: Path):
self.dirpath = dirpath
self.files = files_in_path(self.dirpath)
self.entries = [MailEntry(filename, path) for filename, path in self.files]
@ -127,3 +128,4 @@ class MailStorage:
def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp:
return fp.read()