get auth stage working

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-10 22:20:36 -04:00
parent 53bb412721
commit 02f3068626
3 changed files with 57 additions and 22 deletions

View File

@ -3,6 +3,7 @@ import logging
import os
import ssl
import contextvars
import contextlib
from dataclasses import dataclass
from hashlib import sha256
from pathlib import Path
@ -10,15 +11,18 @@ from .config import User
from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
Request, MailEntry, get_mail, get_mails_list, MailList
from .poputils import InvalidCommand, parse_command, err, Command, \
ClientQuit, ClientDisconnected, ClientError, AuthError, ok, \
msg, end, Request, MailEntry, get_mail, get_mails_list, MailList
async def next_req() -> Request:
for _ in range(InvalidCommand.RETRIES):
line = await state().reader.readline()
# logging.debug(f"Client: {line}")
logging.debug(f"Client: {line!r}")
if not line:
if state().reader.at_eof():
raise ClientDisconnected
continue
try:
request: Request = parse_command(line)
@ -79,19 +83,20 @@ async def auth_stage() -> None:
write(end())
else:
await handle_user_pass_auth(req)
if (username:=state().username) in config().loggedin_users:
if state().username in config().loggedin_users:
logging.warning(
f"User: {username} already has an active session")
f"User: {state().username} already has an active session")
raise AuthError("Already logged in")
else:
config().loggedin_users.add(username)
config().loggedin_users.add(state().username)
write(ok("Login successful"))
return
except AuthError as ae:
write(err(f"Auth Failed: {ae}"))
except ClientQuit as c:
write(ok("Bye"))
logging.warning("Client has QUIT before auth succeeded")
raise ClientError from c
raise
else:
raise ClientError("Failed to authenticate")
@ -201,7 +206,8 @@ def get_deleted_items(deleted_items_path: Path) -> set[str]:
return set()
def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]) -> None:
def save_deleted_items(deleted_items_path: Path,
deleted_items: set[str]) -> None:
with deleted_items_path.open(mode="w") as f:
f.writelines(f"{did}\n" for did in deleted_items)
@ -232,6 +238,12 @@ async def start_session() -> None:
assert state().mbox
await transaction_stage()
logging.info(f"User:{state().username} done")
except ClientDisconnected as c:
logging.info("Client disconnected")
pass
except ClientQuit:
logging.info("Client QUIT")
pass
except ClientError as c:
write(err("Something went wrong"))
logging.error(f"Unexpected client error: {c}")
@ -239,7 +251,7 @@ async def start_session() -> None:
logging.error(f"Serious client error: {e}")
raise
finally:
if state().username:
with contextlib.suppress(KeyError):
config().loggedin_users.remove(state().username)
@ -295,6 +307,7 @@ def make_pop_server_callback(mails_path: Path, users: list[User],
return await asyncio.wait_for(start_session(), timeout_seconds)
finally:
writer.close()
await writer.wait_closed()
return session_cb

View File

@ -3,14 +3,16 @@ import asyncio
import logging
from .pop3 import create_pop_server
from .config import User
from pathlib import Path
class TestPop3(unittest.IsolatedAsyncioTestCase):
def setUp(self):
logging.basicConfig(level=logging.CRITICAL)
def setUp(self) -> None:
logging.basicConfig(level=logging.DEBUG)
async def asyncSetUp(self):
async def asyncSetUp(self) -> None:
logging.debug("at asyncSetUp")
test_hash = "".join((l.strip() for l in """
AFTY5EVN7AX47ZL7UMH3BETYWFBTAV3XHR73CEFAJBPN2NIHPWD
ZHV2UQSMSPHSQQ2A2BFQBNC77VL7F2UKATQNJZGYLCSU6C43UQD
@ -21,12 +23,13 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
]
pop_server = await create_pop_server(host='127.0.0.1',
port=7995,
mails_path='w.tmp',
mails_path=Path('w.tmp'),
users=users)
self.task = asyncio.create_task(pop_server.serve_forever())
self.reader, self.writer = await asyncio.open_connection('127.0.0.1', 7995)
self.reader, self.writer = await asyncio.open_connection(
'127.0.0.1', 7995)
async def test_QUIT(self):
async def test_QUIT(self) -> None:
dialog = """
S: +OK Server Ready
C: QUIT
@ -34,7 +37,7 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
"""
await self.dialog_checker(dialog)
async def test_BAD(self):
async def test_BAD(self) -> None:
dialog = """
S: +OK Server Ready
C: HELO
@ -49,19 +52,25 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
# TODO fix
# self.assertTrue(reader.at_eof(), "server should close the connection")
async def test_AUTH(self):
async def do_login(self) -> None:
dialog = """
S: +OK Server Ready
C: USER foobar
S: +OK Welcome
C: PASS helloworld
S: +OK Login successful
"""
await self.dialog_checker(dialog)
async def test_AUTH(self) -> None:
await self.do_login()
dialog = """
C: QUIT
S: +OK Bye
"""
await self.dialog_checker(dialog)
async def test_dupe_AUTH(self):
async def test_dupe_AUTH(self) -> None:
r1, w1 = await asyncio.open_connection('127.0.0.1', 7995)
r2, w2 = await asyncio.open_connection('127.0.0.1', 7995)
dialog = """
@ -83,7 +92,15 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
await self.dialog_checker_impl(r1, w1, end_dialog)
await self.dialog_checker_impl(r2, w2, end_dialog)
async def test_CAPA(self):
async def test_STAT(self) -> None:
await self.do_login()
dialog = """
C: STAT
S: +OK Bye
"""
await self.dialog_checker(dialog)
async def test_CAPA(self) -> None:
dialog = """
S: +OK Server Ready
C: CAPA
@ -95,16 +112,18 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
"""
await self.dialog_checker(dialog)
async def asyncTearDown(self):
async def asyncTearDown(self) -> None:
logging.debug("at teardown")
self.writer.close()
await self.writer.wait_closed()
self.task.cancel("test done")
async def dialog_checker(self, dialog: str):
async def dialog_checker(self, dialog: str) -> None:
await self.dialog_checker_impl(self.reader, self.writer, dialog)
async def dialog_checker_impl(self, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, dialog: str):
writer: asyncio.StreamWriter,
dialog: str) -> None:
for line in dialog.splitlines():
line = line.strip()
if not line:

View File

@ -11,6 +11,9 @@ class ClientError(Exception):
class ClientQuit(ClientError):
pass
class ClientDisconnected(ClientError):
pass
class InvalidCommand(ClientError):
RETRIES = 3