inspect-test-utils 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_test_utils/__init__.py +37 -0
- inspect_test_utils/_registry.py +39 -0
- inspect_test_utils/assertions.py +114 -0
- inspect_test_utils/eval_runner.py +164 -0
- inspect_test_utils/fixtures.py +100 -0
- inspect_test_utils/hardcoded.py +191 -0
- inspect_test_utils/mockllm.py +48 -0
- inspect_test_utils/scanners.py +54 -0
- inspect_test_utils/scorers.py +82 -0
- inspect_test_utils/solvers.py +149 -0
- inspect_test_utils/tasks.py +377 -0
- inspect_test_utils-0.2.0.dist-info/METADATA +78 -0
- inspect_test_utils-0.2.0.dist-info/RECORD +16 -0
- inspect_test_utils-0.2.0.dist-info/WHEEL +4 -0
- inspect_test_utils-0.2.0.dist-info/entry_points.txt +2 -0
- inspect_test_utils-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Reusable testing framework for Inspect AI evaluation tasks.
|
|
2
|
+
|
|
3
|
+
This framework provides:
|
|
4
|
+
- Model-level mocking (HardcodedModelAPI)
|
|
5
|
+
- Solver-level hardcoded execution
|
|
6
|
+
- Test scorers and tasks
|
|
7
|
+
- Pytest fixtures and assertion helpers
|
|
8
|
+
- Eval runner for integration tests
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from inspect_test_utils.assertions import (
|
|
12
|
+
assert_contains,
|
|
13
|
+
assert_eval_score,
|
|
14
|
+
assert_files_exist,
|
|
15
|
+
assert_score_in_range,
|
|
16
|
+
)
|
|
17
|
+
from inspect_test_utils.eval_runner import (
|
|
18
|
+
EvalTestResult,
|
|
19
|
+
run_eval_test,
|
|
20
|
+
)
|
|
21
|
+
from inspect_test_utils.solvers import (
|
|
22
|
+
hardcoded_bash_solver,
|
|
23
|
+
hardcoded_python_solver,
|
|
24
|
+
inspection_solver,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Public API of the package: the assertion helpers, the eval runner and its
# result type, and the solver factories imported above.
__all__ = [
    "assert_contains",
    "assert_eval_score",
    "assert_files_exist",
    "assert_score_in_range",
    "EvalTestResult",
    "hardcoded_bash_solver",
    "hardcoded_python_solver",
    "inspection_solver",
    "run_eval_test",
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from inspect_test_utils.hardcoded import hardcoded
|
|
2
|
+
from inspect_test_utils.mockllm import mockllm_wrapper
|
|
3
|
+
from inspect_test_utils.scanners import (
|
|
4
|
+
model_roles_scanner,
|
|
5
|
+
suspicious_behaviour,
|
|
6
|
+
word_counter,
|
|
7
|
+
)
|
|
8
|
+
from inspect_test_utils.tasks import (
|
|
9
|
+
configurable_sandbox,
|
|
10
|
+
guess_number,
|
|
11
|
+
guess_number_keep_guessing,
|
|
12
|
+
hardcoded_score,
|
|
13
|
+
network_sandbox,
|
|
14
|
+
say_hello,
|
|
15
|
+
say_hello_with_tools,
|
|
16
|
+
sometimes_fails_scoring,
|
|
17
|
+
sometimes_fails_setup,
|
|
18
|
+
timeout,
|
|
19
|
+
uses_model_roles,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Names exposed by this registry module: the model-provider factories
# (``hardcoded``, ``mockllm_wrapper``), the scanners, and the test tasks
# imported above.
__all__ = [
    "configurable_sandbox",
    "guess_number",
    "guess_number_keep_guessing",
    "hardcoded",
    "hardcoded_score",
    "mockllm_wrapper",
    "model_roles_scanner",
    "network_sandbox",
    "say_hello",
    "say_hello_with_tools",
    "sometimes_fails_scoring",
    "sometimes_fails_setup",
    "suspicious_behaviour",
    "timeout",
    "uses_model_roles",
    "word_counter",
]
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Assertion helpers for Inspect AI evaluation tests.
|
|
2
|
+
|
|
3
|
+
These functions provide convenient assertions for common test patterns.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from inspect_test_utils.eval_runner import EvalTestResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def assert_eval_score(
|
|
13
|
+
result: "EvalTestResult",
|
|
14
|
+
expected: float,
|
|
15
|
+
tolerance: float = 0.01,
|
|
16
|
+
*,
|
|
17
|
+
message: str | None = None,
|
|
18
|
+
) -> None:
|
|
19
|
+
"""Assert that the eval score matches the expected value.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
result: The EvalTestResult from run_eval_test.
|
|
23
|
+
expected: Expected score value.
|
|
24
|
+
tolerance: Acceptable deviation from expected score.
|
|
25
|
+
message: Optional custom failure message.
|
|
26
|
+
|
|
27
|
+
Raises:
|
|
28
|
+
AssertionError: If score doesn't match within tolerance.
|
|
29
|
+
"""
|
|
30
|
+
actual = result.score
|
|
31
|
+
if actual is None:
|
|
32
|
+
raise AssertionError(f"No score returned. Error: {result.error}")
|
|
33
|
+
|
|
34
|
+
if abs(actual - expected) > tolerance:
|
|
35
|
+
msg = message or f"Expected score {expected}, got {actual}"
|
|
36
|
+
if result.explanation:
|
|
37
|
+
msg += f"\nExplanation: {result.explanation}"
|
|
38
|
+
raise AssertionError(msg)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def assert_score_in_range(
|
|
42
|
+
result: "EvalTestResult",
|
|
43
|
+
min_score: float,
|
|
44
|
+
max_score: float,
|
|
45
|
+
*,
|
|
46
|
+
message: str | None = None,
|
|
47
|
+
) -> None:
|
|
48
|
+
"""Assert that the eval score is within a range.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
result: The EvalTestResult from run_eval_test.
|
|
52
|
+
min_score: Minimum acceptable score (inclusive).
|
|
53
|
+
max_score: Maximum acceptable score (inclusive).
|
|
54
|
+
message: Optional custom failure message.
|
|
55
|
+
|
|
56
|
+
Raises:
|
|
57
|
+
AssertionError: If score is outside the range.
|
|
58
|
+
"""
|
|
59
|
+
actual = result.score
|
|
60
|
+
if actual is None:
|
|
61
|
+
raise AssertionError(f"No score returned. Error: {result.error}")
|
|
62
|
+
|
|
63
|
+
if not (min_score <= actual <= max_score):
|
|
64
|
+
msg = message or f"Expected score in [{min_score}, {max_score}], got {actual}"
|
|
65
|
+
if result.explanation:
|
|
66
|
+
msg += f"\nExplanation: {result.explanation}"
|
|
67
|
+
raise AssertionError(msg)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def assert_files_exist(
|
|
71
|
+
files: list[str],
|
|
72
|
+
actual_files: list[str],
|
|
73
|
+
*,
|
|
74
|
+
message: str | None = None,
|
|
75
|
+
) -> None:
|
|
76
|
+
"""Assert that expected files exist in the actual file list.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
files: List of expected file paths/names.
|
|
80
|
+
actual_files: List of actual files found.
|
|
81
|
+
message: Optional custom failure message.
|
|
82
|
+
|
|
83
|
+
Raises:
|
|
84
|
+
AssertionError: If any expected file is missing.
|
|
85
|
+
"""
|
|
86
|
+
missing = [f for f in files if f not in actual_files]
|
|
87
|
+
if missing:
|
|
88
|
+
msg = message or f"Missing files: {missing}"
|
|
89
|
+
msg += f"\nFound: {actual_files}"
|
|
90
|
+
raise AssertionError(msg)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def assert_contains(
|
|
94
|
+
needle: str,
|
|
95
|
+
haystack: str,
|
|
96
|
+
*,
|
|
97
|
+
message: str | None = None,
|
|
98
|
+
) -> None:
|
|
99
|
+
"""Assert that a string contains a substring.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
needle: String to search for.
|
|
103
|
+
haystack: String to search in.
|
|
104
|
+
message: Optional custom failure message.
|
|
105
|
+
|
|
106
|
+
Raises:
|
|
107
|
+
AssertionError: If needle not found in haystack.
|
|
108
|
+
"""
|
|
109
|
+
if needle not in haystack:
|
|
110
|
+
msg = message or f"Expected to find '{needle}' in output"
|
|
111
|
+
# Show a preview of the haystack
|
|
112
|
+
preview = haystack[:500] + "..." if len(haystack) > 500 else haystack
|
|
113
|
+
msg += f"\nActual output:\n{preview}"
|
|
114
|
+
raise AssertionError(msg)
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Eval runner helpers for testing Inspect AI evaluations.
|
|
2
|
+
|
|
3
|
+
Provides a simplified interface for running evals in tests with custom solvers.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from inspect_ai import Task, eval
|
|
10
|
+
from inspect_ai.log import EvalLog
|
|
11
|
+
from inspect_ai.solver import Solver
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class EvalTestResult:
    """Condensed view of an Inspect eval run, shaped for test assertions.

    Attributes:
        success: True only when the eval log reports status ``"success"``.
        score: First numeric score value found across samples, or None.
        scores: Every numeric (int/float) score value, in sample order.
        explanation: First non-None explanation among scores that have a value.
        metadata: First non-empty score metadata among scores that have a
            value (empty dict if none).
        log: The underlying EvalLog, kept for detailed inspection.
        error: Failure description when the run did not succeed.
    """

    success: bool
    score: float | None = None
    scores: list[float] = field(default_factory=list)
    explanation: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    log: EvalLog | None = None
    error: str | None = None

    @classmethod
    def from_log(cls, log: EvalLog) -> "EvalTestResult":
        """Summarize ``log``: the failure status, or the scores gathered from its samples."""
        if log.status != "success":
            return cls(
                success=False,
                log=log,
                error=f"Eval status: {log.status}. Error: {log.error}",
            )

        numeric_values: list[float] = []
        first_explanation: str | None = None
        first_metadata: dict[str, Any] = {}

        for sample in log.samples or []:
            for sample_score in (sample.scores or {}).values():
                value = sample_score.value
                if value is None:
                    continue
                # Non-numeric values (e.g. strings) are skipped for the score
                # lists but still count for explanation/metadata capture.
                if isinstance(value, (int, float)):
                    numeric_values.append(float(value))
                if first_explanation is None:
                    first_explanation = sample_score.explanation
                if not first_metadata:
                    first_metadata = sample_score.metadata or {}

        return cls(
            success=True,
            score=numeric_values[0] if numeric_values else None,
            scores=numeric_values,
            explanation=first_explanation,
            metadata=first_metadata,
            log=log,
        )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def run_eval_test(
    task: Task | Callable[[], Task],
    solver: Solver | list[Solver] | None = None,
    *,
    limit: int | None = 1,
    model: str = "mockllm/model",
    model_args: dict[str, Any] | None = None,
    sandbox_cleanup: bool = True,
    message_limit: int | None = None,
    **eval_kwargs: Any,
) -> EvalTestResult:
    """Run an Inspect eval for a test and return an ``EvalTestResult``.

    By default the eval uses ``mockllm/model``, so no real model API is called.
    Any exception raised while running the eval, or while summarizing its log,
    is caught and reported as ``success=False`` with the exception text in
    ``error``. Exceptions raised while building the task are not caught.

    Args:
        task: A Task, or a zero-argument callable returning one.
        solver: Solver(s) to use in place of the task's own; None keeps the task's solver.
        limit: Number of samples to run (default 1 for fast tests).
        model: Model to evaluate. Use "hardcoded/test" together with
            ``model_args`` to drive HardcodedModelAPI.
        model_args: Forwarded to ``eval`` as ``model_args`` when not None. For
            HardcodedModelAPI:
              - tool_calls: list of commands (strings) or HardcodedToolCall dicts
              - repetitions: number of times to cycle through the commands (default 1)
              - answer: final answer content (default "done")
        sandbox_cleanup: Whether to clean up the sandbox afterwards (default True).
        message_limit: Replacement message_limit for the task; None keeps the task's.
        **eval_kwargs: Extra keyword arguments forwarded to ``eval``.

    Returns:
        EvalTestResult built from the first returned log, or a failed result if
        no logs were returned or an exception occurred.

    Example using a hardcoded solver (bypasses the model):
        result = run_eval_test(
            code_repair,
            solver=hardcoded_bash_solver(["sed -i 's/bug/fix/' file.py"]),
        )

    Example using HardcodedModelAPI (exercises the full agent loop):
        result = run_eval_test(
            code_repair,
            model="hardcoded/test",
            model_args={"tool_calls": ["sed -i 's/bug/fix/' file.py"]},
        )
    """
    resolved = task() if callable(task) else task

    # Rebuild the Task only when an override is requested. The copy carries
    # dataset, solver, scorer, message_limit, sandbox and metadata; other Task
    # settings are not copied.
    if solver is not None or message_limit is not None:
        resolved = Task(
            dataset=resolved.dataset,
            solver=resolved.solver if solver is None else solver,
            scorer=resolved.scorer,
            message_limit=resolved.message_limit if message_limit is None else message_limit,
            sandbox=resolved.sandbox,
            metadata=resolved.metadata,
        )

    try:
        options: dict[str, Any] = {
            "limit": limit,
            "sandbox_cleanup": sandbox_cleanup,
            **eval_kwargs,
        }
        if model_args is not None:
            options["model_args"] = model_args

        logs = eval(resolved, model=model, **options)
        if logs:
            return EvalTestResult.from_log(logs[0])
        return EvalTestResult(success=False, error="No eval logs returned")
    except Exception as e:
        return EvalTestResult(success=False, error=str(e))
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Pytest fixtures and configuration for Inspect AI eval testing.
|
|
2
|
+
|
|
3
|
+
This module provides pytest plugins and fixtures for testing Inspect AI evals.
|
|
4
|
+
Register it in your conftest.py:
|
|
5
|
+
|
|
6
|
+
pytest_plugins = ["inspect_test_utils.fixtures"]
|
|
7
|
+
|
|
8
|
+
Or import individual fixtures and markers directly:
|
|
9
|
+
|
|
10
|
+
from inspect_test_utils.fixtures import skip_sandbox
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
from collections.abc import Generator
|
|
15
|
+
|
|
16
|
+
import pytest
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def pytest_configure(config: pytest.Config) -> None:
    """Declare this plugin's custom markers (``sandbox``, ``slow``, ``compose``) with pytest."""
    marker_descriptions = (
        "sandbox: marks tests as requiring Docker sandbox (may be slow)",
        "slow: marks tests as slow running",
        "compose(path): specifies the compose.yaml path for sandbox tests",
    )
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the boolean ``--skip-sandbox`` and ``--run-slow`` command-line flags."""
    flags = (
        ("--skip-sandbox", "Skip tests that require Docker sandbox"),
        ("--run-slow", "Run tests marked as slow"),
    )
    for flag, description in flags:
        parser.addoption(flag, action="store_true", default=False, help=description)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def pytest_collection_modifyitems(
    config: pytest.Config, items: list[pytest.Item]
) -> None:
    """Attach skip markers to collected tests according to the CLI flags.

    Tests marked ``sandbox`` are skipped when ``--skip-sandbox`` is given.
    Tests marked ``slow`` are skipped unless ``--run-slow`` is given.
    """
    sandbox_skip = pytest.mark.skip(reason="--skip-sandbox specified")
    slow_skip = pytest.mark.skip(reason="need --run-slow option to run")

    # The flags do not change during collection, so read them once.
    skipping_sandbox = config.getoption("--skip-sandbox")
    running_slow = config.getoption("--run-slow")

    for item in items:
        keywords = item.keywords
        if skipping_sandbox and "sandbox" in keywords:
            item.add_marker(sandbox_skip)
        if not running_slow and "slow" in keywords:
            item.add_marker(slow_skip)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@pytest.fixture
def task_root(request: pytest.FixtureRequest) -> str:
    """Return the path to the ``tasks`` directory relative to the requesting test file.

    The path is ``<directory of the test file>/../tasks`` (not normalized).

    It is resolved from ``request.path`` (the test module) rather than this
    module's ``__file__``. That ``__file__`` points inside the installed
    ``inspect_test_utils`` package, so ``../tasks`` relative to it never
    refers to the user's project.
    """
    test_dir = os.path.dirname(str(request.path))
    return os.path.join(test_dir, "..", "tasks")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.fixture
def sandbox_timeout() -> int:
    """Default timeout for sandbox operations, in seconds.

    Returns a fixed value of 60.
    """
    return 60
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Marker shortcut for use in tests: decorating a test with ``@skip_sandbox``
# skips it when the SKIP_SANDBOX environment variable is "1", "true" or "yes"
# (case-insensitive). The variable is read once, when this module is imported.
skip_sandbox = pytest.mark.skipif(
    os.environ.get("SKIP_SANDBOX", "").lower() in {"1", "true", "yes"},
    reason="SKIP_SANDBOX environment variable is set",
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@pytest.fixture
def requires_docker() -> Generator[None, None, None]:
    """Skip the requesting test unless ``docker info`` succeeds.

    The test is skipped with "Docker is not available" when the ``docker``
    executable cannot be run or does not answer within 10 seconds. It is
    skipped with "Docker is not running" when the command exits non-zero.
    """
    import subprocess

    try:
        result = subprocess.run(
            ["docker", "info"],
            capture_output=True,
            timeout=10,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        # OSError covers a missing binary (FileNotFoundError) as well as one
        # that exists but cannot be executed (e.g. PermissionError). Either
        # way Docker is unusable here, so skip rather than error the test.
        pytest.skip("Docker is not available")

    if result.returncode != 0:
        pytest.skip("Docker is not running")

    yield
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import random
|
|
3
|
+
from asyncio import sleep
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any, TypedDict, override
|
|
6
|
+
|
|
7
|
+
import inspect_ai._util.constants
|
|
8
|
+
from inspect_ai.model import (
|
|
9
|
+
ChatCompletionChoice,
|
|
10
|
+
ChatMessage,
|
|
11
|
+
ChatMessageAssistant,
|
|
12
|
+
GenerateConfig,
|
|
13
|
+
ModelAPI,
|
|
14
|
+
ModelCall,
|
|
15
|
+
ModelOutput,
|
|
16
|
+
ModelUsage,
|
|
17
|
+
modelapi,
|
|
18
|
+
)
|
|
19
|
+
from inspect_ai.tool import ToolCall, ToolInfo, ToolChoice
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class HardcodedToolCall(TypedDict):
    """One scripted tool call replayed by HardcodedModelAPI."""

    # Name of the tool to invoke (plain command strings become "bash").
    tool_name: str
    # Arguments passed to the tool; HardcodedModelAPI rejects non-dict values.
    tool_args: dict[str, Any]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class HardcodedModelAPI(ModelAPI):
    """Model provider that replays a fixed script of tool calls instead of calling an LLM.

    Each ``generate`` call counts the assistant messages already in the
    conversation and uses that count as its position in the script. It cycles
    through ``tool_calls`` ``repetitions`` times. When the script is exhausted,
    or immediately if it is empty, it produces a final answer: a call to the
    ``submit`` tool if one is offered, otherwise a plain assistant message
    containing ``answer``.

    Args:
        model_name: Passed through to ``ModelAPI``.
        base_url: Passed through to ``ModelAPI``.
        api_key: Passed through to ``ModelAPI``.
        config: Passed through to ``ModelAPI``.
        tool_calls: The script; see ``_parse_tool_calls`` for accepted forms.
        repetitions: Number of passes through ``tool_calls`` before answering.
        answer: Final answer content.
        delay: Seconds to sleep before each response.
        concurrency: Value reported by ``max_connections``.
        failure_rate: Probability (0-1) that a call returns an exception
            instead of output.
        input_tokens: Input token count reported in usage.
        output_tokens: Output token count reported in usage.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        tool_calls: list[HardcodedToolCall] | str | list[str] | None = None,
        repetitions: int = 1,
        answer: str = "done",
        delay: float = 0.0,
        concurrency: int = inspect_ai._util.constants.DEFAULT_MAX_CONNECTIONS,
        failure_rate: float = 0.0,
        input_tokens: int = 100,
        output_tokens: int = 50,
    ) -> None:
        super().__init__(
            model_name=model_name, base_url=base_url, api_key=api_key, config=config
        )
        self.tool_calls = self._parse_tool_calls(tool_calls)
        self.repetitions = repetitions
        self.answer = answer
        self.delay = delay
        self.concurrency = concurrency
        self.failure_rate = failure_rate
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens

    @staticmethod
    def _decode_json_script(text: str) -> Any:
        """Decode ``text`` as a JSON tool-call script, or return None if it is not one.

        Only JSON strings, objects, and arrays whose items are strings or
        objects count. Invalid JSON, and JSON scalars such as ``true`` or
        ``42`` (far more plausibly literal shell commands), yield None.
        """
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            return None
        if isinstance(decoded, (str, dict)):
            return decoded
        if isinstance(decoded, list) and all(
            isinstance(item, (str, dict)) for item in decoded
        ):
            return decoded
        return None

    def _parse_tool_calls(
        self, tool_calls: list[HardcodedToolCall] | str | list[str] | None
    ) -> list[HardcodedToolCall]:
        """Normalize every accepted ``tool_calls`` form into a list of tool-call dicts.

        Accepted forms:
          * None or an empty list: no scripted calls.
          * A command string, or a list of command strings: one ``bash`` call
            per command, with ``tool_args={"cmd": command}``.
          * A list of HardcodedToolCall dicts: validated and returned.
          * JSON text for any of the above. A single string is decoded as JSON.
            A list of strings is decoded as the elements of one JSON array, so
            e.g. one JSON object per list entry works. A single JSON object is
            treated as one tool call.

        Raises:
            ValueError: If the input cannot be interpreted as tool calls.
        """
        if tool_calls is None:
            return []

        parsed: Any = tool_calls
        if isinstance(tool_calls, str):
            decoded = self._decode_json_script(tool_calls)
            if decoded is not None:
                parsed = decoded
        elif (
            isinstance(tool_calls, list)
            and tool_calls
            and all(isinstance(item, str) for item in tool_calls)
        ):
            # Entries may be JSON fragments (e.g. one object per CLI argument).
            decoded = self._decode_json_script("[" + ",".join(tool_calls) + "]")
            if decoded is not None:
                parsed = decoded

        # A lone command, or a lone JSON object describing one tool call.
        if isinstance(parsed, (str, dict)):
            parsed = [parsed]
        if not isinstance(parsed, (list, tuple)):
            raise ValueError(f"Invalid tool calls: {tool_calls!r}")
        if not parsed:
            return []

        if isinstance(parsed[0], str):
            if not all(isinstance(cmd, str) for cmd in parsed):
                raise ValueError(
                    f"Cannot mix command strings and tool call dicts: {parsed!r}"
                )
            return [
                HardcodedToolCall(tool_name="bash", tool_args={"cmd": cmd})
                for cmd in parsed
            ]

        for tool_call in parsed:
            if not isinstance(tool_call, dict):
                raise ValueError(f"Invalid tool call: {tool_call}")
            if "tool_name" not in tool_call or "tool_args" not in tool_call:
                raise ValueError(f"Invalid tool call: {tool_call}")
            if not isinstance(tool_call.get("tool_args"), dict):
                raise ValueError(f"Invalid tool_args (must be dict): {tool_call}")
        return list(parsed)

    def max_connections(self) -> int:
        return self.concurrency

    def _final_answer_choice(self, tools: list[ToolInfo]) -> ChatCompletionChoice:
        """Build the closing turn: a ``submit`` tool call if offered, else plain text."""
        submit_tool = next((tool for tool in tools if tool.name == "submit"), None)
        if submit_tool is None:
            message = ChatMessageAssistant(content=self.answer)
        else:
            message = ChatMessageAssistant(
                content="I will now submit my answer.",
                tool_calls=[
                    ToolCall(
                        id="hardcoded_submit",
                        function=submit_tool.name,
                        arguments={"answer": self.answer},
                    )
                ],
            )
        return ChatCompletionChoice(message=message, stop_reason="stop")

    def _scripted_tool_choice(
        self, tool_call: HardcodedToolCall, index: int
    ) -> ChatCompletionChoice:
        """Build a turn that invokes ``tool_call``; ``index`` makes the call id unique."""
        tool_name = tool_call["tool_name"]
        tool_args = tool_call["tool_args"]
        message = ChatMessageAssistant(
            content=f"Executing {tool_name} with args: {tool_args}",
            tool_calls=[
                ToolCall(
                    id=f"hardcoded_{index}",
                    function=tool_name,
                    arguments=tool_args,
                )
            ],
        )
        return ChatCompletionChoice(message=message)

    @override
    async def generate(
        self,
        input: list[ChatMessage],
        tools: list[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
        record_call: Callable[[ModelCall], None] | None = None,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
        """Return the next scripted response for this conversation.

        Returns ``(exception, model_call)`` on a simulated failure (probability
        ``failure_rate``), otherwise ``(ModelOutput, model_call)``.
        """
        # Position in the script = number of assistant turns taken so far.
        index = sum(1 for m in input if m.role == "assistant")

        model_call = ModelCall.create(
            request={"hardcoded": "test"}, response=None, filter=None, time=None
        )
        if record_call:
            record_call(model_call)

        if self.delay > 0:
            await sleep(self.delay)

        if random.random() < self.failure_rate:
            model_call.response = {"failure": "test"}
            try:
                raise Exception("Failure")
            except Exception as e:
                # Raising first populates the exception's __traceback__.
                return e, model_call

        script_length = len(self.tool_calls)
        # An empty script has nothing to replay, so answer straight away
        # regardless of ``repetitions``; indexing into it would fail.
        if script_length == 0 or index // script_length >= self.repetitions:
            choice = self._final_answer_choice(tools)
        else:
            choice = self._scripted_tool_choice(
                self.tool_calls[index % script_length], index
            )

        model_call.response = {"test": "hardcoded"}
        return ModelOutput(
            model="hardcoded",
            choices=[choice],
            usage=ModelUsage(
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                total_tokens=self.input_tokens + self.output_tokens,
            ),
        ), model_call

    def should_retry(self, ex: Exception) -> bool:
        # Every error is treated as retryable.
        return True
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@modelapi(name="hardcoded")
def hardcoded() -> type[ModelAPI]:
    """Provider factory registered via ``@modelapi`` under the name ``hardcoded``.

    Returns the HardcodedModelAPI class. Models are then addressed as
    ``hardcoded/<name>`` (e.g. ``hardcoded/test``).
    """
    return HardcodedModelAPI
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Any, Iterable
|
|
2
|
+
|
|
3
|
+
from inspect_ai.model import (
|
|
4
|
+
GenerateConfig,
|
|
5
|
+
ModelAPI,
|
|
6
|
+
ModelInfo,
|
|
7
|
+
ModelOutput,
|
|
8
|
+
modelapi,
|
|
9
|
+
set_model_info,
|
|
10
|
+
)
|
|
11
|
+
from inspect_ai.model._providers.mockllm import MockLLM
|
|
12
|
+
from pydantic import TypeAdapter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MockLLMWrapper(MockLLM):
    """MockLLM wrapper that accepts ``custom_outputs`` given as plain dicts.

    The outputs are validated into ``ModelOutput`` objects with pydantic before
    they reach MockLLM. This is useful when configuring mockllm from an eval
    set config.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        custom_outputs: Iterable[ModelOutput] | None = None,
        **model_args: Any,
    ) -> None:
        """Validate ``custom_outputs`` and initialize the underlying MockLLM.

        Args:
            model_name: Passed through to MockLLM.
            base_url: Passed through to MockLLM.
            api_key: Passed through to MockLLM.
            config: Passed through to MockLLM.
            custom_outputs: ModelOutput objects, or dicts that validate as
                ModelOutput. When empty or None, None is passed to MockLLM.
            **model_args: Extra provider arguments forwarded to MockLLM.
        """
        parsed_outputs = (
            TypeAdapter(list[ModelOutput]).validate_python(custom_outputs)
            if custom_outputs
            else None
        )
        super().__init__(model_name, base_url, api_key, config, parsed_outputs, **model_args)

        # Need to register this so cost tracking works; use the same name
        # canonical_name() reports so the two can never drift apart.
        set_model_info(self.canonical_name(), ModelInfo())

    def canonical_name(self) -> str:
        """Return ``mockllm_wrapper/<model_name>``."""
        return f"mockllm_wrapper/{self.model_name}"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@modelapi(name="mockllm_wrapper")
def mockllm_wrapper() -> type[ModelAPI]:
    """Provider factory registered via ``@modelapi`` under the name ``mockllm_wrapper``.

    Returns the MockLLMWrapper class, so models can be addressed as
    ``mockllm_wrapper/<name>``.
    """
    return MockLLMWrapper
|