synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (36) hide show
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +429 -0
  4. synth_ai/api/train/config_finder.py +120 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +128 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +2 -2
  11. synth_ai/cli/root.py +2 -1
  12. synth_ai/cli/task_apps.py +520 -0
  13. synth_ai/task/__init__.py +94 -1
  14. synth_ai/task/apps/__init__.py +88 -0
  15. synth_ai/task/apps/grpo_crafter.py +438 -0
  16. synth_ai/task/apps/math_single_step.py +852 -0
  17. synth_ai/task/auth.py +132 -0
  18. synth_ai/task/client.py +148 -0
  19. synth_ai/task/contracts.py +29 -14
  20. synth_ai/task/datasets.py +105 -0
  21. synth_ai/task/errors.py +49 -0
  22. synth_ai/task/json.py +77 -0
  23. synth_ai/task/proxy.py +258 -0
  24. synth_ai/task/rubrics.py +212 -0
  25. synth_ai/task/server.py +398 -0
  26. synth_ai/task/tracing_utils.py +79 -0
  27. synth_ai/task/vendors.py +61 -0
  28. synth_ai/tracing_v3/session_tracer.py +13 -5
  29. synth_ai/tracing_v3/storage/base.py +10 -12
  30. synth_ai/tracing_v3/turso/manager.py +20 -6
  31. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
  32. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +36 -14
  33. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
  34. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
  35. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
  36. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
synth_ai/task/proxy.py ADDED
@@ -0,0 +1,258 @@
1
+ from __future__ import annotations
2
+
3
+ """Shared helpers for Task App proxy endpoints (OpenAI, Groq, etc.)."""
4
+
5
+ import copy
6
+ import json
7
+ import re
8
+ from typing import Any, Iterable, List, Tuple
9
+
10
+
11
# Tool schema advertised to the model: a single "interact" function that
# executes an ordered list of environment actions, with optional reasoning.
INTERACT_TOOL_SCHEMA: List[dict[str, Any]] = [
    {
        "type": "function",
        "function": {
            "name": "interact",
            "description": "Perform one or more environment actions.",
            "parameters": {
                "type": "object",
                "properties": {
                    "actions": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of environment actions to execute in order.",
                    },
                    "reasoning": {
                        "type": "string",
                        "description": "Optional reasoning for the chosen actions.",
                    },
                },
                "required": ["actions"],
                "additionalProperties": False,
            },
        },
    }
]

# Task-App-internal request fields stripped before forwarding upstream.
_REMOVE_FIELDS = {
    "stop_after_tool_calls",
    "thinking_mode",
    "thinking_budget",
    "reasoning",
}
# Sampling knobs removed for gpt-5 family models.
_REMOVE_SAMPLING_FIELDS = {"temperature", "top_p"}
# Floor applied to max_completion_tokens for gpt-5 family models.
_GPT5_MIN_COMPLETION_TOKENS = 16000
45
+
46
+
47
def _ensure_tools(payload: dict[str, Any]) -> None:
    """Ensure *payload* carries a non-empty ``tools`` list, mutating in place.

    When ``tools`` is missing, not a list, or empty, install a deep copy of
    the default interact tool schema so later mutation cannot corrupt the
    module-level constant.
    """
    existing = payload.get("tools")
    if isinstance(existing, list) and existing:
        return
    payload["tools"] = copy.deepcopy(INTERACT_TOOL_SCHEMA)
51
+
52
+
53
def prepare_for_openai(model: str | None, payload: dict[str, Any]) -> dict[str, Any]:
    """Sanitise an OpenAI chat-completions payload for Task App usage.

    Returns a deep copy of *payload* with Task-App-internal fields removed.
    For gpt-5 family models it additionally:
      * migrates ``max_tokens`` into ``max_completion_tokens`` (an explicit
        ``max_completion_tokens`` from the caller takes precedence),
      * strips sampling fields (temperature / top_p),
      * enforces a minimum integer completion-token budget,
      * forces a single non-parallel ``interact`` tool call.
    Finally, the default interact tool schema is injected when no tools are
    present.
    """
    sanitized = copy.deepcopy(payload)
    for key in _REMOVE_FIELDS:
        sanitized.pop(key, None)

    if model and "gpt-5" in model:
        legacy_budget = sanitized.pop("max_tokens", None)
        if legacy_budget is not None:
            # Keep a caller-supplied max_completion_tokens; otherwise adopt
            # the legacy max_tokens value.
            sanitized.setdefault("max_completion_tokens", legacy_budget)
        for key in _REMOVE_SAMPLING_FIELDS:
            sanitized.pop(key, None)
        budget = sanitized.get("max_completion_tokens")
        if not isinstance(budget, int) or budget < _GPT5_MIN_COMPLETION_TOKENS:
            sanitized["max_completion_tokens"] = _GPT5_MIN_COMPLETION_TOKENS
        sanitized["tool_choice"] = {"type": "function", "function": {"name": "interact"}}
        sanitized["parallel_tool_calls"] = False

    _ensure_tools(sanitized)
    return sanitized
76
+
77
+
78
def prepare_for_groq(model: str | None, payload: dict[str, Any]) -> dict[str, Any]:
    """Normalise a payload for Groq, which speaks the OpenAI schema.

    Runs the OpenAI sanitisation first. For non-gpt-5 models, renames
    ``max_completion_tokens`` back to Groq's native ``max_tokens`` unless the
    caller already supplied ``max_tokens`` in the original payload.
    """
    sanitized = prepare_for_openai(model, payload)
    is_gpt5 = bool(model) and "gpt-5" in (model or "")
    if model and not is_gpt5:
        if "max_completion_tokens" in sanitized and "max_tokens" not in payload:
            sanitized["max_tokens"] = sanitized.pop("max_completion_tokens")
    return sanitized
87
+
88
+
89
def inject_system_hint(payload: dict[str, Any], hint: str) -> dict[str, Any]:
    """Return a copy of *payload* whose leading system message carries *hint*.

    Idempotent: when the first system message already contains the hint, the
    messages are left untouched. A payload without a well-formed ``messages``
    list is returned as a plain copy; an empty hint returns the original
    payload object unchanged.
    """
    if not hint:
        return payload
    cloned = copy.deepcopy(payload)
    messages = cloned.get("messages")
    if not isinstance(messages, list):
        return cloned
    head = messages[0] if messages else None
    if isinstance(head, dict) and head.get("role") == "system":
        body = head.get("content")
        if isinstance(body, str) and hint not in body:
            updated = dict(head)
            separator = "\n\n" if body else ""
            updated["content"] = body.rstrip() + separator + hint
            messages[0] = updated
    else:
        messages.insert(0, {"role": "system", "content": hint})
    cloned["messages"] = messages
    return cloned
107
+
108
+
109
def extract_message_text(message: Any) -> str:
    """Best-effort extraction of plain text from an OpenAI-style message.

    Accepts strings, lists of parts, and dicts carrying ``content`` (string
    or list of parts) or a string ``text`` field; anything unrecognised is
    stringified. ``None`` yields the empty string.
    """
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    if isinstance(message, list):
        pieces = (extract_message_text(item) for item in message)
        return "\n".join(piece for piece in pieces if piece)
    if isinstance(message, dict):
        body = message.get("content")
        if isinstance(body, str):
            return body
        if isinstance(body, list):
            pieces = (extract_message_text(item) for item in body)
            return "\n".join(piece for piece in pieces if piece)
        text_field = message.get("text")
        if isinstance(text_field, str):
            return text_field
    return str(message)
133
+
134
+
135
def _parse_actions_from_json_candidate(candidate: Any) -> tuple[list[str], str]:
    """Extract (actions, reasoning) from a decoded JSON value.

    Accepts a mapping whose "actions" is a list (items stringified and
    stripped) or a semicolon-separated string; "reasoning" is taken only
    when it is a string. Any other input yields ([], "").
    """
    if not isinstance(candidate, dict):
        return [], ""
    raw_actions = candidate.get("actions")
    if isinstance(raw_actions, list):
        actions = [str(item).strip() for item in raw_actions if str(item).strip()]
    elif isinstance(raw_actions, str):
        actions = [piece.strip() for piece in raw_actions.split(";") if piece.strip()]
    else:
        actions = []
    raw_reason = candidate.get("reasoning")
    reasoning = raw_reason.strip() if isinstance(raw_reason, str) else ""
    return actions, reasoning
147
+
148
+
149
def parse_tool_call_from_text(text: str) -> Tuple[list[str], str]:
    """Derive interact-tool actions and reasoning from free-form assistant text.

    Tried in order: the whole payload as JSON, embedded ``{... actions ...}``
    JSON fragments, an inline "Actions: a, b" line (preceding text becomes
    the reasoning), and per-line "Action N: x" entries (remaining lines
    become the reasoning). Returns ([], text) when nothing actionable is
    found.
    """
    text = (text or "").strip()
    if not text:
        return [], ""

    # 1) The whole message may already be a JSON tool payload.
    try:
        whole = json.loads(text)
    except Exception:
        pass
    else:
        actions, reasoning = _parse_actions_from_json_candidate(whole)
        if actions:
            return actions, reasoning or text

    # 2) Scan for embedded JSON objects mentioning an "actions" field.
    for fragment in re.findall(r"\{[^{}]*actions[^{}]*\}", text, flags=re.IGNORECASE):
        try:
            candidate = json.loads(fragment)
        except Exception:
            continue
        actions, reasoning = _parse_actions_from_json_candidate(candidate)
        if actions:
            return actions, reasoning or text

    # 3) A single "Actions: a, b" line; text before it is the reasoning.
    inline = re.search(r"actions?\s*:\s*([^\n]+)", text, flags=re.IGNORECASE)
    if inline:
        items = [chunk.strip() for chunk in inline.group(1).split(",") if chunk.strip()]
        if items:
            return items, text[: inline.start()].strip()

    # 4) Line-by-line "Action 1: x" entries; other lines become reasoning.
    collected: list[str] = []
    commentary: list[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        hit = re.match(r"action\s*\d*\s*[:\-]\s*(.+)", line, flags=re.IGNORECASE)
        if hit is None:
            commentary.append(line)
            continue
        candidate = hit.group(1).strip()
        if candidate:
            collected.append(candidate)
    if collected:
        return collected, "\n".join(commentary).strip()

    return [], text
203
+
204
+
205
def _build_tool_call(actions: Iterable[str], reasoning: str) -> dict[str, Any]:
    """Wrap *actions* (plus optional *reasoning*) as an OpenAI tool_call dict.

    Used as a fallback when the model answered in prose instead of calling
    the interact tool; the synthetic call carries a fixed id so downstream
    consumers can recognise it.
    """
    arguments: dict[str, Any] = {
        "actions": [str(item).strip() for item in actions if str(item).strip()],
    }
    cleaned_reason = reasoning.strip()
    if cleaned_reason:
        arguments["reasoning"] = cleaned_reason
    return {
        "id": "tool_interact_fallback",
        "type": "function",
        "function": {
            "name": INTERACT_TOOL_SCHEMA[0]["function"]["name"],
            "arguments": json.dumps(arguments, ensure_ascii=False),
        },
    }
219
+
220
+
221
def synthesize_tool_call_if_missing(openai_response: dict[str, Any]) -> dict[str, Any]:
    """Ensure the first choice carries a tool_call, deriving one from its text.

    When the first choice's message has no ``tool_calls``, parse its textual
    content for actions and attach a synthetic interact call. The input is
    returned unchanged when it is malformed, already has tool calls, or no
    actions can be parsed; otherwise a deep-copied response is returned.
    """
    if not isinstance(openai_response, dict):
        return openai_response
    choices = openai_response.get("choices")
    if not (isinstance(choices, list) and choices):
        return openai_response
    first = choices[0]
    if not isinstance(first, dict):
        return openai_response
    message = first.get("message")
    if not isinstance(message, dict):
        return openai_response
    existing_calls = message.get("tool_calls")
    if isinstance(existing_calls, list) and existing_calls:
        return openai_response

    actions, reasoning = parse_tool_call_from_text(extract_message_text(message))
    if not actions:
        return openai_response

    patched_message = copy.deepcopy(message)
    patched_message["tool_calls"] = [_build_tool_call(actions, reasoning)]
    patched_message.setdefault("content", None)

    patched_first = copy.deepcopy(first)
    patched_first["message"] = patched_message

    patched = copy.deepcopy(openai_response)
    patched["choices"] = [patched_first] + choices[1:]
    return patched
258
+
@@ -0,0 +1,212 @@
1
+ from __future__ import annotations
2
+
3
+ """Rubric schema, loading, and scoring helpers for Task Apps."""
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Iterable, Optional
8
+
9
+ from pydantic import BaseModel, Field, field_validator
10
+
11
+
12
class Criterion(BaseModel):
    """A single rubric criterion.

    Attributes:
        id: unique identifier within a rubric.
        description: human-readable statement of what is judged.
        weight: positive multiplier used by weighted aggregation (default 1.0).
        required: whether the criterion is mandatory (default False).
    """

    id: str
    description: str
    weight: float = 1.0
    required: bool = False

    @field_validator("weight")
    @classmethod
    def _check_weight_positive(cls, value: float) -> float:
        """Reject zero or negative weights."""
        if value <= 0:
            raise ValueError("criterion weight must be positive")
        return value
24
+
25
+
26
class Rubric(BaseModel):
    """A versioned set of criteria plus an aggregation strategy.

    Attributes:
        version: rubric schema/content version string.
        goal_text: optional overall goal description.
        criteria: list of Criterion entries with unique ids.
        aggregation: one of "sum", "weighted_sum", "custom", "inherit".
    """

    version: str
    goal_text: str | None = None
    criteria: list[Criterion] = Field(default_factory=list)
    aggregation: str = "weighted_sum"

    @field_validator("aggregation")
    @classmethod
    def _check_aggregation(cls, value: str) -> str:
        """Restrict aggregation to the supported strategies."""
        allowed = {"sum", "weighted_sum", "custom", "inherit"}
        if value not in allowed:
            raise ValueError(f"aggregation must be one of {sorted(allowed)}")
        return value

    @field_validator("criteria")
    @classmethod
    def _check_unique_ids(cls, criteria: list[Criterion]) -> list[Criterion]:
        """Criterion ids must be unique within a rubric."""
        seen: set[str] = set()
        for item in criteria:
            if item.id in seen:
                raise ValueError(f"duplicate criterion id: {item.id}")
            seen.add(item.id)
        return criteria
49
+
50
+
51
def _load_text(source: str) -> tuple[str, Optional[str]]:
    """Return ``(text, suffix)`` for *source*.

    If *source* names an existing file, return its contents and its
    lowercase suffix; otherwise treat *source* itself as inline rubric text
    and return it with a ``None`` suffix.

    The filesystem probe is guarded: arbitrary inline rubric text can be an
    invalid path (embedded NUL bytes raise ValueError on older Pythons;
    over-long names can raise OSError on some platforms), and such input
    must fall through to the inline-text case instead of crashing.
    """
    try:
        path = Path(source)
        if path.exists():
            return path.read_text(encoding="utf-8"), path.suffix.lower()
    except (OSError, ValueError):
        # Not a usable filesystem path -- treat as inline text below.
        pass
    return source, None
56
+
57
+
58
def _parse_structured(text: str, suffix: Optional[str]) -> Dict[str, Any]:
    """Parse rubric *text* into a mapping.

    Resolution order:
      1. ``.yaml``/``.yml`` suffix -> YAML (requires PyYAML).
      2. ``http(s)://`` source -> fetched with ``requests`` and re-parsed.
      3. Explicit JSON object (``{...}``).
      4. Best-effort JSON, falling back to YAML.

    Every branch now enforces a mapping result: previously the JSON paths
    could return a list or scalar (e.g. input "[1, 2]" or "null") despite
    the declared ``Dict`` return type, producing confusing downstream
    validation errors.

    Raises:
        ValueError: empty text or a decode that is not a mapping.
        RuntimeError: PyYAML needed but not installed.
    """
    text = text.strip()
    if not text:
        raise ValueError("Rubric source is empty")
    if suffix in (".yaml", ".yml"):
        try:
            import yaml  # type: ignore
        except Exception as exc:  # pragma: no cover - optional dependency
            raise RuntimeError("PyYAML is required to load YAML rubrics") from exc
        data = yaml.safe_load(text)
        if not isinstance(data, dict):
            raise ValueError("Rubric YAML must produce a mapping")
        return data
    if text.startswith("http://") or text.startswith("https://"):
        import requests  # type: ignore

        response = requests.get(text, timeout=15)
        response.raise_for_status()
        return _parse_structured(response.text, suffix)
    if text.startswith("{"):
        data = json.loads(text)
        if not isinstance(data, dict):
            raise ValueError("Rubric JSON must decode to a mapping")
        return data
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        try:
            import yaml  # type: ignore
        except Exception as exc:  # pragma: no cover - optional dependency
            raise RuntimeError("PyYAML is required to load rubric text") from exc
        data = yaml.safe_load(text)
    if not isinstance(data, dict):
        raise ValueError("Rubric text must decode to a mapping")
    return data
90
+
91
+
92
def load_rubric(source: str | dict[str, Any] | Rubric | None) -> Rubric | None:
    """Coerce *source* into a Rubric.

    Accepts None (returned as-is), an existing Rubric instance, a mapping,
    or a string (file path, URL, or inline JSON/YAML text) resolved via
    ``_load_text`` / ``_parse_structured``.
    """
    if source is None:
        return None
    if isinstance(source, Rubric):
        return source
    if isinstance(source, dict):
        return Rubric.model_validate(source)
    raw_text, suffix = _load_text(str(source))
    return Rubric.model_validate(_parse_structured(raw_text, suffix))
102
+
103
+
104
def _merge_weights(base: Criterion, override: Criterion) -> float:
    """Combine base/override criterion weights; 1.0 means "not customised".

    Both customised -> multiply them; only the override customised -> the
    override wins; otherwise keep the base weight.
    """
    base_custom = base.weight != 1.0
    override_custom = override.weight != 1.0
    if base_custom and override_custom:
        return base.weight * override.weight
    if override_custom:
        return override.weight
    return base.weight
110
+
111
+
112
def blend_rubrics(base: Rubric | None, override: Rubric | None) -> Rubric | None:
    """Merge an override rubric on top of a base rubric.

    Criteria present in both are combined: the override's non-empty
    description wins, weights are merged via ``_merge_weights``. Override
    criteria come first (in override order), followed by base-only criteria.
    An "inherit" aggregation on the override falls back to the base's.
    Either argument being None short-circuits to the other.
    """
    if base is None and override is None:
        return None
    if base is None:
        return override
    if override is None:
        return base

    remaining = {criterion.id: criterion for criterion in base.criteria}
    combined: list[Criterion] = []
    for ov in override.criteria:
        counterpart = remaining.pop(ov.id, None)
        if counterpart is None:
            combined.append(ov)
            continue
        combined.append(
            Criterion(
                id=ov.id,
                description=ov.description or counterpart.description,
                weight=_merge_weights(counterpart, ov),
                # NOTE(review): Criterion.required is a plain bool defaulting
                # to False, so it is never None here and the override always
                # wins -- confirm whether inheriting base.required was intended.
                required=ov.required if ov.required is not None else counterpart.required,
            )
        )
    combined.extend(remaining.values())

    aggregation = base.aggregation if override.aggregation == "inherit" else override.aggregation
    return Rubric(
        version=override.version or base.version,
        goal_text=override.goal_text or base.goal_text,
        criteria=combined,
        aggregation=aggregation,
    )
149
+
150
+
151
def _as_float(value: Any) -> Optional[float]:
    """Return *value* coerced to float, or None when coercion fails."""
    try:
        result = float(value)
    except Exception:
        return None
    return result
156
+
157
+
158
def _score(criteria: Iterable[Criterion], values: Dict[str, float], aggregation: str) -> Dict[str, Any]:
    """Aggregate per-criterion scores.

    Criteria missing from *values* score 0.0. "sum" adds raw scores;
    "weighted_sum" produces a weight-normalised average; "custom" reports a
    None total (the caller supplies its own aggregate); "inherit" is treated
    as "weighted_sum". Returns a dict with "aggregation", "score", and a
    per-criterion breakdown (score / weight / required).
    """
    mode = "weighted_sum" if aggregation == "inherit" else aggregation
    breakdown: Dict[str, Dict[str, Any]] = {}
    raw_total = 0.0
    weight_sum = 0.0
    for criterion in criteria:
        value = values.get(criterion.id, 0.0)
        breakdown[criterion.id] = {
            "score": value,
            "weight": criterion.weight,
            "required": criterion.required,
        }
        if mode == "sum":
            raw_total += value
        elif mode == "weighted_sum":
            raw_total += value * criterion.weight
            weight_sum += criterion.weight
    total: Any = raw_total
    if mode == "weighted_sum" and weight_sum > 0:
        total = raw_total / weight_sum
    if mode == "custom":
        total = None
    return {
        "aggregation": mode,
        "score": total,
        "per_criterion": breakdown,
    }
185
+
186
+
187
def score_events_against_rubric(events: list[dict[str, Any]], rubric: Rubric | None) -> Dict[str, Any]:
    """Score a list of event dicts against *rubric*.

    Each event supplies a criterion id (via "criterion_id", "id", or
    "criterion") and a numeric "score"; later events overwrite earlier ones
    for the same criterion, and non-dict events or non-numeric scores are
    skipped. With no rubric, a neutral "none" result is returned.
    """
    if rubric is None:
        return {"aggregation": "none", "score": None, "per_criterion": {}}
    collected: Dict[str, float] = {}
    for event in events or []:
        if not isinstance(event, dict):
            continue
        key = event.get("criterion_id") or event.get("id") or event.get("criterion")
        value = _as_float(event.get("score"))
        if key and value is not None:
            collected[str(key)] = value
    return _score(rubric.criteria, collected, rubric.aggregation)
199
+
200
+
201
def score_outcome_against_rubric(outcome: dict[str, Any], rubric: Rubric | None) -> Dict[str, Any]:
    """Score an outcome payload against *rubric*.

    Per-criterion scores are read from ``outcome["criteria"]`` when that is
    a mapping, otherwise from the outcome mapping itself; values that do not
    coerce to float are skipped. With no rubric, a neutral "none" result is
    returned.
    """
    if rubric is None:
        return {"aggregation": "none", "score": None, "per_criterion": {}}
    collected: Dict[str, float] = {}
    if isinstance(outcome, dict):
        inner = outcome.get("criteria")
        source = inner if isinstance(inner, dict) else outcome
        for key, raw in source.items():
            value = _as_float(raw)
            if value is not None:
                collected[str(key)] = value
    return _score(rubric.criteria, collected, rubric.aggregation)