ape-framework 0.1.0__tar.gz

PKG-INFO
@@ -0,0 +1,50 @@
+ Metadata-Version: 2.4
+ Name: ape-framework
+ Version: 0.1.0
+ Summary: Package for evaluating algebra problems using AI systems.
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Dynamic: license-file
+
+ # Algebra Problems Evaluator (APE)
+
+ ## What is APE?
+ **APE is a framework that simplifies building and running mathematical benchmarks for AI systems.**
+
+ It was initially meant as a way to evaluate whether mathematical problems from the field of algebra are suitable targets for LLM reasoning research (hence the name).
+ However, it might be more useful to think about it the other way around: the subjects of the evaluation are LLMs, or more broadly AI systems,
+ and they are evaluated on their ability to solve specific algebra problems: problems whose solutions are hard to generate, but relatively easy to check for correctness automatically.
+
+ APE was created as part of a Bachelor's project at the University of Warsaw.
+
+ ## User's Guide
+ [**TODO**]
+
+ ## Development setup
+
+ Install development dependencies:
+
+ ```bash
+ pip install -r requirements-dev.txt
+ ```
+
+ Install git hooks:
+
+ ```bash
+ pre-commit install
+ ```
+
+ Run all hooks manually:
+
+ ```bash
+ pre-commit run --all-files
+ ```
+
+ <!--
+ 1. What does the user have to implement to have a complete benchmark?
+ 2. How to run a test using an APE benchmark?
+ 3. Where to find the documentation?
+ -->

README.md
@@ -0,0 +1,39 @@
+ # Algebra Problems Evaluator (APE)
+
+ ## What is APE?
+ **APE is a framework that simplifies building and running mathematical benchmarks for AI systems.**
+
+ It was initially meant as a way to evaluate whether mathematical problems from the field of algebra are suitable targets for LLM reasoning research (hence the name).
+ However, it might be more useful to think about it the other way around: the subjects of the evaluation are LLMs, or more broadly AI systems,
+ and they are evaluated on their ability to solve specific algebra problems: problems whose solutions are hard to generate, but relatively easy to check for correctness automatically.
+
+ APE was created as part of a Bachelor's project at the University of Warsaw.
+
+ ## User's Guide
+ [**TODO**]
+
+ ## Development setup
+
+ Install development dependencies:
+
+ ```bash
+ pip install -r requirements-dev.txt
+ ```
+
+ Install git hooks:
+
+ ```bash
+ pre-commit install
+ ```
+
+ Run all hooks manually:
+
+ ```bash
+ pre-commit run --all-files
+ ```
+
+ <!--
+ 1. What does the user have to implement to have a complete benchmark?
+ 2. How to run a test using an APE benchmark?
+ 3. Where to find the documentation?
+ -->

ape/__init__.py
@@ -0,0 +1,5 @@
+ import logging
+
+ # Prevent logging if the user did not configure it
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
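
This is the standard `NullHandler` pattern for library loggers: APE emits log records, but nothing is printed unless the consuming application configures logging itself. A minimal sketch of the application side (the `basicConfig` call and level shown here are illustrative, not part of the package):

```python
import logging

# The application, not the library, decides where log records go.
logging.basicConfig(level=logging.INFO)

# Records from the "ape" logger hierarchy now reach the root handler.
logging.getLogger("ape").info("APE logging enabled")
```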

ape/adapter.py
@@ -0,0 +1,11 @@
+ from abc import ABC, abstractmethod
+ from typing import Generic, TypeVar
+
+ RawT = TypeVar("RawT")
+ SolutionT = TypeVar("SolutionT")
+
+
+ class SolutionAdapter(ABC, Generic[RawT, SolutionT]):
+     @classmethod
+     @abstractmethod
+     def adapt(cls, raw: RawT) -> list[SolutionT]: ...
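
A `SolutionAdapter` turns a solver's raw output into a batch of structured solutions. A minimal sketch of a concrete adapter, assuming raw output arrives as a free-form string and solutions are integers; `IntListAdapter` and its regex are illustrative, not part of the package:

```python
import re

from ape.adapter import SolutionAdapter


class IntListAdapter(SolutionAdapter[str, int]):
    """Hypothetical adapter: extract every integer from a raw model response."""

    @classmethod
    def adapt(cls, raw: str) -> list[int]:
        return [int(m) for m in re.findall(r"-?\d+", raw)]


assert IntListAdapter.adapt("roots: 3 and -7") == [3, -7]
```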

ape/checker.py
@@ -0,0 +1,39 @@
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from datetime import timedelta
+ from enum import Enum
+ from typing import Any, Generic, Optional, Type, TypeVar
+
+ from .adapter import SolutionT
+ from .data import DataT
+
+ ResultT = TypeVar("ResultT", bound=Enum)
+
+
+ @dataclass(frozen=True)
+ class CheckerOutput(Generic[ResultT]):
+     result: ResultT
+     runtime: Optional[timedelta]
+     # Additional metadata can be added
+
+
+ class Checker(ABC, Generic[DataT, SolutionT, ResultT]):
+     @abstractmethod
+     def check(
+         self, data: DataT, solution_batch: list[SolutionT]
+     ) -> list[CheckerOutput[ResultT]]: ...
+
+
+ CheckerT = TypeVar("CheckerT", bound=Checker)
+
+
+ class CheckerFactory(Generic[CheckerT]):
+     def __init__(self, checker_cls: Type[CheckerT], **kwargs: object) -> None:
+         self._checker_cls = checker_cls
+         self._kwargs = kwargs
+
+     def create(self) -> CheckerT:
+         return self._checker_cls(**self._kwargs)
+
+     def get_params(self) -> dict[str, Any]:
+         return self._kwargs
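
A sketch of a concrete checker, assuming an `Enum` verdict type. The package's `Data` wrapper (from `ape/data.py`, not shown in this diff) is elided here and the data item treated as a plain value; `Verdict` and `SquareRootChecker` are illustrative:

```python
from enum import Enum

from ape.checker import Checker, CheckerFactory, CheckerOutput


class Verdict(Enum):
    CORRECT = "correct"
    WRONG = "wrong"


class SquareRootChecker(Checker[int, int, Verdict]):
    """Hypothetical checker: data is a perfect square, solutions are candidate roots."""

    def check(self, data: int, solution_batch: list[int]) -> list[CheckerOutput[Verdict]]:
        # Checking is cheap even when producing a root is hard.
        return [
            CheckerOutput(
                result=Verdict.CORRECT if s * s == data else Verdict.WRONG,
                runtime=None,  # a real checker could time each check
            )
            for s in solution_batch
        ]


# Factories let the runner construct fresh checker instances per worker.
factory = CheckerFactory(SquareRootChecker)
assert factory.create().check(49, [7, 6])[0].result is Verdict.CORRECT
```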

ape/cli.py
@@ -0,0 +1,298 @@
+ from __future__ import annotations
+
+ import logging
+ import logging.config
+ from collections import deque
+ from contextlib import AbstractContextManager
+ from datetime import datetime
+ from pathlib import Path
+ from types import TracebackType
+ from typing import Any, Callable, Literal, Optional
+
+ from rich import box
+ from rich.console import Console
+ from rich.layout import Layout
+ from rich.live import Live
+ from rich.panel import Panel
+ from rich.progress import (
+     BarColumn,
+     Progress,
+     SpinnerColumn,
+     TaskID,
+     TaskProgressColumn,
+     TextColumn,
+     TimeElapsedColumn,
+ )
+ from rich.table import Table
+
+ from .runner import Runner
+
+
+ class LogBufferHandler(logging.Handler):
+     def __init__(self, add_entry: Callable[[dict[str, str]], None]) -> None:
+         super().__init__()
+         self._add_entry = add_entry
+
+     def emit(self, record: logging.LogRecord) -> None:
+         try:
+             entry = {
+                 "time": datetime.fromtimestamp(record.created).strftime("%H:%M:%S"),
+                 "level": record.levelname,
+                 "logger": record.name,
+                 "message": record.getMessage(),
+             }
+             self._add_entry(entry)
+         except Exception:
+             self.handleError(record)
+
+
+ class CLI(AbstractContextManager):
+     def __init__(
+         self,
+         runner: Runner,
+         *,
+         log_file: Path | str = Path("output/run.log"),
+         log_level: int | str = logging.INFO,
+     ) -> None:
+         self._runner = runner
+         self._log_file = Path(log_file)
+         self._log_level = self._resolve_log_level(log_level)
+         self._progress_task_id: Optional[TaskID] = None
+         self._log_buffer: deque[dict[str, str]] = deque(maxlen=50)
+         self._solver_statuses: list[dict[str, Any]] = []
+         self._checker_statuses: list[dict[str, Any]] = []
+
+     def __enter__(self) -> CLI:
+         self._setup_logging()
+         self._logger = logging.getLogger(__name__)
+         self._console = Console()
+         self._progress = Progress(
+             SpinnerColumn(),
+             TextColumn("{task.description}", justify="left"),
+             BarColumn(bar_width=None),
+             TaskProgressColumn(),
+             TextColumn("{task.completed}/{task.total}"),
+             TimeElapsedColumn(),
+             console=self._console,
+             transient=False,
+             expand=True,
+         )
+         self._log_handler = LogBufferHandler(self._add_log_entry)
+         self._log_handler.setLevel(self._log_level)
+         logging.getLogger().addHandler(self._log_handler)
+         self._live = Live(
+             self._render_dashboard(),
+             console=self._console,
+             refresh_per_second=8,
+             transient=False,
+         )
+         self._live.start()
+         self._progress.start()
+         self._refresh_dashboard()
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc: BaseException | None,
+         tb: TracebackType | None,
+     ) -> Literal[False]:
+         if self._live:
+             self._live.stop()
+         if self._progress:
+             self._progress.stop()
+         if self._log_handler:
+             logging.getLogger().removeHandler(self._log_handler)
+             self._log_handler.close()
+         return False
+
+     def run(self, *args: Any, **kwargs: Any) -> Any:
+         self._logger.info("CLI runner starting")
+
+         forwarded = dict(kwargs)
+         forwarded.setdefault("progress_hook", self._progress_hook)
+         forwarded.setdefault("worker_status_hook", self._status_hook)
+
+         result = self._runner.run(*args, **forwarded)
+
+         self._logger.info("CLI runner finished")
+         return result
+
+     def _progress_hook(self, done: int, total: int) -> None:
+         # Initialize task on first call
+         if self._progress_task_id is None:
+             description = "Processing data"
+             self._progress_task_id = self._progress.add_task(description, total=total)
+
+         if self._progress_task_id is not None:
+             self._progress.update(self._progress_task_id, completed=done, total=total)
+         self._refresh_dashboard()
+
+     def _status_hook(self, snapshot: dict[str, list[dict[str, Any]]]) -> None:
+         self._solver_statuses = snapshot.get("solvers", [])
+         self._checker_statuses = snapshot.get("checkers", [])
+         self._refresh_dashboard()
+
+     def _setup_logging(self) -> None:
+         self._log_file.parent.mkdir(parents=True, exist_ok=True)
+
+         logging.config.dictConfig(
+             {
+                 "version": 1,
+                 "disable_existing_loggers": False,
+                 "formatters": {
+                     "default": {
+                         "format": "%(asctime)s %(name)s %(levelname)s %(message)s",
+                     }
+                 },
+                 "handlers": {
+                     "file": {
+                         "class": "logging.FileHandler",
+                         "level": self._log_level,
+                         "formatter": "default",
+                         "filename": str(self._log_file),
+                         "encoding": "utf-8",
+                     }
+                 },
+                 "root": {
+                     "level": self._log_level,
+                     "handlers": ["file"],
+                 },
+             }
+         )
+
+     @staticmethod
+     def _resolve_log_level(level: int | str) -> int:
+         if isinstance(level, int):
+             return level
+
+         normalized = level.strip().upper()
+         level_map = logging.getLevelNamesMapping()
+         if normalized in level_map:
+             return level_map[normalized]
+
+         if normalized.isdigit():
+             return int(normalized)
+
+         raise ValueError(
+             f"Invalid log level: {level!r}. Use a valid logging level name or integer."
+         )
+
+     def _add_log_entry(self, entry: dict[str, str]) -> None:
+         self._log_buffer.append(entry)
+         self._refresh_dashboard()
+
+     def _refresh_dashboard(self) -> None:
+         try:
+             self._live.update(self._render_dashboard(), refresh=True)
+         except Exception:
+             # Avoid crashing the run due to rendering issues.
+             return
+
+     def _render_dashboard(self) -> Layout:
+         layout = Layout(name="root")
+         layout.split_column(
+             Layout(self._render_progress_panel(), name="progress", size=4),
+             Layout(self._render_workers_panel(), name="workers", ratio=2),
+             Layout(self._render_logs_panel(), name="logs", ratio=3),
+         )
+         return layout
+
+     def _render_progress_panel(self) -> Panel:
+         if self._progress is None or self._progress_task_id is None:
+             body: Any = "Waiting for progress updates..."
+         else:
+             body = self._progress
+
+         return Panel(body, title="Run Progress", box=box.SIMPLE)
+
+     def _render_workers_panel(self) -> Panel:
+         solvers = Table(
+             show_header=True,
+             header_style="bold",
+             expand=True,
+             box=box.SIMPLE,
+             padding=(0, 0),
+             title="Solvers",
+             title_style="bold",
+         )
+         solvers.add_column("Solver", justify="center")
+         solvers.add_column("State", justify="center")
+         solvers.add_column("Data", justify="center")
+         solvers.add_column("Runs completed", justify="center")
+         solvers.add_column("Session length", justify="center")
+
+         if not self._solver_statuses:
+             solvers.add_row("-", "-", "-", "-", "-")
+         else:
+             for st in self._solver_statuses:
+                 solvers.add_row(
+                     str(st.get("id", "-")),
+                     str(st.get("state", "-")),
+                     str(st.get("data_id", "-")),
+                     f"{st.get('curr_run', '-')}/{st.get('total_runs', '-')}",
+                     str(st.get("session_length", "-")),
+                 )
+
+         checkers = Table(
+             show_header=True,
+             header_style="bold",
+             expand=True,
+             box=box.SIMPLE,
+             padding=(0, 0),
+             title="Checkers",
+             title_style="bold",
+         )
+         checkers.add_column("Checker", justify="center")
+         checkers.add_column("State", justify="center")
+         checkers.add_column("Data", justify="center")
+         checkers.add_column("From solver", justify="center")
+
+         if not self._checker_statuses:
+             checkers.add_row("-", "-", "-", "-")
+         else:
+             for st in self._checker_statuses:
+                 checkers.add_row(
+                     str(st.get("id", "-")),
+                     str(st.get("state", "-")),
+                     str(st.get("data_id", "-")),
+                     str(st.get("solver_id", "-")),
+                 )
+
+         grid = Table.grid(expand=True)
+         grid.add_column(ratio=1)
+         grid.add_column(ratio=1)
+         grid.add_row(solvers, checkers)
+
+         return Panel(grid, box=box.SIMPLE)
+
+     def _render_logs_panel(self) -> Panel:
+         table = Table(
+             show_header=True,
+             header_style="bold",
+             expand=True,
+             box=box.SIMPLE,
+             padding=(0, 0),
+         )
+         table.add_column("Time", style="dim", width=8, no_wrap=True)
+         table.add_column("Level", width=8, no_wrap=True)
+         table.add_column("Logger", style="dim", width=20, no_wrap=True)
+         table.add_column("Message", overflow="fold")
+
+         if not self._log_buffer:
+             table.add_row("-", "-", "-", "Waiting for logs...")
+         else:
+             max_rows = 6
+             if self._console is not None:
+                 max_rows = max(6, min(30, self._console.size.height - 12))
+
+             for entry in reversed(list(self._log_buffer)[-max_rows:]):
+                 table.add_row(
+                     entry.get("time", ""),
+                     entry.get("level", ""),
+                     entry.get("logger", ""),
+                     entry.get("message", ""),
+                 )
+                 table.add_section()
+
+         return Panel(table, title="Latest Logs", box=box.SIMPLE, padding=(0, 0))
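
`Runner` itself is not part of this diff (it lives under `ape/runner/`), so here is a minimal sketch of how the `CLI` wrapper might be driven. `DemoRunner` is a hypothetical stand-in that only exercises the two hooks `CLI.run` forwards; a real `Runner` would dispatch solvers and checkers:

```python
from ape.cli import CLI


class DemoRunner:
    """Hypothetical stand-in for ape's Runner; only the hook contract is real."""

    def run(self, *, progress_hook, worker_status_hook):
        total = 3
        for done in range(1, total + 1):
            progress_hook(done, total)  # drives the progress bar
            worker_status_hook({"solvers": [], "checkers": []})  # drives the tables
        return "ok"


# Hooks are injected via setdefault, so callers can still override them.
with CLI(DemoRunner(), log_file="output/demo.log", log_level="DEBUG") as cli:
    result = cli.run()
```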

ape/combine.py
@@ -0,0 +1,93 @@
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+
+ # We use raw JSON to avoid issues with unpickling types etc.
+ # This file needs to be updated for each API-breaking change, but it works nicely
+
+
+ def _merge_batches(dst: dict, src: dict) -> None:
+     dst["solutions"].extend(src.get("solutions", []))
+     dst["solution_runtimes"].extend(src.get("solution_runtimes", []))
+     dst["solution_total_times"].extend(src.get("solution_total_times", []))
+     dst["checker_outputs"].extend(src.get("checker_outputs", []))
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Combine APE run results.")
+     parser.add_argument("folder", help="Folder containing .jsonl run files")
+     args = parser.parse_args()
+
+     folder = Path(args.folder)
+     if not folder.exists() or not folder.is_dir():
+         print(f"Error: {folder} does not exist or is not a directory.", file=sys.stderr)
+         sys.exit(1)
+
+     output_path = folder / "combined.jsonl"
+
+     if output_path.exists():
+         print(
+             f"Error: Output file {output_path} already exists. Aborting to prevent overwriting.",
+             file=sys.stderr,
+         )
+         sys.exit(1)
+
+     # Sort files for deterministic order
+     input_files = sorted(folder.glob("*.jsonl"))
+
+     if not input_files:
+         print("No .jsonl files found to combine.")
+         return
+
+     print(f"Found {len(input_files)} files to combine into {output_path}")
+
+     merged_data: dict[str, dict] = {}
+
+     def process_file(fpath: Path) -> None:
+         print(f"Reading {fpath.name}...")
+         count = 0
+         try:
+             with open(fpath, "r") as infile:
+                 for line in infile:
+                     line = line.strip()
+                     if not line:
+                         continue
+                     try:
+                         output = json.loads(line)
+                         data_id = output["input_data"]["_Data__id"]
+                         if data_id in merged_data:
+                             _merge_batches(merged_data[data_id], output)
+                         else:
+                             merged_data[data_id] = output
+                         # Count every successfully parsed entry, merged or new
+                         count += 1
+                     except Exception as e:
+                         print(
+                             f"  Error decoding line in {fpath.name}: {e}",
+                             file=sys.stderr,
+                         )
+             print(f"  Processed {count} entries.")
+         except Exception as e:
+             print(f"  Error reading {fpath.name}: {e}", file=sys.stderr)
+
+     for fpath in input_files:
+         process_file(fpath)
+
+     print(f"Writing merged results ({len(merged_data)} items) to {output_path}...")
+     try:
+         with open(output_path, "w") as outfile:
+             for output in merged_data.values():
+                 outfile.write(json.dumps(output) + "\n")
+     except Exception as e:
+         print(f"Error writing to {output_path}: {e}", file=sys.stderr)
+         sys.exit(1)
+
+     print("Done.")
+
+
+ if __name__ == "__main__":
+     main()
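
The merge is keyed on the name-mangled `_Data__id` field inside `input_data`. A sketch of the JSONL line shape the script expects; the keys come from the code above, while the field values and the exact jsonpickle encodings of the runtimes are illustrative:

```python
# One output batch per line; batches sharing an id get their lists concatenated.
example_line = {
    "input_data": {"_Data__id": "problem-042"},  # merge key
    "solutions": ["x = 3"],
    "solution_runtimes": ["0:00:01.500000"],
    "solution_total_times": ["0:00:02.100000"],
    "checker_outputs": [{"result": "CORRECT", "runtime": "0:00:00.020000"}],
}
```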

ape/report.py
@@ -0,0 +1,23 @@
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Any, Generic
+
+ from .adapter import SolutionT
+ from .checker import ResultT
+ from .data import DataT
+ from .runner import Output
+
+
+ class Report(ABC, Generic[DataT, SolutionT, ResultT]):
+
+     # JSON-like dictionary with data in a format compatible with Typst templates.
+     _summarized_data: dict[str, Any]
+
+     @abstractmethod
+     def __init__(self, output: Output[DataT, SolutionT, ResultT]) -> None: ...
+
+     @abstractmethod
+     def save(self, path: Path) -> None: ...
+
+     @abstractmethod
+     def dump_data(self, path: Path) -> None: ...
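
A sketch of a concrete report, assuming the summary is plain JSON-serializable data. `JsonReport` is illustrative, and since the internals of `Output` are not shown in this diff, the aggregation step is a placeholder:

```python
import json
from pathlib import Path
from typing import Any

from ape.report import Report
from ape.runner import Output


class JsonReport(Report[Any, Any, Any]):
    """Hypothetical report that dumps its summary as pretty-printed JSON."""

    def __init__(self, output: Output[Any, Any, Any]) -> None:
        # A real report would aggregate pass rates, runtimes, etc. from `output`.
        self._summarized_data = {"note": "summary placeholder"}

    def save(self, path: Path) -> None:
        # A real implementation might render the summary through a Typst template.
        self.dump_data(path)

    def dump_data(self, path: Path) -> None:
        path.write_text(json.dumps(self._summarized_data, indent=2))
```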

ape/save_handler.py
@@ -0,0 +1,122 @@
+ import logging
+ import os
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Generic, Optional, cast
+
+ import jsonpickle
+
+ from .adapter import SolutionT
+ from .checker import ResultT
+ from .data import Data, SchemaT
+ from .runner.output import OutputBatch
+ from .runner.run_params import RunParams
+
+ logger = logging.getLogger(__name__)
+
+
+ class SaveHandler(Generic[SchemaT, SolutionT, ResultT]):
+     """Responsible for handling the result file: loading an existing save if one
+     is present, and appending new output to it."""
+
+     def __init__(
+         self,
+         run_params: RunParams,
+         solver_params: dict[str, Any],
+         checker_params: dict[str, Any],
+     ):
+         """Create a new save handler. If the output file exists, look for a save
+         in it for cross-run persistence. If a save is present, the last saved data id is
+         found and can be retrieved by calling `last_from_save`. The save format is JSONL for the output
+         batches and JSON for run metadata.
+         """
+         self._save_file: Path = run_params.save_file
+         self._metadata_file: Path = run_params.metadata_file
+         self._last: Optional[str] = None
+
+         self._save_file.parent.mkdir(parents=True, exist_ok=True)
+         self._metadata_file.parent.mkdir(parents=True, exist_ok=True)
+
+         if run_params.load_save:
+             try:
+                 metadata = cast(
+                     dict[str, Any], jsonpickle.decode(self._metadata_file.read_text())
+                 )
+                 _ = self._save_file.read_text()  # Read to ensure the file exists
+
+                 last_id = metadata.get("last_id")
+                 self._last = str(last_id) if last_id is not None else None
+
+                 logger.info(f"Save found and loaded. Last label: {self._last}.")
+                 return
+             except FileNotFoundError:
+                 logger.info("No save found.")
+             except Exception as e:
+                 logger.warning(
+                     f"Tried to read save from file {self._save_file}, but the save is corrupted. "
+                     f"Error encountered: {e}."
+                 )
+
+         logger.info("Creating new save files.")
+         try:
+             self._save_file.touch(exist_ok=False)
+         except FileExistsError:
+             logger.critical(
+                 "Found save files when creating new files. Stopping to prevent data overwriting."
+             )
+             raise RuntimeError("Save file already exists; refusing to overwrite it.")
+
+         metadata = {
+             "run_params": run_params,
+             "solver_params": solver_params,
+             "checker_params": checker_params,
+             "start_date": datetime.now(),
+         }
+         metadata_payload = cast(
+             str,
+             jsonpickle.encode(metadata, unpicklable=True, make_refs=False),
+         )
+         self._metadata_file.write_text(metadata_payload)
+
+     def last_from_save(self) -> str | None:
+         """See which Data id was the last one calculated in the previous run.
+
+         Returns:
+             id (str | None): the last id (or None if no save was loaded).
+         """
+         return self._last
+
+     def save(
+         self,
+         output_batch: OutputBatch[Data[SchemaT], SolutionT, ResultT],
+     ) -> None:
+         """Append a new output batch to the save file and update the run metadata.
+
+         Params:
+             output_batch (OutputBatch[Data[SchemaT], SolutionT, ResultT]): the
+                 batch of solutions and checker results for a single data item.
+         """
+         metadata = cast(
+             dict[str, Any], jsonpickle.decode(self._metadata_file.read_text())
+         )
+         metadata["last_id"] = output_batch.input_data.id
+         metadata["last_time"] = datetime.now()
+         total = metadata.get("total", 0)
+         metadata["total"] = int(total) + 1
+
+         metadata_payload = cast(
+             str,
+             jsonpickle.encode(metadata, unpicklable=True, make_refs=False),
+         )
+         self._metadata_file.write_text(metadata_payload)
+
+         with open(self._save_file, "a") as f:
+             output_payload = cast(
+                 str,
+                 jsonpickle.encode(output_batch, unpicklable=True, make_refs=False),
+             )
+             f.write(output_payload + "\n")
+             f.flush()
+             os.fsync(f.fileno())
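
For reference, the metadata file holds a single jsonpickle-encoded JSON object. A sketch of its shape after a few batches have been saved; the keys match the code above, while the encoded values shown are illustrative:

```python
# metadata.json contents, schematically (jsonpickle encodes the actual objects):
metadata_shape = {
    "run_params": "<jsonpickle-encoded RunParams>",
    "solver_params": {"model": "..."},  # kwargs passed to the SolverFactory
    "checker_params": {},               # kwargs passed to the CheckerFactory
    "start_date": "<encoded datetime>",
    "last_id": "problem-042",           # resume point read by last_from_save()
    "last_time": "<encoded datetime>",
    "total": 3,                         # number of batches appended so far
}
```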

ape/solver.py
@@ -0,0 +1,46 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Generic, Optional, Type, TypeVar
+
+ from .adapter import RawT
+ from .checker import CheckerOutput, ResultT
+ from .data import DataT
+
+
+ class Solver(ABC, Generic[DataT, RawT, ResultT]):
+     @abstractmethod
+     def begin_session(self, data: DataT) -> RawT:
+         """Starts the solving session and returns a first raw solution.
+         Should only be called right after construction, or after advance_session has returned None.
+         """
+         ...
+
+     @abstractmethod
+     def advance_session(self, outputs: list[CheckerOutput[ResultT]]) -> Optional[RawT]:
+         """Receives the results of the previous solution and either produces another one
+         or decides to stop and returns None.
+         Should not be called again after returning None until begin_session has been called again.
+         """
+         ...
+
+     @abstractmethod
+     def retry_last(self) -> RawT:
+         """Tries to generate the solution again, in the event that the previous raw solution
+         could not be parsed.
+         Should only be called after a call to begin_session or advance_session that returned a value.
+         """
+         ...
+
+
+ SolverT = TypeVar("SolverT", bound=Solver)
+
+
+ class SolverFactory(Generic[SolverT]):
+     def __init__(self, solver_cls: Type[SolverT], **kwargs: object) -> None:
+         self._solver_cls = solver_cls
+         self._kwargs = kwargs
+
+     def create(self) -> SolverT:
+         return self._solver_cls(**self._kwargs)
+
+     def get_params(self) -> dict[str, Any]:
+         return self._kwargs
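
A sketch of a minimal concrete solver honoring the session protocol. `OneShotSolver` is illustrative (a real solver would typically wrap an LLM call), and the generic parameters are loosened to `Any` since the `Data` type is not shown in this diff:

```python
from typing import Any, Optional

from ape.checker import CheckerOutput
from ape.solver import Solver, SolverFactory


class OneShotSolver(Solver[Any, str, Any]):
    """Hypothetical solver: proposes one canned answer per session, then stops."""

    def __init__(self, answer: str = "x = 0") -> None:
        self._answer = answer

    def begin_session(self, data: Any) -> str:
        self._last = self._answer  # a real solver would derive an attempt from `data`
        return self._last

    def advance_session(self, outputs: list[CheckerOutput[Any]]) -> Optional[str]:
        return None  # one attempt per session, regardless of the checker verdicts

    def retry_last(self) -> str:
        return self._last  # resend the same raw text if parsing failed


factory = SolverFactory(OneShotSolver, answer="x = 1")
solver = factory.create()
```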

ape/visualizer.py
@@ -0,0 +1,67 @@
+ import sys
+ from typing import Any, cast
+
+ import jsonpickle
+
+ from .runner.output import OutputBatch
+ from .runner.run_params import RunParams
+
+
+ # A very simple visualizer for now
+ class Visualizer:
+     def visualize(self, run_params: RunParams) -> None:
+         save_f = run_params.save_file
+         if not save_f.exists() or not save_f.is_file():
+             print(f"Error: {save_f} does not exist or is not a file.", file=sys.stderr)
+             return
+
+         metadata_f = run_params.metadata_file
+         if not metadata_f.exists() or not metadata_f.is_file():
+             print(
+                 f"Error: {metadata_f} does not exist or is not a file.", file=sys.stderr
+             )
+             return
+
+         print(f"Reading metadata from {metadata_f}...\n")
+         try:
+             metadata = cast(dict[str, Any], jsonpickle.decode(metadata_f.read_text()))
+             for k, v in metadata.items():
+                 print(f"  {k}: {v}")
+         except Exception as e:
+             print(f"Error decoding metadata: {e}", file=sys.stderr)
+
+         print(f"\n\nVisualizing results from {save_f}...\n")
+         try:
+             with open(save_f, "r") as f:
+                 for i, line in enumerate(f):
+                     line = line.strip()
+                     if not line:
+                         continue
+                     try:
+                         output = cast(
+                             OutputBatch[Any, Any, Any], jsonpickle.decode(line)
+                         )
+                         print(f"--- Entry {i + 1} ---")
+                         print(f"Data:\n{output.input_data}")
+                         print("\nResults:")
+                         if output.solutions:
+                             for j, (
+                                 sol,
+                                 gen_time,
+                                 total_gen_time,
+                                 checker_output,
+                             ) in enumerate(output):
+                                 print(f"  Result {j + 1}:")
+                                 print(f"    Solution: {sol}")
+                                 print(f"    Status: {checker_output.result}")
+                                 print(f"    Checking time: {checker_output.runtime}")
+                                 print(f"    Generation time: {gen_time}")
+                                 print(f"    Total generation time: {total_gen_time}")
+                         else:
+                             print("  No results found.")
+                         print("\n" + "=" * 40 + "\n")
+
+                     except Exception as e:
+                         print(f"Error decoding line {i + 1}: {e}", file=sys.stderr)
+         except Exception as e:
+             print(f"Error reading file {save_f}: {e}", file=sys.stderr)

ape_framework.egg-info/PKG-INFO
@@ -0,0 +1,50 @@
+ Metadata-Version: 2.4
+ Name: ape-framework
+ Version: 0.1.0
+ Summary: Package for evaluating algebra problems using AI systems.
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Operating System :: OS Independent
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Dynamic: license-file
+
+ # Algebra Problems Evaluator (APE)
+
+ ## What is APE?
+ **APE is a framework that simplifies building and running mathematical benchmarks for AI systems.**
+
+ It was initially meant as a way to evaluate whether mathematical problems from the field of algebra are suitable targets for LLM reasoning research (hence the name).
+ However, it might be more useful to think about it the other way around: the subjects of the evaluation are LLMs, or more broadly AI systems,
+ and they are evaluated on their ability to solve specific algebra problems: problems whose solutions are hard to generate, but relatively easy to check for correctness automatically.
+
+ APE was created as part of a Bachelor's project at the University of Warsaw.
+
+ ## User's Guide
+ [**TODO**]
+
+ ## Development setup
+
+ Install development dependencies:
+
+ ```bash
+ pip install -r requirements-dev.txt
+ ```
+
+ Install git hooks:
+
+ ```bash
+ pre-commit install
+ ```
+
+ Run all hooks manually:
+
+ ```bash
+ pre-commit run --all-files
+ ```
+
+ <!--
+ 1. What does the user have to implement to have a complete benchmark?
+ 2. How to run a test using an APE benchmark?
+ 3. Where to find the documentation?
+ -->

ape_framework.egg-info/SOURCES.txt
@@ -0,0 +1,16 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ ape/__init__.py
+ ape/adapter.py
+ ape/checker.py
+ ape/cli.py
+ ape/combine.py
+ ape/report.py
+ ape/save_handler.py
+ ape/solver.py
+ ape/visualizer.py
+ ape_framework.egg-info/PKG-INFO
+ ape_framework.egg-info/SOURCES.txt
+ ape_framework.egg-info/dependency_links.txt
+ ape_framework.egg-info/top_level.txt

pyproject.toml
@@ -0,0 +1,37 @@
+ [tool.black]
+ line-length = 88
+ target-version = ["py312"]
+ skip-string-normalization = false
+
+ [tool.isort]
+ profile = "black"
+ line_length = 88
+
+ [tool.mypy]
+ python_version = "3.12"
+ packages = ["ape"]
+ disallow_untyped_defs = true
+ disallow_incomplete_defs = true
+ no_implicit_optional = true
+ warn_return_any = true
+ warn_unused_ignores = true
+ strict_equality = true
+ ignore_missing_imports = true
+
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "ape-framework"
+ version = "0.1.0"
+ description = "Package for evaluating algebra problems using AI systems."
+ readme = "README.md"
+ requires-python = ">=3.12"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "Operating System :: OS Independent"
+ ]
+
+ [tool.setuptools]
+ packages = ["ape"]

setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+