synth-ai 0.2.2.dev0__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/cli/__init__.py +66 -0
- synth_ai/cli/balance.py +205 -0
- synth_ai/cli/calc.py +70 -0
- synth_ai/cli/demo.py +74 -0
- synth_ai/{cli.py → cli/legacy_root_backup.py} +60 -15
- synth_ai/cli/man.py +103 -0
- synth_ai/cli/recent.py +126 -0
- synth_ai/cli/root.py +184 -0
- synth_ai/cli/status.py +126 -0
- synth_ai/cli/traces.py +136 -0
- synth_ai/cli/watch.py +508 -0
- synth_ai/config/base_url.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +252 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_duckdb_v2_backup.py +413 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +760 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_synth.py +34 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +1740 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +1318 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_duckdb_v2_backup.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v2_backup.py +1352 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +2551 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/old/traces/session_crafter_episode_16_15227b68-2906-416f-acc4-d6a9b4fa5828_20250725_001154.json +1363 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +3 -3
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/crafter_custom/environment.py +1 -1
- synth_ai/environments/examples/enron/dataset/corbt___enron_emails_sample_questions/default/0.0.0/293c9fe8170037e01cc9cf5834e0cd5ef6f1a6bb/dataset_info.json +1 -0
- synth_ai/environments/examples/nethack/helpers/achievements.json +64 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +1 -1
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +5 -5
- synth_ai/environments/examples/red/units/test_movement_debug.py +2 -2
- synth_ai/environments/examples/red/units/test_retry_movement.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/available_envs.json +122 -0
- synth_ai/environments/examples/sokoban/verified_puzzles.json +54987 -0
- synth_ai/environments/service/core_routes.py +1 -1
- synth_ai/experimental/synth_oss.py +446 -0
- synth_ai/learning/core.py +21 -0
- synth_ai/learning/gateway.py +4 -0
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/mipro.py +8 -0
- synth_ai/lm/__init__.py +3 -0
- synth_ai/lm/core/main.py +4 -0
- synth_ai/lm/core/main_v3.py +238 -122
- synth_ai/lm/core/vendor_clients.py +4 -0
- synth_ai/lm/provider_support/openai.py +11 -2
- synth_ai/lm/vendors/base.py +7 -0
- synth_ai/lm/vendors/openai_standard.py +339 -4
- synth_ai/lm/vendors/openai_standard_responses.py +243 -0
- synth_ai/lm/vendors/synth_client.py +155 -5
- synth_ai/lm/warmup.py +54 -17
- synth_ai/tracing/__init__.py +18 -0
- synth_ai/tracing_v1/__init__.py +29 -14
- synth_ai/tracing_v3/__init__.py +2 -2
- synth_ai/tracing_v3/abstractions.py +62 -17
- synth_ai/tracing_v3/config.py +13 -7
- synth_ai/tracing_v3/db_config.py +6 -6
- synth_ai/tracing_v3/hooks.py +1 -1
- synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
- synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
- synth_ai/tracing_v3/session_tracer.py +5 -5
- synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
- synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
- synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
- synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
- synth_ai/tracing_v3/turso/manager.py +18 -11
- synth_ai/tracing_v3/turso/models.py +1 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/dashboard.py +329 -0
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/base_client.py +3 -3
- synth_ai/{tracing → v0/tracing}/client_manager.py +1 -1
- synth_ai/{tracing → v0/tracing}/context.py +1 -1
- synth_ai/{tracing → v0/tracing}/decorators.py +11 -11
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/{tracing → v0/tracing}/events/manage.py +4 -4
- synth_ai/{tracing → v0/tracing}/events/scope.py +6 -6
- synth_ai/{tracing → v0/tracing}/events/store.py +3 -3
- synth_ai/{tracing → v0/tracing}/immediate_client.py +6 -6
- synth_ai/{tracing → v0/tracing}/log_client_base.py +2 -2
- synth_ai/{tracing → v0/tracing}/retry_queue.py +3 -3
- synth_ai/{tracing → v0/tracing}/trackers.py +2 -2
- synth_ai/{tracing → v0/tracing}/upload.py +4 -4
- synth_ai/v0/tracing_v1/__init__.py +16 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/base_client.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/client_manager.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/context.py +1 -1
- synth_ai/{tracing_v1 → v0/tracing_v1}/decorators.py +11 -11
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/manage.py +4 -4
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/scope.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/events/store.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/immediate_client.py +6 -6
- synth_ai/{tracing_v1 → v0/tracing_v1}/log_client_base.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/retry_queue.py +3 -3
- synth_ai/{tracing_v1 → v0/tracing_v1}/trackers.py +2 -2
- synth_ai/{tracing_v1 → v0/tracing_v1}/upload.py +4 -4
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +100 -5
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +115 -75
- /synth_ai/{tracing/events/__init__.py → compound/cais.py} +0 -0
- /synth_ai/{tracing_v1/events/__init__.py → environments/examples/crafter_classic/debug_translation.py} +0 -0
- /synth_ai/{tracing → v0/tracing}/abstractions.py +0 -0
- /synth_ai/{tracing → v0/tracing}/config.py +0 -0
- /synth_ai/{tracing → v0/tracing}/local.py +0 -0
- /synth_ai/{tracing → v0/tracing}/utils.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/abstractions.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/config.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/local.py +0 -0
- /synth_ai/{tracing_v1 → v0/tracing_v1}/utils.py +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.2.dev0.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,386 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Filter traces from DuckDB to create OpenAI SFT-ready .jsonl files
|
4
|
+
Supports two modes:
|
5
|
+
1. Trajectory-level filtering: Include entire trajectories above a score threshold
|
6
|
+
2. Window-based filtering: Extract high-scoring windows of actions
|
7
|
+
"""
|
8
|
+
|
9
|
+
import json
|
10
|
+
import argparse
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import List, Dict, Any, Tuple, Optional
|
13
|
+
from collections import defaultdict
|
14
|
+
import numpy as np
|
15
|
+
import os
|
16
|
+
import sys
|
17
|
+
import toml
|
18
|
+
|
19
|
+
# Add synth_ai to path
|
20
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent))
|
21
|
+
|
22
|
+
from synth_ai.tracing_v2.duckdb.ft_utils import FinetuningDataExtractor
|
23
|
+
from synth_ai.tracing_v2.duckdb.manager import DuckDBTraceManager
|
24
|
+
|
25
|
+
|
26
|
+
def create_histogram(data: List[float], bins: int = 20, width: int = 60, height: int = 15,
                     title: str = "", x_label: str = "", y_label: str = "") -> str:
    """Render *data* as an ASCII histogram and return it as one string.

    Each bin is drawn as a single character column, so the plot area is
    ``bins`` characters wide. ``width`` only controls the width used to center
    the title / x-axis label and the length of the title underline.

    Args:
        data: Numeric values to bin. Callers filter out NaN first, because
            ``np.histogram`` cannot auto-detect a range containing NaN.
        bins: Number of histogram bins (one text column per bin).
        width: Nominal chart width used for centering the title and x label.
        height: Number of text rows used for the tallest bar.
        title: Optional centered title printed above the plot.
        x_label: Optional centered label printed under the x-axis.
        y_label: Optional label printed above the y-axis.

    Returns:
        The multi-line chart, or ``"No data to display"`` when *data* is empty.
    """
    if not data:
        return "No data to display"

    counts, edges = np.histogram(data, bins=bins)
    n_bins = len(counts)
    max_count = int(counts.max()) if n_bins > 0 else 1

    # Scale every bar so the tallest one uses the full plot height.
    if max_count > 0:
        heights = [int(c * height / max_count) for c in counts]
    else:
        heights = [0] * n_bins

    lines: List[str] = []

    if title:
        lines.append(f"\n{title.center(width + 10)}")
        lines.append("=" * (width + 10))

    if y_label:
        lines.append(f"{y_label}")

    # Every row is an 8-character gutter ("NNNNNN │") followed by one column per bin.
    for y in range(height, 0, -1):
        y_val = int(max_count * y / height)
        bars = "".join("█" if h >= y else " " for h in heights)
        lines.append(f"{y_val:>6} │{bars}")

    # X-axis line, aligned with the same 8-character gutter.
    lines.append(f"{'':>6} └" + "─" * n_bins)
    lines.append(_format_x_tick_labels(edges, gutter=8))

    if x_label:
        lines.append(f"\n{x_label.center(width + 10)}")

    return "\n".join(lines)


def _format_x_tick_labels(edges, gutter: int = 8) -> str:
    """Build the tick-label row printed under a histogram's x-axis.

    ``edges`` are the ``n_bins + 1`` bin edges returned by ``np.histogram``.
    Edge ``k`` corresponds to text column ``gutter + k``, since bin ``k`` is
    drawn in that column. Label placement:

    - The minimum edge is left-aligned under the first bar.
    - The maximum edge is right-aligned at the end of the axis.
    - The quartile edges are centered in between, but only when they leave at
      least one blank column on each side.
    """
    n_bins = len(edges) - 1
    first = f"{edges[0]:.1f}"
    last = f"{edges[-1]:.1f}"

    placed = [(gutter, first)]  # (start_column, text) of each label to draw
    first_end = gutter + len(first)
    # Right-align the max label with the end of the axis. On very narrow plots,
    # push it right rather than letting it overwrite the min label.
    last_start = max(first_end + 1, gutter + n_bins - len(last))

    prev_end = first_end
    for k in sorted({n_bins // 4, n_bins // 2, (3 * n_bins) // 4}):
        if k <= 0 or k >= n_bins:
            continue  # the first/last edges are already labelled
        text = f"{edges[k]:.1f}"
        start = gutter + k - len(text) // 2
        end = start + len(text)
        # Skip labels that would touch a neighbour, instead of garbling both.
        if start > prev_end and end < last_start:
            placed.append((start, text))
            prev_end = end
    placed.append((last_start, last))

    row = [" "] * (last_start + len(last))
    for start, text in placed:
        row[start:start + len(text)] = text
    return "".join(row)
|
102
|
+
|
103
|
+
|
104
|
+
def create_bar_chart(categories: List[str], values: List[int], width: int = 60,
                     title: str = "", show_values: bool = True) -> str:
    """Draw parallel ``categories`` / ``values`` as a horizontal ASCII bar chart.

    The largest value spans ``width`` block characters and the others are
    scaled proportionally. When the largest value is not positive, every bar
    is empty. Category names are padded to a common width so the bars line
    up. When ``show_values`` is true, the raw value follows each bar.

    Returns:
        The chart as one newline-joined string, or ``"No data to display"``
        when either input sequence is empty.
    """
    if not (categories and values):
        return "No data to display"

    peak = max(values)
    label_width = max(len(name) for name in categories)

    rows: List[str] = [f"\n{title}", "=" * (width + 20)] if title else []

    for name, amount in zip(categories, values):
        bar_length = 0
        if peak > 0:
            bar_length = int(amount * width / peak)
        row = f"{name:<{label_width}} │ {'█' * bar_length}"
        if show_values:
            row = f"{row} {amount}"
        rows.append(row)

    return "\n".join(rows)
|
136
|
+
|
137
|
+
|
138
|
+
def filter_traces_from_duckdb(
    db_path: str,
    output_path: str,
    config: Dict[str, Any]
) -> Tuple[int, Dict[str, Any]]:
    """
    Filter traces from DuckDB based on configuration and write SFT examples.

    A session is kept when all of the following hold:

    - its ``total_reward`` is at least ``filters.min_total_reward``;
    - its ``total_cost`` does not exceed ``filters.max_cost``;
    - its ``total_tokens`` does not exceed ``filters.max_tokens``;
    - when ``filters.min_achievements > 0``, it appears in
      ``extractor.filter_by_achievements(min_achievements)``.

    Kept sessions are converted with ``extractor.extract_openai_format`` and
    written to ``output_path`` as one JSON object per line.

    Args:
        db_path: Path to the DuckDB trace database.
        output_path: Destination .jsonl file. Missing parent directories are
            created.
        config: Dict with an optional ``"mode"`` (``"trajectory"`` or
            ``"window"``) and an optional ``"filters"`` sub-dict.

    Returns:
        Tuple of (num_examples, statistics_dict). The reward, token and cost
        distributions cover only the sessions that passed the filters. The
        model distribution counts every ``'cais'`` event in the database.
    """
    mode = config.get("mode", "trajectory")
    filters = config.get("filters", {})

    # Extract filtering parameters (missing keys disable that filter).
    min_reward = filters.get("min_total_reward", 0.0)
    min_achievements = filters.get("min_achievements", 0)
    max_cost = filters.get("max_cost", float('inf'))
    max_tokens = filters.get("max_tokens", float('inf'))

    statistics = {
        "total_sessions": 0,
        "filtered_sessions": 0,
        "total_examples": 0,
        "reward_distribution": [],
        "token_distribution": [],
        "cost_distribution": [],
        "model_distribution": defaultdict(int)
    }

    with FinetuningDataExtractor(db_path) as extractor:
        all_sessions_query = "SELECT session_id FROM session_traces"
        all_sessions = extractor.db_manager.query_traces(all_sessions_query)
        statistics["total_sessions"] = len(all_sessions)

        # The achievement filter does not depend on the session being checked.
        # Query it once up front instead of re-running it for every session.
        achievement_sessions = (
            extractor.filter_by_achievements(min_achievements)
            if min_achievements > 0
            else None
        )

        filtered_sessions = []
        for session_id in all_sessions['session_id']:
            metrics = extractor.get_session_metrics(session_id)

            # NOTE(review): a NaN metric compares False in each check below and
            # therefore passes that threshold. Confirm this is intended.
            if metrics['total_reward'] < min_reward:
                continue
            if metrics['total_cost'] > max_cost:
                continue
            if metrics['total_tokens'] > max_tokens:
                continue
            if achievement_sessions is not None and session_id not in achievement_sessions:
                continue

            filtered_sessions.append(session_id)

            # Statistics are collected for kept sessions only.
            statistics["reward_distribution"].append(metrics['total_reward'])
            statistics["token_distribution"].append(metrics['total_tokens'])
            statistics["cost_distribution"].append(metrics['total_cost'])

        statistics["filtered_sessions"] = len(filtered_sessions)

        # Window extraction is not implemented yet; both modes export whole
        # trajectories. Say so rather than silently ignoring the requested mode.
        if mode != "trajectory":
            print(f"Warning: mode '{mode}' is not implemented yet; "
                  f"falling back to trajectory mode.")
        training_data = extractor.extract_openai_format(
            session_ids=filtered_sessions,
            min_reward=min_reward
        )
        statistics["total_examples"] = len(training_data)

        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            for example in training_data:
                f.write(json.dumps(example) + '\n')

        # Model usage across all 'cais' events in the database (not only the kept sessions).
        model_query = """
        SELECT model_name, COUNT(*) as count
        FROM events
        WHERE event_type = 'cais' AND model_name IS NOT NULL
        GROUP BY model_name
        """
        model_stats = extractor.db_manager.query_traces(model_query)
        for _, row in model_stats.iterrows():
            statistics["model_distribution"][row['model_name']] = row['count']

    return len(training_data), statistics
|
239
|
+
|
240
|
+
|
241
|
+
def print_statistics(stats: Dict[str, Any]):
    """Print a report of a filtering run, with ASCII charts where data allows.

    Reads the statistics dict built by ``filter_traces_from_duckdb``:

    - session and example counts;
    - per-session reward and token lists (NaN entries are ignored);
    - a model-name -> count mapping.

    A reward summary is always printed, or a notice when no valid rewards
    exist. Token and model charts are printed only when there is data for them.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("FILTERING STATISTICS")
    print(rule)

    total = stats['total_sessions']
    kept = stats['filtered_sessions']
    print(f"\nTotal sessions in database: {total}")
    print(f"Sessions after filtering: {kept}")
    print(f"Training examples generated: {stats['total_examples']}")

    pass_rate = (kept / total * 100) if total > 0 else 0
    print(f"Filter pass rate: {pass_rate:.1f}%")

    # Rewards: NaN entries carry no information, so plot only the real numbers.
    rewards = [r for r in stats['reward_distribution'] if not np.isnan(r)]
    if not rewards:
        print("\nNo valid reward data to display.")
    else:
        print(create_histogram(
            rewards,
            bins=20,
            title="Reward Distribution",
            x_label="Total Reward",
            y_label="Count",
        ))
        print("\nReward statistics:")
        summary = (
            ("Min", min(rewards)),
            ("Max", max(rewards)),
            ("Mean", np.mean(rewards)),
            ("Median", np.median(rewards)),
        )
        for label, value in summary:
            print(f" {label}: {value:.2f}")

    # Tokens: chart only; nothing is printed when there is no valid data.
    tokens = [t for t in stats['token_distribution'] if not np.isnan(t)]
    if tokens:
        print(create_histogram(
            tokens,
            bins=20,
            title="Token Usage Distribution",
            x_label="Total Tokens",
            y_label="Count",
        ))

    by_model = stats['model_distribution']
    if by_model:
        print(create_bar_chart(
            list(by_model.keys()),
            list(by_model.values()),
            title="Model Usage",
            show_values=True,
        ))

    print("\n" + rule)
|
299
|
+
|
300
|
+
|
301
|
+
def main():
    """Command-line entry point: filter DuckDB traces into an SFT .jsonl file.

    Settings are resolved in increasing order of precedence:

    1. built-in defaults;
    2. the optional TOML file passed with ``-c/--config``;
    3. explicit command-line flags.

    With ``--dry-run``, examples are written to the platform null device and
    only the statistics are reported.
    """
    parser = argparse.ArgumentParser(
        description="Filter traces from DuckDB for fine-tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
# Use default config
python filter_traces_sft_duckdb.py -d crafter_traces.duckdb -o ft_data/training.jsonl

# Use custom config file
python filter_traces_sft_duckdb.py -d crafter_traces.duckdb -c filter_config.toml

# Override config parameters
python filter_traces_sft_duckdb.py -d crafter_traces.duckdb --min-reward 5.0 --max-cost 0.1
"""
    )

    parser.add_argument('-d', '--database', required=True, help='Path to DuckDB database')
    parser.add_argument('-o', '--output', default='ft_data/training.jsonl', help='Output JSONL file')
    parser.add_argument('-c', '--config', help='Configuration TOML file')

    # Filter overrides
    parser.add_argument('--mode', choices=['trajectory', 'window'], help='Filtering mode')
    parser.add_argument('--min-reward', type=float, help='Minimum total reward')
    parser.add_argument('--min-achievements', type=int, help='Minimum achievements')
    parser.add_argument('--max-cost', type=float, help='Maximum cost')
    parser.add_argument('--max-tokens', type=int, help='Maximum tokens')

    parser.add_argument('--dry-run', action='store_true', help='Show statistics without writing output')

    args = parser.parse_args()

    # Built-in defaults; the config file and CLI flags below may override them.
    config = {
        "mode": "trajectory",
        "filters": {
            "min_total_reward": 1.0,
            "min_achievements": 0,
            "max_cost": 10.0,
            "max_tokens": 100000
        }
    }

    if args.config:
        with open(args.config, 'r') as f:
            loaded_config = toml.load(f)
        # Merge [filters] key by key. A plain dict.update() would replace the
        # whole filters dict and silently drop defaults (e.g. max_cost) that
        # the config file does not mention.
        loaded_filters = loaded_config.pop("filters", None)
        config.update(loaded_config)
        if loaded_filters is not None:
            if not isinstance(loaded_filters, dict):
                parser.error(f"'filters' in {args.config} must be a table")
            config["filters"].update(loaded_filters)

    # Apply command-line overrides
    if args.mode:
        config["mode"] = args.mode
    if args.min_reward is not None:
        config["filters"]["min_total_reward"] = args.min_reward
    if args.min_achievements is not None:
        config["filters"]["min_achievements"] = args.min_achievements
    if args.max_cost is not None:
        config["filters"]["max_cost"] = args.max_cost
    if args.max_tokens is not None:
        config["filters"]["max_tokens"] = args.max_tokens

    print(f"Using database: {args.database}")
    print(f"Output file: {args.output}")
    print(f"Mode: {config['mode']}")
    print(f"Filters: {json.dumps(config['filters'], indent=2)}")

    if args.dry_run:
        print("\n🔍 DRY RUN - No output will be written")

    # Run filtering. os.devnull is portable; "/dev/null" does not exist on Windows.
    num_examples, stats = filter_traces_from_duckdb(
        args.database,
        os.devnull if args.dry_run else args.output,
        config
    )

    print_statistics(stats)

    if not args.dry_run:
        print(f"\n✅ Successfully wrote {num_examples} training examples to {args.output}")
    else:
        print(f"\n✅ Would write {num_examples} training examples (dry run)")
|
383
|
+
|
384
|
+
|
385
|
+
if __name__ == "__main__":  # run the CLI only when executed as a script, not on import
    main()
|