synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. synth_ai/cli/__init__.py +66 -0
  2. synth_ai/cli/balance.py +205 -0
  3. synth_ai/cli/calc.py +70 -0
  4. synth_ai/cli/demo.py +74 -0
  5. synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
  6. synth_ai/cli/man.py +103 -0
  7. synth_ai/cli/recent.py +126 -0
  8. synth_ai/cli/root.py +184 -0
  9. synth_ai/cli/status.py +126 -0
  10. synth_ai/cli/traces.py +136 -0
  11. synth_ai/cli/watch.py +508 -0
  12. synth_ai/config/base_url.py +53 -0
  13. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
  14. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
  15. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +760 -0
  16. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
  17. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
  18. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
  25. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
  26. synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
  27. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
  28. synth_ai/environments/examples/crafter_classic/environment.py +1 -1
  29. synth_ai/environments/examples/crafter_custom/environment.py +1 -1
  30. synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
  31. synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
  32. synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
  33. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
  34. synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
  35. synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
  36. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
  37. synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
  38. synth_ai/environments/service/core_routes.py +1 -1
  39. synth_ai/experimental/synth_oss.py +446 -0
  40. synth_ai/learning/core.py +21 -0
  41. synth_ai/learning/gateway.py +4 -0
  42. synth_ai/learning/prompts/gepa.py +0 -0
  43. synth_ai/learning/prompts/mipro.py +8 -0
  44. synth_ai/lm/__init__.py +3 -0
  45. synth_ai/lm/core/main.py +4 -0
  46. synth_ai/lm/core/main_v3.py +238 -122
  47. synth_ai/lm/core/vendor_clients.py +4 -0
  48. synth_ai/lm/provider_support/openai.py +11 -2
  49. synth_ai/lm/vendors/base.py +7 -0
  50. synth_ai/lm/vendors/openai_standard.py +339 -4
  51. synth_ai/lm/vendors/openai_standard_responses.py +243 -0
  52. synth_ai/lm/vendors/synth_client.py +155 -5
  53. synth_ai/lm/warmup.py +54 -17
  54. synth_ai/tracing/__init__.py +18 -0
  55. synth_ai/tracing_v1/__init__.py +29 -14
  56. synth_ai/tracing_v3/__init__.py +2 -2
  57. synth_ai/tracing_v3/abstractions.py +62 -17
  58. synth_ai/tracing_v3/config.py +13 -7
  59. synth_ai/tracing_v3/db_config.py +6 -6
  60. synth_ai/tracing_v3/hooks.py +1 -1
  61. synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
  62. synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
  63. synth_ai/tracing_v3/session_tracer.py +5 -5
  64. synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
  65. synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
  66. synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
  67. synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
  68. synth_ai/tracing_v3/turso/manager.py +18 -11
  69. synth_ai/tracing_v3/turso/models.py +1 -0
  70. synth_ai/tui/__main__.py +13 -0
  71. synth_ai/tui/dashboard.py +329 -0
  72. synth_ai/v0/tracing/__init__.py +0 -0
  73. synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
  74. synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
  75. synth_ai/{tracing → v0/tracing}/context.py +1 -1
  76. synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
  77. synth_ai/v0/tracing/events/__init__.py +0 -0
  78. synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
  79. synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
  80. synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
  81. synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
  82. synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
  83. synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
  84. synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
  85. synth_ai/{tracing → v0/tracing}/upload.py +4 -4
  86. synth_ai/v0/tracing_v1/__init__.py +16 -0
  87. synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
  88. synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
  89. synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
  90. synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
  91. synth_ai/v0/tracing_v1/events/__init__.py +0 -0
  92. synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
  93. synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
  94. synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
  95. synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
  96. synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
  97. synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
  98. synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
  99. synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
  100. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +100 -5
  101. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +115 -75
  102. /synth_ai/{tracing/events/__init__.py → compound/cais.py} +0 -0
  103. /synth_ai/{tracing_v1/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
  104. /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
  105. /synth_ai/{tracing → v0/tracing}/config.py +0 -0
  106. /synth_ai/{tracing → v0/tracing}/local.py +0 -0
  107. /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
  108. /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
  109. /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
  110. /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
  111. /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
  112. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
  113. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
  114. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
  115. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
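The {tracing → v0/tracing} and {tracing_v1 → v0/tracing_v1} renames above relocate the legacy tracing packages under synth_ai.v0, while synth_ai/tracing/__init__.py and synth_ai/tracing_v1/__init__.py are modified rather than removed, which suggests they now act as shims. A minimal sketch of resolving the relocated modules, assuming only the paths implied by the renames (whether the old top-level imports still work cannot be determined from this listing alone):

import importlib

# Module paths implied by the {tracing -> v0/tracing} renames in the listing;
# which symbols they re-export is not verifiable from the file list.
v0_tracing_decorators = importlib.import_module("synth_ai.v0.tracing.decorators")
v0_tracing_v1_store = importlib.import_module("synth_ai.v0.tracing_v1.events.store")
print(v0_tracing_decorators.__name__, v0_tracing_v1_store.__name__)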
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py
@@ -0,0 +1,760 @@
+ #!/usr/bin/env python3
+ """
+ Filter traces from Turso/SQLite (v3) to create Modal/Synth SFT-ready .jsonl files
+ Supports two modes:
+ 1. Trajectory-level filtering: Include entire trajectories above a score threshold
+ 2. Window-based filtering: Extract high-scoring windows of actions
+
+ This is the v3 version using the new async Turso-based tracing system.
+ """
+
+ import json
+ import argparse
+ import asyncio
+ from pathlib import Path
+ from typing import List, Dict, Any, Tuple, Optional
+ from collections import defaultdict
+ import numpy as np
+ import os
+ import sys
+ import toml
+ import pandas as pd
+
+ # Add synth_ai to path
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
+
+ from synth_ai.tracing_v3 import SessionTracer
+ from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
+ from synth_ai.tracing_v3.abstractions import LMCAISEvent, EnvironmentEvent, RuntimeEvent
+
+
+ def create_histogram(data: List[float], bins: int = 20, width: int = 60, height: int = 15,
+                      title: str = "", x_label: str = "", y_label: str = "") -> str:
+     """Create a beautiful ASCII histogram."""
+     if not data:
+         return "No data to display"
+
+     # Create histogram
+     counts, edges = np.histogram(data, bins=bins)
+     max_count = max(counts) if len(counts) > 0 else 1
+
+     # Normalize heights
+     if max_count > 0:
+         heights = [int(c * height / max_count) for c in counts]
+     else:
+         heights = [0] * len(counts)
+
+     # Build the plot
+     lines = []
+
+     # Title
+     if title:
+         lines.append(f"\n{title.center(width + 10)}")
+         lines.append("=" * (width + 10))
+
+     # Y-axis label
+     if y_label:
+         lines.append(f"{y_label}")
+
+     # Plot area with y-axis
+     for y in range(height, 0, -1):
+         # Y-axis value
+         y_val = int(max_count * y / height)
+         line = f"{y_val:>6} │"
+
+         # Bars
+         for h in heights:
+             if h >= y:
+                 line += "█"
+             else:
+                 line += " "
+
+         lines.append(line)
+
+     # X-axis
+     lines.append(f"{'':>6} └" + "─" * len(heights))
+
+     # X-axis labels
+     x_labels_line = " " * 8
+     min_val, max_val = min(data), max(data)
+
+     # Add labels at key positions
+     label_positions = [0, len(heights)//4, len(heights)//2, 3*len(heights)//4, len(heights)-1]
+     for i, pos in enumerate(label_positions):
+         if pos < len(edges) - 1:
+             val = edges[pos]
+             label = f"{val:.1f}"
+             # Calculate position
+             target_pos = 8 + pos
+             if i == 0:
+                 x_labels_line = label + x_labels_line[len(label):]
+             elif i == len(label_positions) - 1:
+                 start = max(0, target_pos - len(label))
+                 x_labels_line = x_labels_line[:start] + label
+             else:
+                 start = max(0, target_pos - len(label)//2)
+                 end = min(len(x_labels_line), start + len(label))
+                 if start < len(x_labels_line):
+                     x_labels_line = x_labels_line[:start] + label[:end-start] + x_labels_line[end:]
+
+     lines.append(x_labels_line)
+
+     # X-axis label
+     if x_label:
+         lines.append(f"\n{x_label.center(width + 10)}")
+
+     return "\n".join(lines)
+
+
+ def create_bar_chart(categories: List[str], values: List[int], width: int = 60,
+                      title: str = "", show_values: bool = True) -> str:
+     """Create a horizontal bar chart."""
+     if not categories or not values:
+         return "No data to display"
+
+     max_val = max(values) if values else 1
+     lines = []
+
+     # Title
+     if title:
+         lines.append(f"\n{title}")
+         lines.append("=" * (width + 20))
+
+     # Find longest category name for alignment
+     max_cat_len = max(len(cat) for cat in categories)
+
+     # Create bars
+     for cat, val in zip(categories, values):
+         # Normalize bar length
+         bar_len = int(val * width / max_val) if max_val > 0 else 0
+         bar = "█" * bar_len
+
+         # Format line
+         if show_values:
+             line = f"{cat:<{max_cat_len}} │ {bar} {val}"
+         else:
+             line = f"{cat:<{max_cat_len}} │ {bar}"
+
+         lines.append(line)
+
+     return "\n".join(lines)
+
+
+ class FinetuningDataExtractorV3:
+     """Extract fine-tuning data from v3 Turso traces."""
+
+     def __init__(self, db_url: str):
+         self.db_manager = AsyncSQLTraceManager(db_url)
+         self._initialized = False
+
+     async def __aenter__(self):
+         await self.db_manager.initialize()
+         self._initialized = True
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         await self.db_manager.close()
+
+     async def get_all_sessions(self) -> pd.DataFrame:
+         """Get all session IDs from the database."""
+         query = """
+             SELECT DISTINCT session_id, created_at
+             FROM session_traces
+             ORDER BY created_at DESC
+         """
+         return await self.db_manager.query_traces(query)
+
+     async def get_session_metrics(self, session_id: str) -> Dict[str, Any]:
+         """Get metrics for a specific session."""
+         # Get total reward from environment events
+         # Now that rewards are properly saved in the DB, we can use them directly
+         reward_query = """
+             SELECT COALESCE(SUM(reward), 0) as total_reward
+             FROM events
+             WHERE session_id = :session_id
+             AND event_type = 'environment'
+             AND reward IS NOT NULL
+         """
+         reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
+         total_reward = float(reward_df['total_reward'].iloc[0]) if not reward_df.empty else 0.0
+
+         # Get total tokens and cost from LM events
+         lm_query = """
+             SELECT
+                 COALESCE(SUM(total_tokens), 0) as total_tokens,
+                 COALESCE(SUM(cost_usd) / 100.0, 0) as total_cost
+             FROM events
+             WHERE session_id = :session_id
+             AND event_type = 'cais'
+         """
+         lm_df = await self.db_manager.query_traces(lm_query, {"session_id": session_id})
+
+         total_tokens = int(lm_df['total_tokens'].iloc[0]) if not lm_df.empty else 0
+         total_cost = float(lm_df['total_cost'].iloc[0]) if not lm_df.empty else 0.0
+
+         return {
+             'session_id': session_id,
+             'total_reward': total_reward,
+             'total_tokens': total_tokens,
+             'total_cost': total_cost
+         }
+
+     async def get_session_achievements(self, session_id: str) -> List[str]:
+         """Get list of achievements unlocked in a session."""
+         # Look for achievement events in environment data
+         # In v3, system_state_after is a direct column
+         query = """
+             SELECT system_state_after
+             FROM events
+             WHERE session_id = :session_id
+             AND event_type = 'environment'
+             AND system_state_after IS NOT NULL
+             ORDER BY id DESC
+             LIMIT 1
+         """
+         df = await self.db_manager.query_traces(query, {"session_id": session_id})
+
+         if df.empty:
+             return []
+
+         try:
+             # Parse the system_state_after JSON
+             state_after = df['system_state_after'].iloc[0]
+             if state_after:
+                 # If it's a string, parse it
+                 if isinstance(state_after, str):
+                     state_after = json.loads(state_after)
+
+                 # Look for achievements in public_state
+                 if isinstance(state_after, dict) and 'public_state' in state_after:
+                     public_state = state_after['public_state']
+                     if 'achievements_status' in public_state:
+                         achievements = public_state['achievements_status']
+                         # Return list of unlocked achievements
+                         return [k for k, v in achievements.items() if v]
+         except Exception as e:
+             print(f"Error parsing achievements: {e}")
+             pass
+
+         return []
+
+     async def filter_by_achievements(self, min_achievements: int) -> List[str]:
+         """Get sessions with at least min_achievements unlocked."""
+         all_sessions = await self.get_all_sessions()
+         qualifying_sessions = []
+
+         for _, row in all_sessions.iterrows():
+             session_id = row['session_id']
+             achievements = await self.get_session_achievements(session_id)
+             if len(achievements) >= min_achievements:
+                 qualifying_sessions.append(session_id)
+
+         return qualifying_sessions
+
+     async def extract_openai_format_from_call_records(self, session_ids: List[str], min_reward: float = 0.0) -> List[Dict[str, Any]]:
+         """Extract training data in OpenAI format from call_records in LMCAISEvents.
+
+         This is the new method that uses the detailed LLM interaction data stored
+         in call_records instead of relying on separate message records.
+         """
+         training_data = []
+
+         for session_id in session_ids:
+             # Get LM CAIS events with call_records from the proper column
+             events_query = """
+                 SELECT e.call_records, st.turn_number
+                 FROM events e
+                 LEFT JOIN session_timesteps st ON e.timestep_id = st.id
+                 WHERE e.session_id = :session_id
+                 AND e.event_type = 'cais'
+                 AND e.call_records IS NOT NULL
+                 ORDER BY COALESCE(st.turn_number, e.message_time), e.id
+             """
+
+             events_df = await self.db_manager.query_traces(events_query, {"session_id": session_id})
+
+             if len(events_df) == 0:
+                 # Fall back to old method if no call_records
+                 continue
+
+             # Extract messages from call_records
+             all_messages = []
+
+             for _, row in events_df.iterrows():
+                 call_records_json = row['call_records']
+                 if not call_records_json:
+                     continue
+
+                 # Parse the call_records JSON directly from the column
+                 try:
+                     import json
+                     if isinstance(call_records_json, str):
+                         call_records = json.loads(call_records_json)
+                     else:
+                         call_records = call_records_json
+
+                     # Process each call record
+                     for record in call_records:
+                         # Extract input messages
+                         for msg in record.get('input_messages', []):
+                             role = msg.get('role', 'user')
+                             parts = msg.get('parts', [])
+
+                             # Combine text parts
+                             text_content = []
+                             for part in parts:
+                                 if part.get('type') == 'text' and part.get('text'):
+                                     text_content.append(part['text'])
+
+                             if text_content:
+                                 content = ' '.join(text_content)
+                                 if role == 'system' and not any(m['role'] == 'system' for m in all_messages):
+                                     all_messages.insert(0, {"role": "system", "content": content})
+                                 elif role != 'system':
+                                     all_messages.append({"role": role, "content": content})
+
+                         # Extract output messages
+                         for msg in record.get('output_messages', []):
+                             role = msg.get('role', 'assistant')
+                             parts = msg.get('parts', [])
+
+                             # Combine text parts
+                             text_content = []
+                             for part in parts:
+                                 if part.get('type') == 'text' and part.get('text'):
+                                     text_content.append(part['text'])
+
+                             if text_content:
+                                 content = ' '.join(text_content)
+                                 all_messages.append({"role": role, "content": content})
+
+                 except Exception as e:
+                     print(f"Error parsing call_records for session {session_id}: {e}")
+                     continue
+
+             # Only include if we have a complete conversation
+             if len(all_messages) > 1:
+                 # Get total reward for this session
+                 reward_query = """
+                     SELECT COALESCE(SUM(reward), 0) as total_reward
+                     FROM events
+                     WHERE session_id = :session_id
+                     AND event_type = 'environment'
+                     AND reward IS NOT NULL
+                 """
+                 reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
+                 total_reward = reward_df.iloc[0]['total_reward'] if len(reward_df) > 0 else 0
+
+                 training_data.append({
+                     "messages": all_messages,
+                     "metadata": {
+                         "session_id": session_id,
+                         "total_reward": float(total_reward),
+                         "source": "call_records"  # Mark that this came from call_records
+                     }
+                 })
+
+         return training_data
+
+     async def extract_openai_format(self, session_ids: List[str], min_reward: float = 0.0) -> List[Dict[str, Any]]:
+         """Extract training data in OpenAI format from filtered sessions."""
+         training_data = []
+
+         for session_id in session_ids:
+             # Get messages directly from the messages table
+             messages_query = """
+                 SELECT m.message_type, m.content, m.message_time, st.turn_number
+                 FROM messages m
+                 LEFT JOIN session_timesteps st ON m.timestep_id = st.id
+                 WHERE m.session_id = :session_id
+                 ORDER BY COALESCE(st.turn_number, m.message_time), m.id
+             """
+
+             messages_df = await self.db_manager.query_traces(messages_query, {"session_id": session_id})
+
+             if len(messages_df) == 0:
+                 continue
+
+             # Build conversation history
+             messages = []
+             system_message = None
+
+             for _, row in messages_df.iterrows():
+                 msg_type = row['message_type']
+                 content = row['content']
+
+                 # Parse content if it's JSON (from origin_system_id format)
+                 try:
+                     import json
+                     content_data = json.loads(content)
+                     if isinstance(content_data, dict) and 'payload' in content_data:
+                         content = content_data['payload']
+                 except:
+                     pass
+
+                 if msg_type == 'system' and system_message is None:
+                     # Extract system message from the first system message
+                     if isinstance(content, str):
+                         system_message = content
+
+                 elif msg_type == 'user':
+                     # Format user messages
+                     if isinstance(content, dict):
+                         # Convert observation dict to formatted string
+                         content = self._format_observation_content(content)
+                     messages.append({"role": "user", "content": str(content)})
+
+                 elif msg_type == 'assistant':
+                     messages.append({"role": "assistant", "content": str(content)})
+
+             # Add system message at the beginning if found
+             if system_message:
+                 messages.insert(0, {"role": "system", "content": system_message})
+
+             # Only include if we have a complete conversation
+             if len(messages) > 1:
+                 # Get total reward for this session
+                 reward_query = """
+                     SELECT COALESCE(SUM(reward), 0) as total_reward
+                     FROM events
+                     WHERE session_id = :session_id
+                     AND event_type = 'environment'
+                     AND reward IS NOT NULL
+                 """
+                 reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
+                 total_reward = reward_df.iloc[0]['total_reward'] if len(reward_df) > 0 else 0
+
+                 training_data.append({
+                     "messages": messages,
+                     "metadata": {
+                         "session_id": session_id,
+                         "total_reward": float(total_reward)  # Convert to float for JSON serialization
+                     }
+                 })
+
+         return training_data
+
+     def _format_observation_content(self, obs: Dict[str, Any]) -> str:
+         """Format observation dict into a readable string."""
+         if not isinstance(obs, dict):
+             return str(obs)
+
+         # Extract key fields for a concise representation
+         parts = []
+
+         if 'inventory' in obs:
+             inv = obs['inventory']
+             inv_str = ", ".join([f"{k}: {v}" for k, v in inv.items() if v > 0])
+             if inv_str:
+                 parts.append(f"Inventory: {inv_str}")
+
+         if 'achievements_status' in obs:
+             achievements = [k for k, v in obs['achievements_status'].items() if v]
+             if achievements:
+                 parts.append(f"Achievements: {', '.join(achievements)}")
+
+         if 'health' in obs:
+             parts.append(f"Health: {obs.get('health', 0)}")
+
+         return "; ".join(parts) if parts else "Empty observation"
+
+
+ async def filter_traces_from_turso(
+     db_url: str,
+     output_path: str,
+     config: Dict[str, Any]
+ ) -> Tuple[int, Dict[str, Any]]:
+     """
+     Filter traces from Turso/SQLite v3 database based on configuration.
+
+     Returns:
+         Tuple of (num_examples, statistics_dict)
+     """
+     mode = config.get("mode", "trajectory")
+     filters = config.get("filters", {})
+
+     # Extract filtering parameters
+     min_reward = filters.get("min_total_reward", 0.0)
+     min_achievements = filters.get("min_achievements", 0)
+     max_cost = filters.get("max_cost", float('inf'))
+     max_tokens = filters.get("max_tokens", float('inf'))
+
+     # Modal/Synth specific: filter by model if specified
+     target_models = filters.get("models", [])
+
+     statistics = {
+         "total_sessions": 0,
+         "filtered_sessions": 0,
+         "total_examples": 0,
+         "reward_distribution": [],
+         "token_distribution": [],
+         "cost_distribution": [],
+         "model_distribution": defaultdict(int)
+     }
+
+     async with FinetuningDataExtractorV3(db_url) as extractor:
+         # Get all sessions
+         all_sessions = await extractor.get_all_sessions()
+         statistics["total_sessions"] = len(all_sessions)
+
+         # Filter sessions based on criteria
+         filtered_sessions = []
+
+         for _, row in all_sessions.iterrows():
+             session_id = row['session_id']
+             metrics = await extractor.get_session_metrics(session_id)
+
+             # Apply filters
+             if metrics['total_reward'] < min_reward:
+                 continue
+             if metrics['total_cost'] > max_cost:
+                 continue
+             if metrics['total_tokens'] > max_tokens:
+                 continue
+
+             # Check achievements if required
+             if min_achievements > 0:
+                 achievements = await extractor.get_session_achievements(session_id)
+                 if len(achievements) < min_achievements:
+                     continue
+
+             # Check model filter if specified
+             if target_models:
+                 model_query = """
+                     SELECT DISTINCT model_name
+                     FROM events
+                     WHERE session_id = :session_id
+                     AND event_type = 'cais'
+                     AND model_name IS NOT NULL
+                 """
+                 model_df = await extractor.db_manager.query_traces(
+                     model_query, {"session_id": session_id}
+                 )
+                 session_models = model_df['model_name'].tolist() if not model_df.empty else []
+                 if not any(model in target_models for model in session_models):
+                     continue
+
+             filtered_sessions.append(session_id)
+
+             # Collect statistics
+             statistics["reward_distribution"].append(metrics['total_reward'])
+             statistics["token_distribution"].append(metrics['total_tokens'])
+             statistics["cost_distribution"].append(metrics['total_cost'])
+
+         statistics["filtered_sessions"] = len(filtered_sessions)
+
+         # Extract training data
+         if mode == "trajectory":
+             # Try new method first (using call_records)
+             training_data = await extractor.extract_openai_format_from_call_records(
+                 session_ids=filtered_sessions,
+                 min_reward=min_reward
+             )
+
+             # If no data from call_records, fall back to old method
+             if not training_data:
+                 print("No call_records found, falling back to message-based extraction...")
+                 training_data = await extractor.extract_openai_format(
+                     session_ids=filtered_sessions,
+                     min_reward=min_reward
+                 )
+         else:  # window mode
+             # For window mode, we need to implement window extraction
+             # For now, use trajectory mode
+             training_data = await extractor.extract_openai_format(
+                 session_ids=filtered_sessions,
+                 min_reward=min_reward
+             )
+
+         statistics["total_examples"] = len(training_data)
+
+         # Write to output file
+         output_file = Path(output_path)
+         output_file.parent.mkdir(exist_ok=True)
+
+         with open(output_file, 'w') as f:
+             for example in training_data:
+                 f.write(json.dumps(example) + '\n')
+
+         # Get model distribution
+         model_query = """
+             SELECT model_name, COUNT(*) as count
+             FROM events
+             WHERE event_type = 'cais'
+             AND model_name IS NOT NULL
+             GROUP BY model_name
+         """
+         model_stats = await extractor.db_manager.query_traces(model_query)
+         for _, row in model_stats.iterrows():
+             statistics["model_distribution"][row['model_name']] = int(row['count'])
+
+     return len(training_data), statistics
+
+
+ def print_statistics(stats: Dict[str, Any]):
+     """Print filtering statistics with visualizations."""
+     print("\n" + "="*80)
+     print("FILTERING STATISTICS (Modal/Synth - v3)")
+     print("="*80)
+
+     # Basic counts
+     print(f"\nTotal sessions in database: {stats['total_sessions']}")
+     print(f"Sessions after filtering: {stats['filtered_sessions']}")
+     print(f"Training examples generated: {stats['total_examples']}")
+
+     filter_rate = (stats['filtered_sessions'] / stats['total_sessions'] * 100) if stats['total_sessions'] > 0 else 0
+     print(f"Filter pass rate: {filter_rate:.1f}%")
+
+     # Reward distribution
+     if stats['reward_distribution'] and any(not np.isnan(x) for x in stats['reward_distribution']):
+         valid_rewards = [x for x in stats['reward_distribution'] if not np.isnan(x)]
+         if valid_rewards:
+             print(create_histogram(
+                 valid_rewards,
+                 bins=20,
+                 title="Reward Distribution",
+                 x_label="Total Reward",
+                 y_label="Count"
+             ))
+
+             print(f"\nReward statistics:")
+             print(f"  Min: {min(valid_rewards):.2f}")
+             print(f"  Max: {max(valid_rewards):.2f}")
+             print(f"  Mean: {np.mean(valid_rewards):.2f}")
+             print(f"  Median: {np.median(valid_rewards):.2f}")
+     else:
+         print("\nNo valid reward data to display.")
+
+     # Token distribution
+     if stats['token_distribution'] and any(not np.isnan(x) for x in stats['token_distribution']):
+         valid_tokens = [x for x in stats['token_distribution'] if not np.isnan(x)]
+         if valid_tokens:
+             print(create_histogram(
+                 valid_tokens,
+                 bins=20,
+                 title="Token Usage Distribution",
+                 x_label="Total Tokens",
+                 y_label="Count"
+             ))
+
+     # Model distribution
+     if stats['model_distribution']:
+         models = list(stats['model_distribution'].keys())
+         counts = list(stats['model_distribution'].values())
+         print(create_bar_chart(
+             models,
+             counts,
+             title="Model Usage",
+             show_values=True
+         ))
+
+     print("\n" + "="*80)
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Filter traces from Turso/SQLite v3 for Modal/Synth fine-tuning",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Example usage:
+   # Use default config
+   python filter_traces_sft_turso.py -d sqlite:///traces.db -o ft_data/training.jsonl
+
+   # Use custom config file
+   python filter_traces_sft_turso.py -d sqlite:///traces.db -c filter_config.toml
+
+   # Override config parameters
+   python filter_traces_sft_turso.py -d sqlite:///traces.db --min-reward 5.0 --max-cost 0.1
+
+   # Filter by model
+   python filter_traces_sft_turso.py -d sqlite:///traces.db --models "Qwen/Qwen2.5-7B-Instruct"
+         """
+     )
+
+     parser.add_argument('-d', '--database', required=True, help='Path to Turso/SQLite database or connection URL')
+     parser.add_argument('-o', '--output', default='ft_data/training_modal.jsonl', help='Output JSONL file')
+     parser.add_argument('-c', '--config', help='Configuration TOML file')
+
+     # Filter overrides
+     parser.add_argument('--mode', choices=['trajectory', 'window'], help='Filtering mode')
+     parser.add_argument('--min-reward', type=float, help='Minimum total reward')
+     parser.add_argument('--min-achievements', type=int, help='Minimum achievements')
+     parser.add_argument('--max-cost', type=float, help='Maximum cost')
+     parser.add_argument('--max-tokens', type=int, help='Maximum tokens')
+     parser.add_argument('--models', nargs='+', help='Filter by model names (e.g., Qwen/Qwen2.5-7B-Instruct)')
+
+     parser.add_argument('--dry-run', action='store_true', help='Show statistics without writing output')
+
+     args = parser.parse_args()
+
+     # Load config
+     config = {
+         "mode": "trajectory",
+         "filters": {
+             "min_total_reward": 1.0,
+             "min_achievements": 0,
+             "max_cost": 10.0,
+             "max_tokens": 100000,
+             "models": []  # Empty means all models
+         }
+     }
+
+     if args.config:
+         with open(args.config, 'r') as f:
+             loaded_config = toml.load(f)
+             config.update(loaded_config)
+
+     # Apply command-line overrides
+     if args.mode:
+         config["mode"] = args.mode
+     if args.min_reward is not None:
+         config["filters"]["min_total_reward"] = args.min_reward
+     if args.min_achievements is not None:
+         config["filters"]["min_achievements"] = args.min_achievements
+     if args.max_cost is not None:
+         config["filters"]["max_cost"] = args.max_cost
+     if args.max_tokens is not None:
+         config["filters"]["max_tokens"] = args.max_tokens
+     if args.models:
+         config["filters"]["models"] = args.models
+
+     # Convert database path to proper URL format if needed
+     db_url = args.database
+     if db_url.startswith("sqlite:///"):
+         # Already in URL format
+         pass
+     elif db_url.endswith(".db"):
+         # Convert file path to URL
+         db_url = f"sqlite+aiosqlite:///{db_url}"
+
+     print(f"🤖 Modal/Synth Fine-Tuning Data Filter (v3)")
+     print(f"Using database: {db_url}")
+     print(f"Output file: {args.output}")
+     print(f"Mode: {config['mode']}")
+     print(f"Filters: {json.dumps(config['filters'], indent=2)}")
+
+     if args.dry_run:
+         print("\n🔍 DRY RUN - No output will be written")
+
+     # Run filtering
+     async def run():
+         num_examples, stats = await filter_traces_from_turso(
+             db_url,
+             args.output if not args.dry_run else "/dev/null",
+             config
+         )
+
+         # Print statistics
+         print_statistics(stats)
+
+         if not args.dry_run:
+             print(f"\n✅ Successfully wrote {num_examples} training examples to {args.output}")
+             print(f"  Ready for Modal/Synth fine-tuning!")
+         else:
+             print(f"\n✅ Would write {num_examples} training examples (dry run)")
+
+     asyncio.run(run())
+
+
+ if __name__ == "__main__":
+     main()
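
FinetuningDataExtractorV3 above is an async context manager around AsyncSQLTraceManager, so it can also be driven outside the CLI. A minimal sketch, assuming the module is importable from the installed wheel under the path shown in this diff and that a trace database exists at traces.db (both the import path and the database URL are illustrative):

import asyncio

# Illustrative import; the demo file sits at this path inside the wheel, but it
# can also be loaded by file path if the agent_demos directories are not packages.
from synth_ai.environments.examples.crafter_classic.agent_demos.crafter_modal_ft.filter_traces_sft_turso import (
    FinetuningDataExtractorV3,
)


async def summarize(db_url: str) -> None:
    async with FinetuningDataExtractorV3(db_url) as extractor:
        sessions = await extractor.get_all_sessions()  # DataFrame with session_id, created_at
        for session_id in sessions["session_id"].head(5):
            metrics = await extractor.get_session_metrics(session_id)
            achievements = await extractor.get_session_achievements(session_id)
            print(session_id, metrics["total_reward"], metrics["total_tokens"], len(achievements))


asyncio.run(summarize("sqlite+aiosqlite:///traces.db"))  # database URL is illustrative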
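
main() builds a config dict with a top-level "mode" key and a nested "filters" table, and -c merges a TOML file with the same shape over those defaults. The same structure can be passed to filter_traces_from_turso directly; the dict below mirrors the defaults hard-coded in the script, while the database URL and output path are illustrative:

import asyncio

from synth_ai.environments.examples.crafter_classic.agent_demos.crafter_modal_ft.filter_traces_sft_turso import (
    filter_traces_from_turso,
)

# Mirrors the default config in main(); a TOML file passed via -c is expected to
# carry the same "mode" / [filters] layout.
config = {
    "mode": "trajectory",
    "filters": {
        "min_total_reward": 1.0,
        "min_achievements": 0,
        "max_cost": 10.0,
        "max_tokens": 100000,
        "models": [],  # empty list means no model filter
    },
}

num_examples, stats = asyncio.run(
    filter_traces_from_turso(
        "sqlite+aiosqlite:///traces.db",  # illustrative database URL
        "ft_data/training_modal.jsonl",   # matches the script's default output path
        config,
    )
)
print(f"{num_examples} examples from {stats['filtered_sessions']} sessions")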
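
Each line of the resulting .jsonl file is one JSON object built in extract_openai_format_from_call_records (or in the message-table fallback). The sketch below shows the shape of a single record with placeholder values; the "source" key is only added on the call_records path:

# Shape of one output record (placeholder values; keys taken from the script).
example_record = {
    "messages": [
        {"role": "system", "content": "<system prompt text>"},
        {"role": "user", "content": "<formatted observation>"},
        {"role": "assistant", "content": "<model response>"},
    ],
    "metadata": {
        "session_id": "<session id>",
        "total_reward": 3.0,
        "source": "call_records",
    },
}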