Skip to content
Snippets Groups Projects
Commit 488e7ca7 authored by chrg's avatar chrg
Browse files

Add tricky case

parent 66f014d0
Branches
No related tags found
No related merge requests found
...@@ -2,23 +2,112 @@ ...@@ -2,23 +2,112 @@
""" The jpamb evaluator """ The jpamb evaluator
""" """
from io import StringIO
from typing import TextIO, TypeVar
from collections import defaultdict
import click import click
import subprocess import subprocess
import re
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
prim = bool | int prim = bool | int
@dataclass W = TypeVar("W", bound=TextIO)
def gettargets():
return tuple(sorted(("divide by zero", "*", "assertion error", "ok")))
def print_prim(i: prim, file: W = sys.stdout) -> W:
if isinstance(i, bool):
if i:
file.write("true")
else:
file.write("false")
else:
print(i, file=file, end="")
return file
@dataclass(frozen=True)
class Input:
val: tuple[prim, ...]
@staticmethod
def parse(string: str) -> "Input":
if not (m := re.match(r"\(([^)]*)\)", string)):
raise ValueError(f"Invalid inputs: {string!r}")
parsed_args = []
for i in m.group(1).split(","):
i = i.strip()
if not i:
continue
if i == "true":
parsed_args.append(True)
elif i == "false":
parsed_args.append(False)
else:
parsed_args.append(int(i))
return Input(tuple(parsed_args))
def __str__(self) -> str:
return self.print(StringIO()).getvalue()
def print(self, file: W = sys.stdout) -> W:
open, close = "()"
file.write(open)
if self.val:
print_prim(self.val[0], file=file)
for i in self.val[1:]:
file.write(", ")
print_prim(i, file=file)
file.write(close)
return file
@dataclass(frozen=True)
class Case: class Case:
methodid: str methodid: str
input: str input: Input
result: str result: str
def rebuild(): @dataclass(frozen=True)
subprocess.call(["mvn", "compile"]) class Prediction:
wager: float
@staticmethod
def parse(string: str) -> "Prediction":
if m := re.match(r"([^%]*)\%", string):
p = float(m.group(1)) / 100
return Prediction.from_probability(p)
else:
return Prediction(float(string))
@staticmethod
def from_probability(p: float) -> "Prediction":
negate = False
if p < 0.5:
p = 1 - p
negate = True
if p == 1:
x = float("inf")
else:
x = ((1 - p) / p) * 0.5
return Prediction(-x if negate else x)
def score(self, happens: bool):
wager = (-1 if not happens else 1) * self.wager
if wager > 0:
if wager == float("inf"):
return 1
else:
return 1 - 1 / (wager + 1)
else:
return wager
def runtime(args, enable_assertions=False, **kwargs): def runtime(args, enable_assertions=False, **kwargs):
...@@ -33,13 +122,56 @@ def runtime(args, enable_assertions=False, **kwargs): ...@@ -33,13 +122,56 @@ def runtime(args, enable_assertions=False, **kwargs):
return subprocess.check_output(pargs, text=True, **kwargs) return subprocess.check_output(pargs, text=True, **kwargs)
def run_cmd(cmd, /, timeout, verbose=True, **kwargs):
import time
if verbose:
stderr = sys.stdout
sys.stdout.flush()
else:
stderr = subprocess.DEVNULL
try:
start = time.time()
cp = subprocess.run(
cmd,
text=True,
stderr=stderr,
stdout=subprocess.PIPE,
timeout=timeout,
check=True,
**kwargs,
)
stop = time.time()
result = cp.stdout.strip()
if verbose:
print()
return (result, stop - start)
except subprocess.CalledProcessError:
if verbose:
print()
raise
except subprocess.TimeoutExpired:
if verbose:
print()
raise
def getcases(): def getcases():
import csv import csv
for r in sorted( cases = csv.reader(runtime([]).splitlines(), delimiter=" ", skipinitialspace=True)
csv.reader(runtime([]).splitlines(), delimiter=" ", skipinitialspace=True) for r in sorted(cases):
): args, res = r[1].split(" -> ")
yield Case(r[0], *r[1].split(" -> ")) yield Case(r[0], Input.parse(args), res)
def cases_by_methodid() -> dict[str, list[Case]]:
cases_by_id = defaultdict(list)
for c in getcases():
cases_by_id[c.methodid].append(c)
return cases_by_id
@click.group @click.group
...@@ -47,11 +179,20 @@ def cli(): ...@@ -47,11 +179,20 @@ def cli():
"""The jpamb evaluator""" """The jpamb evaluator"""
@cli.command
def rebuild():
"""Rebuild the test-suite."""
subprocess.call(["mvn", "compile"])
@cli.command @cli.command
def cases(): def cases():
"""Get a list of cases to test""" """Get a list of cases to test"""
import json
for c in getcases(): for c in getcases():
print(c) json.dump(c.__dict__, sys.stdout)
print()
@cli.command @cli.command
...@@ -65,36 +206,23 @@ def test(cmd, timeout): ...@@ -65,36 +206,23 @@ def test(cmd, timeout):
if not cmd: if not cmd:
cmd = ["java", "-cp", "target/classes", "-ea", "jpamb.Runtime"] cmd = ["java", "-cp", "target/classes", "-ea", "jpamb.Runtime"]
rebuild()
cases = list(getcases()) cases = list(getcases())
failed = [] failed = []
for c in cases: for c in cases:
print(f"=" * 80) print(f"=" * 80)
print(f"{c.methodid} with {c.input}") pretty_inputs = str(c.input)
print(f"{c.methodid} with {pretty_inputs}")
print() print()
sys.stdout.flush()
try: try:
cp = subprocess.run( result, time = run_cmd(cmd + [c.methodid, pretty_inputs], timeout=timeout)
cmd + [c.methodid, c.input],
text=True,
stderr=sys.stdout,
stdout=subprocess.PIPE,
timeout=timeout,
check=True,
)
result = cp.stdout.strip()
success = result == c.result success = result == c.result
print() print(f"Got {result!r} in {time} which is {success}")
print(f"Got {result} which is {success}")
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print()
print(f"Process failed.")
success = False success = False
print(f"Process failed.")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
success = "*" == c.result success = "*" == c.result
print()
print(f"Timed out after {timeout}s which is {success}") print(f"Timed out after {timeout}s which is {success}")
if not success: if not success:
failed += [c] failed += [c]
...@@ -108,18 +236,93 @@ def test(cmd, timeout): ...@@ -108,18 +236,93 @@ def test(cmd, timeout):
@cli.command @cli.command
def evaluate(): @click.option("--timeout", default=0.5)
"""Check that all cases are valid""" @click.option("-v", "verbosity", is_flag=True)
@click.option("--target", "targets", multiple=True)
@click.argument(
"CMD",
nargs=-1,
)
def evaluate(cmd, timeout, targets, verbosity):
"""Given an command check if it can predict the results."""
if not cmd:
click.UsageError("Expected a command to evaluate")
cmd = list(cmd)
if not targets:
targets = gettargets()
resulting_targets = []
for target in targets:
resulting_targets.extend(target.split(","))
targets = resulting_targets
cases_by_method = defaultdict(list)
for c in getcases(): for c in getcases():
cases_by_method[c.methodid].append(c)
results = []
for m, cases in cases_by_method.items():
for t in targets:
try: try:
result = runtime( prediction, time = run_cmd(
[c.methodid, c.input], cmd + [m, t], timeout=timeout, verbose=verbosity
enable_assertions=True, )
timeout=0.5, print(prediction)
).strip() prediction = Prediction.parse(prediction)
sometimes = any(c.result == t for c in cases)
score = prediction.score(sometimes)
result = [m, t, prediction.wager, score, time]
except subprocess.CalledProcessError:
result = [m, t, 0, 0, 0]
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
result = "*" result = [m, t, 0, 0, timeout]
print(c, result == c.result) results.append(result)
import csv
w = csv.writer(sys.stdout)
w.writerow(["methodid", "target", "wager", "score", "time"])
w.writerows(
[r[:2] + [f"{r[2]:0.2f}", f"{r[3]:0.2f}", f"{r[4]:0.3f}"] for r in results]
)
w.writerow(
[
"-",
"-",
f"{sum(abs(r[2]) for r in results):0.2f}",
f"{sum(r[3] for r in results):0.2f}",
f"{sum(r[4] for r in results):0.2f}",
]
)
@cli.command
def stats():
methods = []
targets = list(gettargets())
for mid, cases in sorted(cases_by_methodid().items()):
methods.append(
[mid] + list(1 if any(c.result == t for c in cases) else 0 for t in targets)
)
import csv
w = csv.writer(sys.stdout)
w.writerow(["methodid"] + targets)
w.writerows(methods)
w.writerow(
["-"]
+ [
f"{sum(m[i + 1] for m in methods) / (len(methods) * len(targets)):0.1%}"
for i, _ in enumerate(targets)
]
)
if __name__ == "__main__": if __name__ == "__main__":
......
#!/usr/bin/env python3
""" The cheating solution.
This solution uses apriori knowledge about the distribution of the test-cases
to gain an advantage.
"""
import sys, csv
with open("stats/distribution.csv") as f:
distribution = list(csv.DictReader(f))[-1]
print(f"Got {sys.argv[1:]}", file=sys.stderr)
print(distribution[sys.argv[2]])
#!/usr/bin/env python3
""" The conservative solution.
Simply answer don't know (50%) to all questions.
"""
import sys
print(f"Got {sys.argv[1:]}", file=sys.stderr)
print("50%")
...@@ -31,7 +31,7 @@ public class Loops { ...@@ -31,7 +31,7 @@ public class Loops {
} }
@Case("() -> assertion error") @Case("() -> assertion error")
@Tag({ INTEGER_OVERFLOW }) @Tag({ LOOP, INTEGER_OVERFLOW })
public static void terminates() { public static void terminates() {
short i = 0; short i = 0;
while (i++ != 0) { while (i++ != 0) {
......
...@@ -10,16 +10,19 @@ public class Simple { ...@@ -10,16 +10,19 @@ public class Simple {
} }
@Case("(false) -> assertion error") @Case("(false) -> assertion error")
@Case("(true) -> ok")
public static void assertBoolean(boolean shouldFail) { public static void assertBoolean(boolean shouldFail) {
assert shouldFail; assert shouldFail;
} }
@Case("(0) -> assertion error") @Case("(0) -> assertion error")
@Case("(1) -> ok")
public static void assertInteger(int n) { public static void assertInteger(int n) {
assert n != 0; assert n != 0;
} }
@Case("(-1) -> assertion error") @Case("(-1) -> assertion error")
@Case("(1) -> ok")
public static void assertPositive(int num) { public static void assertPositive(int num) {
assert num > 0; assert num > 0;
} }
...@@ -30,11 +33,13 @@ public class Simple { ...@@ -30,11 +33,13 @@ public class Simple {
} }
@Case("(0) -> divide by zero") @Case("(0) -> divide by zero")
@Case("(1) -> ok")
public static int divideByN(int n) { public static int divideByN(int n) {
return 1 / n; return 1 / n;
} }
@Case("(0, 0) -> divide by zero") @Case("(0, 0) -> divide by zero")
@Case("(0, 1) -> ok")
public static int divideZeroByZero(int a, int b) { public static int divideZeroByZero(int a, int b) {
return a / b; return a / b;
} }
......
package jpamb.cases;
import jpamb.utils.*;
import static jpamb.utils.Tag.TagType.*;
public class Tricky {
@Case("(24) -> ok")
@Tag({ LOOP })
public static void collatz(int n) {
while (n != 1) {
if (n % 2 == 0) {
n = n / 2;
} else {
n = n * 3 + 1;
}
}
}
}
...@@ -73,6 +73,8 @@ public record CaseContent( ...@@ -73,6 +73,8 @@ public record CaseContent(
return ASSERTION_ERROR; return ASSERTION_ERROR;
} else if (string.equals("divide by zero")) { } else if (string.equals("divide by zero")) {
return DIVIDE_BY_ZERO; return DIVIDE_BY_ZERO;
} else if (string.equals("ok")) {
return SUCCESS;
} else { } else {
throw new RuntimeException("Invalid result type: " + string); throw new RuntimeException("Invalid result type: " + string);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment