synth-ai 0.2.8.dev11__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic; see the package's registry page for more details.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +429 -0
- synth_ai/api/train/config_finder.py +120 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +128 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +2 -2
- synth_ai/cli/root.py +2 -1
- synth_ai/cli/task_apps.py +520 -0
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +31 -25
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +132 -0
- synth_ai/task/client.py +148 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +37 -15
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
synth_ai/task/proxy.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Shared helpers for Task App proxy endpoints (OpenAI, Groq, etc.)."""
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import json
|
|
7
|
+
import re
|
|
8
|
+
from typing import Any, Iterable, List, Tuple
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# OpenAI-style tool definition for the single ``interact`` function exposed to
# Task App policies. ``actions`` (an ordered list of strings) is mandatory,
# ``reasoning`` is optional free text, and ``additionalProperties: False``
# forbids any other argument keys.
INTERACT_TOOL_SCHEMA: List[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "interact",
            "description": "Perform one or more environment actions.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of environment actions to execute in order.",
                    },
                    "reasoning": {
                        "type": "string",
                        "description": "Optional reasoning for the chosen actions.",
                    },
                },
                "required": ["actions"],
                "additionalProperties": False,
            },
        },
    },
]
|
|
36
|
+
|
|
37
|
+
# Top-level request keys that ``prepare_for_openai`` strips from every payload
# before it is forwarded upstream.
_REMOVE_FIELDS = {
    "reasoning",
    "stop_after_tool_calls",
    "thinking_budget",
    "thinking_mode",
}
# Sampling knobs that ``prepare_for_openai`` drops for gpt-5 models.
_REMOVE_SAMPLING_FIELDS = {"top_p", "temperature"}
# Lower bound ``prepare_for_openai`` enforces on ``max_completion_tokens`` for gpt-5 models.
_GPT5_MIN_COMPLETION_TOKENS = 16_000
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _ensure_tools(payload: dict[str, Any]) -> None:
    """Install the default ``interact`` tool when *payload* has no usable tools.

    Mutates *payload* in place. A non-empty ``tools`` list is left untouched.
    A missing, empty, or non-list value is replaced with a deep copy of
    ``INTERACT_TOOL_SCHEMA``, so the module-level schema is never shared.
    """
    current_tools = payload.get("tools")
    if isinstance(current_tools, list) and len(current_tools) > 0:
        return
    payload["tools"] = copy.deepcopy(INTERACT_TOOL_SCHEMA)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def prepare_for_openai(model: str | None, payload: dict[str, Any]) -> dict[str, Any]:
    """Sanitise an OpenAI chat completions payload for Task App usage.

    Works on a deep copy, so *payload* itself is never mutated.

    * Keys listed in ``_REMOVE_FIELDS`` are always dropped.
    * When *model* contains ``"gpt-5"``:
      - ``max_tokens`` is moved to ``max_completion_tokens`` unless the latter
        is already set;
      - the ``_REMOVE_SAMPLING_FIELDS`` keys are dropped;
      - ``max_completion_tokens`` is raised to at least
        ``_GPT5_MIN_COMPLETION_TOKENS`` (non-int values are replaced by it);
      - ``tool_choice`` is forced to the ``interact`` function;
      - ``parallel_tool_calls`` is disabled.
    * A default ``tools`` list is installed via ``_ensure_tools``.

    Args:
        model: Target model name; ``None``/empty skips the gpt-5 rules.
        payload: Chat-completions request body.

    Returns:
        The sanitised copy.
    """
    result = copy.deepcopy(payload)
    for key in _REMOVE_FIELDS:
        result.pop(key, None)

    is_gpt5 = bool(model) and "gpt-5" in model
    if is_gpt5:
        legacy_limit = result.pop("max_tokens", None)
        if legacy_limit is not None:
            # Only fills the new key when the caller did not already supply it.
            result.setdefault("max_completion_tokens", legacy_limit)
        for key in _REMOVE_SAMPLING_FIELDS:
            result.pop(key, None)
        limit = result.get("max_completion_tokens")
        result["max_completion_tokens"] = (
            max(limit, _GPT5_MIN_COMPLETION_TOKENS)
            if isinstance(limit, int)
            else _GPT5_MIN_COMPLETION_TOKENS
        )
        result["tool_choice"] = {"type": "function", "function": {"name": "interact"}}
        result["parallel_tool_calls"] = False

    _ensure_tools(result)
    return result
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def prepare_for_groq(model: str | None, payload: dict[str, Any]) -> dict[str, Any]:
    """Sanitise a chat-completions payload destined for Groq.

    Groq uses the OpenAI schema, so this starts from ``prepare_for_openai``.
    For a named model that is not gpt-5, a resulting ``max_completion_tokens``
    is renamed back to ``max_tokens``. The rename is skipped when the caller's
    original *payload* already carried ``max_tokens``.
    """
    sanitized = prepare_for_openai(model, payload)
    # Groq supports `max_tokens`; prefer their native parameter when present.
    use_native_limit = (
        bool(model)
        and "gpt-5" not in model
        and "max_completion_tokens" in sanitized
        and "max_tokens" not in payload
    )
    if use_native_limit:
        sanitized["max_tokens"] = sanitized.pop("max_completion_tokens")
    return sanitized
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def inject_system_hint(payload: dict[str, Any], hint: str) -> dict[str, Any]:
    """Insert or augment a system message with the provided hint (idempotent).

    *payload* is never mutated. A deep copy is modified and returned, except
    when *hint* is empty, in which case *payload* itself is returned.

    Behaviour depends on ``payload["messages"]``:

    * not a list: the copy is returned unchanged;
    * first message is a ``system`` dict:
        - string content: *hint* is appended after a blank line unless it is
          already present (whitespace-only content becomes just *hint*);
        - ``None`` content: replaced by *hint*;
        - list of content parts: a ``{"type": "text", "text": hint}`` part is
          appended unless a text part already contains *hint*;
        - any other content: left untouched;
    * otherwise: ``{"role": "system", "content": hint}`` is inserted first.

    Args:
        payload: Chat-completions request body.
        hint: Text to add to the system prompt.

    Returns:
        The payload copy with the hint applied.
    """
    if not hint:
        return payload
    cloned = copy.deepcopy(payload)
    messages = cloned.get("messages")
    if not isinstance(messages, list):
        return cloned

    first = messages[0] if messages else None
    if isinstance(first, dict) and first.get("role") == "system":
        content = first.get("content")
        if content is None:
            content = ""
        if isinstance(content, str):
            if hint not in content:
                base = content.rstrip()
                # Separate with a blank line only when real text precedes the hint.
                first["content"] = f"{base}\n\n{hint}" if base else hint
        elif isinstance(content, list):
            already_present = any(
                isinstance(part, dict)
                and isinstance(part.get("text"), str)
                and hint in part["text"]
                for part in content
            )
            if not already_present:
                content.append({"type": "text", "text": hint})
    else:
        messages.insert(0, {"role": "system", "content": hint})
    return cloned
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def extract_message_text(message: Any) -> str:
    """Return best-effort text from an OpenAI-style message structure.

    * ``None`` -> ``""``
    * ``str`` -> returned as-is
    * ``list`` (e.g. content parts) -> non-empty texts of the elements joined
      with newlines
    * ``dict`` -> in order of preference:
        - its ``content`` if that is a string;
        - the recursively extracted text of ``content`` if it is a list or dict;
        - its ``text`` field if that is a string;
        - otherwise ``""``. This covers tool-call-only assistant turns with
          ``content: None`` and non-text parts such as images.
    * anything else -> ``str(message)``
    """
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    if isinstance(message, list):
        texts = (extract_message_text(part) for part in message)
        return "\n".join(text for text in texts if text)
    if isinstance(message, dict):
        content = message.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, (list, dict)):
            return extract_message_text(content)
        text = message.get("text")
        if isinstance(text, str):
            return text
        # No textual payload: contribute nothing rather than the dict's repr.
        return ""
    return str(message)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _parse_actions_from_json_candidate(candidate: Any) -> tuple[list[str], str]:
    """Extract ``(actions, reasoning)`` from a decoded JSON value.

    Only dicts are inspected. ``actions`` may be:

    * a list: each element is stringified and stripped, blanks dropped;
    * a string: split on ``;``, pieces stripped, blanks dropped.

    ``reasoning`` is used (stripped) only when it is a string. Missing or
    unusable parts come back as ``[]`` / ``""``.
    """
    if not isinstance(candidate, dict):
        return [], ""

    raw_actions = candidate.get("actions")
    if isinstance(raw_actions, list):
        actions = [text for text in (str(item).strip() for item in raw_actions) if text]
    elif isinstance(raw_actions, str):
        actions = [text for text in (chunk.strip() for chunk in raw_actions.split(";")) if text]
    else:
        actions = []

    raw_reasoning = candidate.get("reasoning")
    reasoning = raw_reasoning.strip() if isinstance(raw_reasoning, str) else ""
    return actions, reasoning
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def parse_tool_call_from_text(text: str) -> Tuple[list[str], str]:
    """Derive tool-call actions and reasoning from assistant text.

    Strategies are tried in order. The first one that yields at least one
    action wins:

    1. the whole (stripped) text parsed as JSON with an ``actions`` field;
    2. each flat ``{...}`` fragment (no nested braces) that mentions
       ``actions``, parsed as JSON;
    3. an ``Action:`` / ``Actions:`` label (whole word), with the rest of that
       line split on commas. Reasoning is the text before the label;
    4. ``Action 1: ...`` / ``Action - ...`` lines, one action per line. All
       other non-blank lines become reasoning.

    For the JSON strategies, reasoning falls back to the full text when the
    JSON carries none.

    Returns:
        ``(actions, reasoning)``. Returns ``([], text)`` when nothing matched
        and ``([], "")`` for empty input.
    """
    text = (text or "").strip()
    if not text:
        return [], ""

    # Strategies 1 and 2: the entire payload first, then embedded flat JSON
    # objects that mention "actions".
    fragments = re.findall(r"\{[^{}]*actions[^{}]*\}", text, flags=re.IGNORECASE)
    for candidate_text in (text, *fragments):
        try:
            data = json.loads(candidate_text)
        except (ValueError, RecursionError):
            # Not JSON (or pathologically nested); try the next candidate.
            continue
        actions, reasoning = _parse_actions_from_json_candidate(data)
        if actions:
            return actions, reasoning or text

    # Strategy 3: "Actions: move_right, jump". The leading \b keeps words such
    # as "interactions:" or "transactions:" from being read as an action list.
    inline = re.search(r"\bactions?\s*:\s*([^\n]+)", text, flags=re.IGNORECASE)
    if inline:
        items = [part.strip() for part in inline.group(1).split(",") if part.strip()]
        if items:
            return items, text[: inline.start()].strip()

    # Strategy 4: one action per "Action 1: move_right" style line.
    line_actions: list[str] = []
    reasoning_lines: list[str] = []
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        match = re.match(r"action\s*\d*\s*[:\-]\s*(.+)", stripped, flags=re.IGNORECASE)
        if match:
            candidate = match.group(1).strip()
            if candidate:
                line_actions.append(candidate)
        else:
            reasoning_lines.append(stripped)
    if line_actions:
        return line_actions, "\n".join(reasoning_lines).strip()

    return [], text
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _build_tool_call(actions: Iterable[str], reasoning: str) -> dict[str, Any]:
    """Wrap *actions* (and optional *reasoning*) as a synthetic ``interact`` tool call.

    Actions are stringified and stripped, and blank ones are dropped. Reasoning
    is included only when it is non-blank after stripping. The call uses the
    fixed id ``"tool_interact_fallback"`` and JSON-encoded arguments, with
    non-ASCII characters kept as-is.
    """
    stripped_actions = (str(action).strip() for action in actions)
    arguments: dict[str, Any] = {"actions": [action for action in stripped_actions if action]}
    trimmed_reasoning = reasoning.strip()
    if trimmed_reasoning:
        arguments["reasoning"] = trimmed_reasoning
    tool_name = INTERACT_TOOL_SCHEMA[0]["function"]["name"]
    encoded_arguments = json.dumps(arguments, ensure_ascii=False)
    return {
        "id": "tool_interact_fallback",
        "type": "function",
        "function": {"name": tool_name, "arguments": encoded_arguments},
    }
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def synthesize_tool_call_if_missing(openai_response: dict[str, Any]) -> dict[str, Any]:
    """Ensure the first choice carries a tool_call derived from text if absent.

    When ``choices[0].message`` has no non-empty ``tool_calls`` list, its text
    (via ``extract_message_text``) is parsed with ``parse_tool_call_from_text``.
    If that yields actions, one synthetic ``interact`` tool call is attached.
    The message's ``content`` key is set to ``None`` only if it was missing.

    The input is never mutated. When a tool call is synthesised, a fully
    independent deep copy is returned. Otherwise *openai_response* itself is
    returned, including for malformed input: a non-dict, missing/empty
    ``choices``, or a non-dict first choice or message.
    """
    if not isinstance(openai_response, dict):
        return openai_response
    choices = openai_response.get("choices")
    if not isinstance(choices, list) or not choices:
        return openai_response
    first = choices[0]
    if not isinstance(first, dict):
        return openai_response
    message = first.get("message")
    if not isinstance(message, dict):
        return openai_response
    tool_calls = message.get("tool_calls")
    if isinstance(tool_calls, list) and tool_calls:
        return openai_response

    actions, reasoning = parse_tool_call_from_text(extract_message_text(message))
    if not actions:
        return openai_response

    # A single deep copy keeps the result fully independent of the input,
    # choices[1:] included, and avoids copying the first choice repeatedly.
    result = copy.deepcopy(openai_response)
    new_message = result["choices"][0]["message"]
    new_message["tool_calls"] = [_build_tool_call(actions, reasoning)]
    new_message.setdefault("content", None)
    return result
|
|
258
|
+
|
synth_ai/task/rubrics.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Rubric schema, loading, and scoring helpers for Task Apps."""
|
|
4
|
+
|
|
5
|
+
import json
import math
from pathlib import Path
from typing import Any, Dict, Iterable, Optional

from pydantic import BaseModel, Field, field_validator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Criterion(BaseModel):
    """A single scoring criterion within a rubric."""

    # Identifier used to match scores to this criterion.
    id: str
    # Human-readable statement of what the criterion checks.
    description: str
    # Relative weight (used by weighted aggregation); must be positive and finite.
    weight: float = 1.0
    # Whether the criterion is marked as required.
    required: bool = False

    @field_validator("weight")
    @classmethod
    def _validate_weight(cls, value: float) -> float:
        """Reject non-positive, NaN and infinite weights."""
        if value <= 0:
            raise ValueError("criterion weight must be positive")
        # NaN compares false with everything, so it slips past ``<= 0``. A NaN
        # or infinite weight would turn a weighted-sum normalisation into NaN.
        if not math.isfinite(value):
            raise ValueError("criterion weight must be finite")
        return value
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Rubric(BaseModel):
    """A versioned set of criteria together with an aggregation mode."""

    # Version label of the rubric (required).
    version: str
    # Optional free-text description of the goal being scored.
    goal_text: str | None = None
    # Criteria in declaration order; ids must be unique.
    criteria: list[Criterion] = Field(default_factory=list)
    # One of "sum", "weighted_sum", "custom" or "inherit".
    aggregation: str = "weighted_sum"

    @field_validator("aggregation")
    @classmethod
    def _validate_aggregation(cls, value: str) -> str:
        """Accept only the supported aggregation modes."""
        permitted = {"sum", "weighted_sum", "custom", "inherit"}
        if value in permitted:
            return value
        raise ValueError(f"aggregation must be one of {sorted(permitted)}")

    @field_validator("criteria")
    @classmethod
    def _validate_criteria(cls, criteria: list[Criterion]) -> list[Criterion]:
        """Fail on the first criterion whose id repeats an earlier one."""
        ids_seen: set[str] = set()
        for item in criteria:
            criterion_id = item.id
            if criterion_id in ids_seen:
                raise ValueError(f"duplicate criterion id: {criterion_id}")
            ids_seen.add(criterion_id)
        return criteria
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _load_text(source: str) -> tuple[str, Optional[str]]:
    """Resolve *source* to ``(text, suffix)``.

    If *source* names an existing filesystem path, the function returns its
    UTF-8 contents and the lower-cased file suffix (``""`` when there is
    none). Otherwise *source* is treated as inline rubric text (or a URL) and
    returned unchanged with a ``None`` suffix.
    """
    path = Path(source)
    try:
        is_path = path.exists()
    except (OSError, ValueError):
        # Inline JSON/YAML can exceed NAME_MAX/PATH_MAX or contain characters
        # the OS rejects. On some Python versions ``exists()`` raises for those
        # (e.g. ENAMETOOLONG) instead of returning False.
        is_path = False
    if is_path:
        return path.read_text(encoding="utf-8"), path.suffix.lower()
    return source, None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _load_yaml_mapping(text: str, missing_msg: str, shape_msg: str) -> Dict[str, Any]:
    """Parse *text* with PyYAML's ``safe_load`` and require a mapping.

    Raises:
        RuntimeError: with *missing_msg* if PyYAML cannot be imported.
        ValueError: with *shape_msg* if the document is not a mapping.
    """
    try:
        import yaml  # type: ignore
    except Exception as exc:  # pragma: no cover - optional dependency
        raise RuntimeError(missing_msg) from exc
    data = yaml.safe_load(text)
    if not isinstance(data, dict):
        raise ValueError(shape_msg)
    return data


def _parse_structured(text: str, suffix: Optional[str]) -> Dict[str, Any]:
    """Decode rubric *text* into a mapping.

    Resolution order:

    1. *suffix* ``.yaml``/``.yml`` -> YAML;
    2. text starting with ``{`` -> JSON;
    3. ``http://``/``https://`` URL -> fetched with ``requests`` (15 s
       timeout), and the response body is parsed recursively;
    4. otherwise JSON, falling back to YAML if JSON decoding fails.

    Raises:
        ValueError: empty text, or a document that does not decode to a mapping.
        json.JSONDecodeError: text starting with ``{`` that is invalid JSON.
        RuntimeError: PyYAML is needed but not installed.
    """
    text = text.strip()
    if not text:
        raise ValueError("Rubric source is empty")
    if suffix in (".yaml", ".yml"):
        return _load_yaml_mapping(
            text,
            "PyYAML is required to load YAML rubrics",
            "Rubric YAML must produce a mapping",
        )
    if text.startswith("{"):
        return json.loads(text)
    if text.startswith(("http://", "https://")):
        import requests  # type: ignore

        response = requests.get(text, timeout=15)
        response.raise_for_status()
        return _parse_structured(response.text, suffix)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return _load_yaml_mapping(
            text,
            "PyYAML is required to load rubric text",
            "Rubric text must decode to a mapping",
        )
    if not isinstance(data, dict):
        # Valid JSON such as "[1, 2]" or "42" is not a rubric document.
        raise ValueError("Rubric text must decode to a mapping")
    return data
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def load_rubric(source: str | dict[str, Any] | Rubric | None) -> Rubric | None:
    """Coerce *source* into a validated ``Rubric``.

    ``None`` and existing ``Rubric`` instances are returned unchanged, and a
    dict is validated directly. Any other value is stringified, resolved by
    ``_load_text`` (file path or inline text) and decoded by
    ``_parse_structured`` before validation.
    """
    if source is None or isinstance(source, Rubric):
        return source
    if isinstance(source, dict):
        data = source
    else:
        text, suffix = _load_text(str(source))
        data = _parse_structured(text, suffix)
    return Rubric.model_validate(data)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _merge_weights(base: Criterion, override: Criterion) -> float:
    """Combine the weights of a base criterion and its override.

    A weight of exactly ``1.0`` behaves as "unset":

    * if only one side differs from 1.0, that side's weight is used;
    * if both differ, the weights are multiplied.

    Every case reduces to the plain product, because ``w * 1.0 == w``.
    """
    return base.weight * override.weight
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def blend_rubrics(base: Rubric | None, override: Rubric | None) -> Rubric | None:
    """Merge an *override* rubric on top of a *base* rubric.

    * If either side is ``None``, the other is returned unchanged (``None`` if
      both are).
    * Criteria sharing an id are merged:
        - description comes from the override unless it is empty;
        - weights are combined with ``_merge_weights``;
        - ``required`` comes from the override only when it was explicitly set
          there, and is otherwise inherited from the base.
    * The result lists criteria in override order, followed by base-only
      criteria in base order.
    * An override ``aggregation`` of ``"inherit"`` defers to the base.
    * ``version`` and ``goal_text`` come from the override when truthy.
    """
    if base is None:
        return override
    if override is None:
        return base

    remaining = {criterion.id: criterion for criterion in base.criteria}
    merged: list[Criterion] = []
    for ov in override.criteria:
        existing = remaining.pop(ov.id, None)
        if existing is None:
            merged.append(ov)
            continue
        # ``required`` is a bool defaulting to False and is never None, so
        # pydantic's explicitly-set tracking distinguishes "override said False"
        # from "override did not mention it".
        required = ov.required if "required" in ov.model_fields_set else existing.required
        merged.append(
            Criterion(
                id=ov.id,
                description=ov.description or existing.description,
                weight=_merge_weights(existing, ov),
                required=required,
            )
        )
    merged.extend(remaining.values())

    aggregation = base.aggregation if override.aggregation == "inherit" else override.aggregation
    return Rubric(
        version=override.version or base.version,
        goal_text=override.goal_text or base.goal_text,
        criteria=merged,
        aggregation=aggregation,
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _as_float(value: Any) -> Optional[float]:
    """Convert *value* to a finite float, or return ``None`` if that fails.

    Accepts anything ``float()`` accepts: numbers, numeric strings and bools.
    The following all map to ``None``, so one bad score cannot turn an
    aggregate into NaN/inf:

    * non-numeric values;
    * values too large for a float;
    * NaN and infinities.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return number if math.isfinite(number) else None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _score(criteria: Iterable[Criterion], values: Dict[str, float], aggregation: str) -> Dict[str, Any]:
    """Aggregate per-criterion *values* according to *aggregation*.

    Criteria missing from *values* score ``0.0``. Modes:

    * ``"sum"``: plain sum of scores;
    * ``"weighted_sum"`` (also used for ``"inherit"``): weight-normalised
      mean; stays ``0.0`` when the total weight is zero;
    * ``"custom"``: the aggregate is ``None`` (scoring left to the caller).

    The ``required`` flag is reported per criterion but does not affect the
    aggregate.

    Returns:
        ``{"aggregation", "score", "per_criterion"}``.
    """
    mode = "weighted_sum" if aggregation == "inherit" else aggregation
    breakdown: Dict[str, Dict[str, Any]] = {}
    accumulated = 0.0
    weight_sum = 0.0
    for crit in criteria:
        value = values.get(crit.id, 0.0)
        breakdown[crit.id] = {"score": value, "weight": crit.weight, "required": crit.required}
        if mode == "weighted_sum":
            accumulated += value * crit.weight
            weight_sum += crit.weight
        elif mode == "sum":
            accumulated += value

    if mode == "custom":
        final = None
    elif mode == "weighted_sum" and weight_sum > 0:
        final = accumulated / weight_sum
    else:
        final = accumulated
    return {"aggregation": mode, "score": final, "per_criterion": breakdown}
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def score_events_against_rubric(events: list[dict[str, Any]], rubric: Rubric | None) -> Dict[str, Any]:
    """Score a list of per-criterion events against *rubric*.

    Each dict event is keyed by the first truthy value among ``criterion_id``,
    ``id`` and ``criterion``. Its ``score`` is converted with ``_as_float``,
    and events without a key or a convertible score are ignored. A later event
    for the same criterion overwrites an earlier one.

    Without a rubric, a placeholder result with aggregation ``"none"`` is
    returned.
    """
    if rubric is None:
        return {"aggregation": "none", "score": None, "per_criterion": {}}
    collected: Dict[str, float] = {}
    for event in events or []:
        if not isinstance(event, dict):
            continue
        criterion_key = event.get("criterion_id") or event.get("id") or event.get("criterion")
        number = _as_float(event.get("score"))
        if not criterion_key or number is None:
            continue
        collected[str(criterion_key)] = number
    return _score(rubric.criteria, collected, rubric.aggregation)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def score_outcome_against_rubric(outcome: dict[str, Any], rubric: Rubric | None) -> Dict[str, Any]:
    """Score an outcome mapping against *rubric*.

    Scores are read from ``outcome["criteria"]`` when that is a dict, and
    otherwise from the top level of *outcome*. Every entry whose value
    ``_as_float`` accepts is kept under its stringified key.

    Without a rubric, a placeholder result with aggregation ``"none"`` is
    returned.
    """
    if rubric is None:
        return {"aggregation": "none", "score": None, "per_criterion": {}}
    values: Dict[str, float] = {}
    if isinstance(outcome, dict):
        nested = outcome.get("criteria")
        source = nested if isinstance(nested, dict) else outcome
        for key, raw in source.items():
            number = _as_float(raw)
            if number is not None:
                values[str(key)] = number
    return _score(rubric.criteria, values, rubric.aggregation)
|