pop3 refactor wip
This commit is contained in:
parent
c441863af7
commit
5600d30f54
@ -26,7 +26,14 @@ class User(Jata):
|
||||
class Config(Jata):
|
||||
certfile: str
|
||||
keyfile: str
|
||||
debug: bool = False
|
||||
mails_path: str
|
||||
host = '0.0.0.0'
|
||||
smtp_port = 25
|
||||
smtp_port_tls = 465
|
||||
smtp_port_submission = 587
|
||||
pop_port = 995
|
||||
smtputf8 = True
|
||||
rules: list[Rule]
|
||||
boxes: list[Mbox]
|
||||
users: list[User]
|
||||
|
153
mail4one/pop3.py
153
mail4one/pop3.py
@ -2,11 +2,14 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
from _contextvars import ContextVar
|
||||
import contextvars
|
||||
from dataclasses import dataclass
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, List, Set
|
||||
from .config import User
|
||||
from .pwhash import parse_hash, check_pass
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
|
||||
from .poputils import InvalidCommand, parse_command, err, Command, ClientQuit, ClientError, AuthError, ok, msg, end, \
|
||||
Request, MailEntry, get_mail, get_mails_list, MailList
|
||||
@ -19,21 +22,16 @@ def add_season(content: bytes, season: bytes):
|
||||
# noinspection PyProtectedMember
|
||||
@dataclass
|
||||
class Session:
|
||||
_reader: asyncio.StreamReader
|
||||
_reader: StreamReader
|
||||
_writer: asyncio.StreamWriter
|
||||
username: str
|
||||
mbox: str
|
||||
|
||||
# common state
|
||||
all_sessions: ClassVar[Set] = set()
|
||||
mails_path: ClassVar[Path] = Path("")
|
||||
users: ClassVar[list[User]] = list()
|
||||
current_session: ClassVar = ContextVar("session")
|
||||
password_hash: ClassVar[str] = ""
|
||||
SALT: ClassVar[bytes] = b"balki is awesome+"
|
||||
pepper: ClassVar[bytes]
|
||||
|
||||
@classmethod
|
||||
def init_password(cls, salted_hash: str):
|
||||
cls.pepper = os.urandom(32)
|
||||
cls.password_hash = add_season(bytes.fromhex(salted_hash), cls.pepper)
|
||||
|
||||
@classmethod
|
||||
def get(cls):
|
||||
@ -50,7 +48,7 @@ class Session:
|
||||
|
||||
async def next_req():
|
||||
for _ in range(InvalidCommand.RETRIES):
|
||||
line = await Session.reader().readline()
|
||||
line = await state().reader.readline()
|
||||
logging.debug(f"Client: {line}")
|
||||
if not line:
|
||||
continue
|
||||
@ -76,13 +74,20 @@ async def expect_cmd(*commands: Command):
|
||||
|
||||
def write(data):
|
||||
logging.debug(f"Server: {data}")
|
||||
Session.writer().write(data)
|
||||
state().writer.write(data)
|
||||
|
||||
|
||||
def validate_password(password):
|
||||
if Session.password_hash != add_season(add_season(password.encode(), Session.SALT), Session.pepper):
|
||||
def validate_password(username, password):
|
||||
try:
|
||||
pwinfo, mbox = config().users[username]
|
||||
except:
|
||||
raise AuthError("Invalid user pass")
|
||||
|
||||
if not check_pass(password, pwinfo):
|
||||
raise AuthError("Invalid user pass")
|
||||
state().username = username
|
||||
state().mbox = mbox
|
||||
|
||||
|
||||
async def handle_user_pass_auth(user_cmd):
|
||||
username = user_cmd.arg1
|
||||
@ -91,9 +96,8 @@ async def handle_user_pass_auth(user_cmd):
|
||||
write(ok("Welcome"))
|
||||
cmd = await expect_cmd(Command.PASS)
|
||||
password = cmd.arg1
|
||||
validate_password(password)
|
||||
logging.info(f"User: {username} has logged in successfully")
|
||||
return username
|
||||
validate_password(username, password)
|
||||
logging.info(f"{username=} has logged in successfully")
|
||||
|
||||
|
||||
async def auth_stage():
|
||||
@ -107,12 +111,12 @@ async def auth_stage():
|
||||
write(end())
|
||||
else:
|
||||
username = await handle_user_pass_auth(req)
|
||||
if username in Session.all_sessions:
|
||||
logging.warning(f"User: {username} already has an active session")
|
||||
if username in config().loggedin_users:
|
||||
logging.warning(
|
||||
f"User: {username} already has an active session")
|
||||
raise AuthError("Already logged in")
|
||||
else:
|
||||
write(ok("Login successful"))
|
||||
return username
|
||||
except AuthError as ae:
|
||||
write(err(f"Auth Failed: {ae}"))
|
||||
except ClientQuit as c:
|
||||
@ -218,7 +222,7 @@ async def process_transactions(mails_list: List[MailEntry]):
|
||||
raise ClientError("We shouldn't reach here")
|
||||
else:
|
||||
func(mails, req)
|
||||
await Session.writer().drain()
|
||||
await state().writer.drain()
|
||||
|
||||
|
||||
def get_deleted_items(deleted_items_path: Path):
|
||||
@ -234,32 +238,35 @@ def save_deleted_items(deleted_items_path: Path, deleted_items: Set):
|
||||
|
||||
|
||||
async def transaction_stage(existing_deleted_items: Set):
|
||||
mails_list = [entry for entry in get_mails_list(Session.mails_path / 'new') if
|
||||
entry.uid not in existing_deleted_items]
|
||||
mails_list = [
|
||||
entry for entry in get_mails_list(config().mails_path / 'new')
|
||||
if entry.uid not in existing_deleted_items
|
||||
]
|
||||
|
||||
new_deleted_items: Set = await process_transactions(mails_list)
|
||||
return new_deleted_items
|
||||
|
||||
|
||||
async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
|
||||
session = Session(stream_reader, stream_writer)
|
||||
Session.current_session.set(session)
|
||||
logging.info(f"New session started with {stream_reader} and {stream_writer}")
|
||||
username = None
|
||||
async def start_session():
|
||||
logging.info("New session started")
|
||||
try:
|
||||
username = await auth_stage()
|
||||
await auth_stage()
|
||||
assert username is not None
|
||||
Session.all_sessions.add(username)
|
||||
deleted_items_path = Session.mails_path / username
|
||||
config().loggedin_users.add(username)
|
||||
_, mbox = config().users[username]
|
||||
deleted_items_path = config().mails_path/ mbox / username
|
||||
logging.info(f"User:{username} logged in successfully")
|
||||
|
||||
existing_deleted_items: Set = get_deleted_items(deleted_items_path)
|
||||
|
||||
new_deleted_items: Set = await transaction_stage(existing_deleted_items)
|
||||
logging.info(f"User:{username} completed transactions. Deleted:{new_deleted_items}")
|
||||
logging.info(
|
||||
f"{username=} completed transactions. Deleted:{len(new_deleted_items)}"
|
||||
)
|
||||
|
||||
if new_deleted_items:
|
||||
save_deleted_items(deleted_items_path, existing_deleted_items.union(new_deleted_items))
|
||||
save_deleted_items(deleted_items_path,
|
||||
existing_deleted_items.union(new_deleted_items))
|
||||
|
||||
logging.info(f"User:{username} Saved deleted items")
|
||||
|
||||
@ -271,23 +278,78 @@ async def new_session(stream_reader: asyncio.StreamReader, stream_writer: asynci
|
||||
raise
|
||||
finally:
|
||||
if username:
|
||||
Session.all_sessions.remove(username)
|
||||
stream_writer.close()
|
||||
config().loggedin_users.remove(username)
|
||||
|
||||
|
||||
async def timed_cb(stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter):
|
||||
try:
|
||||
return await asyncio.wait_for(new_session(stream_reader, stream_writer), 60)
|
||||
finally:
|
||||
stream_writer.close()
|
||||
def parse_users(users: list[User]):
|
||||
|
||||
def inner():
|
||||
for user in Users:
|
||||
user = User(user)
|
||||
pwinfo = parse_hash(user.password_hash)
|
||||
yield user.username, (pwinfo, user.mbox)
|
||||
|
||||
return dict(inner())
|
||||
|
||||
|
||||
async def create_pop_server(dirpath: Path, port: int, password_hash: str, host="", context: ssl.SSLContext = None):
|
||||
Session.mails_path = dirpath
|
||||
Session.init_password(password_hash)
|
||||
@dataclass
|
||||
class State:
|
||||
reader: StreamReader
|
||||
writer: StreamWriter
|
||||
username: str = ""
|
||||
mbox: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
mails_path: Path
|
||||
users: dict[str, tuple[pwhash.PWInfo, str]]
|
||||
loggedin_users: set[str] = set()
|
||||
|
||||
|
||||
c_config = contextvars.ContextVar('config')
|
||||
|
||||
|
||||
def config() -> Config:
|
||||
return c_config.get()
|
||||
|
||||
|
||||
c_state = contextvars.ContextVar('state')
|
||||
|
||||
|
||||
def state() -> State:
|
||||
return c_state.get()
|
||||
|
||||
|
||||
def make_pop_server_callback(dirpath: Path, users: list[User],
|
||||
timeout_seconds: int):
|
||||
config = Config(mails_path=dirpath, users=parse_users(users))
|
||||
|
||||
async def session_cb(reader: StreamReader, writer: StreamWriter):
|
||||
c_config.set(config)
|
||||
c_state.set(State(reader=reader, writer=writer))
|
||||
try:
|
||||
return await asyncio.wait_for(start_session(), timeout_seconds)
|
||||
finally:
|
||||
stream_writer.close()
|
||||
|
||||
return session_cb
|
||||
|
||||
|
||||
async def create_pop_server(dirpath: Path,
|
||||
port: int,
|
||||
users: list[User],
|
||||
host="",
|
||||
context: ssl.SSLContext = None,
|
||||
timeout_seconds: int = 60):
|
||||
logging.info(
|
||||
f"Starting POP3 server Maildir={dirpath}, host={host}, port={port}, context={context}")
|
||||
return await asyncio.start_server(timed_cb, host=host, port=port, ssl=context)
|
||||
f"Starting POP3 server {dirpath=}, {host=}, {port=}, {timeout_seconds=}, ssl={context != None}"
|
||||
)
|
||||
return await asyncio.start_server(make_pop_server_callback(
|
||||
dirpath, users, timeout_seconds),
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=context)
|
||||
|
||||
|
||||
async def a_main(*args, **kwargs):
|
||||
@ -304,7 +366,6 @@ def debug_main():
|
||||
|
||||
mails_path = Path(mails_path)
|
||||
port = int(port)
|
||||
password_hash = add_season(password.encode(), Session.SALT).hex()
|
||||
|
||||
asyncio.run(a_main(mails_path, port, password_hash=password_hash))
|
||||
|
||||
|
@ -13,6 +13,8 @@ from pathlib import Path
|
||||
from .smtp import create_smtp_server_starttls, create_smtp_server_tls
|
||||
from .pop3 import create_pop_server
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
||||
def create_tls_context(certfile, keyfile):
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
@ -48,44 +50,43 @@ def setup_logging(args):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def drop_privileges():
|
||||
try:
|
||||
import pwd
|
||||
except ImportError:
|
||||
logging.error("Cannot import pwd; run as root")
|
||||
sys.exit(1)
|
||||
nobody = pwd.getpwnam('nobody')
|
||||
try:
|
||||
os.setgid(nobody.pw_gid)
|
||||
os.setuid(nobody.pw_uid)
|
||||
except PermissionError:
|
||||
logging.error("Cannot setuid nobody; run as root")
|
||||
sys.exit(1)
|
||||
logging.info("Dropped privileges")
|
||||
logging.debug("Signalled! Clients can come in")
|
||||
async def a_main(config, tls_context):
|
||||
pop_server = await create_pop_server(config.mails_path,
|
||||
port=config.pop_port,
|
||||
host=config.host,
|
||||
context=tls_context,
|
||||
users=config.users)
|
||||
|
||||
|
||||
async def a_main(args, tls_context):
|
||||
pop_server = await create_pop_server(
|
||||
args.mail_dir_path, port=args.pop_port, host=args.host, context=tls_context, password_hash=args.password_hash)
|
||||
smtp_server_starttls = await create_smtp_server_starttls(
|
||||
args.mail_dir_path, port=args.smtp_port, host=args.host, context=tls_context)
|
||||
smtp_server_tls = await create_smtp_server_tls(
|
||||
args.mail_dir_path, port=args.smtp_port_tls, host=args.host, context=tls_context)
|
||||
drop_privileges()
|
||||
await asyncio.gather(
|
||||
pop_server.serve_forever(),
|
||||
smtp_server_starttls.serve_forever(),
|
||||
smtp_server_tls.serve_forever())
|
||||
config.mail_dir_path,
|
||||
port=config.smtp_port,
|
||||
host=config.host,
|
||||
context=tls_context)
|
||||
|
||||
smtp_server_tls = await create_smtp_server_tls(config.mail_dir_path,
|
||||
port=config.smtp_port_tls,
|
||||
host=config.host,
|
||||
context=tls_context)
|
||||
|
||||
await asyncio.gather(pop_server.serve_forever(),
|
||||
smtp_server_starttls.serve_forever(),
|
||||
smtp_server_tls.serve_forever())
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
tls_context = create_tls_context(args.certfile, args.keyfile)
|
||||
config_path = sys.argv[1]
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("config_path")
|
||||
args = parser.parse_args()
|
||||
config = Config(open(args.config_path).read())
|
||||
|
||||
setup_logging(args)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.set_debug(args.debug)
|
||||
asyncio.run(a_main(args, tls_context))
|
||||
loop.set_debug(config.debug)
|
||||
|
||||
tls_context = create_tls_context(config.certfile, config.keyfile)
|
||||
|
||||
asyncio.run(a_main(config, tls_context))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user