Add type annotations and use match statement

Now works only in python3.10 and above
This commit is contained in:
Balakrishnan Balasubramanian 2022-12-09 21:57:25 -05:00
parent 6b37b6e50f
commit 7c11878345
2 changed files with 28 additions and 26 deletions

View File

@ -40,7 +40,7 @@
Git config file (`~/.gitconfig`) should look like this Git config file (`~/.gitconfig`) should look like this
```toml ```TOML
[alias] [alias]
... ...
dt = difftool --tool vimtabdiff --dir-diff dt = difftool --tool vimtabdiff --dir-diff
@ -51,12 +51,12 @@ Git config file (`~/.gitconfig`) should look like this
# Known issues # Known issues
1. If your path to custom vim has space, it does not work. i.e. Following does *not* work 1. If your path to custom vim has space, it does not work. i.e. Following does **not** work
```bash ```bash
git config --global difftool.vimtabdiff.cmd 'vimtabdiff.py --vim "/home/foo/my files/bin/vim" $LOCAL $REMOTE' git config --global difftool.vimtabdiff.cmd 'vimtabdiff.py --vim "/home/foo/my files/bin/vim" $LOCAL $REMOTE'
``` ```
2. Not tested in non-linux OS. Pull requests welcome if found any issues. 2. Not tested in non-linux OS. Pull requests welcome if found any issues but hopefully should work fine.
# Similar # Similar

View File

@ -2,32 +2,36 @@
import os import os
import argparse import argparse
import pathlib
import itertools import itertools
import tempfile import tempfile
import subprocess import subprocess
from pathlib import Path
from typing import Callable, TypeVar
from collections.abc import Iterator, Sequence
T = TypeVar('T')
def star(f): def star(f: Callable[..., T]) -> Callable[[Sequence], T]:
""" see https://stackoverflow.com/q/21892989 """ """ see https://stackoverflow.com/q/21892989 """
return lambda args: f(*args) return lambda args: f(*args)
def parse_args(): def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Show diff of files from two directories in vim tabs") description="Show diff of files from two directories in vim tabs",
parser.add_argument("pathA") epilog="See https://github.com/balki/vimtabdiff for more info")
parser.add_argument("pathB") parser.add_argument("pathA", type=Path)
parser.add_argument("pathB", type=Path)
parser.add_argument("--vim", help="vim command to run", default="vim") parser.add_argument("--vim", help="vim command to run", default="vim")
return parser.parse_args() return parser.parse_args()
def get_dir_info(dirname): def get_dir_info(dirpath: Path | None) -> tuple[list[Path], list[Path]]:
if not dirname: if not dirpath:
return [], [] return [], []
dirs, files = [], [] dirs, files = [], []
dirp = pathlib.Path(dirname) for p in dirpath.iterdir():
for p in dirp.iterdir():
if p.is_dir(): if p.is_dir():
dirs.append(p) dirs.append(p)
else: else:
@ -35,27 +39,25 @@ def get_dir_info(dirname):
return dirs, files return dirs, files
def get_pairs(aItems, bItems): def get_pairs(aItems: list[Path],
bItems: list[Path]) -> Iterator[tuple[Path | None, Path | None]]:
aItems = [(item, 'A') for item in aItems] aItems = [(item, 'A') for item in aItems]
bItems = [(item, 'B') for item in bItems] bItems = [(item, 'B') for item in bItems]
abItems = aItems + bItems abItems = aItems + bItems
abItems.sort(key=star(lambda item, tag: (item.name, tag))) abItems.sort(key=star(lambda item, tag: (item.name, tag)))
for _, items in itertools.groupby(abItems, for _, items in itertools.groupby(abItems,
key=star(lambda item, _: item.name)): key=star(lambda item, _: item.name)):
items = list(items) match list(items):
# NOTE: python 3.10's match expression can make this better case [(aItem, _), (bItem, _)]:
if len(items) == 2:
(aItem, _), (bItem, _) = items
yield aItem, bItem yield aItem, bItem
else: case [(item, 'A'),]:
(item, tag), = items
if tag == 'A':
yield item, None yield item, None
else: case [(item, 'B'),]:
yield None, item yield None, item
def get_file_pairs(a, b): def get_file_pairs(a: Path,
b: Path) -> Iterator[tuple[Path | None, Path | None]]:
aDirs, aFiles = get_dir_info(a) aDirs, aFiles = get_dir_info(a)
bDirs, bFiles = get_dir_info(b) bDirs, bFiles = get_dir_info(b)
yield from get_pairs(aFiles, bFiles) yield from get_pairs(aFiles, bFiles)
@ -63,7 +65,7 @@ def get_file_pairs(a, b):
yield from get_file_pairs(aDir, bDir) yield from get_file_pairs(aDir, bDir)
def main(): def main() -> None:
args = parse_args() args = parse_args()
vimCmdFile = tempfile.NamedTemporaryFile(mode='w', delete=False) vimCmdFile = tempfile.NamedTemporaryFile(mode='w', delete=False)
with vimCmdFile: with vimCmdFile: