Skip to content
Snippets Groups Projects
Commit 522f86be authored by chrg's avatar chrg
Browse files

Add --batch option

parent 8b64ce1f
No related branches found
No related tags found
No related merge requests found
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, cast
import subprocess
import csv import csv
import logging import logging
import tempfile import tempfile
from contextlib import contextmanager
import click import click
import git import git
...@@ -99,6 +101,7 @@ class BlobHandler: ...@@ -99,6 +101,7 @@ class BlobHandler:
help="The file to output the csv mapping to", help="The file to output the csv mapping to",
type=click.File("w", lazy=False), type=click.File("w", lazy=False),
) )
@click.option("--batch/--no-batch", default=False)
@click.option( @click.option(
"-v", "-v",
"--verbose", "--verbose",
...@@ -115,6 +118,7 @@ def regit( ...@@ -115,6 +118,7 @@ def regit(
program: Path, program: Path,
args: tuple[str], args: tuple[str],
verbose: int, verbose: int,
batch: bool,
): ):
"""A simple program that runs a command on every commit on a repo.""" """A simple program that runs a command on every commit on a repo."""
...@@ -132,22 +136,6 @@ def regit( ...@@ -132,22 +136,6 @@ def regit(
repo = git.Repo.clone_from(url=repo, to_path=output, no_local=True) repo = git.Repo.clone_from(url=repo, to_path=output, no_local=True)
log.debug("Cloned repo to %s", output) log.debug("Cloned repo to %s", output)
with tempfile.TemporaryDirectory() as folder:
folder = Path(folder)
def transformer(file: Path, content: bytes) -> bytes:
pargs = list(args)
try:
ix = pargs.index("{}")
except ValueError:
return utils.run_stdout([program] + pargs, input=content)
else:
with utils.tfile(folder, file.name, content) as tmp_file:
pargs[ix] = str(tmp_file)
utils.run([program] + pargs)
with open(tmp_file, "rb") as f:
return f.read()
def is_relevant(file: Path): def is_relevant(file: Path):
if pattern is None: if pattern is None:
return True return True
...@@ -155,6 +143,7 @@ def regit( ...@@ -155,6 +143,7 @@ def regit(
log.debug("Check if %s matched pattern %s", file, match) log.debug("Check if %s matched pattern %s", file, match)
return match return match
with mktransformer(program, args, batch) as transformer:
handler = BlobHandler(repo, is_relevant=is_relevant, transform=transformer) handler = BlobHandler(repo, is_relevant=is_relevant, transform=transformer)
log.debug("Starting handler") log.debug("Starting handler")
with handler: with handler:
...@@ -176,5 +165,67 @@ def regit( ...@@ -176,5 +165,67 @@ def regit(
print(repo.working_dir) print(repo.working_dir)
@contextmanager
def mktransformer(program, args, batch):
with tempfile.TemporaryDirectory() as folder:
folder = Path(folder)
pargs = list(args)
if batch:
formatproc = utils.popen(
[program] + pargs,
stdin=subprocess.PIPE,
universal_newlines=False,
)
if formatproc.stdin is None:
raise RuntimeError("Could not start the formatting program")
fin = formatproc.stdin
def transformer(file: Path, content: bytes) -> bytes:
with utils.tfile(folder, file.name, content) as tmp_file:
try:
fin.write(str(tmp_file).encode("utf-8") + b"\n") # type: ignore
fin.flush()
except subprocess.CalledProcessError:
file = Path(file.name).with_suffix("input").absolute()
log.error("Writing argument to %s", file)
with open(file, "wb") as f:
f.write(content)
raise
with open(tmp_file, "rb") as f:
return f.read()
yield transformer
formatproc.stdin.close()
formatproc.wait()
elif "{}" in pargs:
ix = pargs.index("{}")
def transformer(file: Path, content: bytes) -> bytes:
with utils.tfile(folder, file.name, content) as tmp_file:
pargs[ix] = str(tmp_file)
try:
utils.run([program] + pargs)
except subprocess.CalledProcessError:
file = Path(file.name).with_suffix("input").absolute()
log.error("Writing argument to %s", file)
with open(file, "wb") as f:
f.write(content)
raise
with open(tmp_file, "rb") as f:
return f.read()
yield transformer
else:
def transformer(file: Path, content: bytes) -> bytes:
return utils.run_stdout([program] + pargs, input=content)
yield transformer
if __name__ == "__main__": if __name__ == "__main__":
regit() regit()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment