get auth stage working

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-10 22:20:36 -04:00
parent 53bb412721
commit 02f3068626
3 changed files with 57 additions and 22 deletions

View File

@ -3,6 +3,7 @@ import logging
import os import os
import ssl import ssl
import contextvars import contextvars
import contextlib
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
@ -10,15 +11,18 @@ from .config import User
from .pwhash import parse_hash, check_pass, PWInfo from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \ from .poputils import InvalidCommand, parse_command, err, Command, \
Request, MailEntry, get_mail, get_mails_list, MailList ClientQuit, ClientDisconnected, ClientError, AuthError, ok, \
msg, end, Request, MailEntry, get_mail, get_mails_list, MailList
async def next_req() -> Request: async def next_req() -> Request:
for _ in range(InvalidCommand.RETRIES): for _ in range(InvalidCommand.RETRIES):
line = await state().reader.readline() line = await state().reader.readline()
# logging.debug(f"Client: {line}") logging.debug(f"Client: {line!r}")
if not line: if not line:
if state().reader.at_eof():
raise ClientDisconnected
continue continue
try: try:
request: Request = parse_command(line) request: Request = parse_command(line)
@ -79,19 +83,20 @@ async def auth_stage() -> None:
write(end()) write(end())
else: else:
await handle_user_pass_auth(req) await handle_user_pass_auth(req)
if (username:=state().username) in config().loggedin_users: if state().username in config().loggedin_users:
logging.warning( logging.warning(
f"User: {username} already has an active session") f"User: {state().username} already has an active session")
raise AuthError("Already logged in") raise AuthError("Already logged in")
else: else:
config().loggedin_users.add(username) config().loggedin_users.add(state().username)
write(ok("Login successful")) write(ok("Login successful"))
return
except AuthError as ae: except AuthError as ae:
write(err(f"Auth Failed: {ae}")) write(err(f"Auth Failed: {ae}"))
except ClientQuit as c: except ClientQuit as c:
write(ok("Bye")) write(ok("Bye"))
logging.warning("Client has QUIT before auth succeeded") logging.warning("Client has QUIT before auth succeeded")
raise ClientError from c raise
else: else:
raise ClientError("Failed to authenticate") raise ClientError("Failed to authenticate")
@ -201,7 +206,8 @@ def get_deleted_items(deleted_items_path: Path) -> set[str]:
return set() return set()
def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]) -> None: def save_deleted_items(deleted_items_path: Path,
deleted_items: set[str]) -> None:
with deleted_items_path.open(mode="w") as f: with deleted_items_path.open(mode="w") as f:
f.writelines(f"{did}\n" for did in deleted_items) f.writelines(f"{did}\n" for did in deleted_items)
@ -232,6 +238,12 @@ async def start_session() -> None:
assert state().mbox assert state().mbox
await transaction_stage() await transaction_stage()
logging.info(f"User:{state().username} done") logging.info(f"User:{state().username} done")
except ClientDisconnected as c:
logging.info("Client disconnected")
pass
except ClientQuit:
logging.info("Client QUIT")
pass
except ClientError as c: except ClientError as c:
write(err("Something went wrong")) write(err("Something went wrong"))
logging.error(f"Unexpected client error: {c}") logging.error(f"Unexpected client error: {c}")
@ -239,7 +251,7 @@ async def start_session() -> None:
logging.error(f"Serious client error: {e}") logging.error(f"Serious client error: {e}")
raise raise
finally: finally:
if state().username: with contextlib.suppress(KeyError):
config().loggedin_users.remove(state().username) config().loggedin_users.remove(state().username)
@ -295,6 +307,7 @@ def make_pop_server_callback(mails_path: Path, users: list[User],
return await asyncio.wait_for(start_session(), timeout_seconds) return await asyncio.wait_for(start_session(), timeout_seconds)
finally: finally:
writer.close() writer.close()
await writer.wait_closed()
return session_cb return session_cb

View File

@ -3,14 +3,16 @@ import asyncio
import logging import logging
from .pop3 import create_pop_server from .pop3 import create_pop_server
from .config import User from .config import User
from pathlib import Path
class TestPop3(unittest.IsolatedAsyncioTestCase): class TestPop3(unittest.IsolatedAsyncioTestCase):
def setUp(self): def setUp(self) -> None:
logging.basicConfig(level=logging.CRITICAL) logging.basicConfig(level=logging.DEBUG)
async def asyncSetUp(self): async def asyncSetUp(self) -> None:
logging.debug("at asyncSetUp")
test_hash = "".join((l.strip() for l in """ test_hash = "".join((l.strip() for l in """
AFTY5EVN7AX47ZL7UMH3BETYWFBTAV3XHR73CEFAJBPN2NIHPWD AFTY5EVN7AX47ZL7UMH3BETYWFBTAV3XHR73CEFAJBPN2NIHPWD
ZHV2UQSMSPHSQQ2A2BFQBNC77VL7F2UKATQNJZGYLCSU6C43UQD ZHV2UQSMSPHSQQ2A2BFQBNC77VL7F2UKATQNJZGYLCSU6C43UQD
@ -21,12 +23,13 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
] ]
pop_server = await create_pop_server(host='127.0.0.1', pop_server = await create_pop_server(host='127.0.0.1',
port=7995, port=7995,
mails_path='w.tmp', mails_path=Path('w.tmp'),
users=users) users=users)
self.task = asyncio.create_task(pop_server.serve_forever()) self.task = asyncio.create_task(pop_server.serve_forever())
self.reader, self.writer = await asyncio.open_connection('127.0.0.1', 7995) self.reader, self.writer = await asyncio.open_connection(
'127.0.0.1', 7995)
async def test_QUIT(self): async def test_QUIT(self) -> None:
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
C: QUIT C: QUIT
@ -34,7 +37,7 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
""" """
await self.dialog_checker(dialog) await self.dialog_checker(dialog)
async def test_BAD(self): async def test_BAD(self) -> None:
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
C: HELO C: HELO
@ -49,19 +52,25 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
# TODO fix # TODO fix
# self.assertTrue(reader.at_eof(), "server should close the connection") # self.assertTrue(reader.at_eof(), "server should close the connection")
async def test_AUTH(self): async def do_login(self) -> None:
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
C: USER foobar C: USER foobar
S: +OK Welcome S: +OK Welcome
C: PASS helloworld C: PASS helloworld
S: +OK Login successful S: +OK Login successful
"""
await self.dialog_checker(dialog)
async def test_AUTH(self) -> None:
await self.do_login()
dialog = """
C: QUIT C: QUIT
S: +OK Bye S: +OK Bye
""" """
await self.dialog_checker(dialog) await self.dialog_checker(dialog)
async def test_dupe_AUTH(self): async def test_dupe_AUTH(self) -> None:
r1, w1 = await asyncio.open_connection('127.0.0.1', 7995) r1, w1 = await asyncio.open_connection('127.0.0.1', 7995)
r2, w2 = await asyncio.open_connection('127.0.0.1', 7995) r2, w2 = await asyncio.open_connection('127.0.0.1', 7995)
dialog = """ dialog = """
@ -83,7 +92,15 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
await self.dialog_checker_impl(r1, w1, end_dialog) await self.dialog_checker_impl(r1, w1, end_dialog)
await self.dialog_checker_impl(r2, w2, end_dialog) await self.dialog_checker_impl(r2, w2, end_dialog)
async def test_CAPA(self): async def test_STAT(self) -> None:
await self.do_login()
dialog = """
C: STAT
S: +OK Bye
"""
await self.dialog_checker(dialog)
async def test_CAPA(self) -> None:
dialog = """ dialog = """
S: +OK Server Ready S: +OK Server Ready
C: CAPA C: CAPA
@ -95,16 +112,18 @@ class TestPop3(unittest.IsolatedAsyncioTestCase):
""" """
await self.dialog_checker(dialog) await self.dialog_checker(dialog)
async def asyncTearDown(self): async def asyncTearDown(self) -> None:
logging.debug("at teardown")
self.writer.close() self.writer.close()
await self.writer.wait_closed() await self.writer.wait_closed()
self.task.cancel("test done") self.task.cancel("test done")
async def dialog_checker(self, dialog: str): async def dialog_checker(self, dialog: str) -> None:
await self.dialog_checker_impl(self.reader, self.writer, dialog) await self.dialog_checker_impl(self.reader, self.writer, dialog)
async def dialog_checker_impl(self, reader: asyncio.StreamReader, async def dialog_checker_impl(self, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, dialog: str): writer: asyncio.StreamWriter,
dialog: str) -> None:
for line in dialog.splitlines(): for line in dialog.splitlines():
line = line.strip() line = line.strip()
if not line: if not line:

View File

@ -11,6 +11,9 @@ class ClientError(Exception):
class ClientQuit(ClientError): class ClientQuit(ClientError):
pass pass
class ClientDisconnected(ClientError):
pass
class InvalidCommand(ClientError): class InvalidCommand(ClientError):
RETRIES = 3 RETRIES = 3