Format using black and then lsp format

This commit is contained in:
Balakrishnan Balasubramanian 2023-06-16 21:55:57 -04:00
parent 0ed2341d68
commit 2fa748c444
6 changed files with 138 additions and 89 deletions

View File

@ -63,7 +63,7 @@ class SmtpCfg(ServerCfg):
class Config(Jata):
default_tls: TLSCfg | None
default_host: str = '0.0.0.0'
default_host: str = "0.0.0.0"
debug: bool = False
mails_path: str

View File

@ -11,9 +11,24 @@ from .config import User
from .pwhash import parse_hash, check_pass, PWInfo
from asyncio import StreamReader, StreamWriter
from .poputils import InvalidCommand, parse_command, err, Command, \
ClientQuit, ClientDisconnected, ClientError, AuthError, ok, \
msg, end, Request, MailEntry, get_mail, get_mails_list, MailList
from .poputils import (
InvalidCommand,
parse_command,
err,
Command,
ClientQuit,
ClientDisconnected,
ClientError,
AuthError,
ok,
msg,
end,
Request,
MailEntry,
get_mail,
get_mails_list,
MailList,
)
async def next_req() -> Request:
@ -85,7 +100,8 @@ async def auth_stage() -> None:
await handle_user_pass_auth(req)
if state().username in config().loggedin_users:
logging.warning(
f"User: {state().username} already has an active session")
f"User: {state().username} already has an active session"
)
raise AuthError("Already logged in")
else:
config().loggedin_users.add(state().username)
@ -217,7 +233,7 @@ async def transaction_stage() -> None:
existing_deleted_items: set[str] = get_deleted_items(deleted_items_path)
mails_list = [
entry
for entry in get_mails_list(config().mails_path / state().mbox / 'new')
for entry in get_mails_list(config().mails_path / state().mbox / "new")
if entry.uid not in existing_deleted_items
]
@ -282,14 +298,14 @@ class Config:
self.loggedin_users: set[str] = set()
c_config: contextvars.ContextVar = contextvars.ContextVar('config')
c_config: contextvars.ContextVar = contextvars.ContextVar("config")
def config() -> Config:
return c_config.get()
c_state: contextvars.ContextVar = contextvars.ContextVar('state')
c_state: contextvars.ContextVar = contextvars.ContextVar("state")
def state() -> State:
@ -312,20 +328,23 @@ def make_pop_server_callback(mails_path: Path, users: list[User],
return session_cb
async def create_pop_server(host: str,
async def create_pop_server(
host: str,
port: int,
mails_path: Path,
users: list[User],
ssl_context: ssl.SSLContext | None = None,
timeout_seconds: int = 60) -> asyncio.Server:
timeout_seconds: int = 60,
) -> asyncio.Server:
logging.info(
f"Starting POP3 server {host=}, {port=}, {mails_path=}, {len(users)=}, {ssl_context != None=}, {timeout_seconds=}"
)
return await asyncio.start_server(make_pop_server_callback(
mails_path, users, timeout_seconds),
return await asyncio.start_server(
make_pop_server_callback(mails_path, users, timeout_seconds),
host=host,
port=port,
ssl=ssl_context)
ssl=ssl_context,
)
async def a_main(*args, **kwargs) -> None:

View File

@ -11,6 +11,7 @@ class ClientError(Exception):
class ClientQuit(ClientError):
pass
class ClientDisconnected(ClientError):
pass
@ -85,7 +86,7 @@ def parse_command(bline: bytes) -> Request:
if parts:
request.arg2, *parts = parts
if parts:
request.rest, = parts
(request.rest, ) = parts
return request
@ -124,7 +125,7 @@ def set_nid(entries: list[MailEntry]):
def get_mail(entry: MailEntry) -> bytes:
with open(entry.path, mode='rb') as fp:
with open(entry.path, mode="rb") as fp:
return fp.read()

View File

@ -12,7 +12,7 @@ SCRYPT_R = 8
SCRYPT_P = 1
# If any of above parameters change, version will be incremented
VERSION = b'\x01'
VERSION = b"\x01"
SALT_LEN = 30
KEY_LEN = 64 # This is python default
@ -51,21 +51,26 @@ def parse_hash(pwhash_str: str) -> PWInfo:
def check_pass(password: str, pwinfo: PWInfo) -> bool:
# No need for constant time compare for hashes. See https://security.stackexchange.com/a/46215
return pwinfo.scrypt_hash == scrypt(password.encode(),
return pwinfo.scrypt_hash == scrypt(
password.encode(),
salt=pwinfo.salt,
n=SCRYPT_N,
r=SCRYPT_R,
p=SCRYPT_P,
dklen=KEY_LEN)
dklen=KEY_LEN,
)
if __name__ == '__main__':
if __name__ == "__main__":
import sys
if len(sys.argv) == 2:
print(gen_pwhash(sys.argv[1]))
elif len(sys.argv) == 3:
ok = check_pass(sys.argv[1], parse_hash(sys.argv[2]))
print("OK" if ok else "NOT OK")
else:
print("Usage: python3 -m mail4one.pwhash <password> [password_hash]",
file=sys.stderr)
print(
"Usage: python3 -m mail4one.pwhash <password> [password_hash]",
file=sys.stderr,
)

View File

@ -28,7 +28,6 @@ def setup_logging(args):
async def a_main(cfg: config.Config) -> None:
default_tls_context: ssl.SSLContext | None = None
if tls := cfg.default_tls:
@ -60,7 +59,8 @@ async def a_main(cfg: config.Config) -> None:
mails_path=Path(cfg.mails_path),
users=cfg.users,
ssl_context=get_tls_context(pop.tls),
timeout_seconds=pop.timeout_seconds)
timeout_seconds=pop.timeout_seconds,
)
servers.append(pop_server)
if cfg.smtp_starttls:
@ -73,17 +73,19 @@ async def a_main(cfg: config.Config) -> None:
port=stls.port,
mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder,
ssl_context=stls_context)
ssl_context=stls_context,
)
servers.append(smtp_server_starttls)
if cfg.smtp:
smtp = config.SmtpCfg(cfg.smtp)
smtp_server = await create_smtp_server(host=get_host(smtp.host),
smtp_server = await create_smtp_server(
host=get_host(smtp.host),
port=smtp.port,
mails_path=Path(cfg.mails_path),
mbox_finder=mbox_finder,
ssl_context=get_tls_context(
smtp.tls))
ssl_context=get_tls_context(smtp.tls),
)
servers.append(smtp_server)
if servers:
@ -93,31 +95,41 @@ async def a_main(cfg: config.Config) -> None:
def main() -> None:
parser = ArgumentParser(description="Personal Mail Server", epilog="See https://gitea.balki.me/balki/mail4one for more info")
parser = ArgumentParser(
description="Personal Mail Server",
epilog="See https://gitea.balki.me/balki/mail4one for more info",
)
parser.add_argument(
"-e",
"--echo_password",
action="store_true",
help="Show password in command line if -g without password is used")
help="Show password in command line if -g without password is used",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-c",
group.add_argument(
"-c",
"--config",
metavar="CONFIG_PATH",
type=Path,
help="Run mail server with passed config")
group.add_argument("-g",
help="Run mail server with passed config",
)
group.add_argument(
"-g",
"--genpwhash",
nargs="?",
dest="password",
const="FROM_TERMINAL",
metavar="PASSWORD",
help="Generate password hash to add in config")
group.add_argument("-r",
help="Generate password hash to add in config",
)
group.add_argument(
"-r",
"--pwverify",
dest="password_pwhash",
nargs=2,
metavar=("PASSWORD", "PWHASH"),
help="Check if password matches password hash")
help="Check if password matches password hash",
)
args = parser.parse_args()
if password := args.password:
if password == "FROM_TERMINAL":
@ -138,5 +150,5 @@ def main() -> None:
asyncio.run(a_main(cfg), debug=cfg.debug)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -23,7 +23,8 @@ from aiosmtpd.smtp import Session as SMTPSession
class MyHandler(AsyncMessage):
def __init__(self, mails_path: Path, mbox_finder: Callable[[str], list[str]]):
def __init__(self, mails_path: Path, mbox_finder: Callable[[str],
list[str]]):
super().__init__()
self.mails_path = mails_path
self.mbox_finder = mbox_finder
@ -41,7 +42,7 @@ class MyHandler(AsyncMessage):
if not all_mboxes:
return
for mbox in all_mboxes:
for sub in ('new', 'tmp', 'cur'):
for sub in ("new", "tmp", "cur"):
sub_path = self.mails_path / mbox / sub
sub_path.mkdir(mode=0o755, exist_ok=True, parents=True)
with tempfile.TemporaryDirectory() as tmpdir:
@ -50,7 +51,7 @@ class MyHandler(AsyncMessage):
gen = BytesGenerator(fp, policy=email.policy.SMTP)
gen.flatten(m)
for mbox in all_mboxes:
shutil.copy(temp_email_path, self.mails_path / mbox / 'new')
shutil.copy(temp_email_path, self.mails_path / mbox / "new")
def protocol_factory_starttls(mails_path: Path,
@ -59,17 +60,20 @@ def protocol_factory_starttls(mails_path: Path,
logging.info("Got smtp client cb starttls")
try:
handler = MyHandler(mails_path, mbox_finder)
smtp = SMTP(handler=handler,
smtp = SMTP(
handler=handler,
require_starttls=True,
tls_context=context,
enable_SMTPUTF8=True)
enable_SMTPUTF8=True,
)
except Exception as e:
logging.error("Something went wrong", e)
raise
return smtp
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]):
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
list[str]]):
logging.info("Got smtp client cb")
try:
handler = MyHandler(mails_path, mbox_finder)
@ -80,30 +84,38 @@ def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]):
return smtp
async def create_smtp_server_starttls(host: str,
async def create_smtp_server_starttls(
host: str,
port: int,
mails_path: Path,
mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext) -> asyncio.Server:
ssl_context: ssl.SSLContext,
) -> asyncio.Server:
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory_starttls,
mails_path, mbox_finder, ssl_context),
return await loop.create_server(
partial(protocol_factory_starttls, mails_path, mbox_finder,
ssl_context),
host=host,
port=port,
start_serving=False)
start_serving=False,
)
async def create_smtp_server(host: str,
async def create_smtp_server(
host: str,
port: int,
mails_path: Path,
mbox_finder: Callable[[str], list[str]],
ssl_context: ssl.SSLContext | None = None) -> asyncio.Server:
ssl_context: ssl.SSLContext | None = None,
) -> asyncio.Server:
loop = asyncio.get_event_loop()
return await loop.create_server(partial(protocol_factory, mails_path, mbox_finder),
return await loop.create_server(
partial(protocol_factory, mails_path, mbox_finder),
host=host,
port=port,
ssl=ssl_context,
start_serving=False)
start_serving=False,
)
async def a_main(*args, **kwargs):