diff --git a/evaluate.py b/evaluate.py
index 7047daca48f528370513c27121e51f3de3046f54..7c303ac5f38c8e4aea9ea66c7b6c0a3198c0ada9 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -2,23 +2,112 @@
 """ The jpamb evaluator
 """
 
+from io import StringIO
+from typing import TextIO, TypeVar
+from collections import defaultdict
 import click
 import subprocess
+import re
 import sys
 from dataclasses import dataclass
 
 prim = bool | int
 
 
-@dataclass
+W = TypeVar("W", bound=TextIO)
+
+
+def gettargets():
+    return tuple(sorted(("divide by zero", "*", "assertion error", "ok")))
+
+
+def print_prim(i: prim, file: W = sys.stdout) -> W:
+    if isinstance(i, bool):
+        if i:
+            file.write("true")
+        else:
+            file.write("false")
+    else:
+        print(i, file=file, end="")
+    return file
+
+
+@dataclass(frozen=True)
+class Input:
+    val: tuple[prim, ...]
+
+    @staticmethod
+    def parse(string: str) -> "Input":
+        if not (m := re.match(r"\(([^)]*)\)", string)):
+            raise ValueError(f"Invalid inputs: {string!r}")
+        parsed_args = []
+        for i in m.group(1).split(","):
+            i = i.strip()
+            if not i:
+                continue
+            if i == "true":
+                parsed_args.append(True)
+            elif i == "false":
+                parsed_args.append(False)
+            else:
+                parsed_args.append(int(i))
+        return Input(tuple(parsed_args))
+
+    def __str__(self) -> str:
+        return self.print(StringIO()).getvalue()
+
+    def print(self, file: W = sys.stdout) -> W:
+        lparen, rparen = "()"  # avoid shadowing the builtin open()
+        file.write(lparen)
+        if self.val:
+            print_prim(self.val[0], file=file)
+            for i in self.val[1:]:
+                file.write(", ")
+                print_prim(i, file=file)
+        file.write(rparen)
+        return file
+
+
+@dataclass(frozen=True)
 class Case:
     methodid: str
-    input: str
+    input: Input
     result: str
 
 
-def rebuild():
-    subprocess.call(["mvn", "compile"])
+@dataclass(frozen=True)
+class Prediction:
+    wager: float
+
+    @staticmethod
+    def parse(string: str) -> "Prediction":
+        if m := re.match(r"([^%]*)\%", string):
+            p = float(m.group(1)) / 100
+            return Prediction.from_probability(p)
+        else:
+            return Prediction(float(string))
+
+    @staticmethod
+    def from_probability(p: float) -> "Prediction":
+        negate = False
+        if p < 0.5:
+            p = 1 - p
+            negate = True
+        if p == 1:
+            x = float("inf")
+        else:
+            x = (2 * p - 1) / (2 * (1 - p))  # 50% -> wager 0 ("don't know"); grows to inf as p -> 1
+        return Prediction(-x if negate else x)
+
+    def score(self, happens: bool):
+        wager = (-1 if not happens else 1) * self.wager
+        if wager > 0:
+            if wager == float("inf"):
+                return 1
+            else:
+                return 1 - 1 / (wager + 1)
+        else:
+            return wager
 
 
 def runtime(args, enable_assertions=False, **kwargs):
@@ -33,13 +122,56 @@ def runtime(args, enable_assertions=False, **kwargs):
     return subprocess.check_output(pargs, text=True, **kwargs)
 
 
+def run_cmd(cmd, /, timeout, verbose=True, **kwargs):
+    import time
+
+    if verbose:
+        stderr = sys.stdout
+        sys.stdout.flush()
+    else:
+        stderr = subprocess.DEVNULL
+    try:
+        start = time.time()
+        cp = subprocess.run(
+            cmd,
+            text=True,
+            stderr=stderr,
+            stdout=subprocess.PIPE,
+            timeout=timeout,
+            check=True,
+            **kwargs,
+        )
+        stop = time.time()
+        result = cp.stdout.strip()
+        if verbose:
+            print()
+        return (result, stop - start)
+    except subprocess.CalledProcessError:
+        if verbose:
+            print()
+        raise
+    except subprocess.TimeoutExpired:
+        if verbose:
+            print()
+        raise
+
+
 def getcases():
     import csv
 
-    for r in sorted(
-        csv.reader(runtime([]).splitlines(), delimiter=" ", skipinitialspace=True)
-    ):
-        yield Case(r[0], *r[1].split(" -> "))
+    cases = csv.reader(runtime([]).splitlines(), delimiter=" ", skipinitialspace=True)
+    for r in sorted(cases):
+        args, res = r[1].split(" -> ")
+        yield Case(r[0], Input.parse(args), res)
+
+
+def cases_by_methodid() -> dict[str, list[Case]]:
+    cases_by_id = defaultdict(list)
+
+    for c in getcases():
+        cases_by_id[c.methodid].append(c)
+
+    return cases_by_id
 
 
 @click.group
@@ -47,11 +179,20 @@ def cli():
     """The jpamb evaluator"""
 
 
+@cli.command
+def rebuild():
+    """Rebuild the test-suite."""
+    subprocess.call(["mvn", "compile"])
+
+
 @cli.command
 def cases():
     """Get a list of cases to test"""
+    import json
+
     for c in getcases():
-        print(c)
+        json.dump({**c.__dict__, "input": str(c.input)}, sys.stdout)  # Input itself is not JSON-serializable
+        print()
 
 
 @cli.command
@@ -65,36 +206,23 @@ def test(cmd, timeout):
     if not cmd:
         cmd = ["java", "-cp", "target/classes", "-ea", "jpamb.Runtime"]
 
-    rebuild()
-
     cases = list(getcases())
     failed = []
 
     for c in cases:
         print(f"=" * 80)
-        print(f"{c.methodid} with {c.input}")
+        pretty_inputs = str(c.input)
+        print(f"{c.methodid} with {pretty_inputs}")
         print()
-        sys.stdout.flush()
         try:
-            cp = subprocess.run(
-                cmd + [c.methodid, c.input],
-                text=True,
-                stderr=sys.stdout,
-                stdout=subprocess.PIPE,
-                timeout=timeout,
-                check=True,
-            )
-            result = cp.stdout.strip()
+            result, time = run_cmd(cmd + [c.methodid, pretty_inputs], timeout=timeout)
             success = result == c.result
-            print()
-            print(f"Got {result} which is {success}")
+            print(f"Got {result!r} in {time} which is {success}")
         except subprocess.CalledProcessError:
-            print()
-            print(f"Process failed.")
             success = False
+            print(f"Process failed.")
         except subprocess.TimeoutExpired:
             success = "*" == c.result
-            print()
             print(f"Timed out after {timeout}s which is {success}")
         if not success:
             failed += [c]
@@ -108,18 +236,93 @@ def test(cmd, timeout):
 
 
 @cli.command
-def evaluate():
-    """Check that all cases are valid"""
+@click.option("--timeout", default=0.5)
+@click.option("-v", "verbosity", is_flag=True)
+@click.option("--target", "targets", multiple=True)
+@click.argument(
+    "CMD",
+    nargs=-1,
+)
+def evaluate(cmd, timeout, targets, verbosity):
+    """Given a command, check if it can predict the results."""
+
+    if not cmd:
+        raise click.UsageError("Expected a command to evaluate")
+    cmd = list(cmd)
+
+    if not targets:
+        targets = gettargets()
+    resulting_targets = []
+    for target in targets:
+        resulting_targets.extend(target.split(","))
+    targets = resulting_targets
+
+    cases_by_method = defaultdict(list)
     for c in getcases():
-        try:
-            result = runtime(
-                [c.methodid, c.input],
-                enable_assertions=True,
-                timeout=0.5,
-            ).strip()
-        except subprocess.TimeoutExpired:
-            result = "*"
-        print(c, result == c.result)
+        cases_by_method[c.methodid].append(c)
+
+    results = []
+
+    for m, cases in cases_by_method.items():
+        for t in targets:
+            try:
+                prediction, time = run_cmd(
+                    cmd + [m, t], timeout=timeout, verbose=verbosity
+                )
+                print(prediction)
+                prediction = Prediction.parse(prediction)
+
+                sometimes = any(c.result == t for c in cases)
+                score = prediction.score(sometimes)
+
+                result = [m, t, prediction.wager, score, time]
+            except subprocess.CalledProcessError:
+                result = [m, t, 0, 0, 0]
+            except subprocess.TimeoutExpired:
+                result = [m, t, 0, 0, timeout]
+            results.append(result)
+
+    import csv
+
+    w = csv.writer(sys.stdout)
+    w.writerow(["methodid", "target", "wager", "score", "time"])
+    w.writerows(
+        [r[:2] + [f"{r[2]:0.2f}", f"{r[3]:0.2f}", f"{r[4]:0.3f}"] for r in results]
+    )
+    w.writerow(
+        [
+            "-",
+            "-",
+            f"{sum(abs(r[2]) for r in results):0.2f}",
+            f"{sum(r[3] for r in results):0.2f}",
+            f"{sum(r[4] for r in results):0.2f}",
+        ]
+    )
+
+
+@cli.command
+def stats():
+    methods = []
+
+    targets = list(gettargets())
+
+    for mid, cases in sorted(cases_by_methodid().items()):
+        methods.append(
+            [mid] + list(1 if any(c.result == t for c in cases) else 0 for t in targets)
+        )
+
+    import csv
+
+    w = csv.writer(sys.stdout)
+    w.writerow(["methodid"] + targets)
+    w.writerows(methods)
+    w.writerow(
+        ["-"]
+        + [
+            f"{sum(m[i + 1] for m in methods) / (len(methods) * len(targets)):0.1%}"
+            for i, _ in enumerate(targets)
+        ]
+    )
 
 
 if __name__ == "__main__":
diff --git a/solutions/apriori.py b/solutions/apriori.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff7f76e8b685f26d8239580db66a6a47b62436b4
--- /dev/null
+++ b/solutions/apriori.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+""" The cheating solution.
+
+This solution uses apriori knowledge about the distribution of the test-cases
+to gain an advantage.
+"""
+
+import sys, csv
+
+with open("stats/distribution.csv") as f:
+    distribution = list(csv.DictReader(f))[-1]
+
+print(f"Got {sys.argv[1:]}", file=sys.stderr)
+print(distribution[sys.argv[2]])
diff --git a/solutions/conservative.py b/solutions/conservative.py
new file mode 100755
index 0000000000000000000000000000000000000000..8078c240917a911615085aed3845dc059df8a531
--- /dev/null
+++ b/solutions/conservative.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+""" The conservative solution.
+
+Simply answer don't know (50%) to all questions.
+
+"""
+import sys
+
+print(f"Got {sys.argv[1:]}", file=sys.stderr)
+print("50%")
diff --git a/src/main/java/jpamb/cases/Loops.java b/src/main/java/jpamb/cases/Loops.java
index 423d026c5c267ed80867db954e80057bb48e0c7c..235ff253ea5dcef7bb12fa6e308bd948148fc2b7 100644
--- a/src/main/java/jpamb/cases/Loops.java
+++ b/src/main/java/jpamb/cases/Loops.java
@@ -31,7 +31,7 @@ public class Loops {
   }
 
   @Case("() -> assertion error")
-  @Tag({ INTEGER_OVERFLOW })
+  @Tag({ LOOP, INTEGER_OVERFLOW })
   public static void terminates() {
     short i = 0;
     while (i++ != 0) {
diff --git a/src/main/java/jpamb/cases/Simple.java b/src/main/java/jpamb/cases/Simple.java
index 0caf225310bd9bb679bc8a3a8508df27da445af8..d744556309a2e98fa69058a3e062aad01de07fbb 100644
--- a/src/main/java/jpamb/cases/Simple.java
+++ b/src/main/java/jpamb/cases/Simple.java
@@ -10,16 +10,19 @@ public class Simple {
   }
 
   @Case("(false) -> assertion error")
+  @Case("(true) -> ok")
   public static void assertBoolean(boolean shouldFail) {
     assert shouldFail;
   }
 
   @Case("(0) -> assertion error")
+  @Case("(1) -> ok")
   public static void assertInteger(int n) {
     assert n != 0;
   }
 
   @Case("(-1) -> assertion error")
+  @Case("(1) -> ok")
   public static void assertPositive(int num) {
     assert num > 0;
   }
@@ -30,11 +33,13 @@ public class Simple {
   }
 
   @Case("(0) -> divide by zero")
+  @Case("(1) -> ok")
   public static int divideByN(int n) {
     return 1 / n;
   }
 
   @Case("(0, 0) -> divide by zero")
+  @Case("(0, 1) -> ok")
   public static int divideZeroByZero(int a, int b) {
     return a / b;
   }
diff --git a/src/main/java/jpamb/cases/Tricky.java b/src/main/java/jpamb/cases/Tricky.java
new file mode 100644
index 0000000000000000000000000000000000000000..661625e5df9643ff834a1ff2aa8860bac19c43ae
--- /dev/null
+++ b/src/main/java/jpamb/cases/Tricky.java
@@ -0,0 +1,20 @@
+package jpamb.cases;
+
+import jpamb.utils.*;
+import static jpamb.utils.Tag.TagType.*;
+
+public class Tricky {
+
+  @Case("(24) -> ok")
+  @Tag({ LOOP })
+  public static void collatz(int n) {
+    while (n != 1) {
+      if (n % 2 == 0) {
+        n = n / 2;
+      } else {
+        n = n * 3 + 1;
+      }
+    }
+  }
+
+}
diff --git a/src/main/java/jpamb/utils/CaseContent.java b/src/main/java/jpamb/utils/CaseContent.java
index 9e0ffe1198762b595e2220837f9dd8644aa96dcb..b66c7cebecddb09f3742d0ff684c423621b5f8de 100644
--- a/src/main/java/jpamb/utils/CaseContent.java
+++ b/src/main/java/jpamb/utils/CaseContent.java
@@ -73,6 +73,8 @@ public record CaseContent(
       return ASSERTION_ERROR;
     } else if (string.equals("divide by zero")) {
       return DIVIDE_BY_ZERO;
+    } else if (string.equals("ok")) {
+      return SUCCESS;
     } else {
       throw new RuntimeException("Invalid result type: " + string);
     }