osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of osmosis-ai has been flagged as a potentially problematic release.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +67 -0
- osmosis_ai/cli_services/config.py +407 -0
- osmosis_ai/cli_services/dataset.py +229 -0
- osmosis_ai/cli_services/engine.py +251 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +498 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -5
- osmosis_ai-0.2.3.dist-info/METADATA +303 -0
- osmosis_ai-0.2.3.dist-info/RECORD +27 -0
- osmosis_ai-0.2.3.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/top_level.txt +0 -0
osmosis_ai/cli_services/session.py
ADDED
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Optional, Sequence
+
+from ..rubric_eval import ensure_api_key_available
+from ..rubric_types import MissingAPIKeyError
+from .config import RubricConfig, RubricSuite, discover_rubric_config_path, load_rubric_suite
+from .dataset import DatasetLoader, DatasetRecord
+from .engine import RubricEvaluationEngine, EvaluationReport
+from .errors import CLIError
+from .reporting import BaselineComparator, BaselineStatistics, JsonReportWriter
+
+
+_CACHE_ROOT = Path("~/.cache/osmosis/eval_result").expanduser()
+
+
+def _sanitise_rubric_folder(rubric_id: str) -> str:
+    """Produce a filesystem-safe folder name for the rubric id."""
+    clean = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in rubric_id.strip())
+    clean = clean.strip("_") or "rubric"
+    return clean.lower()
+
+
+@dataclass(frozen=True)
+class EvaluationSessionRequest:
+    rubric_id: str
+    data_path: Path
+    number: int = 1
+    config_path: Optional[Path] = None
+    output_path: Optional[Path] = None
+    output_identifier: Optional[str] = None
+    baseline_path: Optional[Path] = None
+
+
+@dataclass
+class EvaluationSessionResult:
+    request: EvaluationSessionRequest
+    config_path: Path
+    data_path: Path
+    rubric_config: RubricConfig
+    records: Sequence[DatasetRecord]
+    report: EvaluationReport
+    baseline: Optional[BaselineStatistics]
+    written_path: Optional[Path]
+    output_identifier: Optional[str]
+
+
+class EvaluationSession:
+    """Coordinates rubric evaluation end-to-end for reusable orchestration."""
+
+    def __init__(
+        self,
+        *,
+        config_locator: Callable[[Optional[str], Path], Path] = discover_rubric_config_path,
+        suite_loader: Callable[[Path], RubricSuite] = load_rubric_suite,
+        dataset_loader: Optional[DatasetLoader] = None,
+        engine: Optional[RubricEvaluationEngine] = None,
+        baseline_comparator: Optional[BaselineComparator] = None,
+        report_writer: Optional[JsonReportWriter] = None,
+        identifier_factory: Optional[Callable[[], str]] = None,
+    ):
+        self._config_locator = config_locator
+        self._suite_loader = suite_loader
+        self._dataset_loader = dataset_loader or DatasetLoader()
+        self._engine = engine or RubricEvaluationEngine()
+        self._baseline_comparator = baseline_comparator or BaselineComparator()
+        self._report_writer = report_writer or JsonReportWriter()
+        self._identifier_factory = identifier_factory or self._default_identifier
+
+    def execute(self, request: EvaluationSessionRequest) -> EvaluationSessionResult:
+        rubric_id = request.rubric_id.strip()
+        if not rubric_id:
+            raise CLIError("Rubric identifier cannot be empty.")
+
+        number_value = request.number if request.number is not None else 1
+        number = int(number_value)
+        if number < 1:
+            raise CLIError("Number of runs must be a positive integer.")
+
+        data_path = request.data_path.expanduser()
+        if not data_path.exists():
+            raise CLIError(f"Data path '{data_path}' does not exist.")
+        if data_path.is_dir():
+            raise CLIError(f"Expected a JSONL file but received directory '{data_path}'.")
+
+        config_override = str(request.config_path.expanduser()) if request.config_path else None
+        config_path = self._config_locator(config_override, data_path)
+        suite = self._suite_loader(config_path)
+        rubric_config = suite.get(rubric_id)
+
+        try:
+            ensure_api_key_available(rubric_config.model_info)
+        except (MissingAPIKeyError, TypeError) as exc:
+            raise CLIError(str(exc)) from exc
+
+        all_records = self._dataset_loader.load(data_path)
+        matching_records = [
+            record for record in all_records if record.rubric_id.lower() == rubric_id.lower()
+        ]
+        if not matching_records:
+            raise CLIError(f"No records in '{data_path}' reference rubric '{rubric_id}'.")
+
+        baseline_stats = self._load_baseline(request.baseline_path)
+
+        resolved_output_path, resolved_identifier = self._resolve_output_path(
+            request.output_path,
+            request.output_identifier,
+            rubric_id=rubric_id,
+        )
+
+        report = self._engine.execute(
+            rubric_config=rubric_config,
+            config_path=config_path,
+            data_path=data_path,
+            records=matching_records,
+            number=number,
+        )
+
+        written_path = None
+        if resolved_output_path is not None:
+            written_path = self._report_writer.write(
+                report,
+                output_path=resolved_output_path,
+                output_identifier=resolved_identifier,
+                baseline=baseline_stats,
+            )
+
+        return EvaluationSessionResult(
+            request=request,
+            config_path=config_path,
+            data_path=data_path,
+            rubric_config=rubric_config,
+            records=matching_records,
+            report=report,
+            baseline=baseline_stats,
+            written_path=written_path,
+            output_identifier=resolved_identifier,
+        )
+
+    def _load_baseline(self, baseline_path: Optional[Path]) -> Optional[BaselineStatistics]:
+        if baseline_path is None:
+            return None
+        resolved = baseline_path.expanduser()
+        return self._baseline_comparator.load(resolved)
+
+    def _resolve_output_path(
+        self,
+        output_candidate: Optional[Path],
+        output_identifier: Optional[str],
+        *,
+        rubric_id: str,
+    ) -> tuple[Optional[Path], Optional[str]]:
+        if output_candidate is None:
+            identifier = output_identifier or self._identifier_factory()
+            target_dir = _CACHE_ROOT / _sanitise_rubric_folder(rubric_id)
+            target_dir.mkdir(parents=True, exist_ok=True)
+            return target_dir / f"rubric_eval_result_{identifier}.json", identifier
+
+        candidate = output_candidate.expanduser()
+        if candidate.suffix:
+            if candidate.exists() and candidate.is_dir():
+                raise CLIError(f"Output path '{candidate}' is a directory.")
+            return candidate, output_identifier
+
+        candidate.mkdir(parents=True, exist_ok=True)
+        identifier = output_identifier or self._identifier_factory()
+        return candidate / f"rubric_eval_result_{identifier}.json", identifier
+
+    @staticmethod
+    def _default_identifier() -> str:
+        return str(int(time.time()))
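A minimal usage sketch of the new session API follows; the rubric id and dataset path are hypothetical placeholders, and everything else comes from the file above.

    from pathlib import Path

    from osmosis_ai.cli_services.session import EvaluationSession, EvaluationSessionRequest

    # Hypothetical rubric id and JSONL dataset; substitute real values.
    request = EvaluationSessionRequest(
        rubric_id="helpfulness",
        data_path=Path("data/records.jsonl"),
        number=3,  # run the evaluation three times
    )

    session = EvaluationSession()  # default loaders, engine, comparator, and report writer
    result = session.execute(request)
    # With no output_path set, the report is written under ~/.cache/osmosis/eval_result/<rubric>/.
    print(result.written_path)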
osmosis_ai/cli_services/shared.py
ADDED
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+from statistics import mean, pvariance, pstdev
+from typing import Any, Collection, Optional, Set
+
+from .errors import CLIError
+
+
+def coerce_optional_float(value: Any, field_name: str, source_label: str) -> Optional[float]:
+    if value is None:
+        return None
+    if isinstance(value, (int, float)) and not isinstance(value, bool):
+        return float(value)
+    raise CLIError(
+        f"Expected '{field_name}' in {source_label} to be numeric, got {type(value).__name__}."
+    )
+
+
+def collapse_preview_text(value: Any, *, max_length: int = 140) -> Optional[str]:
+    if not isinstance(value, str):
+        return None
+    collapsed = " ".join(value.strip().split())
+    if not collapsed:
+        return None
+    if len(collapsed) > max_length:
+        collapsed = collapsed[: max_length - 3].rstrip() + "..."
+    return collapsed
+
+
+def calculate_statistics(scores: list[float]) -> dict[str, float]:
+    if not scores:
+        return {
+            "average": 0.0,
+            "variance": 0.0,
+            "stdev": 0.0,
+            "min": 0.0,
+            "max": 0.0,
+        }
+    average = mean(scores)
+    variance = pvariance(scores)
+    std_dev = pstdev(scores)
+    return {
+        "average": average,
+        "variance": variance,
+        "stdev": std_dev,
+        "min": min(scores),
+        "max": max(scores),
+    }
+
+
+def calculate_stat_deltas(baseline: dict[str, float], current: dict[str, float]) -> dict[str, float]:
+    delta: dict[str, float] = {}
+    for key, current_value in current.items():
+        if key not in baseline:
+            continue
+        try:
+            baseline_value = float(baseline[key])
+            current_numeric = float(current_value)
+        except (TypeError, ValueError):
+            continue
+        delta[key] = current_numeric - baseline_value
+    return delta
+
+
+def gather_text_fragments(
+    node: Any,
+    fragments: list[str],
+    *,
+    allow_free_strings: bool = False,
+    seen: Optional[Set[int]] = None,
+    string_key_allowlist: Optional[Collection[str]] = None,
+) -> None:
+    """Collect textual snippets from nested message-like structures.
+
+    The traversal favours common chat-completions shapes (e.g. ``{"type": "text"}``
+    blocks) and avoids indiscriminately pulling in metadata values such as IDs.
+    ``allow_free_strings`` controls whether bare strings encountered at the current
+    level should be considered textual content (useful for raw message content but
+    typically disabled for metadata fields).
+    """
+
+    if seen is None:
+        seen = set()
+
+    if isinstance(node, str):
+        if allow_free_strings:
+            stripped = node.strip()
+            if stripped:
+                fragments.append(stripped)
+        return
+
+    if isinstance(node, list):
+        for item in node:
+            gather_text_fragments(
+                item,
+                fragments,
+                allow_free_strings=allow_free_strings,
+                seen=seen,
+                string_key_allowlist=string_key_allowlist,
+            )
+        return
+
+    if not isinstance(node, dict):
+        return
+
+    node_id = id(node)
+    if node_id in seen:
+        return
+    seen.add(node_id)
+
+    allowlist = {"text", "value", "message"}
+    if string_key_allowlist is not None:
+        allowlist = {key.lower() for key in string_key_allowlist}
+    else:
+        allowlist = {key.lower() for key in allowlist}
+
+    prioritized_keys = ("text", "value")
+    handled_keys: Set[str] = {
+        "text",
+        "value",
+        "content",
+        "message",
+        "parts",
+        "input_text",
+        "output_text",
+        "type",
+        "role",
+        "name",
+        "id",
+        "index",
+        "finish_reason",
+        "reason",
+        "tool_call_id",
+        "metadata",
+    }
+
+    for key in prioritized_keys:
+        if key not in node:
+            continue
+        before_count = len(fragments)
+        gather_text_fragments(
+            node[key],
+            fragments,
+            allow_free_strings=True,
+            seen=seen,
+            string_key_allowlist=string_key_allowlist,
+        )
+        if len(fragments) > before_count:
+            break
+
+    if node.get("type") == "tool_result" and "content" in node:
+        gather_text_fragments(
+            node["content"],
+            fragments,
+            allow_free_strings=True,
+            seen=seen,
+            string_key_allowlist=string_key_allowlist,
+        )
+    elif "content" in node:
+        gather_text_fragments(
+            node["content"],
+            fragments,
+            allow_free_strings=True,
+            seen=seen,
+            string_key_allowlist=string_key_allowlist,
+        )
+
+    for key in ("message", "parts", "input_text", "output_text"):
+        if key in node:
+            gather_text_fragments(
+                node[key],
+                fragments,
+                allow_free_strings=True,
+                seen=seen,
+                string_key_allowlist=string_key_allowlist,
+            )
+
+    for key, value in node.items():
+        if key in handled_keys:
+            continue
+        if isinstance(value, (list, dict)):
+            gather_text_fragments(
+                value,
+                fragments,
+                allow_free_strings=False,
+                seen=seen,
+                string_key_allowlist=string_key_allowlist,
+            )
+        elif isinstance(value, str) and key.lower() in allowlist:
+            stripped = value.strip()
+            if stripped:
+                fragments.append(stripped)
+
+
+def collect_text_fragments(
+    node: Any,
+    *,
+    allow_free_strings: bool = False,
+    string_key_allowlist: Optional[Collection[str]] = None,
+) -> list[str]:
+    fragments: list[str] = []
+    gather_text_fragments(
+        node,
+        fragments,
+        allow_free_strings=allow_free_strings,
+        seen=set(),
+        string_key_allowlist=string_key_allowlist,
+    )
+    return fragments
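Two quick illustrations of the helpers above. First, the fragment gatherer on a chat-completions-style message (the sample payload is hypothetical); it keeps the text block and skips role/id/name metadata by design:

    from osmosis_ai.cli_services.shared import collect_text_fragments

    message = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Hello there."},
            {"type": "tool_use", "id": "call_1", "name": "lookup", "input": {"q": "x"}},
        ],
    }

    print(collect_text_fragments(message))  # ['Hello there.']

Second, calculate_statistics uses population statistics (pvariance/pstdev), so a two-score sample divides by 2, not 1:

    from osmosis_ai.cli_services.shared import calculate_statistics

    print(calculate_statistics([0.5, 1.0]))
    # {'average': 0.75, 'variance': 0.0625, 'stdev': 0.25, 'min': 0.5, 'max': 1.0}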
osmosis_ai/providers/__init__.py
ADDED
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+from .anthropic_provider import AnthropicProvider
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRegistry, ProviderRequest, RubricProvider
+from .gemini_provider import GeminiProvider
+from .openai_family import OpenAIProvider, XAIProvider
+
+_REGISTRY = ProviderRegistry()
+_REGISTRY.register(OpenAIProvider())
+_REGISTRY.register(XAIProvider())
+_REGISTRY.register(AnthropicProvider())
+_REGISTRY.register(GeminiProvider())
+
+
+def get_provider(name: str) -> RubricProvider:
+    return _REGISTRY.get(name)
+
+
+def register_provider(provider: RubricProvider) -> None:
+    _REGISTRY.register(provider)
+
+
+def supported_providers() -> Tuple[str, ...]:
+    return _REGISTRY.supported_providers()
+
+
+__all__ = [
+    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
+    "ProviderRequest",
+    "RubricProvider",
+    "get_provider",
+    "register_provider",
+    "supported_providers",
+]
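With the registry populated at import time, callers resolve providers by name. A short sketch; only "anthropic" is a confirmed provider name in this diff, so treat the other names as assumptions:

    from osmosis_ai.providers import get_provider, supported_providers

    print(supported_providers())  # alphabetical tuple of registered provider names
    provider = get_provider("anthropic")
    print(provider.default_timeout("claude-3-5-sonnet-latest"))  # 30.0 unless a provider overrides it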
osmosis_ai/providers/anthropic_provider.py
ADDED
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+try:  # pragma: no cover - optional dependency
+    import anthropic  # type: ignore
+    from anthropic import APIError  # type: ignore
+except ImportError:  # pragma: no cover - optional dependency
+    anthropic = None  # type: ignore[assignment]
+    APIError = None  # type: ignore[assignment]
+
+from ..rubric_types import ModelNotFoundError, ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, extract_structured_score, reward_schema_definition
+
+
+class AnthropicProvider(RubricProvider):
+    name = "anthropic"
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        if anthropic is None or APIError is None:
+            raise ProviderRequestError(
+                self.name,
+                request.model,
+                "Anthropic SDK is required. Install it via `pip install anthropic`.",
+            )
+
+        client = anthropic.Anthropic(api_key=request.api_key)
+        tool_name = "emit_reward_rubric_response"
+        schema_definition = reward_schema_definition()
+        tool = {
+            "name": tool_name,
+            "description": "Return the reward rubric score and explanation as structured JSON.",
+            "input_schema": schema_definition,
+        }
+
+        try:
+            response = client.messages.create(
+                model=request.model,
+                system=request.system_content,
+                messages=[{"role": "user", "content": [{"type": "text", "text": request.user_content}]}],
+                tools=[tool],
+                tool_choice={"type": "tool", "name": tool_name},
+                max_tokens=512,
+                temperature=0,
+                timeout=request.timeout,
+            )
+        except APIError as err:
+            detail = getattr(err, "message", None)
+            if not isinstance(detail, str) or not detail.strip():
+                detail = str(err)
+            status_code = getattr(err, "status_code", None)
+            if status_code == 404:
+                not_found_detail = (
+                    f"Model '{request.model}' was not found. Confirm your Anthropic account has access "
+                    "to the requested snapshot or update the model identifier."
+                )
+                raise ModelNotFoundError(self.name, request.model, not_found_detail) from err
+            raise ProviderRequestError(self.name, request.model, detail) from err
+        except Exception as err:
+            detail = str(err).strip() or "Unexpected error during Anthropic request."
+            raise ProviderRequestError(self.name, request.model, detail) from err
+
+        raw = dump_model(response)
+
+        payload: Dict[str, Any] | None = None
+        content_blocks = raw.get("content") if isinstance(raw, dict) else None
+        if isinstance(content_blocks, list):
+            for block in content_blocks:
+                if isinstance(block, dict) and block.get("type") == "tool_use" and block.get("name") == tool_name:
+                    maybe_input = block.get("input")
+                    if isinstance(maybe_input, dict):
+                        payload = maybe_input
+                        break
+        if payload is None:
+            raise ProviderRequestError(self.name, request.model, "Model response missing expected tool output.")
+        score, explanation = extract_structured_score(payload)
+        bounded = max(request.score_min, min(request.score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw}
+
+
+__all__ = ["AnthropicProvider"]
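A hedged sketch of driving the provider directly; the model id and rubric strings are placeholders, and a real call needs the anthropic SDK, an API key, and network access:

    import os

    from osmosis_ai.providers.anthropic_provider import AnthropicProvider
    from osmosis_ai.providers.base import ProviderRequest

    request = ProviderRequest(
        provider="anthropic",
        model="claude-3-5-sonnet-latest",  # placeholder snapshot id
        api_key=os.environ["ANTHROPIC_API_KEY"],
        system_content="You are a strict rubric grader.",
        user_content="Rubric: ...\nCandidate answer: ...",
        score_min=0.0,
        score_max=1.0,
        timeout=30.0,
    )

    result = AnthropicProvider().run(request)
    # run() clamps the model's score into [score_min, score_max] before returning.
    print(result["score"], result["explanation"])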
osmosis_ai/providers/base.py
ADDED
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, Tuple
+
+from ..rubric_types import RewardRubricRunResult
+
+DEFAULT_REQUEST_TIMEOUT_SECONDS = 30.0
+
+
+@dataclass(frozen=True)
+class ProviderRequest:
+    provider: str
+    model: str
+    api_key: str
+    system_content: str
+    user_content: str
+    score_min: float
+    score_max: float
+    timeout: float
+
+
+class RubricProvider:
+    """Interface for hosted LLM providers that can score rubrics."""
+
+    name: str
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        raise NotImplementedError
+
+
+class ProviderRegistry:
+    def __init__(self) -> None:
+        self._providers: Dict[str, RubricProvider] = {}
+
+    def register(self, provider: RubricProvider) -> None:
+        key = provider.name
+        if key in self._providers:
+            raise ValueError(f"Provider '{key}' is already registered.")
+        self._providers[key] = provider
+
+    def get(self, name: str) -> RubricProvider:
+        try:
+            return self._providers[name]
+        except KeyError as exc:
+            raise ValueError(f"Unsupported provider '{name}'.") from exc
+
+    def supported_providers(self) -> Tuple[str, ...]:
+        return tuple(sorted(self._providers))
+
+
+__all__ = [
+    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
+    "ProviderRequest",
+    "RubricProvider",
+    "ProviderRegistry",
+]
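Because ProviderRegistry rejects duplicate names, third-party backends plug in through register_provider; a minimal sketch with a hypothetical stub provider:

    from osmosis_ai.providers import register_provider, supported_providers
    from osmosis_ai.providers.base import ProviderRequest, RubricProvider

    class EchoProvider(RubricProvider):
        """Hypothetical provider that returns the floor score without calling any API."""

        name = "echo"

        def run(self, request: ProviderRequest):
            return {"score": request.score_min, "explanation": "stub response", "raw": {}}

    register_provider(EchoProvider())
    print(supported_providers())  # now includes 'echo' alongside the built-ins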