synth-ai 0.2.9.dev17__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/multi_step/task_app_config_notes.md +488 -0
- examples/qwen_coder/infer_ft_smoke.py +1 -0
- examples/qwen_coder/scripts/infer_coder.sh +1 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +1 -0
- examples/qwen_coder/subset_jsonl.py +1 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +1 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +33 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +30 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/run_eval.py +142 -25
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +146 -2
- synth_ai/__init__.py +5 -20
- synth_ai/api/train/builders.py +25 -14
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/root.py +1 -23
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +143 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +738 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/jobs/client.py +15 -3
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/METADATA +33 -88
- {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/RECORD +53 -41
- {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/top_level.txt +0 -1
- synth/__init__.py +0 -14
- synth_ai/_docs_message.py +0 -10
- synth_ai/main.py +0 -5
- {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,738 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Filter traces from Turso/SQLite (v3) to create Modal/Synth SFT-ready .jsonl files
|
|
4
|
+
Supports two modes:
|
|
5
|
+
1. Trajectory-level filtering: Include entire trajectories above a score threshold
|
|
6
|
+
2. Window-based filtering: Extract high-scoring windows of actions
|
|
7
|
+
|
|
8
|
+
This is the v3 version using the new async Turso-based tracing system.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import argparse
|
|
13
|
+
import asyncio
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import List, Dict, Any, Tuple, Optional
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
import numpy as np
|
|
18
|
+
import os
|
|
19
|
+
import sys
|
|
20
|
+
import toml
|
|
21
|
+
import pandas as pd
|
|
22
|
+
|
|
23
|
+
# Add synth_ai to path
|
|
24
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
|
|
25
|
+
|
|
26
|
+
from synth_ai.tracing_v3 import SessionTracer
|
|
27
|
+
from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
|
|
28
|
+
from synth_ai.tracing_v3.abstractions import LMCAISEvent, EnvironmentEvent, RuntimeEvent
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def create_histogram(data: List[float], bins: int = 20, width: int = 60, height: int = 15,
                     title: str = "", x_label: str = "", y_label: str = "") -> str:
    """Render *data* as a vertical ASCII histogram.

    Args:
        data: Values to bin. An empty sequence yields a placeholder message.
        bins: Number of bins passed to ``numpy.histogram``; every bin is drawn
            as one character column.
        width: Nominal width used only to centre the title / x-axis caption
            and to size the rule under the title (bars are one column each).
        height: Number of text rows used for the tallest bar.
        title: Optional heading printed above the plot.
        x_label: Optional caption printed below the x-axis tick labels.
        y_label: Optional caption printed above the y-axis.

    Returns:
        The multi-line plot as one string, or ``"No data to display"`` when
        *data* is empty.
    """
    if not data:
        return "No data to display"

    counts, edges = np.histogram(data, bins=bins)
    max_count = int(counts.max()) if len(counts) > 0 else 1

    # Scale every bin count to a bar height in text rows.
    if max_count > 0:
        heights = [int(c * height / max_count) for c in counts]
    else:
        heights = [0] * len(counts)

    lines = []

    if title:
        lines.append(f"\n{title.center(width + 10)}")
        lines.append("=" * (width + 10))

    if y_label:
        lines.append(f"{y_label}")

    # Plot area, top row first; each row starts with an 8-character gutter
    # ("  count │") followed by one cell per bin.
    for y in range(height, 0, -1):
        y_val = int(max_count * y / height)
        bars = "".join("█" if h >= y else " " for h in heights)
        lines.append(f"{y_val:>6} │" + bars)

    lines.append(f"{'':>6} └" + "─" * len(heights))

    # Tick-label row. It must span the gutter *and* the bars; the previous
    # 8-character row meant interior labels were never drawn and the last
    # label was glued to column 8.
    gutter = 8
    n_bars = len(heights)
    row = [" "] * (gutter + n_bars)

    def _place(start: int, text: str) -> None:
        """Write *text* at *start* unless it would touch an existing label."""
        start = max(0, start)
        stop = start + len(text)
        if stop > len(row):
            row.extend(" " * (stop - len(row)))
        lo, hi = max(0, start - 1), min(len(row), stop + 1)
        if all(ch == " " for ch in row[lo:hi]):
            row[start:stop] = text

    # Candidate tick positions (bin indices), de-duplicated for small `bins`.
    candidates = [0, n_bars // 4, n_bars // 2, 3 * n_bars // 4, n_bars - 1]
    positions = [p for p in dict.fromkeys(candidates) if 0 <= p < len(edges) - 1]

    if positions:
        first, last = positions[0], positions[-1]
        # The two end labels get priority; interior ones fill remaining room.
        _place(0, f"{edges[first]:.1f}")
        if last != first:
            last_label = f"{edges[last]:.1f}"
            _place(gutter + last + 1 - len(last_label), last_label)
        for pos in positions[1:-1]:
            label = f"{edges[pos]:.1f}"
            _place(gutter + pos - len(label) // 2, label)

    lines.append("".join(row).rstrip())

    if x_label:
        lines.append(f"\n{x_label.center(width + 10)}")

    return "\n".join(lines)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def create_bar_chart(categories: List[str], values: List[int], width: int = 60,
                     title: str = "", show_values: bool = True) -> str:
    """Render paired *categories* / *values* as a horizontal ASCII bar chart.

    Bars are scaled so that the largest value spans ``width`` cells. Category
    names are left-aligned and padded to the longest name. When *title* is
    given it is printed above a rule of ``width + 20`` ``=`` characters.
    Returns ``"No data to display"`` if either input is empty.
    """
    if not (categories and values):
        return "No data to display"

    peak = max(values)
    name_width = max(map(len, categories))

    def _render(name: str, amount: int) -> str:
        # Bar length is proportional to the peak; non-positive peaks draw nothing.
        cells = int(amount * width / peak) if peak > 0 else 0
        text = f"{name:<{name_width}} │ " + "█" * cells
        return f"{text} {amount}" if show_values else text

    header = [f"\n{title}", "=" * (width + 20)] if title else []
    body = [_render(name, amount) for name, amount in zip(categories, values)]
    return "\n".join(header + body)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class FinetuningDataExtractorV3:
    """Extract fine-tuning data from v3 Turso traces.

    Intended to be used as an async context manager: entering initialises the
    underlying trace manager, exiting closes it. All queries go through
    ``self.db_manager.query_traces`` and are consumed as pandas DataFrames.
    """

    def __init__(self, db_url: str):
        self.db_manager = AsyncSQLTraceManager(db_url)
        self._initialized = False

    async def __aenter__(self):
        await self.db_manager.initialize()
        self._initialized = True
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.db_manager.close()

    @staticmethod
    def _unwrap_payload(content: Any) -> Any:
        """Return ``content['payload']`` when *content* is a JSON object with that key.

        Anything that is not valid JSON (including ``None``), or that decodes
        to something other than a dict with a ``payload`` key, is returned
        unchanged.
        """
        try:
            decoded = json.loads(content)
        except (TypeError, ValueError):
            # TypeError: not str/bytes; ValueError: covers JSONDecodeError.
            return content
        if isinstance(decoded, dict) and 'payload' in decoded:
            return decoded['payload']
        return content

    async def get_all_sessions(self) -> pd.DataFrame:
        """Get all session IDs (with creation time) from the database, newest first."""
        query = """
            SELECT DISTINCT session_id, created_at
            FROM session_traces
            ORDER BY created_at DESC
        """
        return await self.db_manager.query_traces(query)

    async def get_session_metrics(self, session_id: str) -> Dict[str, Any]:
        """Get reward, token and cost totals for a specific session.

        The reward is taken from ``outcome_rewards`` when that lookup succeeds
        and is non-zero; otherwise it falls back to the sum of environment
        event rewards.

        Returns:
            Dict with ``session_id``, ``total_reward``, ``total_tokens`` and
            ``total_cost`` keys.
        """
        outcome_query = """
            SELECT COALESCE(MAX(total_reward), 0) as total_reward
            FROM outcome_rewards
            WHERE session_id = :session_id
        """
        total_reward: float = 0.0
        try:
            # The query itself sits inside the try so that a failing lookup
            # (presumably e.g. a database without an outcome_rewards table)
            # falls through to the event-sum fallback instead of aborting.
            outcome_df = await self.db_manager.query_traces(outcome_query, {"session_id": session_id})
            if not outcome_df.empty:
                total_reward = float(outcome_df['total_reward'].iloc[0] or 0.0)
        except Exception:
            # Deliberate best-effort: any failure here means "use the fallback".
            total_reward = 0.0

        if total_reward == 0.0:
            # Fallback: sum environment rewards
            reward_query = """
                SELECT COALESCE(SUM(reward), 0) as total_reward
                FROM events
                WHERE session_id = :session_id
                AND event_type = 'environment'
                AND reward IS NOT NULL
            """
            reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
            total_reward = float(reward_df['total_reward'].iloc[0]) if not reward_df.empty else 0.0

        # Get total tokens and cost from LM events
        lm_query = """
            SELECT
                COALESCE(SUM(total_tokens), 0) as total_tokens,
                COALESCE(SUM(cost_usd) / 100.0, 0) as total_cost
            FROM events
            WHERE session_id = :session_id
            AND event_type = 'cais'
        """
        lm_df = await self.db_manager.query_traces(lm_query, {"session_id": session_id})

        total_tokens = int(lm_df['total_tokens'].iloc[0]) if not lm_df.empty else 0
        total_cost = float(lm_df['total_cost'].iloc[0]) if not lm_df.empty else 0.0

        return {
            'session_id': session_id,
            'total_reward': total_reward,
            'total_tokens': total_tokens,
            'total_cost': total_cost
        }

    async def get_session_achievements(self, session_id: str) -> List[str]:
        """Get list of achievements unlocked in a session.

        Aggregates across ALL environment events with a non-null system_state_after,
        unioning any flags that were ever true. This is more robust than inspecting
        only the last event, which can miss transient unlocks.
        """
        query = """
            SELECT system_state_after
            FROM events
            WHERE session_id = :session_id
            AND event_type = 'environment'
            AND system_state_after IS NOT NULL
            ORDER BY id ASC
        """
        df = await self.db_manager.query_traces(query, {"session_id": session_id})

        if df.empty:
            return []

        unlocked: Dict[str, bool] = {}
        for _, row in df.iterrows():
            try:
                state_after = row['system_state_after']
                if not state_after:
                    continue
                if isinstance(state_after, str):
                    state_after = json.loads(state_after)
                if not isinstance(state_after, dict):
                    continue
                public_state = state_after.get('public_state')
                if not isinstance(public_state, dict):
                    continue
                ach = public_state.get('achievements_status')
                if not isinstance(ach, dict):
                    continue
                for name, flag in ach.items():
                    if flag:
                        unlocked[name] = True
            except Exception as e:
                # One malformed row must not discard the rest of the session.
                print(f"Error parsing achievements row: {e}")
                continue

        return [k for k, v in unlocked.items() if v]

    async def filter_by_achievements(self, min_achievements: int) -> List[str]:
        """Get sessions with at least min_achievements unlocked."""
        all_sessions = await self.get_all_sessions()
        qualifying_sessions = []

        for _, row in all_sessions.iterrows():
            session_id = row['session_id']
            achievements = await self.get_session_achievements(session_id)
            if len(achievements) >= min_achievements:
                qualifying_sessions.append(session_id)

        return qualifying_sessions

    async def extract_openai_format(self, session_ids: List[str], min_reward: float = 0.0) -> List[Dict[str, Any]]:
        """Extract training data in OpenAI chat format from filtered sessions.

        One example is produced per session: the first system message (if any)
        followed by all user / assistant messages in turn order. Sessions with
        fewer than two messages are skipped.

        Args:
            session_ids: Sessions to export; callers are expected to have
                filtered them already.
            min_reward: Accepted for interface compatibility; not used to
                filter inside this method.

        Returns:
            List of ``{"messages": [...], "metadata": {...}}`` dicts.
        """
        training_data = []

        for session_id in session_ids:
            # Get messages directly from the messages table
            messages_query = """
                SELECT m.message_type, m.content, m.message_time, st.turn_number
                FROM messages m
                LEFT JOIN session_timesteps st ON m.timestep_id = st.id
                WHERE m.session_id = :session_id
                ORDER BY COALESCE(st.turn_number, m.message_time), m.id
            """

            messages_df = await self.db_manager.query_traces(messages_query, {"session_id": session_id})

            if len(messages_df) == 0:
                continue

            # Build conversation history
            messages = []
            system_message = None

            for _, row in messages_df.iterrows():
                msg_type = row['message_type']
                # Content may be a JSON envelope (origin_system_id format).
                content = self._unwrap_payload(row['content'])

                if msg_type == 'system' and system_message is None:
                    # Only the first system message is kept.
                    if isinstance(content, str):
                        system_message = content

                elif msg_type == 'user':
                    if isinstance(content, dict):
                        # Convert observation dict to formatted string
                        content = self._format_observation_content(content)
                    messages.append({"role": "user", "content": str(content)})

                elif msg_type == 'assistant':
                    messages.append({"role": "assistant", "content": str(content)})

            # Add system message at the beginning if found
            if system_message:
                messages.insert(0, {"role": "system", "content": system_message})

            # Only include if we have a complete conversation
            if len(messages) > 1:
                # Get total reward for this session
                reward_query = """
                    SELECT COALESCE(SUM(reward), 0) as total_reward
                    FROM events
                    WHERE session_id = :session_id
                    AND event_type = 'environment'
                    AND reward IS NOT NULL
                """
                reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
                total_reward = reward_df.iloc[0]['total_reward'] if len(reward_df) > 0 else 0

                training_data.append({
                    "messages": messages,
                    "metadata": {
                        "session_id": session_id,
                        "total_reward": float(total_reward)  # Convert to float for JSON serialization
                    }
                })

        return training_data

    async def extract_openai_window_format(self, session_ids: List[str]) -> List[Dict[str, Any]]:
        """Extract per-turn user→assistant pairs (window mode) for SFT.

        For every turn (``session_timesteps.turn_number``) that has both, emits
        one example pairing the first assistant message with the last user
        message that precedes it. Messages without a turn number are ignored.
        """
        window_data: List[Dict[str, Any]] = []

        for session_id in session_ids:
            messages_query = """
                SELECT st.turn_number, m.message_type, m.content, m.id AS message_id
                FROM messages m
                LEFT JOIN session_timesteps st ON m.timestep_id = st.id
                WHERE m.session_id = :session_id
                ORDER BY COALESCE(st.turn_number, m.message_time), m.id
            """
            df = await self.db_manager.query_traces(messages_query, {"session_id": session_id})
            if df is None or df.empty:
                continue

            # Parse content and group by turn_number
            turn_to_msgs: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
            for _, row in df.iterrows():
                tn = row.get('turn_number')
                # A LEFT JOIN miss arrives as NaN in a DataFrame, not None;
                # both mean "not associated with a turn".
                if tn is None or pd.isna(tn):
                    continue
                turn_to_msgs[int(tn)].append({
                    'turn_number': tn,
                    'message_type': row.get('message_type'),
                    'content': self._unwrap_payload(row.get('content')),
                })

            # For each turn, find the user -> assistant pair
            for tn in sorted(turn_to_msgs):
                user_content: Optional[str] = None
                assistant_content: Optional[str] = None
                for r in turn_to_msgs[tn]:
                    if r['message_type'] == 'user':
                        user_content = r['content']
                    elif r['message_type'] == 'assistant':
                        # Stop at the first assistant so that a later user
                        # message cannot replace the prompt it answered.
                        assistant_content = r['content']
                        break
                if user_content and assistant_content:
                    window_data.append({
                        'messages': [
                            {'role': 'user', 'content': str(user_content)},
                            {'role': 'assistant', 'content': str(assistant_content)},
                        ],
                        'metadata': {
                            'session_id': session_id,
                            'turn_number': tn,
                        }
                    })

        return window_data

    def _format_observation_content(self, obs: Dict[str, Any]) -> str:
        """Format observation dict into a concise, readable string.

        Only ``inventory`` (positive numeric entries), ``achievements_status``
        (truthy entries) and ``health`` are rendered; returns
        ``"Empty observation"`` when none of them produce output.
        """
        if not isinstance(obs, dict):
            return str(obs)

        parts = []

        inv = obs.get('inventory')
        if isinstance(inv, dict):
            # Skip non-numeric counts instead of raising on the comparison.
            inv_str = ", ".join(
                f"{k}: {v}" for k, v in inv.items()
                if isinstance(v, (int, float)) and v > 0
            )
            if inv_str:
                parts.append(f"Inventory: {inv_str}")

        status = obs.get('achievements_status')
        if isinstance(status, dict):
            achievements = [k for k, v in status.items() if v]
            if achievements:
                parts.append(f"Achievements: {', '.join(achievements)}")

        if 'health' in obs:
            parts.append(f"Health: {obs.get('health', 0)}")

        return "; ".join(parts) if parts else "Empty observation"
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
async def filter_traces_from_turso(
    db_url: str,
    output_path: str,
    config: Dict[str, Any]
) -> Tuple[int, Dict[str, Any]]:
    """
    Filter traces from Turso/SQLite v3 database based on configuration.

    Sessions are kept when they satisfy every configured bound in
    ``config["filters"]`` (``min_total_reward``, ``max_cost``, ``max_tokens``,
    ``min_achievements``, and — if non-empty — at least one model listed in
    ``models``). The surviving sessions are exported as JSONL to
    ``output_path``, one example per line.

    Args:
        db_url: Connection URL handed to ``FinetuningDataExtractorV3``.
        output_path: Destination JSONL file; missing parent directories are
            created.
        config: Dict with optional ``mode`` (``"trajectory"`` or ``"window"``)
            and ``filters`` keys.

    Returns:
        Tuple of (num_examples, statistics_dict)
    """
    mode = config.get("mode", "trajectory")
    filters = config.get("filters", {})

    # Extract filtering parameters
    min_reward = filters.get("min_total_reward", 0.0)
    min_achievements = filters.get("min_achievements", 0)
    max_cost = filters.get("max_cost", float('inf'))
    max_tokens = filters.get("max_tokens", float('inf'))

    # Modal/Synth specific: filter by model if specified
    target_models = filters.get("models", [])

    statistics = {
        "total_sessions": 0,
        "filtered_sessions": 0,
        "total_examples": 0,
        "reward_distribution": [],
        "token_distribution": [],
        "cost_distribution": [],
        "model_distribution": defaultdict(int)
    }

    async with FinetuningDataExtractorV3(db_url) as extractor:
        # Get all sessions
        all_sessions = await extractor.get_all_sessions()
        statistics["total_sessions"] = len(all_sessions)

        # Filter sessions based on criteria
        filtered_sessions = []

        for _, row in all_sessions.iterrows():
            session_id = row['session_id']
            metrics = await extractor.get_session_metrics(session_id)

            # Apply filters
            if metrics['total_reward'] < min_reward:
                continue
            if metrics['total_cost'] > max_cost:
                continue
            if metrics['total_tokens'] > max_tokens:
                continue

            # Check achievements if required
            if min_achievements > 0:
                achievements = await extractor.get_session_achievements(session_id)
                if len(achievements) < min_achievements:
                    continue

            # Check model filter if specified
            if target_models:
                model_query = """
                    SELECT DISTINCT model_name
                    FROM events
                    WHERE session_id = :session_id
                    AND event_type = 'cais'
                    AND model_name IS NOT NULL
                """
                model_df = await extractor.db_manager.query_traces(
                    model_query, {"session_id": session_id}
                )
                session_models = model_df['model_name'].tolist() if not model_df.empty else []
                if not any(model in target_models for model in session_models):
                    continue

            filtered_sessions.append(session_id)

            # Collect statistics (only for sessions that passed every filter)
            statistics["reward_distribution"].append(metrics['total_reward'])
            statistics["token_distribution"].append(metrics['total_tokens'])
            statistics["cost_distribution"].append(metrics['total_cost'])

        statistics["filtered_sessions"] = len(filtered_sessions)

        # Extract training data
        if mode == "window":
            # Window mode emits one user→assistant pair per turn rather than
            # silently falling back to whole-trajectory export.
            training_data = await extractor.extract_openai_window_format(
                session_ids=filtered_sessions
            )
        else:
            training_data = await extractor.extract_openai_format(
                session_ids=filtered_sessions,
                min_reward=min_reward
            )

        statistics["total_examples"] = len(training_data)

        # Write to output file
        output_file = Path(output_path)
        # parents=True so a nested, not-yet-existing output directory works.
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            for example in training_data:
                f.write(json.dumps(example) + '\n')

        # Get model distribution (counted over every 'cais' event in the
        # database, not just the filtered sessions)
        model_query = """
            SELECT model_name, COUNT(*) as count
            FROM events
            WHERE event_type = 'cais'
            AND model_name IS NOT NULL
            GROUP BY model_name
        """
        model_stats = await extractor.db_manager.query_traces(model_query)
        for _, row in model_stats.iterrows():
            statistics["model_distribution"][row['model_name']] = int(row['count'])

    return len(training_data), statistics
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def print_statistics(stats: Dict[str, Any]):
    """Print a filtering report: counts, pass rate, and ASCII charts.

    Reads ``total_sessions``, ``filtered_sessions``, ``total_examples``,
    ``reward_distribution``, ``token_distribution`` and ``model_distribution``
    from *stats*. NaN entries are dropped before plotting.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("FILTERING STATISTICS (Modal/Synth - v3)")
    print(rule)

    # Headline counts
    total = stats['total_sessions']
    kept = stats['filtered_sessions']
    print(f"\nTotal sessions in database: {total}")
    print(f"Sessions after filtering: {kept}")
    print(f"Training examples generated: {stats['total_examples']}")

    pass_rate = (kept / total * 100) if total > 0 else 0
    print(f"Filter pass rate: {pass_rate:.1f}%")

    # Rewards: histogram plus summary numbers, or a notice when nothing is usable
    rewards = [value for value in stats['reward_distribution'] if not np.isnan(value)]
    if rewards:
        print(create_histogram(
            rewards,
            bins=20,
            title="Reward Distribution",
            x_label="Total Reward",
            y_label="Count"
        ))

        print("\nReward statistics:")
        print(f"  Min: {min(rewards):.2f}")
        print(f"  Max: {max(rewards):.2f}")
        print(f"  Mean: {np.mean(rewards):.2f}")
        print(f"  Median: {np.median(rewards):.2f}")
    else:
        print("\nNo valid reward data to display.")

    # Tokens: histogram only
    tokens = [value for value in stats['token_distribution'] if not np.isnan(value)]
    if tokens:
        print(create_histogram(
            tokens,
            bins=20,
            title="Token Usage Distribution",
            x_label="Total Tokens",
            y_label="Count"
        ))

    # Models: one bar per model name
    per_model = stats['model_distribution']
    if per_model:
        print(create_bar_chart(
            list(per_model.keys()),
            list(per_model.values()),
            title="Model Usage",
            show_values=True
        ))

    print("\n" + rule)
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def main():
    """Command-line entry point: parse arguments, build the config, run the filter.

    Defaults are overridden first by the optional TOML config file and then by
    any explicit command-line flags. With ``--dry-run`` the examples are
    written to the platform's null device and only statistics are shown.
    """
    parser = argparse.ArgumentParser(
        description="Filter traces from Turso/SQLite v3 for Modal/Synth fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
  # Use default config
  python filter_traces_sft_turso.py -d sqlite:///traces.db -o ft_data/training.jsonl

  # Use custom config file
  python filter_traces_sft_turso.py -d sqlite:///traces.db -c filter_config.toml

  # Override config parameters
  python filter_traces_sft_turso.py -d sqlite:///traces.db --min-reward 5.0 --max-cost 0.1

  # Filter by model
  python filter_traces_sft_turso.py -d sqlite:///traces.db --models "Qwen/Qwen2.5-7B-Instruct"
        """
    )

    parser.add_argument('-d', '--database', required=True, help='Path to Turso/SQLite database or connection URL')
    parser.add_argument('-o', '--output', default='ft_data/training_modal.jsonl', help='Output JSONL file')
    parser.add_argument('-c', '--config', help='Configuration TOML file')

    # Filter overrides
    parser.add_argument('--mode', choices=['trajectory', 'window'], help='Filtering mode')
    parser.add_argument('--min-reward', type=float, help='Minimum total reward')
    parser.add_argument('--min-achievements', type=int, help='Minimum achievements')
    parser.add_argument('--max-cost', type=float, help='Maximum cost')
    parser.add_argument('--max-tokens', type=int, help='Maximum tokens')
    parser.add_argument('--models', nargs='+', help='Filter by model names (e.g., Qwen/Qwen2.5-7B-Instruct)')

    parser.add_argument('--dry-run', action='store_true', help='Show statistics without writing output')

    args = parser.parse_args()

    # Load config
    config = {
        "mode": "trajectory",
        "filters": {
            "min_total_reward": 1.0,
            "min_achievements": 0,
            "max_cost": 10.0,
            "max_tokens": 100000,
            "models": []  # Empty means all models
        }
    }

    if args.config:
        with open(args.config, 'r') as f:
            loaded_config = toml.load(f)
            config.update(loaded_config)

    # Apply command-line overrides
    if args.mode:
        config["mode"] = args.mode
    if args.min_reward is not None:
        config["filters"]["min_total_reward"] = args.min_reward
    if args.min_achievements is not None:
        config["filters"]["min_achievements"] = args.min_achievements
    if args.max_cost is not None:
        config["filters"]["max_cost"] = args.max_cost
    if args.max_tokens is not None:
        config["filters"]["max_tokens"] = args.max_tokens
    if args.models:
        config["filters"]["models"] = args.models

    # Convert a bare ``*.db`` file path to a URL. Anything that already has a
    # scheme (``sqlite:///``, ``sqlite+aiosqlite:///``, ...) is left alone so
    # it is never prefixed a second time.
    db_url = args.database
    if "://" not in db_url and db_url.endswith(".db"):
        db_url = f"sqlite+aiosqlite:///{db_url}"

    # os.devnull instead of a hard-coded "/dev/null" keeps dry runs portable.
    output_target = os.devnull if args.dry_run else args.output

    print("🤖 Modal/Synth Fine-Tuning Data Filter (v3)")
    print(f"Using database: {db_url}")
    print(f"Output file: {args.output}")
    print(f"Mode: {config['mode']}")
    print(f"Filters: {json.dumps(config['filters'], indent=2)}")

    if args.dry_run:
        print("\n🔍 DRY RUN - No output will be written")

    # Run filtering
    async def run():
        num_examples, stats = await filter_traces_from_turso(
            db_url,
            output_target,
            config
        )

        # Print statistics
        print_statistics(stats)

        if not args.dry_run:
            print(f"\n✅ Successfully wrote {num_examples} training examples to {args.output}")
            print("   Ready for Modal/Synth fine-tuning!")
        else:
            print(f"\n✅ Would write {num_examples} training examples (dry run)")

    asyncio.run(run())
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
if __name__ == "__main__":
|
|
738
|
+
main()
|