synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai has been flagged as potentially problematic.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/demo_registry.py +7 -7
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
examples/warming_up_to_rl/configs/rl_from_ft.toml (added, +56 lines)
@@ -0,0 +1,56 @@
+# RL training starting from a finetuned model id (TOML-only model selection)
+
+[services]
+# Task app base URL used by the RL job for rollouts
+# task_url = "https://YOUR-TASK-APP.modal.run"
+
+[compute]
+# Cluster shape for RL pipeline
+gpu_type = "H100"
+gpu_count = 8
+
+[topology]
+# Split GPUs across vLLM, training, and reference
+# Must sum to compute.gpu_count
+#gpus_for_vllm = 4
+#gpus_for_training = 3
+#gpus_for_ref = 1
+
+[vllm]
+# Serving tensor parallel size
+# tensor_parallel_size = 4
+
+[model]
+# Finetuned model id to continue training from (required for this config)
+# source = "ft:YOUR_FT_MODEL_ID"
+label = "crafter-rl-from-ft"
+
+[rollout]
+max_turns = 10
+episodes_per_batch = 64
+
+[evaluation]
+# Run baseline evaluation on the first 100 task seeds every 20 iterations
+instances = 100
+every_n_iters = 20
+seeds = [
+    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+    10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
+    20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
+    30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
+    40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
+    50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+    60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
+    70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
+    80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+    90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
+]
+
+[training]
+log_interval = 1
+# Additional RL hyperparameters can go here
+
+[training.weight_sync]
+enable = true
+targets = ["policy"]
+weight_sync_interval = 1
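
The [topology] comments encode the one cross-section invariant in this config: any GPU split you uncomment must sum to [compute].gpu_count. A minimal sketch of that check, assuming Python 3.11+ (for the stdlib tomllib parser) and a local copy of the file at a placeholder path:

import tomllib  # stdlib TOML parser (Python 3.11+)

with open("examples/warming_up_to_rl/configs/rl_from_ft.toml", "rb") as fh:  # placeholder path
    cfg = tomllib.load(fh)

# As shipped, the splits are commented out, so the sum is 0 and no check applies.
topology = cfg.get("topology", {})
split = sum(topology.get(key, 0) for key in ("gpus_for_vllm", "gpus_for_training", "gpus_for_ref"))
if split and split != cfg["compute"]["gpu_count"]:
    raise ValueError("[topology] GPU split must sum to [compute].gpu_count")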
examples/warming_up_to_rl/export_trace_sft.py (added, +541 lines)
@@ -0,0 +1,541 @@
+#!/usr/bin/env python3
+"""Export behavioural-cloning datasets from tracing_v3 SQLite traces with filters."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sqlite3
+import sys
+from collections import Counter, defaultdict
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Set, Tuple
+
+Row = sqlite3.Row
+
+
+def connect(db_path: Path) -> sqlite3.Connection:
+    conn = sqlite3.connect(str(db_path))
+    conn.row_factory = sqlite3.Row
+    return conn
+
+
+def _parse_json(value: Any) -> Any:
+    if value is None:
+        return None
+    if isinstance(value, (dict, list)):
+        return value
+    try:
+        return json.loads(value)
+    except Exception:
+        return None
+
+
+AchievementMap = dict[Tuple[str, int], dict[str, list[str]]]
+
+
+def fetch_achievement_data(
+    conn: sqlite3.Connection,
+) -> tuple[
+    AchievementMap,
+    Counter,
+    Counter,
+    Counter,
+    dict[str, set[str]],
+    dict[str, set[str]],
+]:
+    achievements_map: AchievementMap = defaultdict(lambda: {"unique": [], "all": []})
+    session_unique_sets: dict[str, set[str]] = defaultdict(set)
+    session_final_achievements: dict[str, set[str]] = defaultdict(set)
+    achievement_name_counts: Counter = Counter()
+
+    rows = conn.execute(
+        """
+        SELECT er.session_id, er.reward_value, er.annotation, ev.metadata
+        FROM event_rewards er
+        JOIN events ev ON er.event_id = ev.id
+        WHERE er.reward_type = 'unique_achievement_delta' AND er.reward_value > 0
+        """
+    ).fetchall()
+    for row in rows:
+        session_id = row["session_id"]
+        annotation = _parse_json(row["annotation"]) or {}
+        metadata = _parse_json(row["metadata"]) or {}
+        turn = metadata.get("turn")
+        if turn is None:
+            continue
+        new_unique = annotation.get("new_unique") or []
+        if not isinstance(new_unique, list):
+            continue
+        if new_unique:
+            achievements_map[(session_id, int(turn))]["unique"].extend(new_unique)
+            session_unique_sets[session_id].update(new_unique)
+
+    rows = conn.execute(
+        """
+        SELECT er.session_id, er.reward_value, er.annotation, ev.metadata
+        FROM event_rewards er
+        JOIN events ev ON er.event_id = ev.id
+        WHERE er.reward_type = 'achievement_delta' AND er.reward_value > 0
+        """
+    ).fetchall()
+    for row in rows:
+        session_id = row["session_id"]
+        annotation = _parse_json(row["annotation"]) or {}
+        metadata = _parse_json(row["metadata"]) or {}
+        turn = metadata.get("turn")
+        if turn is None:
+            continue
+        turned_true = annotation.get("turned_true") or []
+        if not isinstance(turned_true, list):
+            continue
+        if turned_true:
+            achievements_map[(session_id, int(turn))]["all"].extend(turned_true)
+
+    rows = conn.execute(
+        """
+        SELECT session_id, reward_metadata
+        FROM outcome_rewards
+        WHERE reward_metadata IS NOT NULL
+        """
+    ).fetchall()
+    for row in rows:
+        session_id = row["session_id"]
+        metadata = _parse_json(row["reward_metadata"])
+        if not isinstance(metadata, dict):
+            continue
+        final_achievements = metadata.get("achievements") or []
+        if isinstance(final_achievements, list):
+            cleaned = [a for a in final_achievements if isinstance(a, str)]
+            session_unique_sets[session_id].update(cleaned)
+            session_final_achievements[session_id].update(cleaned)
+
+    unique_counts_per_session: Counter = Counter()
+    for session_id, achievement_set in session_unique_sets.items():
+        unique_counts_per_session[session_id] = len(achievement_set)
+        achievement_name_counts.update(achievement_set)
+
+    achievement_size_counts: Counter = Counter()
+    for session_id, count in unique_counts_per_session.items():
+        achievement_size_counts[count] += 1
+
+    return (
+        achievements_map,
+        unique_counts_per_session,
+        achievement_name_counts,
+        achievement_size_counts,
+        session_unique_sets,
+        session_final_achievements,
+    )
+
+
+def fetch_session_models(conn: sqlite3.Connection) -> dict[str, tuple[str, str, int]]:
+    rows = conn.execute(
+        """
+        SELECT session_id, model_name, provider, COUNT(*) AS calls
+        FROM events
+        WHERE event_type = 'cais' AND model_name IS NOT NULL
+        GROUP BY session_id, model_name, provider
+        """
+    ).fetchall()
+
+    session_models: dict[str, tuple[str, str, int]] = {}
+    for row in rows:
+        session_id = row["session_id"]
+        calls = int(row["calls"] or 0)
+        current = session_models.get(session_id)
+        if current is None or calls > current[2]:
+            session_models[session_id] = (row["model_name"], row["provider"], calls)
+    return session_models
+
+
+def fetch_outcome_rewards(conn: sqlite3.Connection) -> dict[str, dict[str, Any]]:
+    rows = conn.execute(
+        """
+        SELECT session_id, total_reward, reward_metadata
+        FROM outcome_rewards
+        """
+    ).fetchall()
+
+    outcome_data: dict[str, dict[str, Any]] = {}
+    for row in rows:
+        metadata = _parse_json(row["reward_metadata"])
+        achievements = set()
+        if isinstance(metadata, dict):
+            ach = metadata.get("achievements") or []
+            if isinstance(ach, list):
+                achievements = {a for a in ach if isinstance(a, str)}
+        outcome_data[row["session_id"]] = {
+            "total_reward": float(row["total_reward"] or 0.0),
+            "achievements": achievements,
+        }
+    return outcome_data
+
+
+def fetch_event_reward_totals(conn: sqlite3.Connection) -> dict[str, dict[str, dict[str, float]]]:
+    rows = conn.execute(
+        """
+        SELECT session_id, reward_type, COUNT(*) AS events, COALESCE(SUM(reward_value), 0) AS total_value
+        FROM event_rewards
+        GROUP BY session_id, reward_type
+        """
+    ).fetchall()
+
+    event_totals: dict[str, dict[str, dict[str, float]]] = defaultdict(dict)
+    for row in rows:
+        event_totals[row["session_id"]][row["reward_type"]] = {
+            "events": int(row["events"] or 0),
+            "total": float(row["total_value"] or 0.0),
+        }
+    return event_totals
+
+
+def parse_event_filters(specs: list[str] | None) -> list[tuple[str, float]]:
+    filters: list[tuple[str, float]] = []
+    if not specs:
+        return filters
+    for spec in specs:
+        reward_type, _, min_val_str = spec.partition(":")
+        reward_type = reward_type.strip()
+        if not reward_type:
+            continue
+        min_val = 0.0
+        if min_val_str:
+            try:
+                min_val = float(min_val_str)
+            except ValueError:
+                print(f"Invalid event reward specification '{spec}'", file=sys.stderr)
+                raise SystemExit(1)
+        filters.append((reward_type, min_val))
+    return filters
+
+
+def _collect_text(parts: Iterable[dict[str, Any]] | None) -> str:
+    texts: list[str] = []
+    if not parts:
+        return ""
+    for part in parts:
+        if not isinstance(part, dict):
+            continue
+        if part.get("type") == "text":
+            text = part.get("text")
+            if isinstance(text, str) and text:
+                texts.append(text)
+    return "\n".join(texts)
+
+
+def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
+    normalised: list[dict[str, Any]] = []
+    if not tool_calls:
+        return normalised
+    for idx, call in enumerate(tool_calls):
+        if not isinstance(call, dict):
+            continue
+        entry = dict(call)
+
+        func_payload: dict[str, Any] | None = entry.get("function") if isinstance(entry.get("function"), dict) else None
+        name = entry.get("name") or (func_payload.get("name") if func_payload else None) or "tool"
+
+        args = None
+        if func_payload and "arguments" in func_payload:
+            args = func_payload.get("arguments")
+        else:
+            args = entry.get("arguments")
+        if args is None:
+            raw = entry.pop("arguments_json", None)
+            if isinstance(raw, str):
+                try:
+                    args = json.loads(raw)
+                except Exception:
+                    args = raw
+
+        if isinstance(args, (dict, list)):
+            args_str = json.dumps(args, ensure_ascii=False)
+        elif isinstance(args, str):
+            args_str = args
+        elif args is None:
+            args_str = "{}"
+        else:
+            args_str = str(args)
+
+        call_id = entry.get("id") or entry.get("call_id") or f"call_{idx}"
+
+        normalised.append(
+            {
+                "id": str(call_id),
+                "type": "function",
+                "function": {
+                    "name": str(name),
+                    "arguments": args_str,
+                },
+            }
+        )
+
+    return normalised
+
+
+def build_sft_dataset(
+    conn: sqlite3.Connection,
+    achievements_map: AchievementMap,
+    sessions_filter: Set[str],
+    *,
+    allowed_models: set[str] | None = None,
+    limit: int | None = None,
+) -> list[dict[str, Any]]:
+    rows = conn.execute(
+        """
+        SELECT id, session_id, metadata, model_name, provider, call_records
+        FROM events
+        WHERE event_type = 'cais' AND call_records IS NOT NULL
+        ORDER BY session_id, id
+        """
+    ).fetchall()
+
+    dataset: list[dict[str, Any]] = []
+    cumulative_unique: dict[str, int] = defaultdict(int)
+    session_turn_counters: dict[str, int] = defaultdict(int)
+
+    for row in rows:
+        session_id = row["session_id"]
+        if session_id not in sessions_filter:
+            continue
+        if allowed_models and row["model_name"] not in allowed_models:
+            continue
+
+        metadata = _parse_json(row["metadata"]) or {}
+        turn = metadata.get("turn")
+        if turn is None:
+            step_id = metadata.get("step_id")
+            if isinstance(step_id, str) and step_id.startswith("turn_"):
+                try:
+                    turn = int(step_id.split("_", 1)[1])
+                except (ValueError, IndexError):
+                    turn = None
+            if turn is None:
+                turn = session_turn_counters[session_id]
+            session_turn_counters[session_id] = turn + 1
+        else:
+            try:
+                turn = int(turn)
+            except (TypeError, ValueError):
+                continue
+            session_turn_counters[session_id] = max(session_turn_counters[session_id], turn + 1)
+
+        call_records = _parse_json(row["call_records"]) or []
+        if not isinstance(call_records, list) or not call_records:
+            continue
+
+        for record in call_records:
+            messages: list[dict[str, Any]] = []
+            for message in record.get("input_messages", []):
+                role = message.get("role", "unknown")
+                content = _collect_text(message.get("parts"))
+                if not content:
+                    continue
+                messages.append({"role": role, "content": content})
+
+            assistant_content = ""
+            assistant_tool_calls: list[dict[str, Any]] = []
+
+            output_text = record.get("output_text")
+            parsed_response: dict[str, Any] | None = None
+            if isinstance(output_text, str) and output_text:
+                try:
+                    parsed_response = json.loads(output_text)
+                except json.JSONDecodeError:
+                    parsed_response = None
+
+            if parsed_response:
+                choices = parsed_response.get("choices") or []
+                if choices:
+                    message = choices[0].get("message") or {}
+                    assistant_content = message.get("content") or ""
+                    assistant_tool_calls = _normalise_tool_calls(message.get("tool_calls"))
+
+            if not assistant_tool_calls:
+                assistant_tool_calls = _normalise_tool_calls(record.get("output_tool_calls"))
+
+            assistant_message: dict[str, Any] = {"role": "assistant", "content": assistant_content or ""}
+            if assistant_tool_calls:
+                assistant_message["tool_calls"] = assistant_tool_calls
+
+            if assistant_message.get("content") == "" and not assistant_message.get("tool_calls"):
+                continue
+
+            messages.append(assistant_message)
+
+            if len(messages) < 2:
+                continue
+
+            achievements = achievements_map.get((session_id, turn), {"unique": [], "all": []})
+            cumulative_unique[session_id] += len(achievements.get("unique", []))
+
+            metadata = {
+                "session_id": session_id,
+                "turn": turn,
+                "model": row["model_name"],
+                "provider": row["provider"] or "unknown",
+                "achievements": {
+                    "new_unique": achievements.get("unique", []),
+                    "turned_true": achievements.get("all", []),
+                    "cumulative_unique": cumulative_unique[session_id],
+                },
+            }
+
+            dataset.append({"messages": messages, "metadata": metadata})
+            if limit is not None and len(dataset) >= limit:
+                return dataset
+
+    return dataset
+
+
+def write_jsonl(path: Path, records: Iterable[dict[str, Any]]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as fh:
+        for record in records:
+            json.dump(record, fh, ensure_ascii=False)
+            fh.write("\n")
+
+
+def _validate_dataset(records: list[dict[str, Any]]) -> None:
+    errors: list[str] = []
+    for idx, record in enumerate(records, start=1):
+        messages = record.get("messages")
+        if not isinstance(messages, list) or not messages:
+            errors.append(f"row {idx}: missing messages list")
+            if len(errors) >= 20:
+                break
+            continue
+        for msg_idx, msg in enumerate(messages):
+            if not isinstance(msg, dict):
+                errors.append(f"row {idx}: message {msg_idx} is not an object")
+                break
+            if "role" not in msg or "content" not in msg:
+                errors.append(f"row {idx}: message {msg_idx} missing role/content")
+                break
+            if not isinstance(msg["role"], str):
+                errors.append(f"row {idx}: message {msg_idx} role not string")
+                break
+            if not isinstance(msg["content"], str):
+                errors.append(f"row {idx}: message {msg_idx} content not string")
+                break
+        if len(errors) >= 20:
+            break
+    if errors:
+        summary = "\n - ".join(errors)
+        raise SystemExit(f"Validation error while exporting dataset:\n - {summary}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--db", type=Path, default=Path("traces/v3/synth_ai.db"), help="Path to tracing_v3 SQLite DB")
+    parser.add_argument("--output", type=Path, required=True, help="Destination JSONL path for the exported dataset")
+    parser.add_argument("--model", action="append", dest="models", help="Restrict to sessions whose dominant model matches (repeatable)")
+    parser.add_argument("--provider", action="append", dest="providers", help="Restrict to sessions whose dominant provider matches (repeatable)")
+    parser.add_argument("--min-unique", type=int, default=None, help="Minimum unique achievements per session")
+    parser.add_argument("--max-unique", type=int, default=None, help="Maximum unique achievements per session")
+    parser.add_argument(
+        "--exclude-achievement",
+        action="append",
+        dest="exclude_achievements",
+        help="Achievements to ignore when evaluating --min-unique/--max-unique (repeatable)",
+    )
+    parser.add_argument("--require-achievement", action="append", dest="required_achievements", help="Require these outcome achievements (repeatable)")
+    parser.add_argument("--min-outcome-reward", type=float, default=None, help="Minimum total outcome reward per session")
+    parser.add_argument("--max-outcome-reward", type=float, default=None, help="Maximum total outcome reward per session")
+    parser.add_argument("--event-reward", action="append", dest="event_reward_filters", help="Require reward_type[:min_total] in event_rewards (repeatable)")
+    parser.add_argument("--limit", type=int, default=None, help="Maximum number of examples to emit")
+    args = parser.parse_args()
+
+    if not args.db.exists():
+        print(f"Database not found: {args.db}", file=sys.stderr)
+        raise SystemExit(1)
+
+    conn = connect(args.db)
+    try:
+        (
+            achievements_map,
+            unique_counts_per_session,
+            _name_counts,
+            _size_counts,
+            session_unique_sets,
+            session_final_achievements,
+        ) = fetch_achievement_data(conn)
+        session_models = fetch_session_models(conn)
+        outcome_data = fetch_outcome_rewards(conn)
+        event_totals = fetch_event_reward_totals(conn)
+        event_filters = parse_event_filters(args.event_reward_filters)
+
+        allowed_models = set(args.models) if args.models else None
+        allowed_providers = set(args.providers) if args.providers else None
+        required_achievements = set(args.required_achievements or [])
+        excluded_achievements = set(args.exclude_achievements or [])
+
+        eligible_sessions: set[str] = set()
+        for session_id, (model_name, provider, _calls) in session_models.items():
+            if allowed_models and model_name not in allowed_models:
+                continue
+            if allowed_providers and (provider or "unknown") not in allowed_providers:
+                continue
+
+            session_uniques = session_unique_sets.get(session_id, set())
+            adjusted_uniques = {a for a in session_uniques if a not in excluded_achievements}
+            unique_count = len(adjusted_uniques)
+            if args.min_unique is not None and unique_count < args.min_unique:
+                continue
+            if args.max_unique is not None and unique_count > args.max_unique:
+                continue
+
+            outcome = outcome_data.get(session_id)
+            total_reward = outcome["total_reward"] if outcome else 0.0
+            final_achievements = outcome["achievements"] if outcome else session_final_achievements.get(session_id, set())
+
+            if args.min_outcome_reward is not None and total_reward < args.min_outcome_reward:
+                continue
+            if args.max_outcome_reward is not None and total_reward > args.max_outcome_reward:
+                continue
+            if required_achievements and not required_achievements.issubset(final_achievements):
+                continue
+
+            session_event_totals = event_totals.get(session_id, {})
+            meets_event_filters = True
+            for reward_type, min_total in event_filters:
+                total = session_event_totals.get(reward_type, {}).get("total", 0.0)
+                if total < min_total:
+                    meets_event_filters = False
+                    break
+            if not meets_event_filters:
+                continue
+
+            eligible_sessions.add(session_id)
+
+        if not eligible_sessions:
+            print("No sessions matched the provided filters.", file=sys.stderr)
+            raise SystemExit(1)
+
+        dataset = build_sft_dataset(
+            conn,
+            achievements_map,
+            eligible_sessions,
+            allowed_models=allowed_models,
+            limit=args.limit,
+        )
+
+        if not dataset:
+            print("No rollout steps matched the filters (after session selection).", file=sys.stderr)
+            raise SystemExit(1)
+
+        _validate_dataset(dataset)
+        write_jsonl(args.output, dataset)
+        session_ids = {item.get("metadata", {}).get("session_id") for item in dataset}
+        session_ids.discard(None)
+        print(
+            f"Wrote {len(dataset)} examples from {len(session_ids)} session(s) -> {args.output}",
+            file=sys.stderr,
+        )
+    finally:
+        conn.close()
+
+
+if __name__ == "__main__":
+    main()
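
Each line the exporter writes is a JSON object pairing a chat-style "messages" list with a "metadata" block (session_id, turn, model, provider, achievement counts). A minimal sketch of inspecting an export, assuming it was produced with --output data/crafter_sft.jsonl (a placeholder path):

import json
from collections import Counter
from pathlib import Path

# Tally exported examples per source session; mirrors the script's
# closing "Wrote N examples from M session(s)" summary.
per_session: Counter = Counter()
with Path("data/crafter_sft.jsonl").open(encoding="utf-8") as fh:
    for line in fh:
        record = json.loads(line)  # {"messages": [...], "metadata": {...}}
        per_session[record["metadata"]["session_id"]] += 1

print(f"{sum(per_session.values())} examples from {len(per_session)} session(s)")
print(per_session.most_common(5))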
examples/warming_up_to_rl/groq_test.py (added, +88 lines)
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+"""Quick smoke test that drives a rollout through the Groq proxy-backed Crafter Task App."""
+
+import argparse
+import asyncio
+import os
+from typing import Any
+
+from synth_ai.task import (
+    INTERACT_TOOL_SCHEMA,
+    RolloutEnvSpec,
+    RolloutPolicySpec,
+    RolloutRequest,
+    TaskAppClient,
+    to_jsonable,
+)
+
+
+def _build_policy_payload(seed: int, model: str) -> dict[str, Any]:
+    return {
+        "model": model,
+        "tools": INTERACT_TOOL_SCHEMA,
+        "messages": [
+            {
+                "role": "system",
+                "content": "You control the Crafter agent. Think briefly, then call the interact tool with 3-5 actions to maximize achievements.",
+            },
+            {
+                "role": "user",
+                "content": (
+                    "Environment seed {seed}. Plan initial survival/crafting steps and then call interact with concrete actions."
+                ).format(seed=seed),
+            },
+        ],
+    }
+
+
+async def run(args: argparse.Namespace) -> None:
+    client = TaskAppClient(args.base_url, api_key=args.api_key, timeout=args.timeout)
+
+    health = await client.health()
+    print("/health →", to_jsonable(health))
+
+    info = await client.info()
+    print("/info →", to_jsonable(info))
+
+    inference_url = args.inference_url or f"{args.base_url.rstrip('/')}/proxy/groq"
+
+    request = RolloutRequest(
+        run_id=args.run_id,
+        env=RolloutEnvSpec(env_name="crafter", seed=args.seed, config={"seed": args.seed}),
+        policy=RolloutPolicySpec(
+            policy_name="groq-smoke",
+            config={"model": args.model, "inference_url": inference_url.rstrip("/")},
+        ),
+        ops=[
+            {"type": "policy", "payload": _build_policy_payload(args.seed, args.model)},
+            {"type": "env"},
+        ],
+    )
+
+    response = await client.rollout(request)
+    print("rollout.metrics →", to_jsonable(response.metrics.model_dump()))
+    for idx, step in enumerate(response.trajectories[0].steps, start=1):
+        print(f"step[{idx}] tool_calls={step.tool_calls} reward={step.reward} info={to_jsonable(step.info)}")
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--base-url", default=os.getenv("TASK_APP_BASE_URL", "http://localhost:8000"))
+    parser.add_argument("--api-key", default=os.getenv("TASK_APP_API_KEY"), required=os.getenv("TASK_APP_API_KEY") is None)
+    parser.add_argument("--model", default=os.getenv("GROQ_MODEL", "groq/mixtral-8x7b"))
+    parser.add_argument("--inference-url", default=os.getenv("TASK_APP_INFERENCE_URL"))
+    parser.add_argument("--seed", type=int, default=int(os.getenv("CRAFTER_TEST_SEED", "42")))
+    parser.add_argument("--run-id", default=os.getenv("TASK_APP_RUN_ID", "groq-test"))
+    parser.add_argument("--timeout", type=float, default=float(os.getenv("TASK_APP_TIMEOUT", "60")))
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = _parse_args()
+    asyncio.run(run(args))
+
+
+if __name__ == "__main__":
+    main()
+
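
Because main() only builds an argparse.Namespace and hands it to run(), the same smoke test can be driven from Python rather than the CLI. A minimal sketch, assuming groq_test.py is importable (e.g. examples/warming_up_to_rl is on PYTHONPATH) and using placeholder endpoint and key values:

import argparse
import asyncio

import groq_test  # assumes examples/warming_up_to_rl is on PYTHONPATH

# Mirror the fields _parse_args() would populate; all values are placeholders.
args = argparse.Namespace(
    base_url="http://localhost:8000",
    api_key="sk-PLACEHOLDER",  # use your real TASK_APP_API_KEY
    model="groq/mixtral-8x7b",
    inference_url=None,  # run() falls back to <base_url>/proxy/groq
    seed=42,
    run_id="groq-test",
    timeout=60.0,
)
asyncio.run(groq_test.run(args))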