synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (107)
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  98. synth_ai/api/train/config_finder.py +18 -18
  99. synth_ai/api/train/env_resolver.py +28 -1
  100. synth_ai/cli/task_apps.py +264 -55
  101. synth_ai/task/apps/__init__.py +54 -13
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +107 -12
  104. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,56 @@
1
+ # RL training starting from a finetuned model id (TOML-only model selection)
2
+
3
+ [services]
4
+ # Task app base URL used by the RL job for rollouts
5
+ # task_url = "https://YOUR-TASK-APP.modal.run"
6
+
7
+ [compute]
8
+ # Cluster shape for RL pipeline
9
+ gpu_type = "H100"
10
+ gpu_count = 8
11
+
12
+ [topology]
13
+ # Split GPUs across vLLM, training, and reference
14
+ # Must sum to compute.gpu_count
15
+ #gpus_for_vllm = 4
16
+ #gpus_for_training = 3
17
+ #gpus_for_ref = 1
18
+
19
+ [vllm]
20
+ # Serving tensor parallel size
21
+ # tensor_parallel_size = 4
22
+
23
+ [model]
24
+ # Finetuned model id to continue training from (required for this config)
25
+ # source = "ft:YOUR_FT_MODEL_ID"
26
+ label = "crafter-rl-from-ft"
27
+
28
+ [rollout]
29
+ max_turns = 10
30
+ episodes_per_batch = 64
31
+
32
+ [evaluation]
33
+ # Run baseline evaluation on the first 100 task seeds every 20 iterations
34
+ instances = 100
35
+ every_n_iters = 20
36
+ seeds = [
37
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
38
+ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
39
+ 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
40
+ 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
41
+ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
42
+ 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
43
+ 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
44
+ 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
45
+ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
46
+ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
47
+ ]
48
+
49
+ [training]
50
+ log_interval = 1
51
+ # Additional RL hyperparameters can go here
52
+
53
+ [training.weight_sync]
54
+ enable = true
55
+ targets = ["policy"]
56
+ weight_sync_interval = 1
@@ -0,0 +1,541 @@
1
+ #!/usr/bin/env python3
2
+ """Export behavioural-cloning datasets from tracing_v3 SQLite traces with filters."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import sqlite3
9
+ import sys
10
+ from collections import Counter, defaultdict
11
+ from pathlib import Path
12
+ from typing import Any, Dict, Iterable, List, Set, Tuple
13
+
14
# Convenience alias for the row type that connect() installs as the row factory.
Row = sqlite3.Row


def connect(db_path: Path) -> sqlite3.Connection:
    """Open the tracing SQLite database at *db_path* with name-addressable rows.

    The connection's row factory is :class:`sqlite3.Row`, so query results
    support ``row["column"]`` access throughout this module.
    """
    connection = sqlite3.connect(database=str(db_path))
    connection.row_factory = sqlite3.Row
    return connection
21
+
22
+
23
+ def _parse_json(value: Any) -> Any:
24
+ if value is None:
25
+ return None
26
+ if isinstance(value, (dict, list)):
27
+ return value
28
+ try:
29
+ return json.loads(value)
30
+ except Exception:
31
+ return None
32
+
33
+
34
# Per-turn achievement index keyed by (session_id, turn). Each value holds two
# lists of achievement names: "unique" (from 'unique_achievement_delta' event
# rewards) and "all" (from 'achievement_delta' event rewards).
AchievementMap = dict[Tuple[str, int], dict[str, list[str]]]


def _positive_event_achievements(
    conn: sqlite3.Connection,
    reward_type: str,
    annotation_key: str,
) -> list[tuple[str, int, list[str]]]:
    """Return ``(session_id, turn, names)`` for positive event rewards of one type.

    Reads ``event_rewards`` rows of *reward_type* with ``reward_value > 0``. The
    turn comes from the joined event's ``metadata["turn"]`` and the names come
    from ``annotation[annotation_key]``. Rows without an integer-convertible
    turn, or without a non-empty list of string names, are skipped. Non-string
    names are dropped.
    """
    rows = conn.execute(
        """
        SELECT er.session_id, er.annotation, ev.metadata
        FROM event_rewards er
        JOIN events ev ON er.event_id = ev.id
        WHERE er.reward_type = ? AND er.reward_value > 0
        """,
        (reward_type,),
    ).fetchall()

    results: list[tuple[str, int, list[str]]] = []
    for row in rows:
        annotation = _parse_json(row["annotation"])
        metadata = _parse_json(row["metadata"])
        if not isinstance(annotation, dict) or not isinstance(metadata, dict):
            continue
        turn = metadata.get("turn")
        if turn is None:
            continue
        try:
            turn_index = int(turn)
        except (TypeError, ValueError):
            # Same policy as build_sft_dataset: skip events whose turn is not an int.
            continue
        names = annotation.get(annotation_key) or []
        if not isinstance(names, list):
            continue
        cleaned = [name for name in names if isinstance(name, str)]
        if cleaned:
            results.append((row["session_id"], turn_index, cleaned))
    return results


def fetch_achievement_data(
    conn: sqlite3.Connection,
) -> tuple[
    AchievementMap,
    Counter,
    Counter,
    Counter,
    dict[str, set[str]],
    dict[str, set[str]],
]:
    """Collect per-turn and per-session achievement statistics from the trace DB.

    Returns:
        A 6-tuple of:

        - ``achievements_map``: a ``defaultdict`` AchievementMap with the new
          unique / turned-true achievement names per ``(session_id, turn)``.
        - ``unique_counts_per_session``: Counter of distinct achievements per session.
        - ``achievement_name_counts``: Counter of how many sessions have each name.
        - ``achievement_size_counts``: Counter mapping a per-session distinct
          count to the number of sessions with that count.
        - ``session_unique_sets``: distinct names per session, from unique
          achievement event rewards plus outcome ``achievements`` metadata.
        - ``session_final_achievements``: names from outcome metadata only.
    """
    achievements_map: AchievementMap = defaultdict(lambda: {"unique": [], "all": []})
    session_unique_sets: dict[str, set[str]] = defaultdict(set)
    session_final_achievements: dict[str, set[str]] = defaultdict(set)

    for session_id, turn, names in _positive_event_achievements(
        conn, "unique_achievement_delta", "new_unique"
    ):
        achievements_map[(session_id, turn)]["unique"].extend(names)
        session_unique_sets[session_id].update(names)

    for session_id, turn, names in _positive_event_achievements(
        conn, "achievement_delta", "turned_true"
    ):
        achievements_map[(session_id, turn)]["all"].extend(names)

    rows = conn.execute(
        """
        SELECT session_id, reward_metadata
        FROM outcome_rewards
        WHERE reward_metadata IS NOT NULL
        """
    ).fetchall()
    for row in rows:
        session_id = row["session_id"]
        metadata = _parse_json(row["reward_metadata"])
        if not isinstance(metadata, dict):
            continue
        final_achievements = metadata.get("achievements") or []
        if isinstance(final_achievements, list):
            cleaned = [a for a in final_achievements if isinstance(a, str)]
            # Update even when empty so the session still gets (empty) entries.
            session_unique_sets[session_id].update(cleaned)
            session_final_achievements[session_id].update(cleaned)

    unique_counts_per_session: Counter = Counter()
    achievement_name_counts: Counter = Counter()
    for session_id, achievement_set in session_unique_sets.items():
        unique_counts_per_session[session_id] = len(achievement_set)
        achievement_name_counts.update(achievement_set)

    achievement_size_counts: Counter = Counter(unique_counts_per_session.values())

    return (
        achievements_map,
        unique_counts_per_session,
        achievement_name_counts,
        achievement_size_counts,
        session_unique_sets,
        session_final_achievements,
    )
130
+
131
+
132
def fetch_session_models(conn: sqlite3.Connection) -> dict[str, tuple[str, str, int]]:
    """Return each session's dominant ``(model_name, provider, calls)`` triple.

    Only ``'cais'`` events with a non-NULL ``model_name`` are counted. The
    dominant pair is the ``(model_name, provider)`` with the most calls in the
    session. Ties are broken deterministically by ascending ``model_name`` and
    then ``provider``, using SQLite ordering in which NULL sorts first.
    """
    rows = conn.execute(
        """
        SELECT session_id, model_name, provider, COUNT(*) AS calls
        FROM events
        WHERE event_type = 'cais' AND model_name IS NOT NULL
        GROUP BY session_id, model_name, provider
        ORDER BY session_id, calls DESC, model_name, provider
        """
    ).fetchall()

    session_models: dict[str, tuple[str, str, int]] = {}
    for row in rows:
        session_id = row["session_id"]
        # Rows arrive best-first within each session (explicit ORDER BY), so the
        # first row seen for a session is its dominant model.
        if session_id in session_models:
            continue
        session_models[session_id] = (row["model_name"], row["provider"], int(row["calls"] or 0))
    return session_models
150
+
151
+
152
def fetch_outcome_rewards(conn: sqlite3.Connection) -> dict[str, dict[str, Any]]:
    """Map each session id to its outcome reward summary.

    Returns ``{session_id: {"total_reward": float, "achievements": set[str]}}``
    built from ``outcome_rewards``. A NULL ``total_reward`` becomes ``0.0``.
    Achievements are the string entries of ``reward_metadata["achievements"]``,
    or an empty set when that is missing or malformed. A later row for the same
    session replaces an earlier one.
    """
    query = """
        SELECT session_id, total_reward, reward_metadata
        FROM outcome_rewards
        """
    summaries: dict[str, dict[str, Any]] = {}
    for outcome_row in conn.execute(query).fetchall():
        reward_metadata = _parse_json(outcome_row["reward_metadata"])
        names: set[str] = set()
        if isinstance(reward_metadata, dict):
            raw_names = reward_metadata.get("achievements") or []
            if isinstance(raw_names, list):
                names = {name for name in raw_names if isinstance(name, str)}
        summaries[outcome_row["session_id"]] = {
            "total_reward": float(outcome_row["total_reward"] or 0.0),
            "achievements": names,
        }
    return summaries
173
+
174
+
175
def fetch_event_reward_totals(conn: sqlite3.Connection) -> dict[str, dict[str, dict[str, float]]]:
    """Aggregate ``event_rewards`` per session and reward type.

    Returns a ``defaultdict`` shaped
    ``{session_id: {reward_type: {"events": int, "total": float}}}``. Here
    ``events`` is the row count and ``total`` is the sum of ``reward_value``
    (an all-NULL sum becomes ``0.0``).
    """
    totals_by_session: dict[str, dict[str, dict[str, float]]] = defaultdict(dict)
    cursor = conn.execute(
        """
        SELECT session_id, reward_type, COUNT(*) AS events, COALESCE(SUM(reward_value), 0) AS total_value
        FROM event_rewards
        GROUP BY session_id, reward_type
        """
    )
    for session_id, reward_type, n_events, total_value in cursor.fetchall():
        per_type = totals_by_session[session_id]
        per_type[reward_type] = {
            "events": int(n_events or 0),
            "total": float(total_value or 0.0),
        }
    return totals_by_session
191
+
192
+
193
def parse_event_filters(specs: list[str] | None) -> list[tuple[str, float]]:
    """Parse ``--event-reward`` specs of the form ``reward_type[:min_total]``.

    Blank reward types are ignored and a missing threshold defaults to ``0.0``.
    When a threshold is not a valid float, prints an error to stderr and exits
    with status 1.
    """
    parsed: list[tuple[str, float]] = []
    for spec in specs or []:
        name, _sep, threshold_text = spec.partition(":")
        name = name.strip()
        if not name:
            continue
        if not threshold_text:
            threshold = 0.0
        else:
            try:
                threshold = float(threshold_text)
            except ValueError:
                print(f"Invalid event reward specification '{spec}'", file=sys.stderr)
                raise SystemExit(1)
        parsed.append((name, threshold))
    return parsed
211
+
212
+
213
+ def _collect_text(parts: Iterable[dict[str, Any]] | None) -> str:
214
+ texts: list[str] = []
215
+ if not parts:
216
+ return ""
217
+ for part in parts:
218
+ if not isinstance(part, dict):
219
+ continue
220
+ if part.get("type") == "text":
221
+ text = part.get("text")
222
+ if isinstance(text, str) and text:
223
+ texts.append(text)
224
+ return "\n".join(texts)
225
+
226
+
227
+ def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
228
+ normalised: list[dict[str, Any]] = []
229
+ if not tool_calls:
230
+ return normalised
231
+ for idx, call in enumerate(tool_calls):
232
+ if not isinstance(call, dict):
233
+ continue
234
+ entry = dict(call)
235
+
236
+ func_payload: dict[str, Any] | None = entry.get("function") if isinstance(entry.get("function"), dict) else None
237
+ name = entry.get("name") or (func_payload.get("name") if func_payload else None) or "tool"
238
+
239
+ args = None
240
+ if func_payload and "arguments" in func_payload:
241
+ args = func_payload.get("arguments")
242
+ else:
243
+ args = entry.get("arguments")
244
+ if args is None:
245
+ raw = entry.pop("arguments_json", None)
246
+ if isinstance(raw, str):
247
+ try:
248
+ args = json.loads(raw)
249
+ except Exception:
250
+ args = raw
251
+
252
+ if isinstance(args, (dict, list)):
253
+ args_str = json.dumps(args, ensure_ascii=False)
254
+ elif isinstance(args, str):
255
+ args_str = args
256
+ elif args is None:
257
+ args_str = "{}"
258
+ else:
259
+ args_str = str(args)
260
+
261
+ call_id = entry.get("id") or entry.get("call_id") or f"call_{idx}"
262
+
263
+ normalised.append(
264
+ {
265
+ "id": str(call_id),
266
+ "type": "function",
267
+ "function": {
268
+ "name": str(name),
269
+ "arguments": args_str,
270
+ },
271
+ }
272
+ )
273
+
274
+ return normalised
275
+
276
+
277
def _resolve_event_turn(
    metadata: dict[str, Any],
    session_id: str,
    turn_counters: dict[str, int],
) -> int | None:
    """Determine the turn index of one ``cais`` event and advance *turn_counters*.

    Prefers ``metadata["turn"]``, then a ``"turn_<n>"`` ``step_id``, and finally
    the session's running counter. *turn_counters* is expected to be a
    ``defaultdict(int)``. Returns ``None``, leaving the counter untouched, when
    an explicit turn cannot be converted to ``int``.
    """
    turn = metadata.get("turn")
    if turn is None:
        step_id = metadata.get("step_id")
        if isinstance(step_id, str) and step_id.startswith("turn_"):
            try:
                turn = int(step_id.split("_", 1)[1])
            except (ValueError, IndexError):
                turn = None
    if turn is None:
        # No usable turn recorded: number events sequentially within the session.
        turn = turn_counters[session_id]
        turn_counters[session_id] = turn + 1
        return turn
    try:
        turn = int(turn)
    except (TypeError, ValueError):
        return None
    turn_counters[session_id] = max(turn_counters[session_id], turn + 1)
    return turn


def _input_messages_from_record(record: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten a call record's ``input_messages`` into ``{"role", "content"}`` dicts.

    Messages that are not dicts, or whose text parts are all empty, are dropped.
    """
    messages: list[dict[str, Any]] = []
    for message in record.get("input_messages") or []:
        if not isinstance(message, dict):
            continue
        content = _collect_text(message.get("parts"))
        if content:
            messages.append({"role": message.get("role", "unknown"), "content": content})
    return messages


def _assistant_message_from_record(record: dict[str, Any]) -> dict[str, Any] | None:
    """Build the assistant chat message for one call record, or ``None`` if empty.

    Content and tool calls come from the first choice of the JSON response
    stored in ``output_text``. When that yields no tool calls,
    ``output_tool_calls`` is used instead. List-valued content is reduced to its
    text parts. Returns ``None`` when the message would have neither content
    nor tool calls.
    """
    content: Any = ""
    tool_calls: list[dict[str, Any]] = []

    output_text = record.get("output_text")
    parsed_response: Any = None
    if isinstance(output_text, str) and output_text:
        try:
            parsed_response = json.loads(output_text)
        except json.JSONDecodeError:
            parsed_response = None

    # Only a JSON object can carry "choices"; other valid JSON values
    # (numbers, strings, arrays) must not be treated as a response dict.
    if isinstance(parsed_response, dict):
        choices = parsed_response.get("choices") or []
        first_choice = choices[0] if isinstance(choices, list) and choices else None
        if isinstance(first_choice, dict):
            message = first_choice.get("message") or {}
            if isinstance(message, dict):
                content = message.get("content") or ""
                tool_calls = _normalise_tool_calls(message.get("tool_calls"))

    if isinstance(content, list):
        # Multi-part content: keep the text parts, as for input messages.
        content = _collect_text(content)
    elif not isinstance(content, str):
        content = ""

    if not tool_calls:
        tool_calls = _normalise_tool_calls(record.get("output_tool_calls"))

    if not content and not tool_calls:
        return None
    assistant_message: dict[str, Any] = {"role": "assistant", "content": content}
    if tool_calls:
        assistant_message["tool_calls"] = tool_calls
    return assistant_message


def build_sft_dataset(
    conn: sqlite3.Connection,
    achievements_map: AchievementMap,
    sessions_filter: Set[str],
    *,
    allowed_models: set[str] | None = None,
    limit: int | None = None,
) -> list[dict[str, Any]]:
    """Turn recorded ``cais`` LLM calls into chat-format SFT examples.

    Each call record of each ``cais`` event in *sessions_filter* (and, if given,
    *allowed_models*) becomes ``{"messages": [...], "metadata": {...}}``. An
    example needs at least one non-empty input message plus a non-empty
    assistant message.

    The metadata carries the session, turn, model, provider, and that turn's
    achievements. ``cumulative_unique`` is the running total of new unique
    achievements in the session, with each turn counted once.

    Args:
        conn: Trace DB connection with a name-addressable row factory.
        achievements_map: Per-``(session_id, turn)`` achievement names.
        sessions_filter: Session ids to include.
        allowed_models: Optional set of model names to keep.
        limit: Maximum number of examples; ``None`` means no cap, and
            non-positive values yield an empty list.

    Returns:
        The examples in ``(session_id, event id)`` order.
    """
    if limit is not None and limit <= 0:
        return []

    rows = conn.execute(
        """
        SELECT id, session_id, metadata, model_name, provider, call_records
        FROM events
        WHERE event_type = 'cais' AND call_records IS NOT NULL
        ORDER BY session_id, id
        """
    ).fetchall()

    dataset: list[dict[str, Any]] = []
    cumulative_unique: dict[str, int] = defaultdict(int)
    counted_turns: dict[str, set[int]] = defaultdict(set)
    turn_counters: dict[str, int] = defaultdict(int)

    for row in rows:
        session_id = row["session_id"]
        if session_id not in sessions_filter:
            continue
        if allowed_models and row["model_name"] not in allowed_models:
            continue

        event_metadata = _parse_json(row["metadata"])
        if not isinstance(event_metadata, dict):
            event_metadata = {}
        turn = _resolve_event_turn(event_metadata, session_id, turn_counters)
        if turn is None:
            continue

        call_records = _parse_json(row["call_records"]) or []
        if not isinstance(call_records, list) or not call_records:
            continue

        achievements = achievements_map.get((session_id, turn), {"unique": [], "all": []})
        new_unique = achievements.get("unique", [])
        turned_true = achievements.get("all", [])

        for record in call_records:
            if not isinstance(record, dict):
                continue
            messages = _input_messages_from_record(record)
            assistant_message = _assistant_message_from_record(record)
            if assistant_message is None:
                continue
            messages.append(assistant_message)
            if len(messages) < 2:
                continue

            # Add a turn's new unique achievements only once per session, even
            # when several events or call records share that turn.
            if turn not in counted_turns[session_id]:
                counted_turns[session_id].add(turn)
                cumulative_unique[session_id] += len(new_unique)

            dataset.append(
                {
                    "messages": messages,
                    "metadata": {
                        "session_id": session_id,
                        "turn": turn,
                        "model": row["model_name"],
                        "provider": row["provider"] or "unknown",
                        "achievements": {
                            "new_unique": new_unique,
                            "turned_true": turned_true,
                            "cumulative_unique": cumulative_unique[session_id],
                        },
                    },
                }
            )
            if limit is not None and len(dataset) >= limit:
                return dataset

    return dataset
390
+
391
+
392
def write_jsonl(path: Path, records: Iterable[dict[str, Any]]) -> None:
    """Write *records* to *path* as UTF-8 JSON Lines, one object per line.

    Parent directories are created as needed and an existing file is
    overwritten. Non-ASCII characters are written as-is rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in records)
398
+
399
+
400
+ def _validate_dataset(records: list[dict[str, Any]]) -> None:
401
+ errors: list[str] = []
402
+ for idx, record in enumerate(records, start=1):
403
+ messages = record.get("messages")
404
+ if not isinstance(messages, list) or not messages:
405
+ errors.append(f"row {idx}: missing messages list")
406
+ if len(errors) >= 20:
407
+ break
408
+ continue
409
+ for msg_idx, msg in enumerate(messages):
410
+ if not isinstance(msg, dict):
411
+ errors.append(f"row {idx}: message {msg_idx} is not an object")
412
+ break
413
+ if "role" not in msg or "content" not in msg:
414
+ errors.append(f"row {idx}: message {msg_idx} missing role/content")
415
+ break
416
+ if not isinstance(msg["role"], str):
417
+ errors.append(f"row {idx}: message {msg_idx} role not string")
418
+ break
419
+ if not isinstance(msg["content"], str):
420
+ errors.append(f"row {idx}: message {msg_idx} content not string")
421
+ break
422
+ if len(errors) >= 20:
423
+ break
424
+ if errors:
425
+ summary = "\n - ".join(errors)
426
+ raise SystemExit(f"Validation error while exporting dataset:\n - {summary}")
427
+
428
+
429
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the trace exporter."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--db", type=Path, default=Path("traces/v3/synth_ai.db"), help="Path to tracing_v3 SQLite DB")
    parser.add_argument("--output", type=Path, required=True, help="Destination JSONL path for the exported dataset")
    parser.add_argument("--model", action="append", dest="models", help="Restrict to sessions whose dominant model matches (repeatable)")
    parser.add_argument("--provider", action="append", dest="providers", help="Restrict to sessions whose dominant provider matches (repeatable)")
    parser.add_argument("--min-unique", type=int, default=None, help="Minimum unique achievements per session")
    parser.add_argument("--max-unique", type=int, default=None, help="Maximum unique achievements per session")
    parser.add_argument(
        "--exclude-achievement",
        action="append",
        dest="exclude_achievements",
        help="Achievements to ignore when evaluating --min-unique/--max-unique (repeatable)",
    )
    parser.add_argument("--require-achievement", action="append", dest="required_achievements", help="Require these outcome achievements (repeatable)")
    parser.add_argument("--min-outcome-reward", type=float, default=None, help="Minimum total outcome reward per session")
    parser.add_argument("--max-outcome-reward", type=float, default=None, help="Maximum total outcome reward per session")
    parser.add_argument("--event-reward", action="append", dest="event_reward_filters", help="Require reward_type[:min_total] in event_rewards (repeatable)")
    parser.add_argument("--limit", type=int, default=None, help="Maximum number of examples to emit")
    return parser


def _select_eligible_sessions(
    args: argparse.Namespace,
    session_models: dict[str, tuple[str, str, int]],
    session_unique_sets: dict[str, set[str]],
    session_final_achievements: dict[str, set[str]],
    outcome_data: dict[str, dict[str, Any]],
    event_totals: dict[str, dict[str, dict[str, float]]],
    event_filters: list[tuple[str, float]],
) -> set[str]:
    """Return ids of sessions that pass every model, achievement and reward filter in *args*.

    Only sessions present in *session_models* are considered. Unique
    achievement counts ignore ``--exclude-achievement`` names. Sessions without
    an outcome row are treated as having reward ``0.0`` and use their
    outcome-metadata achievements.
    """
    allowed_models = set(args.models) if args.models else None
    allowed_providers = set(args.providers) if args.providers else None
    required_achievements = set(args.required_achievements or [])
    excluded_achievements = set(args.exclude_achievements or [])

    eligible: set[str] = set()
    for session_id, (model_name, provider, _calls) in session_models.items():
        if allowed_models and model_name not in allowed_models:
            continue
        if allowed_providers and (provider or "unknown") not in allowed_providers:
            continue

        unique_count = len(session_unique_sets.get(session_id, set()) - excluded_achievements)
        if args.min_unique is not None and unique_count < args.min_unique:
            continue
        if args.max_unique is not None and unique_count > args.max_unique:
            continue

        outcome = outcome_data.get(session_id)
        if outcome:
            total_reward = outcome["total_reward"]
            final_achievements = outcome["achievements"]
        else:
            total_reward = 0.0
            final_achievements = session_final_achievements.get(session_id, set())

        if args.min_outcome_reward is not None and total_reward < args.min_outcome_reward:
            continue
        if args.max_outcome_reward is not None and total_reward > args.max_outcome_reward:
            continue
        if required_achievements and not required_achievements.issubset(final_achievements):
            continue

        session_event_totals = event_totals.get(session_id, {})
        if any(
            session_event_totals.get(reward_type, {}).get("total", 0.0) < min_total
            for reward_type, min_total in event_filters
        ):
            continue

        eligible.add(session_id)
    return eligible


def main() -> None:
    """Export a filtered SFT dataset from a tracing_v3 SQLite database.

    Exits with status 1 when the database is missing, an ``--event-reward``
    spec is invalid, no sessions match the filters, or no rollout steps remain
    after session selection. Exits with status 2 (argparse error) for a
    non-positive ``--limit``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.limit is not None and args.limit < 1:
        # Reject non-positive limits up front rather than relying on
        # build_sft_dataset's handling of them.
        parser.error("--limit must be a positive integer")

    if not args.db.exists():
        print(f"Database not found: {args.db}", file=sys.stderr)
        raise SystemExit(1)

    # Parse filters before touching the database so malformed input fails fast.
    event_filters = parse_event_filters(args.event_reward_filters)

    conn = connect(args.db)
    try:
        (
            achievements_map,
            _unique_counts,
            _name_counts,
            _size_counts,
            session_unique_sets,
            session_final_achievements,
        ) = fetch_achievement_data(conn)
        session_models = fetch_session_models(conn)
        outcome_data = fetch_outcome_rewards(conn)
        event_totals = fetch_event_reward_totals(conn)

        eligible_sessions = _select_eligible_sessions(
            args,
            session_models,
            session_unique_sets,
            session_final_achievements,
            outcome_data,
            event_totals,
            event_filters,
        )
        if not eligible_sessions:
            print("No sessions matched the provided filters.", file=sys.stderr)
            raise SystemExit(1)

        dataset = build_sft_dataset(
            conn,
            achievements_map,
            eligible_sessions,
            allowed_models=set(args.models) if args.models else None,
            limit=args.limit,
        )
        if not dataset:
            print("No rollout steps matched the filters (after session selection).", file=sys.stderr)
            raise SystemExit(1)

        _validate_dataset(dataset)
        write_jsonl(args.output, dataset)
        session_ids = {item.get("metadata", {}).get("session_id") for item in dataset}
        session_ids.discard(None)
        print(
            f"Wrote {len(dataset)} examples from {len(session_ids)} session(s) -> {args.output}",
            file=sys.stderr,
        )
    finally:
        conn.close()


if __name__ == "__main__":
    main()
@@ -0,0 +1,88 @@
1
# The docstring must come first to become the module ``__doc__`` (used as the
# argparse description in _parse_args). A string placed after an import is
# only a discarded expression.
"""Quick smoke test that drives a rollout through the Groq proxy-backed Crafter Task App."""

from __future__ import annotations
4
+
5
+ import argparse
6
+ import asyncio
7
+ import os
8
+ from typing import Any
9
+
10
+ from synth_ai.task import (
11
+ INTERACT_TOOL_SCHEMA,
12
+ RolloutEnvSpec,
13
+ RolloutPolicySpec,
14
+ RolloutRequest,
15
+ TaskAppClient,
16
+ to_jsonable,
17
+ )
18
+
19
+
20
def _build_policy_payload(seed: int, model: str) -> dict[str, Any]:
    """Return the policy-op payload: model name, interact tool schema and seeded prompt.

    The system message tells the model to call the ``interact`` tool with 3-5
    actions. The user message embeds *seed*.
    """
    system_prompt = "You control the Crafter agent. Think briefly, then call the interact tool with 3-5 actions to maximize achievements."
    user_prompt = f"Environment seed {seed}. Plan initial survival/crafting steps and then call interact with concrete actions."
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return {"model": model, "tools": INTERACT_TOOL_SCHEMA, "messages": conversation}
37
+
38
+
39
async def run(args: argparse.Namespace) -> None:
    """Probe /health and /info, then run one policy+env rollout and print its steps.

    When ``--inference-url`` is not given, the inference URL defaults to the
    task app's ``/proxy/groq`` route. Exits with an error message if the rollout
    returns no trajectories.
    """
    client = TaskAppClient(args.base_url, api_key=args.api_key, timeout=args.timeout)

    health = await client.health()
    print("/health →", to_jsonable(health))

    info = await client.info()
    print("/info →", to_jsonable(info))

    inference_url = (args.inference_url or f"{args.base_url.rstrip('/')}/proxy/groq").rstrip("/")

    request = RolloutRequest(
        run_id=args.run_id,
        env=RolloutEnvSpec(env_name="crafter", seed=args.seed, config={"seed": args.seed}),
        policy=RolloutPolicySpec(
            policy_name="groq-smoke",
            config={"model": args.model, "inference_url": inference_url},
        ),
        ops=[
            {"type": "policy", "payload": _build_policy_payload(args.seed, args.model)},
            {"type": "env"},
        ],
    )

    response = await client.rollout(request)
    print("rollout.metrics →", to_jsonable(response.metrics.model_dump()))
    if not response.trajectories:
        # A smoke test should fail clearly here instead of raising IndexError.
        raise SystemExit("rollout returned no trajectories")
    for idx, step in enumerate(response.trajectories[0].steps, start=1):
        print(f"step[{idx}] tool_calls={step.tool_calls} reward={step.reward} info={to_jsonable(step.info)}")
67
+
68
+
69
+ def _parse_args() -> argparse.Namespace:
70
+ parser = argparse.ArgumentParser(description=__doc__)
71
+ parser.add_argument("--base-url", default=os.getenv("TASK_APP_BASE_URL", "http://localhost:8000"))
72
+ parser.add_argument("--api-key", default=os.getenv("TASK_APP_API_KEY"), required=os.getenv("TASK_APP_API_KEY") is None)
73
+ parser.add_argument("--model", default=os.getenv("GROQ_MODEL", "groq/mixtral-8x7b"))
74
+ parser.add_argument("--inference-url", default=os.getenv("TASK_APP_INFERENCE_URL"))
75
+ parser.add_argument("--seed", type=int, default=int(os.getenv("CRAFTER_TEST_SEED", "42")))
76
+ parser.add_argument("--run-id", default=os.getenv("TASK_APP_RUN_ID", "groq-test"))
77
+ parser.add_argument("--timeout", type=float, default=float(os.getenv("TASK_APP_TIMEOUT", "60")))
78
+ return parser.parse_args()
79
+
80
+
81
def main() -> None:
    """Entry point: parse CLI arguments and run the async smoke test to completion."""
    asyncio.run(run(_parse_args()))


if __name__ == "__main__":
    main()
+ main()
88
+