format using black
This commit is contained in:
		
							
								
								
									
										3
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								Makefile
									
									
									
									
									
								
							@@ -28,3 +28,6 @@ build: clean requirements.txt
 | 
				
			|||||||
clean:
 | 
					clean:
 | 
				
			||||||
	rm -rf build
 | 
						rm -rf build
 | 
				
			||||||
	rm -rf mail4one.pyz
 | 
						rm -rf mail4one.pyz
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					format:
 | 
				
			||||||
 | 
						black mail4one/*py
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -56,13 +56,13 @@ class PopCfg(ServerCfg):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class SmtpStartTLSCfg(ServerCfg):
 | 
					class SmtpStartTLSCfg(ServerCfg):
 | 
				
			||||||
    server_type = "smtp_starttls"
 | 
					    server_type = "smtp_starttls"
 | 
				
			||||||
    smtputf8 = True # Not used yet
 | 
					    smtputf8 = True  # Not used yet
 | 
				
			||||||
    port = 25
 | 
					    port = 25
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SmtpCfg(ServerCfg):
 | 
					class SmtpCfg(ServerCfg):
 | 
				
			||||||
    server_type = "smtp_starttls"
 | 
					    server_type = "smtp_starttls"
 | 
				
			||||||
    smtputf8 = True # Not used yet
 | 
					    smtputf8 = True  # Not used yet
 | 
				
			||||||
    port = 465
 | 
					    port = 465
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -89,7 +89,6 @@ Checker = tuple[str, CheckerFn, bool]
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def parse_checkers(cfg: Config) -> list[Checker]:
 | 
					def parse_checkers(cfg: Config) -> list[Checker]:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def make_match_fn(m: Match):
 | 
					    def make_match_fn(m: Match):
 | 
				
			||||||
        if m.addrs and m.addr_rexs:
 | 
					        if m.addrs and m.addr_rexs:
 | 
				
			||||||
            raise Exception("Both addrs and addr_rexs is set")
 | 
					            raise Exception("Both addrs and addr_rexs is set")
 | 
				
			||||||
@@ -97,8 +96,7 @@ def parse_checkers(cfg: Config) -> list[Checker]:
 | 
				
			|||||||
            return lambda malias: malias in m.addrs
 | 
					            return lambda malias: malias in m.addrs
 | 
				
			||||||
        elif m.addr_rexs:
 | 
					        elif m.addr_rexs:
 | 
				
			||||||
            compiled_res = [re.compile(reg) for reg in m.addr_rexs]
 | 
					            compiled_res = [re.compile(reg) for reg in m.addr_rexs]
 | 
				
			||||||
            return lambda malias: any(
 | 
					            return lambda malias: any(reg.match(malias) for reg in compiled_res)
 | 
				
			||||||
                reg.match(malias) for reg in compiled_res)
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise Exception("Neither addrs nor addr_rexs is set")
 | 
					            raise Exception("Neither addrs nor addr_rexs is set")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -114,13 +112,13 @@ def parse_checkers(cfg: Config) -> list[Checker]:
 | 
				
			|||||||
        return mbox_name, match_fn, rule.stop_check
 | 
					        return mbox_name, match_fn, rule.stop_check
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return [
 | 
					    return [
 | 
				
			||||||
        make_checker(mbox.name, Rule(rule)) for mbox in cfg.boxes or []
 | 
					        make_checker(mbox.name, Rule(rule))
 | 
				
			||||||
 | 
					        for mbox in cfg.boxes or []
 | 
				
			||||||
        for rule in mbox.rules
 | 
					        for rule in mbox.rules
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
 | 
					def get_mboxes(addr: str, checks: list[Checker]) -> list[str]:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def inner():
 | 
					    def inner():
 | 
				
			||||||
        for mbox, match_fn, stop_check in checks:
 | 
					        for mbox, match_fn, stop_check in checks:
 | 
				
			||||||
            if match_fn(addr):
 | 
					            if match_fn(addr):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -44,7 +44,6 @@ class State:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SharedState:
 | 
					class SharedState:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
 | 
					    def __init__(self, mails_path: Path, users: dict[str, tuple[PWInfo, str]]):
 | 
				
			||||||
        self.mails_path = mails_path
 | 
					        self.mails_path = mails_path
 | 
				
			||||||
        self.users = users
 | 
					        self.users = users
 | 
				
			||||||
@@ -56,8 +55,7 @@ class SharedState:
 | 
				
			|||||||
        return self.counter
 | 
					        return self.counter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
c_shared_state: contextvars.ContextVar = contextvars.ContextVar(
 | 
					c_shared_state: contextvars.ContextVar = contextvars.ContextVar("pop_shared_state")
 | 
				
			||||||
    "pop_shared_state")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def scfg() -> SharedState:
 | 
					def scfg() -> SharedState:
 | 
				
			||||||
@@ -72,7 +70,6 @@ def state() -> State:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PopLogger(logging.LoggerAdapter):
 | 
					class PopLogger(logging.LoggerAdapter):
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
        super().__init__(logging.getLogger("pop3"), None)
 | 
					        super().__init__(logging.getLogger("pop3"), None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -280,8 +277,7 @@ def get_deleted_items(deleted_items_path: Path) -> set[str]:
 | 
				
			|||||||
    return set()
 | 
					    return set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def save_deleted_items(deleted_items_path: Path,
 | 
					def save_deleted_items(deleted_items_path: Path, deleted_items: set[str]) -> None:
 | 
				
			||||||
                       deleted_items: set[str]) -> None:
 | 
					 | 
				
			||||||
    with deleted_items_path.open(mode="w") as f:
 | 
					    with deleted_items_path.open(mode="w") as f:
 | 
				
			||||||
        f.writelines(f"{did}\n" for did in deleted_items)
 | 
					        f.writelines(f"{did}\n" for did in deleted_items)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -298,8 +294,9 @@ async def transaction_stage() -> None:
 | 
				
			|||||||
    new_deleted_items: set[str] = await process_transactions(mails_list)
 | 
					    new_deleted_items: set[str] = await process_transactions(mails_list)
 | 
				
			||||||
    logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
 | 
					    logger.info(f"completed transactions. Deleted:{len(new_deleted_items)}")
 | 
				
			||||||
    if new_deleted_items:
 | 
					    if new_deleted_items:
 | 
				
			||||||
        save_deleted_items(deleted_items_path,
 | 
					        save_deleted_items(
 | 
				
			||||||
                           existing_deleted_items.union(new_deleted_items))
 | 
					            deleted_items_path, existing_deleted_items.union(new_deleted_items)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info(f"Saved deleted items")
 | 
					    logger.info(f"Saved deleted items")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -330,7 +327,6 @@ async def start_session() -> None:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
 | 
					def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def inner():
 | 
					    def inner():
 | 
				
			||||||
        for user in users:
 | 
					        for user in users:
 | 
				
			||||||
            user = User(user)
 | 
					            user = User(user)
 | 
				
			||||||
@@ -340,15 +336,13 @@ def parse_users(users: list[User]) -> dict[str, tuple[PWInfo, str]]:
 | 
				
			|||||||
    return dict(inner())
 | 
					    return dict(inner())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_pop_server_callback(mails_path: Path, users: list[User],
 | 
					def make_pop_server_callback(mails_path: Path, users: list[User], timeout_seconds: int):
 | 
				
			||||||
                             timeout_seconds: int):
 | 
					 | 
				
			||||||
    scfg = SharedState(mails_path=mails_path, users=parse_users(users))
 | 
					    scfg = SharedState(mails_path=mails_path, users=parse_users(users))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def session_cb(reader: StreamReader, writer: StreamWriter):
 | 
					    async def session_cb(reader: StreamReader, writer: StreamWriter):
 | 
				
			||||||
        c_shared_state.set(scfg)
 | 
					        c_shared_state.set(scfg)
 | 
				
			||||||
        ip, _ = writer.get_extra_info("peername")
 | 
					        ip, _ = writer.get_extra_info("peername")
 | 
				
			||||||
        c_state.set(
 | 
					        c_state.set(State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
 | 
				
			||||||
            State(reader=reader, writer=writer, ip=ip, req_id=scfg.next_id()))
 | 
					 | 
				
			||||||
        logger.info(f"Got pop server callback")
 | 
					        logger.info(f"Got pop server callback")
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -86,7 +86,7 @@ def parse_command(bline: bytes) -> Request:
 | 
				
			|||||||
    if parts:
 | 
					    if parts:
 | 
				
			||||||
        request.arg2, *parts = parts
 | 
					        request.arg2, *parts = parts
 | 
				
			||||||
    if parts:
 | 
					    if parts:
 | 
				
			||||||
        (request.rest, ) = parts
 | 
					        (request.rest,) = parts
 | 
				
			||||||
    return request
 | 
					    return request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -130,7 +130,6 @@ def get_mail(entry: MailEntry) -> bytes:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MailList:
 | 
					class MailList:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, entries: list[MailEntry]):
 | 
					    def __init__(self, entries: list[MailEntry]):
 | 
				
			||||||
        self.entries = entries
 | 
					        self.entries = entries
 | 
				
			||||||
        set_nid(self.entries)
 | 
					        set_nid(self.entries)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,17 +19,13 @@ KEY_LEN = 64  # This is python default
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def gen_pwhash(password: str) -> str:
 | 
					def gen_pwhash(password: str) -> str:
 | 
				
			||||||
    salt = os.urandom(SALT_LEN)
 | 
					    salt = os.urandom(SALT_LEN)
 | 
				
			||||||
    sh = scrypt(password.encode(),
 | 
					    sh = scrypt(
 | 
				
			||||||
                salt=salt,
 | 
					        password.encode(), salt=salt, n=SCRYPT_N, r=SCRYPT_R, p=SCRYPT_P, dklen=KEY_LEN
 | 
				
			||||||
                n=SCRYPT_N,
 | 
					    )
 | 
				
			||||||
                r=SCRYPT_R,
 | 
					 | 
				
			||||||
                p=SCRYPT_P,
 | 
					 | 
				
			||||||
                dklen=KEY_LEN)
 | 
					 | 
				
			||||||
    return b32encode(VERSION + salt + sh).decode()
 | 
					    return b32encode(VERSION + salt + sh).decode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PWInfo:
 | 
					class PWInfo:
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, salt: bytes, sh: bytes):
 | 
					    def __init__(self, salt: bytes, sh: bytes):
 | 
				
			||||||
        self.salt = salt
 | 
					        self.salt = salt
 | 
				
			||||||
        self.scrypt_hash = sh
 | 
					        self.scrypt_hash = sh
 | 
				
			||||||
@@ -40,12 +36,13 @@ def parse_hash(pwhash_str: str) -> PWInfo:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    if not len(pwhash) == 1 + SALT_LEN + KEY_LEN:
 | 
					    if not len(pwhash) == 1 + SALT_LEN + KEY_LEN:
 | 
				
			||||||
        raise Exception(
 | 
					        raise Exception(
 | 
				
			||||||
            f"Invalid hash size, {len(pwhash)} !=  {1 + SALT_LEN + KEY_LEN}")
 | 
					            f"Invalid hash size, {len(pwhash)} !=  {1 + SALT_LEN + KEY_LEN}"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (ver := pwhash[0:1]) != VERSION:
 | 
					    if (ver := pwhash[0:1]) != VERSION:
 | 
				
			||||||
        raise Exception(f"Invalid hash version, {ver!r} !=  {VERSION!r}")
 | 
					        raise Exception(f"Invalid hash version, {ver!r} !=  {VERSION!r}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    salt, sh = pwhash[1:SALT_LEN + 1], pwhash[-KEY_LEN:]
 | 
					    salt, sh = pwhash[1 : SALT_LEN + 1], pwhash[-KEY_LEN:]
 | 
				
			||||||
    return PWInfo(salt, sh)
 | 
					    return PWInfo(salt, sh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,21 +22,22 @@ def create_tls_context(certfile, keyfile) -> ssl.SSLContext:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def setup_logging(cfg: config.LogCfg):
 | 
					def setup_logging(cfg: config.LogCfg):
 | 
				
			||||||
    logging_format = "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
 | 
					    logging_format = (
 | 
				
			||||||
 | 
					        "%(asctime)s %(name)s %(levelname)s %(message)s @ %(filename)s:%(lineno)d"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    if cfg.logfile == "CONSOLE":
 | 
					    if cfg.logfile == "CONSOLE":
 | 
				
			||||||
        logging.basicConfig(level=cfg.level, format=logging_format)
 | 
					        logging.basicConfig(level=cfg.level, format=logging_format)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        logging.basicConfig(filename=cfg.logfile,
 | 
					        logging.basicConfig(
 | 
				
			||||||
                            level=cfg.level,
 | 
					            filename=cfg.logfile, level=cfg.level, format=logging_format
 | 
				
			||||||
                            format=logging_format)
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def a_main(cfg: config.Config) -> None:
 | 
					async def a_main(cfg: config.Config) -> None:
 | 
				
			||||||
    default_tls_context: ssl.SSLContext | None = None
 | 
					    default_tls_context: ssl.SSLContext | None = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if tls := cfg.default_tls:
 | 
					    if tls := cfg.default_tls:
 | 
				
			||||||
        logging.info(
 | 
					        logging.info(f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
 | 
				
			||||||
            f"Initializing default tls {tls.certfile=}, {tls.keyfile=}")
 | 
					 | 
				
			||||||
        default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
 | 
					        default_tls_context = create_tls_context(tls.certfile, tls.keyfile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_tls_context(tls: config.TLSCfg | str):
 | 
					    def get_tls_context(tls: config.TLSCfg | str):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,15 +25,14 @@ logger = logging.getLogger("smtp")
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MyHandler(AsyncMessage):
 | 
					class MyHandler(AsyncMessage):
 | 
				
			||||||
 | 
					    def __init__(self, mails_path: Path, mbox_finder: Callable[[str], list[str]]):
 | 
				
			||||||
    def __init__(self, mails_path: Path, mbox_finder: Callable[[str],
 | 
					 | 
				
			||||||
                                                               list[str]]):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.mails_path = mails_path
 | 
					        self.mails_path = mails_path
 | 
				
			||||||
        self.mbox_finder = mbox_finder
 | 
					        self.mbox_finder = mbox_finder
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def handle_DATA(self, server: SMTPServer, session: SMTPSession,
 | 
					    async def handle_DATA(
 | 
				
			||||||
                          envelope: SMTPEnvelope) -> str:
 | 
					        self, server: SMTPServer, session: SMTPSession, envelope: SMTPEnvelope
 | 
				
			||||||
 | 
					    ) -> str:
 | 
				
			||||||
        self.rcpt_tos = envelope.rcpt_tos
 | 
					        self.rcpt_tos = envelope.rcpt_tos
 | 
				
			||||||
        self.peer = session.peer
 | 
					        self.peer = session.peer
 | 
				
			||||||
        return await super().handle_DATA(server, session, envelope)
 | 
					        return await super().handle_DATA(server, session, envelope)
 | 
				
			||||||
@@ -63,9 +62,9 @@ class MyHandler(AsyncMessage):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def protocol_factory_starttls(mails_path: Path,
 | 
					def protocol_factory_starttls(
 | 
				
			||||||
                              mbox_finder: Callable[[str], list[str]],
 | 
					    mails_path: Path, mbox_finder: Callable[[str], list[str]], context: ssl.SSLContext
 | 
				
			||||||
                              context: ssl.SSLContext):
 | 
					):
 | 
				
			||||||
    logger.info("Got smtp client cb starttls")
 | 
					    logger.info("Got smtp client cb starttls")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        handler = MyHandler(mails_path, mbox_finder)
 | 
					        handler = MyHandler(mails_path, mbox_finder)
 | 
				
			||||||
@@ -81,8 +80,7 @@ def protocol_factory_starttls(mails_path: Path,
 | 
				
			|||||||
    return smtp
 | 
					    return smtp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def protocol_factory(mails_path: Path, mbox_finder: Callable[[str],
 | 
					def protocol_factory(mails_path: Path, mbox_finder: Callable[[str], list[str]]):
 | 
				
			||||||
                                                             list[str]]):
 | 
					 | 
				
			||||||
    logger.info("Got smtp client cb")
 | 
					    logger.info("Got smtp client cb")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        handler = MyHandler(mails_path, mbox_finder)
 | 
					        handler = MyHandler(mails_path, mbox_finder)
 | 
				
			||||||
@@ -105,8 +103,7 @@ async def create_smtp_server_starttls(
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
    loop = asyncio.get_event_loop()
 | 
					    loop = asyncio.get_event_loop()
 | 
				
			||||||
    return await loop.create_server(
 | 
					    return await loop.create_server(
 | 
				
			||||||
        partial(protocol_factory_starttls, mails_path, mbox_finder,
 | 
					        partial(protocol_factory_starttls, mails_path, mbox_finder, ssl_context),
 | 
				
			||||||
                ssl_context),
 | 
					 | 
				
			||||||
        host=host,
 | 
					        host=host,
 | 
				
			||||||
        port=port,
 | 
					        port=port,
 | 
				
			||||||
        start_serving=False,
 | 
					        start_serving=False,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,2 +1 @@
 | 
				
			|||||||
 | 
					 | 
				
			||||||
VERSION = "DEVELOMENT"
 | 
					VERSION = "DEVELOMENT"
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user