synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (98)
  1. synth_ai/cli/__init__.py +66 -0
  2. synth_ai/cli/balance.py +205 -0
  3. synth_ai/cli/calc.py +70 -0
  4. synth_ai/cli/demo.py +74 -0
  5. synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
  6. synth_ai/cli/man.py +103 -0
  7. synth_ai/cli/recent.py +126 -0
  8. synth_ai/cli/root.py +184 -0
  9. synth_ai/cli/status.py +126 -0
  10. synth_ai/cli/traces.py +136 -0
  11. synth_ai/cli/watch.py +508 -0
  12. synth_ai/config/base_url.py +53 -0
  13. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
  14. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
  15. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +646 -0
  16. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
  17. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
  18. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
  24. synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
  25. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
  26. synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
  27. synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
  28. synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
  29. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
  30. synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
  31. synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
  32. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
  33. synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
  34. synth_ai/experimental/synth_oss.py +446 -0
  35. synth_ai/learning/core.py +21 -0
  36. synth_ai/learning/gateway.py +4 -0
  37. synth_ai/learning/prompts/mipro.py +0 -0
  38. synth_ai/lm/__init__.py +3 -0
  39. synth_ai/lm/core/main.py +4 -0
  40. synth_ai/lm/core/main_v3.py +68 -13
  41. synth_ai/lm/core/vendor_clients.py +4 -0
  42. synth_ai/lm/provider_support/openai.py +11 -2
  43. synth_ai/lm/vendors/base.py +7 -0
  44. synth_ai/lm/vendors/openai_standard.py +339 -4
  45. synth_ai/lm/vendors/openai_standard_responses.py +243 -0
  46. synth_ai/lm/vendors/synth_client.py +155 -5
  47. synth_ai/lm/warmup.py +54 -17
  48. synth_ai/tracing/__init__.py +18 -0
  49. synth_ai/tracing_v1/__init__.py +29 -14
  50. synth_ai/tracing_v3/config.py +13 -7
  51. synth_ai/tracing_v3/db_config.py +6 -6
  52. synth_ai/tracing_v3/turso/manager.py +8 -8
  53. synth_ai/tui/__main__.py +13 -0
  54. synth_ai/tui/dashboard.py +329 -0
  55. synth_ai/v0/tracing/__init__.py +0 -0
  56. synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
  57. synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
  58. synth_ai/{tracing → v0/tracing}/context.py +1 -1
  59. synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
  60. synth_ai/v0/tracing/events/__init__.py +0 -0
  61. synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
  62. synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
  63. synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
  64. synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
  65. synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
  66. synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
  67. synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
  68. synth_ai/{tracing → v0/tracing}/upload.py +4 -4
  69. synth_ai/v0/tracing_v1/__init__.py +16 -0
  70. synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
  71. synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
  72. synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
  73. synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
  74. synth_ai/v0/tracing_v1/events/__init__.py +0 -0
  75. synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
  76. synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
  77. synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
  78. synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
  79. synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
  80. synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
  81. synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
  82. synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
  83. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/METADATA +98 -4
  84. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/RECORD +98 -62
  85. /synth_ai/{tracing/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
  86. /synth_ai/{tracing_v1/events/__init__.py → learning/prompts/gepa.py} +0 -0
  87. /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
  88. /synth_ai/{tracing → v0/tracing}/config.py +0 -0
  89. /synth_ai/{tracing → v0/tracing}/local.py +0 -0
  90. /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
  91. /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
  92. /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
  93. /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
  94. /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
  95. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/WHEEL +0 -0
  96. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/entry_points.txt +0 -0
  97. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
  98. {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.3.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py
@@ -0,0 +1,580 @@
+ #!/usr/bin/env python3
+ """
+ Filter traces from Turso/SQLite (v3) to create OpenAI SFT-ready .jsonl files
+ Supports two modes:
+ 1. Trajectory-level filtering: Include entire trajectories above a score threshold
+ 2. Window-based filtering: Extract high-scoring windows of actions
+
+ This is the v3 version using the new async Turso-based tracing system.
+ """
+
+ import json
+ import argparse
+ import asyncio
+ from pathlib import Path
+ from typing import List, Dict, Any, Tuple, Optional
+ from collections import defaultdict
+ import numpy as np
+ import os
+ import sys
+ import toml
+ import pandas as pd
+
+ # Add synth_ai to path
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent))
+
+ from synth_ai.tracing_v3 import SessionTracer
+ from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
+ from synth_ai.tracing_v3.abstractions import LMCAISEvent, EnvironmentEvent, RuntimeEvent
+
+
+ def create_histogram(data: List[float], bins: int = 20, width: int = 60, height: int = 15,
+                      title: str = "", x_label: str = "", y_label: str = "") -> str:
+     """Create a beautiful ASCII histogram."""
+     if not data:
+         return "No data to display"
+
+     # Create histogram
+     counts, edges = np.histogram(data, bins=bins)
+     max_count = max(counts) if len(counts) > 0 else 1
+
+     # Normalize heights
+     if max_count > 0:
+         heights = [int(c * height / max_count) for c in counts]
+     else:
+         heights = [0] * len(counts)
+
+     # Build the plot
+     lines = []
+
+     # Title
+     if title:
+         lines.append(f"\n{title.center(width + 10)}")
+         lines.append("=" * (width + 10))
+
+     # Y-axis label
+     if y_label:
+         lines.append(f"{y_label}")
+
+     # Plot area with y-axis
+     for y in range(height, 0, -1):
+         # Y-axis value
+         y_val = int(max_count * y / height)
+         line = f"{y_val:>6} │"
+
+         # Bars
+         for h in heights:
+             if h >= y:
+                 line += "█"
+             else:
+                 line += " "
+
+         lines.append(line)
+
+     # X-axis
+     lines.append(f"{'':>6} └" + "─" * len(heights))
+
+     # X-axis labels
+     x_labels_line = " " * 8
+     min_val, max_val = min(data), max(data)
+
+     # Add labels at key positions
+     label_positions = [0, len(heights)//4, len(heights)//2, 3*len(heights)//4, len(heights)-1]
+     for i, pos in enumerate(label_positions):
+         if pos < len(edges) - 1:
+             val = edges[pos]
+             label = f"{val:.1f}"
+             # Calculate position
+             target_pos = 8 + pos
+             if i == 0:
+                 x_labels_line = label + x_labels_line[len(label):]
+             elif i == len(label_positions) - 1:
+                 start = max(0, target_pos - len(label))
+                 x_labels_line = x_labels_line[:start] + label
+             else:
+                 start = max(0, target_pos - len(label)//2)
+                 end = min(len(x_labels_line), start + len(label))
+                 if start < len(x_labels_line):
+                     x_labels_line = x_labels_line[:start] + label[:end-start] + x_labels_line[end:]
+
+     lines.append(x_labels_line)
+
+     # X-axis label
+     if x_label:
+         lines.append(f"\n{x_label.center(width + 10)}")
+
+     return "\n".join(lines)
+
+
+ def create_bar_chart(categories: List[str], values: List[int], width: int = 60,
+                      title: str = "", show_values: bool = True) -> str:
+     """Create a horizontal bar chart."""
+     if not categories or not values:
+         return "No data to display"
+
+     max_val = max(values) if values else 1
+     lines = []
+
+     # Title
+     if title:
+         lines.append(f"\n{title}")
+         lines.append("=" * (width + 20))
+
+     # Find longest category name for alignment
+     max_cat_len = max(len(cat) for cat in categories)
+
+     # Create bars
+     for cat, val in zip(categories, values):
+         # Normalize bar length
+         bar_len = int(val * width / max_val) if max_val > 0 else 0
+         bar = "█" * bar_len
+
+         # Format line
+         if show_values:
+             line = f"{cat:<{max_cat_len}} │ {bar} {val}"
+         else:
+             line = f"{cat:<{max_cat_len}} │ {bar}"
+
+         lines.append(line)
+
+     return "\n".join(lines)
+
+
+ class FinetuningDataExtractorV3:
+     """Extract fine-tuning data from v3 Turso traces."""
+
+     def __init__(self, db_url: str):
+         self.db_manager = AsyncSQLTraceManager(db_url)
+         self._initialized = False
+
+     async def __aenter__(self):
+         await self.db_manager.initialize()
+         self._initialized = True
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         await self.db_manager.close()
+
+     async def get_all_sessions(self) -> pd.DataFrame:
+         """Get all session IDs from the database."""
+         query = """
+             SELECT DISTINCT session_id, created_at
+             FROM session_traces
+             ORDER BY created_at DESC
+         """
+         return await self.db_manager.query_traces(query)
+
+     async def get_session_metrics(self, session_id: str) -> Dict[str, Any]:
+         """Get metrics for a specific session."""
+         # Get total reward from environment events
+         reward_query = """
+             SELECT COALESCE(SUM(CAST(data->>'reward' AS REAL)), 0) as total_reward
+             FROM events
+             WHERE session_id = :session_id
+             AND event_type = 'environment'
+             AND data->>'reward' IS NOT NULL
+         """
+         reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
+         total_reward = float(reward_df['total_reward'].iloc[0]) if not reward_df.empty else 0.0
+
+         # Get total tokens and cost from LM events
+         lm_query = """
+             SELECT
+                 COALESCE(SUM(CAST(data->>'total_tokens' AS INTEGER)), 0) as total_tokens,
+                 COALESCE(SUM(CAST(data->>'cost_usd' AS REAL)), 0) as total_cost
+             FROM events
+             WHERE session_id = :session_id
+             AND event_type = 'lm_cais'
+         """
+         lm_df = await self.db_manager.query_traces(lm_query, {"session_id": session_id})
+
+         total_tokens = int(lm_df['total_tokens'].iloc[0]) if not lm_df.empty else 0
+         total_cost = float(lm_df['total_cost'].iloc[0]) if not lm_df.empty else 0.0
+
+         return {
+             'session_id': session_id,
+             'total_reward': total_reward,
+             'total_tokens': total_tokens,
+             'total_cost': total_cost
+         }
+
+     async def get_session_achievements(self, session_id: str) -> List[str]:
+         """Get list of achievements unlocked in a session."""
+         # Look for achievement events in environment data
+         query = """
+             SELECT DISTINCT data->>'$.system_state_after.achievements_status' as achievements
+             FROM events
+             WHERE session_id = :session_id
+             AND event_type = 'environment'
+             AND data->>'$.system_state_after.achievements_status' IS NOT NULL
+             ORDER BY created_at DESC
+             LIMIT 1
+         """
+         df = await self.db_manager.query_traces(query, {"session_id": session_id})
+
+         if df.empty:
+             return []
+
+         try:
+             # Parse the achievements JSON
+             achievements_str = df['achievements'].iloc[0]
+             if achievements_str:
+                 achievements = json.loads(achievements_str)
+                 # Return list of unlocked achievements
+                 return [k for k, v in achievements.items() if v]
+         except:
+             pass
+
+         return []
+
+     async def filter_by_achievements(self, min_achievements: int) -> List[str]:
+         """Get sessions with at least min_achievements unlocked."""
+         all_sessions = await self.get_all_sessions()
+         qualifying_sessions = []
+
+         for _, row in all_sessions.iterrows():
+             session_id = row['session_id']
+             achievements = await self.get_session_achievements(session_id)
+             if len(achievements) >= min_achievements:
+                 qualifying_sessions.append(session_id)
+
+         return qualifying_sessions
+
+     async def extract_openai_format(self, session_ids: List[str], min_reward: float = 0.0) -> List[Dict[str, Any]]:
+         """Extract training data in OpenAI format from filtered sessions."""
+         training_data = []
+
+         for session_id in session_ids:
+             # Get session trace
+             trace_data = await self.db_manager.get_session_trace(session_id)
+             if not trace_data:
+                 continue
+
+             # Build conversation history
+             messages = []
+
+             # Add system message if available
+             system_message = None
+             for timestep in trace_data.get('timesteps', []):
+                 for msg in timestep.get('messages', []):
+                     if msg['message_type'] == 'system' and system_message is None:
+                         system_message = msg['content']
+                         break
+                 if system_message:
+                     break
+
+             if system_message:
+                 messages.append({"role": "system", "content": system_message})
+
+             # Process timesteps in order
+             for timestep in sorted(trace_data.get('timesteps', []), key=lambda x: x['turn_number']):
+                 # Add messages from this timestep
+                 for msg in timestep.get('messages', []):
+                     if msg['message_type'] == 'user':
+                         messages.append({"role": "user", "content": msg['content']})
+                     elif msg['message_type'] == 'assistant':
+                         messages.append({"role": "assistant", "content": msg['content']})
+
+             # Only include if we have a complete conversation
+             if len(messages) > 1:
+                 training_data.append({
+                     "messages": messages,
+                     "metadata": {
+                         "session_id": session_id,
+                         "total_reward": trace_data.get('metadata', {}).get('total_reward', 0)
+                     }
+                 })
+
+         return training_data
+
+
+ async def filter_traces_from_turso(
+     db_url: str,
+     output_path: str,
+     config: Dict[str, Any]
+ ) -> Tuple[int, Dict[str, Any]]:
+     """
+     Filter traces from Turso/SQLite v3 database based on configuration.
+
+     Returns:
+         Tuple of (num_examples, statistics_dict)
+     """
+     mode = config.get("mode", "trajectory")
+     filters = config.get("filters", {})
+
+     # Extract filtering parameters
+     min_reward = filters.get("min_total_reward", 0.0)
+     min_achievements = filters.get("min_achievements", 0)
+     max_cost = filters.get("max_cost", float('inf'))
+     max_tokens = filters.get("max_tokens", float('inf'))
+
+     # OpenAI specific: filter by model if specified
+     target_models = filters.get("models", [])
+
+     statistics = {
+         "total_sessions": 0,
+         "filtered_sessions": 0,
+         "total_examples": 0,
+         "reward_distribution": [],
+         "token_distribution": [],
+         "cost_distribution": [],
+         "model_distribution": defaultdict(int)
+     }
+
+     async with FinetuningDataExtractorV3(db_url) as extractor:
+         # Get all sessions
+         all_sessions = await extractor.get_all_sessions()
+         statistics["total_sessions"] = len(all_sessions)
+
+         # Filter sessions based on criteria
+         filtered_sessions = []
+
+         for _, row in all_sessions.iterrows():
+             session_id = row['session_id']
+             metrics = await extractor.get_session_metrics(session_id)
+
+             # Apply filters
+             if metrics['total_reward'] < min_reward:
+                 continue
+             if metrics['total_cost'] > max_cost:
+                 continue
+             if metrics['total_tokens'] > max_tokens:
+                 continue
+
+             # Check achievements if required
+             if min_achievements > 0:
+                 achievements = await extractor.get_session_achievements(session_id)
+                 if len(achievements) < min_achievements:
+                     continue
+
+             # Check model filter if specified
+             if target_models:
+                 model_query = """
+                     SELECT DISTINCT data->>'model_name' as model_name
+                     FROM events
+                     WHERE session_id = :session_id
+                     AND event_type = 'lm_cais'
+                     AND data->>'model_name' IS NOT NULL
+                 """
+                 model_df = await extractor.db_manager.query_traces(
+                     model_query, {"session_id": session_id}
+                 )
+                 session_models = model_df['model_name'].tolist() if not model_df.empty else []
+                 if not any(model in target_models for model in session_models):
+                     continue
+
+             filtered_sessions.append(session_id)
+
+             # Collect statistics
+             statistics["reward_distribution"].append(metrics['total_reward'])
+             statistics["token_distribution"].append(metrics['total_tokens'])
+             statistics["cost_distribution"].append(metrics['total_cost'])
+
+         statistics["filtered_sessions"] = len(filtered_sessions)
+
+         # Extract training data
+         if mode == "trajectory":
+             training_data = await extractor.extract_openai_format(
+                 session_ids=filtered_sessions,
+                 min_reward=min_reward
+             )
+         else:  # window mode
+             # For window mode, we need to implement window extraction
+             # For now, use trajectory mode
+             training_data = await extractor.extract_openai_format(
+                 session_ids=filtered_sessions,
+                 min_reward=min_reward
+             )
+
+         statistics["total_examples"] = len(training_data)
+
+         # Write to output file
+         output_file = Path(output_path)
+         output_file.parent.mkdir(exist_ok=True)
+
+         with open(output_file, 'w') as f:
+             for example in training_data:
+                 f.write(json.dumps(example) + '\n')
+
+         # Get model distribution
+         model_query = """
+             SELECT data->>'model_name' as model_name, COUNT(*) as count
+             FROM events
+             WHERE event_type = 'lm_cais'
+             AND data->>'model_name' IS NOT NULL
+             GROUP BY data->>'model_name'
+         """
+         model_stats = await extractor.db_manager.query_traces(model_query)
+         for _, row in model_stats.iterrows():
+             statistics["model_distribution"][row['model_name']] = int(row['count'])
+
+     return len(training_data), statistics
+
+
+ def print_statistics(stats: Dict[str, Any]):
+     """Print filtering statistics with visualizations."""
+     print("\n" + "="*80)
+     print("FILTERING STATISTICS (OpenAI - v3)")
+     print("="*80)
+
+     # Basic counts
+     print(f"\nTotal sessions in database: {stats['total_sessions']}")
+     print(f"Sessions after filtering: {stats['filtered_sessions']}")
+     print(f"Training examples generated: {stats['total_examples']}")
+
+     filter_rate = (stats['filtered_sessions'] / stats['total_sessions'] * 100) if stats['total_sessions'] > 0 else 0
+     print(f"Filter pass rate: {filter_rate:.1f}%")
+
+     # Reward distribution
+     if stats['reward_distribution'] and any(not np.isnan(x) for x in stats['reward_distribution']):
+         valid_rewards = [x for x in stats['reward_distribution'] if not np.isnan(x)]
+         if valid_rewards:
+             print(create_histogram(
+                 valid_rewards,
+                 bins=20,
+                 title="Reward Distribution",
+                 x_label="Total Reward",
+                 y_label="Count"
+             ))
+
+             print(f"\nReward statistics:")
+             print(f"  Min: {min(valid_rewards):.2f}")
+             print(f"  Max: {max(valid_rewards):.2f}")
+             print(f"  Mean: {np.mean(valid_rewards):.2f}")
+             print(f"  Median: {np.median(valid_rewards):.2f}")
+     else:
+         print("\nNo valid reward data to display.")
+
+     # Token distribution
+     if stats['token_distribution'] and any(not np.isnan(x) for x in stats['token_distribution']):
+         valid_tokens = [x for x in stats['token_distribution'] if not np.isnan(x)]
+         if valid_tokens:
+             print(create_histogram(
+                 valid_tokens,
+                 bins=20,
+                 title="Token Usage Distribution",
+                 x_label="Total Tokens",
+                 y_label="Count"
+             ))
+
+     # Model distribution
+     if stats['model_distribution']:
+         models = list(stats['model_distribution'].keys())
+         counts = list(stats['model_distribution'].values())
+         print(create_bar_chart(
+             models,
+             counts,
+             title="Model Usage",
+             show_values=True
+         ))
+
+     print("\n" + "="*80)
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Filter traces from Turso/SQLite v3 for OpenAI fine-tuning",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Example usage:
+   # Use default config
+   python filter_traces_sft_turso.py -d sqlite:///traces.db -o ft_data/training.jsonl
+
+   # Use custom config file
+   python filter_traces_sft_turso.py -d sqlite:///traces.db -c filter_config.toml
+
+   # Override config parameters
+   python filter_traces_sft_turso.py -d sqlite:///traces.db --min-reward 5.0 --max-cost 0.1
+
+   # Filter by model
+   python filter_traces_sft_turso.py -d sqlite:///traces.db --models "gpt-4" "gpt-3.5-turbo"
+         """
+     )
+
+     parser.add_argument('-d', '--database', required=True, help='Path to Turso/SQLite database or connection URL')
+     parser.add_argument('-o', '--output', default='ft_data/training_openai.jsonl', help='Output JSONL file')
+     parser.add_argument('-c', '--config', help='Configuration TOML file')
+
+     # Filter overrides
+     parser.add_argument('--mode', choices=['trajectory', 'window'], help='Filtering mode')
+     parser.add_argument('--min-reward', type=float, help='Minimum total reward')
+     parser.add_argument('--min-achievements', type=int, help='Minimum achievements')
+     parser.add_argument('--max-cost', type=float, help='Maximum cost')
+     parser.add_argument('--max-tokens', type=int, help='Maximum tokens')
+     parser.add_argument('--models', nargs='+', help='Filter by model names (e.g., gpt-4 gpt-3.5-turbo)')
+
+     parser.add_argument('--dry-run', action='store_true', help='Show statistics without writing output')
+
+     args = parser.parse_args()
+
+     # Load config
+     config = {
+         "mode": "trajectory",
+         "filters": {
+             "min_total_reward": 1.0,
+             "min_achievements": 0,
+             "max_cost": 10.0,
+             "max_tokens": 100000,
+             "models": []  # Empty means all models
+         }
+     }
+
+     if args.config:
+         with open(args.config, 'r') as f:
+             loaded_config = toml.load(f)
+             config.update(loaded_config)
+
+     # Apply command-line overrides
+     if args.mode:
+         config["mode"] = args.mode
+     if args.min_reward is not None:
+         config["filters"]["min_total_reward"] = args.min_reward
+     if args.min_achievements is not None:
+         config["filters"]["min_achievements"] = args.min_achievements
+     if args.max_cost is not None:
+         config["filters"]["max_cost"] = args.max_cost
+     if args.max_tokens is not None:
+         config["filters"]["max_tokens"] = args.max_tokens
+     if args.models:
+         config["filters"]["models"] = args.models
+
+     # Convert database path to proper URL format if needed
+     db_url = args.database
+     if db_url.startswith("sqlite:///"):
+         # Already in URL format
+         pass
+     elif db_url.endswith(".db"):
+         # Convert file path to URL
+         db_url = f"sqlite+aiosqlite:///{db_url}"
+
+     print(f"🤖 OpenAI Fine-Tuning Data Filter (v3)")
+     print(f"Using database: {db_url}")
+     print(f"Output file: {args.output}")
+     print(f"Mode: {config['mode']}")
+     print(f"Filters: {json.dumps(config['filters'], indent=2)}")
+
+     if args.dry_run:
+         print("\n🔍 DRY RUN - No output will be written")
+
+     # Run filtering
+     async def run():
+         num_examples, stats = await filter_traces_from_turso(
+             db_url,
+             args.output if not args.dry_run else "/dev/null",
+             config
+         )
+
+         # Print statistics
+         print_statistics(stats)
+
+         if not args.dry_run:
+             print(f"\n✅ Successfully wrote {num_examples} training examples to {args.output}")
+             print(f"   Ready for OpenAI fine-tuning!")
+         else:
+             print(f"\n✅ Would write {num_examples} training examples (dry run)")
+
+     asyncio.run(run())
+
+
+ if __name__ == "__main__":
+     main()
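
For orientation, the sketch below drives the same filtering step programmatically instead of through the CLI shown in the epilog. It is a minimal, hypothetical example, not part of the package: the database URL, output path, threshold values, and model name are placeholders, and it assumes the script's directory is on sys.path (so it imports as a plain module) and that synth_ai with tracing_v3 is installed. The config keys mirror the defaults defined in main() above.

import asyncio

# Hypothetical usage sketch; filter_traces_from_turso and print_statistics
# are defined in the new file added by this release.
from filter_traces_sft_turso import filter_traces_from_turso, print_statistics

# Same shape as the default config built in main(); values are placeholders.
config = {
    "mode": "trajectory",
    "filters": {
        "min_total_reward": 2.0,    # keep trajectories with total reward >= 2.0
        "min_achievements": 3,      # and at least 3 unlocked achievements
        "max_cost": 5.0,
        "max_tokens": 100000,
        "models": ["gpt-4o-mini"],  # empty list means "all models"
    },
}

async def build_dataset() -> None:
    # Same sqlite+aiosqlite URL form that main() produces for *.db paths.
    num_examples, stats = await filter_traces_from_turso(
        db_url="sqlite+aiosqlite:///traces.db",
        output_path="ft_data/training_openai.jsonl",
        config=config,
    )
    print_statistics(stats)
    print(f"Wrote {num_examples} examples")

if __name__ == "__main__":
    asyncio.run(build_dataset())

Each line of the resulting .jsonl file is one chat-format record of the form {"messages": [...], "metadata": {"session_id": ..., "total_reward": ...}}, which is what extract_openai_format emits above.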