Format using black and then lsp format

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-16 21:55:57 -04:00
parent 0ed2341d68
commit 2fa748c444
6 changed files with 138 additions and 89 deletions

View File

@ -63,7 +63,7 @@ class SmtpCfg(ServerCfg):
class Config(Jata): class Config(Jata):
default_tls: TLSCfg | None default_tls: TLSCfg | None
default_host: str = '0.0.0.0' default_host: str = "0.0.0.0"
debug: bool = False debug: bool = False
mails_path: str mails_path: str

View File

@ -11,9 +11,24 @@ from .config import User
from .pwhash import parse_hash, check_pass, PWInfo from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, \ from .poputils import (
ClientQuit, ClientDisconnected, ClientError, AuthError, ok, \ InvalidCommand,
msg, end, Request, MailEntry, get_mail, get_mails_list, MailList parse_command,
err,
Command,
ClientQuit,
ClientDisconnected,
ClientError,
AuthError,
ok,
msg,
end,
Request,
MailEntry,
get_mail,
get_mails_list,
MailList,
)
async def next_req() -> Request: async def next_req() -> Request:
@ -85,7 +100,8 @@ async def auth_stage() -> None:
await handle_user_pass_auth(req) await handle_user_pass_auth(req)
if state().username in config().loggedin_users: if state().username in config().loggedin_users:
logging.warning( logging.warning(
f"User: {state().username} already has an active session") f"User: {state().username} already has an active session"
)
raise AuthError("Already logged in") raise AuthError("Already logged in")
else: else:
config().loggedin_users.add(state().username) config().loggedin_users.add(state().username)
@ -217,7 +233,7 @@ async def transaction_stage() -> None:
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path) existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
mails_list = [ mails_list = [
entry entry
for entry in get_mails_list(config().mails_path / state().mbox / 'new') for entry in get_mails_list(config().mails_path / state().mbox / "new")
if entry.uid not in existing_deleted_items if entry.uid not in existing_deleted_items
] ]
@ -282,14 +298,14 @@ class Config:
self.loggedin_users: set[str] = set() self.loggedin_users: set[str] = set()
c_config: contextvars.ContextVar = contextvars.ContextVar('config') c_config: contextvars.ContextVar = contextvars.ContextVar("config")
def config() -> Config: def config() -> Config:
return c_config.get() return c_config.get()
c_state: contextvars.ContextVar = contextvars.ContextVar('state') c_state: contextvars.ContextVar = contextvars.ContextVar("state")
def state() -> State: def state() -> State:
@ -312,20 +328,23 @@ def make_pop_server_callback(mails_path: Path, users: list[User],
return session_cb return session_cb
async def create_pop_server(host: str, async def create_pop_server(
host: str,
port: int, port: int,
mails_path: Path, mails_path: Path,
users: list[User], users: list[User],
ssl_context: ssl.SSLContext | None = None, ssl_context: ssl.SSLContext | None = None,
timeout_seconds: int = 60) -> asyncio.Server: timeout_seconds: int = 60,
) -> asyncio.Server:
logging.info( logging.info(
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}" f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
) )
return await asyncio.start_server(make_pop_server_callback( return await asyncio.start_server(
mails_path, users, timeout_seconds), make_pop_server_callback(mails_path, users, timeout_seconds),
host=host, host=host,
port=port, port=port,
ssl=ssl_context) ssl=ssl_context,
)
async def a_main(*args, **kwargs) -> None: async def a_main(*args, **kwargs) -> None:

View File

@ -11,6 +11,7 @@ class ClientError(Exception):
class ClientQuit(ClientError): class ClientQuit(ClientError):
pass pass
class ClientDisconnected(ClientError): class ClientDisconnected(ClientError):
pass pass
@ -85,7 +86,7 @@ def parse_command(bline: bytes) -> Request:
if parts: if parts:
request.arg2, *parts = parts request.arg2, *parts = parts
if parts: if parts:
request.rest, = parts (request.rest, ) = parts
return request return request
@ -124,7 +125,7 @@ def set_nid(entries: list[MailEntry]):
def get_mail(entry: MailEntry) -> bytes: def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp: with open(entry.path, mode="rb") as fp:
return fp.read() return fp.read()

View File

@ -12,7 +12,7 @@ SCRYPT_R = 8
SCRYPT_P = 1 SCRYPT_P = 1
# If any of above parameters change, version will be incremented # If any of above parameters change, version will be incremented
VERSION = b'\x01' VERSION = b"\x01"
SALT_LEN = 30 SALT_LEN = 30
KEY_LEN = 64 # This is python default KEY_LEN = 64 # This is python default
@ -51,21 +51,26 @@ def parse_hash(pwhash_str: str) -> PWInfo:
def check_pass(password: str, pwinfo: PWInfo) -> bool: def check_pass(password: str, pwinfo: PWInfo) -> bool:
# No need for constant time compare for hashes. See https://security.stackexchange.com/a/46215 # No need for constant time compare for hashes. See https://security.stackexchange.com/a/46215
return pwinfo.scrypt_hash == scrypt(password.encode(), return pwinfo.scrypt_hash == scrypt(
password.encode(),
salt=pwinfo.salt, salt=pwinfo.salt,
n=SCRYPT_N, n=SCRYPT_N,
r=SCRYPT_R, r=SCRYPT_R,
p=SCRYPT_P, p=SCRYPT_P,
dklen=KEY_LEN) dklen=KEY_LEN,
)
if __name__ == '__main__': if __name__ == "__main__":
import sys import sys
if len(sys.argv) == 2: if len(sys.argv) == 2:
print(gen_pwhash(sys.argv[1])) print(gen_pwhash(sys.argv[1]))
elif len(sys.argv) == 3: elif len(sys.argv) == 3:
ok = check_pass(sys.argv[1], parse_hash(sys.argv[2])) ok = check_pass(sys.argv[1], parse_hash(sys.argv[2]))
print("OK" if ok else "NOT OK") print("OK" if ok else "NOT OK")
else: else:
print("Usage: python3 -m mail4one.pwhash <password> [password_hash]", print(
file=sys.stderr) "Usage: python3 -m mail4one.pwhash <password> [password_hash]",
file=sys.stderr,
)

View File

@ -28,7 +28,6 @@ def setup_logging(args):
async def a_main(cfg: config.Config) -> None: async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls: if tls := cfg.default_tls:
@ -60,7 +59,8 @@ async def a_main(cfg: config.Config) -> None:
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
users=cfg.users, users=cfg.users,
ssl_context=get_tls_context(pop.tls), ssl_context=get_tls_context(pop.tls),
timeout_seconds=pop.timeout_seconds) timeout_seconds=pop.timeout_seconds,
)
servers.append(pop_server) servers.append(pop_server)
if cfg.smtp_starttls: if cfg.smtp_starttls:
@ -73,17 +73,19 @@ async def a_main(cfg: config.Config) -> None:
port=stls.port, port=stls.port,
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder, mbox_finder=mbox_finder,
ssl_context=stls_context) ssl_context=stls_context,
)
servers.append(smtp_server_starttls) servers.append(smtp_server_starttls)
if cfg.smtp: if cfg.smtp:
smtp = config.SmtpCfg(cfg.smtp) smtp = config.SmtpCfg(cfg.smtp)
smtp_server = await create_smtp_server(host=get_host(smtp.host), smtp_server = await create_smtp_server(
host=get_host(smtp.host),
port=smtp.port, port=smtp.port,
mails_path=Path(cfg.mails_path), mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder, mbox_finder=mbox_finder,
ssl_context=get_tls_context( ssl_context=get_tls_context(smtp.tls),
smtp.tls)) )
servers.append(smtp_server) servers.append(smtp_server)
if servers: if servers:
@ -93,31 +95,41 @@ async def a_main(cfg: config.Config) -> None:
def main() -> None: def main() -> None:
parser = ArgumentParser(description="Personal Mail Server", epilog="See https://gitea.balki.me/balki/mail4one for more info") parser = ArgumentParser(
description="Personal Mail Server",
epilog="See https://gitea.balki.me/balki/mail4one for more info",
)
parser.add_argument( parser.add_argument(
"-e", "-e",
"--echo_password", "--echo_password",
action="store_true", action="store_true",
help="Show password in command line if -g without password is used") help="Show password in command line if -g without password is used",
)
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-c", group.add_argument(
"-c",
"--config", "--config",
metavar="CONFIG_PATH", metavar="CONFIG_PATH",
type=Path, type=Path,
help="Run mail server with passed config") help="Run mail server with passed config",
group.add_argument("-g", )
group.add_argument(
"-g",
"--genpwhash", "--genpwhash",
nargs="?", nargs="?",
dest="password", dest="password",
const="FROM_TERMINAL", const="FROM_TERMINAL",
metavar="PASSWORD", metavar="PASSWORD",
help="Generate password hash to add in config") help="Generate password hash to add in config",
group.add_argument("-r", )
group.add_argument(
"-r",
"--pwverify", "--pwverify",
dest="password_pwhash", dest="password_pwhash",
nargs=2, nargs=2,
metavar=("PASSWORD", "PWHASH"), metavar=("PASSWORD", "PWHASH"),
help="Check if password matches password hash") help="Check if password matches password hash",
)
args = parser.parse_args() args = parser.parse_args()
if password := args.password: if password := args.password:
if password == "FROM_TERMINAL": if password == "FROM_TERMINAL":
@ -138,5 +150,5 @@ def main() -> None:
asyncio.run(a_main(cfg), debug=cfg.debug) asyncio.run(a_main(cfg), debug=cfg.debug)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -23,7 +23,8 @@ from aiosmtpd.smtp import Session as SMTPSession
class MyHandler(AsyncMessage): class MyHandler(AsyncMessage):
def __init__(self, mails_path: Path, mbox_finder: Callable[[str], list[str]]): def __init__(self, mails_path: Path, mbox_finder: Callable[[str],
list[str]]):
super().__init__() super().__init__()
self.mails_path = mails_path self.mails_path = mails_path
self.mbox_finder = mbox_finder self.mbox_finder = mbox_finder
@ -41,7 +42,7 @@ class MyHandler(AsyncMessage):
if not all_mboxes: if not all_mboxes:
return return
for mbox in all_mboxes: for mbox in all_mboxes:
for sub in ('new', 'tmp', 'cur'): for sub in ("new", "tmp", "cur"):
sub_path = self.mails_path / mbox / sub sub_path = self.mails_path / mbox / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True) sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
@ -50,7 +51,7 @@ class MyHandler(AsyncMessage):
gen = BytesGenerator(fp, policy=email.policy.SMTP) gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m) gen.flatten(m)
for mbox in all_mboxes: for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / 'new') shutil.copy(temp_email_path, self.mails_path / mbox / "new")
def protocol_factory_starttls(mails_path: Path, def protocol_factory_starttls(mails_path: Path,
@ -59,17 +60,20 @@ def protocol_factory_starttls(mails_path: Path,
logging.info("Got smtp client cb starttls") logging.info("Got smtp client cb starttls")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler, smtp = SMTP(
handler=handler,
require_starttls=True, require_starttls=True,
tls_context=context, tls_context=context,
enable_SMTPUTF8=True) enable_SMTPUTF8=True,
)
except Exception as e: except Exception as e:
logging.error("Something went wrong", e) logging.error("Something went wrong", e)
raise raise
return smtp return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]): def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
list[str]]):
logging.info("Got smtp client cb") logging.info("Got smtp client cb")
try: try:
handler = MyHandler(mails_path, mbox_finder) handler = MyHandler(mails_path, mbox_finder)
@ -80,30 +84,38 @@ def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]):
return smtp return smtp
async def create_smtp_server_starttls(host: str, async def create_smtp_server_starttls(
host: str,
port: int, port: int,
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext) -> asyncio.Server: ssl_context: ssl.SSLContext,
) -> asyncio.Server:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls, return await loop.create_server(
mails_path, mbox_finder, ssl_context), partial(protocol_factory_starttls, mails_path, mbox_finder,
ssl_context),
host=host, host=host,
port=port, port=port,
start_serving=False) start_serving=False,
)
async def create_smtp_server(host: str, async def create_smtp_server(
host: str,
port: int, port: int,
mails_path: Path, mails_path: Path,
mbox_finder: Callable[[str], list[str]], mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext | None = None) -> asyncio.Server: ssl_context: ssl.SSLContext | None = None,
) -> asyncio.Server:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder), return await loop.create_server(
partial(protocol_factory, mails_path, mbox_finder),
host=host, host=host,
port=port, port=port,
ssl=ssl_context, ssl=ssl_context,
start_serving=False) start_serving=False,
)
async def a_main(*args, **kwargs): async def a_main(*args, **kwargs):