synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +685 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai-0.2.5.dist-info/METADATA +106 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,984 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Filter traces to create OpenAI SFT-ready .jsonl files
|
|
4
|
+
Supports two modes:
|
|
5
|
+
1. Trajectory-level filtering: Include entire trajectories above a score threshold
|
|
6
|
+
2. Window-based filtering: Extract high-scoring windows of actions
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import argparse
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import List, Dict, Any, Tuple, Optional
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
import numpy as np
|
|
15
|
+
import os
|
|
16
|
+
import sys
|
|
17
|
+
import toml
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def create_histogram(data: List[float], bins: int = 20, width: int = 60, height: int = 15,
                     title: str = "", x_label: str = "", y_label: str = "") -> str:
    """Render *data* as a vertical ASCII histogram.

    Args:
        data: Values to bin. An empty sequence yields a placeholder message.
        bins: Passed straight to ``np.histogram``; one bar column is drawn per bin.
        width: Only used to centre the title / x-axis caption and size the rule
            under the title (``width + 10`` characters).
        height: Number of text rows used for the tallest bar.
        title: Optional heading printed above the plot.
        x_label: Optional caption printed below the plot.
        y_label: Optional caption printed above the y-axis.

    Returns:
        The plot as a single newline-joined string.
    """
    if not data:
        return "No data to display"

    counts, edges = np.histogram(data, bins=bins)
    n_bars = len(counts)
    max_count = int(counts.max()) if n_bars > 0 else 1

    # Scale every bin so the fullest one is exactly `height` rows tall.
    if max_count > 0:
        heights = [int(c * height / max_count) for c in counts]
    else:
        heights = [0] * n_bars

    lines = []

    if title:
        lines.append(f"\n{title.center(width + 10)}")
        lines.append("=" * (width + 10))

    if y_label:
        lines.append(f"{y_label}")

    # Plot area, top row first; each row is prefixed with its count value.
    for y in range(height, 0, -1):
        y_val = int(max_count * y / height)
        row = f"{y_val:>6} │"
        row += "".join("█" if h >= y else " " for h in heights)
        lines.append(row)

    lines.append(f"{'':>6} └" + "─" * n_bars)

    # The label row must span the gutter plus every bar column; with a row
    # only as wide as the gutter, the interior labels could never be placed.
    gutter = 8  # width of the "{count:>6} │" prefix in front of the bars
    label_row = [" "] * (gutter + n_bars)

    def place(text: str, start: int, force: bool = False) -> None:
        """Write *text* at *start* unless it would touch an existing label."""
        start = max(0, start)
        stop = start + len(text)
        if force:
            if stop > len(label_row):
                label_row.extend(" " * (stop - len(label_row)))
        elif stop > len(label_row) or any(
                ch != " " for ch in label_row[max(0, start - 1):stop + 1]):
            return
        label_row[start:stop] = text

    if n_bars > 0:
        # De-duplicate so small bin counts do not print the same edge twice.
        positions = sorted({0, n_bars // 4, n_bars // 2, 3 * n_bars // 4, n_bars - 1})

        # End labels go first so they win over interior labels on collision.
        place(f"{edges[positions[0]]:.1f}", 0, force=True)
        if len(positions) > 1:
            last = f"{edges[positions[-1]]:.1f}"
            place(last, gutter + positions[-1] - len(last))
        for pos in positions[1:-1]:
            text = f"{edges[pos]:.1f}"
            place(text, gutter + pos - len(text) // 2)

    lines.append("".join(label_row).rstrip())

    if x_label:
        lines.append(f"\n{x_label.center(width + 10)}")

    return "\n".join(lines)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def create_bar_chart(categories: List[str], values: List[int], width: int = 60,
                     title: str = "", show_values: bool = True) -> str:
    """Render a horizontal ASCII bar chart, one row per category.

    Args:
        categories: Row labels; an empty list yields a placeholder message.
        values: One number per category, paired with ``categories`` by position.
        width: Character length of the bar for the largest value.
        title: Optional heading, followed by a rule line.
        show_values: When true, append the raw value after each bar.

    Returns:
        The chart as a single newline-joined string.
    """
    if not categories:
        return "No data to display"

    peak = max(values) if values else 1
    label_width = max(map(len, categories))

    rendered = []
    if title:
        rendered += [f"\n{title}", "=" * (width + label_width + 15)]

    for name, amount in zip(categories, values):
        # Bars are scaled relative to the largest value; a non-positive peak draws nothing.
        filled = "█" * (int(amount * width / peak) if peak > 0 else 0)
        row = f"{name:<{label_width}} │ {filled}"
        if show_values:
            row += f" {amount}"
        rendered.append(row)

    return "\n".join(rendered)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def display_analysis_results(scores: List[float], achievements_per_trace: Dict[str, List[str]],
                             window_scores: Optional[List[Tuple[str, int, int, float]]] = None):
    """Print a text report summarising trace scores, achievements and windows.

    Args:
        scores: One trajectory score per trace; truncated to ``int`` for the
            distribution chart.
        achievements_per_trace: Maps a trace name to the achievements it unlocked.
        window_scores: Optional ``(trace_name, start, end, score)`` tuples; when
            given, window statistics and window-filter recommendations are printed.

    Returns:
        None. All output goes to stdout.
    """
    from collections import Counter

    print("\n" + "╔" + "═" * 78 + "╗")
    print("║" + " CRAFTER TRACE ANALYSIS RESULTS ".center(78) + "║")
    print("╚" + "═" * 78 + "╝")

    # 1. Trajectory score distribution (scores are discrete, so one column per value).
    if scores:
        score_counts = {}
        for s in scores:
            score_counts[int(s)] = score_counts.get(int(s), 0) + 1

        max_score = int(max(scores))
        max_count = max(score_counts.values())

        print("\n" + "📊 Trajectory Score Distribution".center(70))
        print("=" * 70)
        print("Traces")

        # At most ~10 rows: step down from the tallest count towards 1.
        for y in range(max_count, 0, -max(1, max_count // 10)):
            line = f"{y:>6} │"
            for score in range(max_score + 1):
                # Every score owns a 10-character column, filled or blank,
                # so bars stay above their x-axis label.
                if score_counts.get(score, 0) >= y:
                    line += " ████████ "
                else:
                    line += " " * 10
            print(line)

        print(f"{'':>6} └" + "─" * (10 * (max_score + 1)))

        # Labels start after the 8-character y-axis gutter and are centred
        # within the same 10-character columns as the bars.
        x_labels = " " * 8
        for score in range(max_score + 1):
            x_labels += f"{score:^10}"
        print(x_labels)
        print("\n" + "Number of Achievements (Score)".center(70))

        print("\n┌─ Statistics ─────────────────────────┐")
        print(f"│ Total traces: {len(scores):<23}│")
        print(f"│ Mean score: {np.mean(scores):<25.2f}│")
        print(f"│ Median score: {np.median(scores):<23.1f}│")
        print(f"│ Max score: {max(scores):<26.0f}│")
        print(f"│ Traces with score > 0: {sum(1 for s in scores if s > 0):<14}│")
        print("└──────────────────────────────────────┘")

    # 2. Achievement distribution across all traces.
    achievement_counts = Counter(
        name for unlocked in achievements_per_trace.values() for name in unlocked
    )

    if achievement_counts:
        top_achievements = achievement_counts.most_common(10)
        categories = [ach for ach, _ in top_achievements]
        values = [count for _, count in top_achievements]

        print("\n" + create_bar_chart(
            categories,
            values,
            width=40,
            title="🏆 Top 10 Achievements Unlocked",
            show_values=True
        ))

    # 3. Window analysis (only when window data was supplied).
    if window_scores:
        window_score_values = [score for _, _, _, score in window_scores]
        unique_window_scores = sorted(set(window_score_values))

        print("\n┌─ Window Analysis ────────────────────┐")
        print(f"│ Total windows analyzed: {len(window_scores):<13}│")
        print(f"│ Windows with score > 0: {sum(1 for s in window_score_values if s > 0):<13}│")
        print(f"│ Unique score values: {unique_window_scores}".ljust(39) + "│")
        print("└──────────────────────────────────────┘")

    # 4. How many traces / windows each candidate threshold would keep.
    print("\n" + "╔" + "═" * 78 + "╗")
    print("║" + " FILTERING RECOMMENDATIONS ".center(78) + "║")
    print("╚" + "═" * 78 + "╝")

    score_thresholds = [1, 2, 3]
    print("\n┌─ Trajectory Filtering ───────────────┐")
    for threshold in score_thresholds:
        count = sum(1 for s in scores if s >= threshold)
        pct = count / len(scores) * 100 if scores else 0
        print(f"│ Score ≥ {threshold}: {count:>3} traces ({pct:>5.1f}%) │")
    print("└──────────────────────────────────────┘")

    if window_scores:
        print("\n┌─ Window Filtering ───────────────────┐")
        window_thresholds = [1, 2]
        for threshold in window_thresholds:
            count = sum(1 for s in window_score_values if s >= threshold)
            pct = count / len(window_scores) * 100 if window_scores else 0
            print(f"│ Score ≥ {threshold}: {count:>3} windows ({pct:>5.1f}%) │")
        print("└──────────────────────────────────────┘")
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def load_trace(trace_file: Path) -> Dict[str, Any]:
    """Read one trace file and return its parsed JSON content.

    Args:
        trace_file: Path to a JSON trace on disk.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so non-ASCII trace content loads the same on
    # every platform rather than depending on the locale's default encoding.
    with open(trace_file, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def extract_trajectory_score(trace: Dict[str, Any]) -> float:
    """Return the trajectory score of a trace: its number of achievements.

    ``session_metadata`` is read in either of two layouts: a list of entries,
    where the first dict with ``metadata_type == 'episode_results'`` carries the
    results under ``data``, or a dict with an ``episode_results`` key. Missing
    or null metadata, results or counts all score ``0.0`` instead of raising.

    Args:
        trace: Parsed trace document.

    Returns:
        ``num_achievements`` from the episode results as a float, or ``0.0``.
    """
    metadata = trace.get('session_metadata') or []

    episode_results = {}
    if isinstance(metadata, list):
        for item in metadata:
            if isinstance(item, dict) and item.get('metadata_type') == 'episode_results':
                # `data` may be present but null; treat that as "no results".
                episode_results = item.get('data') or {}
                break
    elif isinstance(metadata, dict):
        episode_results = metadata.get('episode_results') or {}

    # Number of achievements is the primary score; a null count scores zero.
    return float(episode_results.get('num_achievements') or 0)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def extract_llm_calls(trace: Dict[str, Any], hook_config: Optional[Dict[str, Any]] = None) -> List[Tuple[int, Dict[str, Any], Dict[str, Any]]]:
    """Collect the agent's LLM call records from a trace's event history.

    Only events whose ``system_instance_id`` starts with ``crafter-react-agent``
    are considered. When ``hook_config`` is given, an event is dropped if any of
    its ``exclude_hooks`` fired, or if ``include_hooks`` is non-empty and none
    of them fired.

    Args:
        trace: Parsed trace document.
        hook_config: Optional dict with ``exclude_hooks`` / ``include_hooks`` lists.

    Returns:
        ``(turn_number, llm_record, event)`` tuples in event order; the turn is
        the event's ``time_record.message_time`` (0 when absent). Falsy records
        are skipped.
    """
    blocked = hook_config.get('exclude_hooks', []) if hook_config else []
    required = hook_config.get('include_hooks', []) if hook_config else []

    collected = []
    for event in trace.get('event_history', []):
        if not event.get('system_instance_id', '').startswith('crafter-react-agent'):
            continue

        fired = event.get('hooks_triggered', [])
        if blocked and any(name in fired for name in blocked):
            continue
        if required and not any(name in fired for name in required):
            continue

        records = event.get('llm_call_records', [])
        turn = event.get('time_record', {}).get('message_time', 0)
        collected.extend((turn, record, event) for record in records if record)

    return collected
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def calculate_window_score(trace: Dict[str, Any], start_turn: int, end_turn: int) -> float:
    """Count the achievements newly unlocked between two turns of a trace.

    The unlocked set is read from the ``achievements_status`` payload of the
    observation message at turn ``start_turn - 1`` (before) and at ``end_turn``
    (after). If several observations share a turn, the last one wins; a turn
    with no observation contributes an empty set.

    Args:
        trace: Parsed trace document with a ``message_history`` list.
        start_turn: First turn of the window.
        end_turn: Last turn of the window (inclusive).

    Returns:
        Number of achievements present after the window but not before it.
    """
    turn_before_window = start_turn - 1
    unlocked_before = set()
    unlocked_after = set()

    for entry in trace.get('message_history', []):
        turn = entry.get('time_record', {}).get('message_time', -1)
        if entry.get('message_type') != 'observation':
            continue

        status = entry.get('content', {}).get('payload', {}).get('achievements_status', {})
        if turn == turn_before_window:
            unlocked_before = {name for name, done in status.items() if done}
        elif turn == end_turn:
            unlocked_after = {name for name, done in status.items() if done}

    return len(unlocked_after - unlocked_before)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def convert_to_openai_format(llm_call: Dict[str, Any], quality_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Turn one LLM call record into an OpenAI fine-tuning example.

    The example is the call's prompt ``messages`` followed by the first
    response choice as an assistant message (with ``tool_calls`` when present).

    Args:
        llm_call: Record with ``messages`` and ``response.choices``.
        quality_config: Optional filters: ``min_response_length``,
            ``require_tool_calls`` and ``exclude_keywords`` (case-insensitive
            substring match against the response content).

    Returns:
        ``{"messages": [...]}``, or ``None`` when the record has no choices or
        is rejected by a quality filter.
    """
    prompt_messages = llm_call.get('messages', [])
    choices = llm_call.get('response', {}).get('choices', [])
    if not choices:
        return None

    reply = choices[0].get('message', {})
    text = reply.get('content', '')

    if quality_config:
        # Length filter: an empty response fails any positive minimum.
        min_length = quality_config.get('min_response_length', 0)
        if text:
            if len(text) < min_length:
                return None
        elif min_length > 0:
            return None

        # Tool-call filter.
        if quality_config.get('require_tool_calls', False) and not reply.get('tool_calls', []):
            return None

        # Keyword filter (case-insensitive).
        banned = quality_config.get('exclude_keywords', [])
        if text and banned:
            lowered = text.lower()
            if any(word.lower() in lowered for word in banned):
                return None

    completion = {"role": "assistant", "content": text}
    tool_calls = reply.get('tool_calls', [])
    if tool_calls:
        completion["tool_calls"] = tool_calls

    return {"messages": prompt_messages + [completion]}
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def filter_by_trajectory_score(traces_dir: Path, output_file: Path, config: Dict[str, Any]):
    """Write SFT examples from every trace whose score meets the threshold.

    Each ``*.json`` file in ``traces_dir`` is scored with
    ``extract_trajectory_score``. For traces at or above
    ``config['trajectory_filtering']['score_threshold']`` (default 2.0), all
    LLM calls that survive hook filtering (``config['hook_filtering']``) and
    quality filtering (``config['quality_filtering']``) are converted and
    written to ``output_file`` as JSON lines. A summary is printed to stdout.

    Args:
        traces_dir: Directory containing trace JSON files.
        output_file: Destination ``.jsonl`` path (overwritten).
        config: Parsed filter configuration.

    Returns:
        ``(trace_contributions, included_trace_scores)``: per-trace example
        counts and per-trace scores, keyed by file name, for traces that
        contributed at least one example.
    """
    examples = []
    included_count = 0
    hook_filtered_count = 0
    quality_filtered_count = 0
    trace_contributions = {}  # examples contributed per trace
    included_trace_scores = {}  # score of each contributing trace

    score_threshold = config.get('trajectory_filtering', {}).get('score_threshold', 2.0)
    hook_config = config.get('hook_filtering', {})
    quality_config = config.get('quality_filtering', {})

    trace_files = sorted(traces_dir.glob("*.json"))
    print(f"Processing {len(trace_files)} trace files...")

    for trace_file in trace_files:
        trace = load_trace(trace_file)
        score = extract_trajectory_score(trace)

        if score >= score_threshold:
            llm_calls = extract_llm_calls(trace, hook_config)

            # Hook filtering happens inside extract_llm_calls, so the number it
            # removed is only visible by comparing against an unfiltered pass.
            if hook_config:
                hook_filtered_count += len(extract_llm_calls(trace)) - len(llm_calls)

            trajectory_examples = []
            for _turn, llm_call, _event in llm_calls:
                example = convert_to_openai_format(llm_call, quality_config)
                if example:
                    trajectory_examples.append(example)
                else:
                    quality_filtered_count += 1

            if trajectory_examples:
                trace_contributions[trace_file.name] = len(trajectory_examples)
                included_trace_scores[trace_file.name] = score

            examples.extend(trajectory_examples)
            included_count += 1

    with open(output_file, 'w', encoding='utf-8') as f:
        for example in examples:
            f.write(json.dumps(example) + '\n')

    print(f"\n✓ Included {included_count}/{len(trace_files)} traces (score ≥ {score_threshold})")
    if hook_filtered_count > 0:
        print(f"✓ Filtered out {hook_filtered_count} events due to hook exclusions")
    if quality_filtered_count > 0:
        print(f"✓ Filtered out {quality_filtered_count} events due to quality filters")
    print(f"✓ Extracted {len(examples)} training examples from {len(trace_contributions)} unique traces")

    return trace_contributions, included_trace_scores
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def filter_by_window_score(traces_dir: Path, output_file: Path, config: Dict[str, Any]):
    """Write SFT examples from high-scoring, non-overlapping turn windows.

    For every ``*.json`` trace in ``traces_dir``, windows of
    ``config['window_filtering']['window_size']`` turns (default 5) are tried
    greedily from turn 0. A window is kept when ``calculate_window_score``
    reaches ``score_threshold`` (default 1.0) and at least one of its LLM calls
    survives quality filtering; its turns are then unavailable to later
    windows. Examples are written to ``output_file`` as JSON lines and a
    summary is printed to stdout.

    Args:
        traces_dir: Directory containing trace JSON files.
        output_file: Destination ``.jsonl`` path (overwritten).
        config: Parsed filter configuration (``window_filtering``,
            ``hook_filtering``, ``quality_filtering`` sections).

    Returns:
        ``(trace_contributions, window_scores_by_trace, trace_overall_scores)``,
        each keyed by trace file name.
    """
    all_examples = []
    window_stats = defaultdict(int)  # kept-window count per window score
    traces_with_windows = 0
    hook_filtered_count = 0
    quality_filtered_count = 0
    trace_contributions = {}  # examples contributed per trace
    window_scores_by_trace = defaultdict(list)
    trace_overall_scores = {}

    window_config = config.get('window_filtering', {})
    window_size = window_config.get('window_size', 5)
    score_threshold = window_config.get('score_threshold', 1.0)
    hook_config = config.get('hook_filtering', {})
    quality_config = config.get('quality_filtering', {})

    trace_files = sorted(traces_dir.glob("*.json"))
    print(f"Processing {len(trace_files)} trace files with window_size={window_size}...")

    for trace_file in trace_files:
        trace = load_trace(trace_file)
        trace_score = extract_trajectory_score(trace)
        llm_calls = extract_llm_calls(trace, hook_config)

        # Hook filtering happens inside extract_llm_calls, so the number it
        # removed is only visible by comparing against an unfiltered pass.
        if hook_config:
            hook_filtered_count += len(extract_llm_calls(trace)) - len(llm_calls)

        if not llm_calls:
            continue

        examples = []
        used_turns = set()
        found_window = False
        trace_window_count = 0

        max_turn = max(turn for turn, _, _ in llm_calls)

        for start in range(0, max_turn - window_size + 2):
            end = start + window_size - 1
            window_turns = range(start, end + 1)

            # Greedy extraction: a turn may belong to at most one kept window.
            if not used_turns.isdisjoint(window_turns):
                continue

            score = calculate_window_score(trace, start, end)

            if score >= score_threshold:
                window_examples = []
                for turn, llm_call, _event in llm_calls:
                    if start <= turn <= end:
                        example = convert_to_openai_format(llm_call, quality_config)
                        if example:
                            window_examples.append(example)
                        else:
                            quality_filtered_count += 1

                # NOTE(review): a window is counted only when it yields examples,
                # which keeps trace_contributions free of zero-example traces —
                # confirm against the intended nesting of the stats updates.
                if window_examples:
                    examples.extend(window_examples)
                    trace_window_count += len(window_examples)
                    window_scores_by_trace[trace_file.name].append(score)
                    used_turns.update(window_turns)

                    window_stats[score] += 1
                    found_window = True

        if found_window:
            traces_with_windows += 1
            trace_contributions[trace_file.name] = trace_window_count
            trace_overall_scores[trace_file.name] = trace_score

        all_examples.extend(examples)

    with open(output_file, 'w', encoding='utf-8') as f:
        for example in all_examples:
            f.write(json.dumps(example) + '\n')

    total_windows = sum(window_stats.values())
    print(f"\n✓ Found qualifying windows in {traces_with_windows}/{len(trace_files)} traces")
    print(f"✓ Extracted {total_windows} windows (score ≥ {score_threshold})")
    if hook_filtered_count > 0:
        print(f"✓ Filtered out {hook_filtered_count} events due to hook exclusions")
    if quality_filtered_count > 0:
        print(f"✓ Filtered out {quality_filtered_count} events due to quality filters")
    print(f"✓ Generated {len(all_examples)} training examples from {len(trace_contributions)} unique traces")

    return trace_contributions, window_scores_by_trace, trace_overall_scores
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _extract_episode_results(trace: Dict[str, Any]) -> Dict[str, Any]:
    """Return the episode-results mapping stored in a trace's session metadata.

    ``session_metadata`` is read in two layouts:

    * a list of entries -- the first dict entry whose ``metadata_type`` equals
      ``'episode_results'`` supplies its ``data`` value;
    * anything else -- treated as a mapping and queried for ``'episode_results'``.

    An empty dict is returned when no entry matches (including an empty list),
    so callers never see an unbound or stale value.
    """
    metadata = trace.get('session_metadata', [])
    if not isinstance(metadata, list):
        return metadata.get('episode_results', {})
    for item in metadata:
        if isinstance(item, dict) and item.get('metadata_type') == 'episode_results':
            return item.get('data', {})
    return {}


def analyze_traces(traces_dir: Path, analyze_windows: bool = True, hook_config: Optional[Dict[str, Any]] = None):
    """Analyze traces to help choose thresholds.

    Args:
        traces_dir: Directory whose ``*.json`` files are loaded as traces.
        analyze_windows: When true, every 5-turn window of each trace is
            scored and windows with a positive score are collected.
        hook_config: Passed through unchanged to ``extract_llm_calls``.

    The collected trajectory scores, unlocked achievements per trace and
    (optionally) window scores are handed to ``display_analysis_results``;
    nothing is returned.
    """
    # Size of the sliding window used for the analysis pass.
    window_size = 5

    scores = []
    achievements_per_trace = {}
    window_scores = []

    trace_files = sorted(traces_dir.glob("*.json"))

    print(f"Analyzing {len(trace_files)} trace files...")

    for trace_file in trace_files:
        trace = load_trace(trace_file)
        scores.append(extract_trajectory_score(trace))

        # Record the names of achievements whose value is truthy.
        achievements = _extract_episode_results(trace).get('achievements', {})
        achievements_per_trace[trace_file.name] = [
            name for name, reached in achievements.items() if reached
        ]

        if not analyze_windows:
            continue

        llm_calls = extract_llm_calls(trace, hook_config)
        if not llm_calls:
            continue

        max_turn = max(turn for turn, _, _ in llm_calls)
        # Slide over every window, including the one that ends on max_turn.
        # The bound matches the other window scans in this file
        # (max_turn - window_size + 2).
        for start in range(0, max_turn - window_size + 2):
            end = start + window_size - 1
            window_score = calculate_window_score(trace, start, end)
            if window_score > 0:  # Only track windows with achievements
                window_scores.append((trace_file.name, start, end, window_score))

    # Display results
    display_analysis_results(scores, achievements_per_trace, window_scores if analyze_windows else None)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _prompt_float(prompt: str) -> float:
    """Ask on stdin until the reply parses as a float, then return it."""
    while True:
        try:
            return float(input(prompt))
        except ValueError:
            print("Please enter a valid number.")


def _prompt_jsonl_filename(prompt: str, default: str) -> str:
    """Ask for a filename; empty input yields *default*, and '.jsonl' is appended if missing."""
    name = input(prompt).strip()
    if not name:
        return default
    if not name.endswith('.jsonl'):
        name += '.jsonl'
    return name


def _count_lines(path: Path) -> int:
    """Count the lines of *path*, closing the file afterwards."""
    with open(path) as f:
        return sum(1 for _ in f)


def _print_score_histogram(title: str, scores: Any, unit: str) -> None:
    """Print *title* followed by how many scores fall into each integer-truncated bucket."""
    counts = {}
    for score in scores:
        bucket = int(score)
        counts[bucket] = counts.get(bucket, 0) + 1

    print(title)
    for bucket in sorted(counts):
        print(f" Score {bucket}: {counts[bucket]} {unit}")


def _print_contribution_summary(contributions: Dict[str, Any], trace_scores: Dict[str, Any]) -> None:
    """Print the unique-trace count, examples-per-trace stats and trace score histogram."""
    print(f" └─ From {len(contributions)} unique traces")

    example_counts = list(contributions.values())
    if example_counts:
        avg_examples = sum(example_counts) / len(example_counts)
        print(f" └─ Examples per trace: min={min(example_counts)}, avg={avg_examples:.1f}, max={max(example_counts)}")

    if trace_scores:
        _print_score_histogram(" └─ Trace score distribution:", trace_scores.values(), "traces")


def main():
    """Command-line entry point: analyze traces or filter them into SFT datasets.

    Flow:
      1. Parse arguments and load a TOML config (``--config`` if it exists,
         otherwise ``filter_config.toml`` next to the traces directory);
         command-line thresholds override config values.
      2. With ``--analyze``, run ``analyze_traces`` and stop.
      3. With ``--interactive`` or when neither threshold is configured, show
         the analysis, prompt for missing thresholds and output filenames, and
         ask for confirmation.
      4. Run trajectory- and window-based filtering into ``ft_dataset/`` beside
         this script, print a comparison of the two outputs, and write
         metadata via ``generate_metadata``.
    """
    parser = argparse.ArgumentParser(description="Filter traces for SFT data")
    parser.add_argument("traces_dir", type=Path, help="Directory containing trace files")
    parser.add_argument("--config", type=Path, default=None,
                        help="Path to TOML configuration file")
    parser.add_argument("--analyze", action="store_true", help="Analyze traces to help choose thresholds")
    parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode for choosing thresholds")

    # Trajectory filtering options
    parser.add_argument("--trajectory-threshold", type=float, default=None,
                        help="Minimum trajectory score for inclusion")

    # Window filtering options
    parser.add_argument("--window-size", type=int, default=None,
                        help="Window size for window-based filtering")
    parser.add_argument("--window-threshold", type=float, default=None,
                        help="Minimum window score for inclusion")

    args = parser.parse_args()

    # Load configuration: explicit --config first, then the default location.
    config = {}
    config_path = None
    if args.config and args.config.exists():
        config_path = args.config
    else:
        if args.config:
            # Do not silently ignore a mistyped --config path.
            print(f"Warning: configuration file not found: {args.config}")
        # Look for default config in same directory as traces
        default_config = args.traces_dir.parent / "filter_config.toml"
        if default_config.exists():
            config_path = default_config
    if config_path is not None:
        config = toml.load(config_path)
        print(f"Loaded configuration from {config_path}")

    # Override config with command line arguments
    if args.trajectory_threshold is not None:
        config.setdefault('trajectory_filtering', {})['score_threshold'] = args.trajectory_threshold
    if args.window_size is not None:
        config.setdefault('window_filtering', {})['window_size'] = args.window_size
    if args.window_threshold is not None:
        config.setdefault('window_filtering', {})['score_threshold'] = args.window_threshold

    if not args.traces_dir.exists():
        print(f"Error: Traces directory not found: {args.traces_dir}")
        return

    if args.analyze:
        analyze_traces(args.traces_dir, hook_config=config.get('hook_filtering', {}))
        return

    # Interactive mode or direct filtering
    traj_threshold = config.get('trajectory_filtering', {}).get('score_threshold')
    window_threshold = config.get('window_filtering', {}).get('score_threshold')

    if args.interactive or (traj_threshold is None and window_threshold is None):
        # First show analysis
        analyze_traces(args.traces_dir, hook_config=config.get('hook_filtering', {}))

        print("\n" + "╔" + "═" * 78 + "╗")
        print("║" + " INTERACTIVE THRESHOLD SELECTION ".center(78) + "║")
        print("╚" + "═" * 78 + "╝")

        # Get thresholds interactively (only the ones not already configured)
        print("\nBased on the analysis above, please choose filtering thresholds:")

        if traj_threshold is None:
            traj_threshold = _prompt_float("\n📊 Trajectory score threshold (e.g., 2.0): ")
            config.setdefault('trajectory_filtering', {})['score_threshold'] = traj_threshold

        if window_threshold is None:
            window_threshold = _prompt_float("📊 Window score threshold (e.g., 1.0): ")
            config.setdefault('window_filtering', {})['score_threshold'] = window_threshold

        print("\nYou selected:")
        print(f" • Trajectory threshold: {traj_threshold}")
        print(f" • Window threshold: {window_threshold}")

        # Get custom filenames
        print("\nOutput file names:")
        traj_filename = _prompt_jsonl_filename(
            "📁 Trajectory output filename (default: trajectory_score.jsonl): ",
            "trajectory_score.jsonl",
        )
        window_filename = _prompt_jsonl_filename(
            "📁 Window output filename (default: window_score.jsonl): ",
            "window_score.jsonl",
        )

        # Store filenames in config for later use
        config.setdefault('output', {})['trajectory_file'] = traj_filename
        config.setdefault('output', {})['window_file'] = window_filename

        print("\nFiles will be saved as:")
        print(f" • Trajectory data: {traj_filename}")
        print(f" • Window data: {window_filename}")

        confirm = input("\nProceed with filtering? (y/n): ")
        if confirm.lower() != 'y':
            print("Filtering cancelled.")
            return

    # Ensure we have defaults if still None
    if traj_threshold is None:
        config.setdefault('trajectory_filtering', {})['score_threshold'] = 2.0
    if window_threshold is None:
        config.setdefault('window_filtering', {})['score_threshold'] = 1.0

    # Show configuration summary if hook or quality filtering enabled
    hook_config = config.get('hook_filtering', {})
    quality_config = config.get('quality_filtering', {})

    if hook_config.get('exclude_hooks') or hook_config.get('include_hooks') or quality_config:
        print("\n" + "Configuration Summary".center(50))
        print("=" * 50)

        if hook_config.get('exclude_hooks'):
            print(f"Excluding events with hooks: {hook_config['exclude_hooks']}")
        if hook_config.get('include_hooks'):
            print(f"Including only events with hooks: {hook_config['include_hooks']}")

        if quality_config.get('min_response_length'):
            print(f"Min response length: {quality_config['min_response_length']}")
        if quality_config.get('require_tool_calls'):
            print("Requiring tool calls in responses")
        if quality_config.get('exclude_keywords'):
            print(f"Excluding keywords: {quality_config['exclude_keywords']}")

    # Create ft_dataset directory next to this script
    script_dir = Path(__file__).parent
    ft_dataset_dir = script_dir / "ft_dataset"
    ft_dataset_dir.mkdir(exist_ok=True)

    # Get output file names from config
    output_config = config.get('output', {})
    traj_filename = output_config.get('trajectory_file', 'trajectory_score.jsonl')
    window_filename = output_config.get('window_file', 'window_score.jsonl')

    # Run both filtering methods
    print("\nRunning trajectory-based filtering...")
    print("=" * 50)
    traj_contributions, traj_trace_scores = filter_by_trajectory_score(
        args.traces_dir,
        ft_dataset_dir / traj_filename,
        config
    )

    print("\n\nRunning window-based filtering...")
    print("=" * 50)
    window_contributions, window_scores_by_trace, window_trace_scores = filter_by_window_score(
        args.traces_dir,
        ft_dataset_dir / window_filename,
        config
    )

    # Compare results using the files that were actually written above,
    # not the default names, so custom filenames are honored.
    traj_file = ft_dataset_dir / traj_filename
    window_file = ft_dataset_dir / window_filename

    if traj_file.exists() and window_file.exists():
        traj_count = _count_lines(traj_file)
        window_count = _count_lines(window_file)

        trace_files = list(args.traces_dir.glob("*.json"))
        total_traces = len(trace_files)

        traj_threshold = config.get('trajectory_filtering', {}).get('score_threshold', 2.0)
        window_size = config.get('window_filtering', {}).get('window_size', 5)
        window_threshold = config.get('window_filtering', {}).get('score_threshold', 1.0)
        hook_config = config.get('hook_filtering', {})

        # Load each trace once and evaluate both yield criteria on it.
        included_traces = 0
        traces_with_windows = 0
        for trace_file in trace_files:
            trace = load_trace(trace_file)

            # For trajectory: count traces above threshold
            if extract_trajectory_score(trace) >= traj_threshold:
                included_traces += 1

            # For windows: count traces with at least one qualifying window
            llm_calls = extract_llm_calls(trace, hook_config)
            if llm_calls:
                max_turn = max(turn for turn, _, _ in llm_calls)
                if any(
                    calculate_window_score(trace, start, start + window_size - 1) >= window_threshold
                    for start in range(0, max_turn - window_size + 2)
                ):
                    traces_with_windows += 1

        traj_yield = (included_traces / total_traces * 100) if total_traces > 0 else 0
        window_yield = (traces_with_windows / total_traces * 100) if total_traces > 0 else 0

        print("\n\nComparison:")
        print("=" * 50)
        print(f"Trajectory-based: {traj_count} examples ({traj_yield:.1f}% of traces)")

        # Show trajectory contribution distribution
        if traj_contributions:
            _print_contribution_summary(traj_contributions, traj_trace_scores)

        print(f"\nWindow-based: {window_count} examples ({window_yield:.1f}% of traces)")

        # Show window contribution distribution
        if window_contributions:
            _print_contribution_summary(window_contributions, window_trace_scores)

            # Show window score distribution
            all_window_scores = [
                score
                for trace_window_scores in window_scores_by_trace.values()
                for score in trace_window_scores
            ]
            if all_window_scores:
                _print_score_histogram(" └─ Window score distribution:", all_window_scores, "windows")

    print(f"\nOutput files saved to: {ft_dataset_dir}/")

    # Generate metadata
    print("\nGenerating metadata...")
    generate_metadata(args.traces_dir, ft_dataset_dir, config)
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def generate_metadata(traces_dir: Path, output_dir: Path, config: Dict[str, Any]):
    """Generate comprehensive metadata for the filtered datasets.

    Args:
        traces_dir: Directory whose ``*.json`` files are the source traces.
        output_dir: Directory holding the filtered ``.jsonl`` outputs; the
            metadata JSON file is written here as well.
        config: Filtering configuration. Thresholds and window size fall back
            to 2.0 / 1.0 / 5 when absent. ``config['output']`` may name the
            trajectory, window and metadata files.

    A ``trajectory_filtering`` / ``window_filtering`` section is only added
    when the corresponding output file exists in ``output_dir``.
    """
    # Glob once; the same listing feeds every count below.
    trace_files = list(traces_dir.glob("*.json"))

    metadata = {
        "dataset_creation": {
            "source_traces_dir": traces_dir.name,
            "num_source_traces": len(trace_files),
            "filtering_methods": ["trajectory_score", "window_score"],
            "config_used": config
        }
    }

    traj_threshold = config.get('trajectory_filtering', {}).get('score_threshold', 2.0)
    window_threshold = config.get('window_filtering', {}).get('score_threshold', 1.0)
    window_size = config.get('window_filtering', {}).get('window_size', 5)
    hook_config = config.get('hook_filtering', {})

    # Resolve output names from the config so custom filenames are found;
    # the defaults match the names used when none are configured.
    output_config = config.get('output', {})
    traj_file = output_dir / output_config.get('trajectory_file', 'trajectory_score.jsonl')
    window_file = output_dir / output_config.get('window_file', 'window_score.jsonl')

    # Analyze trajectory filtering
    if traj_file.exists():
        with open(traj_file) as f:
            traj_examples = sum(1 for _ in f)

        # Score every trace and remember which ones meet the threshold.
        included_traces = set()
        trace_scores = {}
        for trace_file in trace_files:
            score = extract_trajectory_score(load_trace(trace_file))
            trace_scores[trace_file.name] = score
            if score >= traj_threshold:
                included_traces.add(trace_file.name)

        metadata["trajectory_filtering"] = {
            "threshold": traj_threshold,
            "total_traces": len(trace_scores),
            "included_traces": len(included_traces),
            "excluded_traces": len(trace_scores) - len(included_traces),
            "yield_rate": (len(included_traces) / len(trace_scores) * 100) if trace_scores else 0,
            "total_examples": traj_examples,
            "avg_examples_per_trace": traj_examples / len(included_traces) if included_traces else 0
        }

    # Analyze window filtering
    if window_file.exists():
        with open(window_file) as f:
            window_examples = sum(1 for _ in f)

        # Count traces with qualifying windows.
        # NOTE(review): this counts every window position that meets the
        # threshold, overlapping ones included -- confirm that this is the
        # intended meaning of "total_windows_extracted".
        traces_with_windows = set()
        window_count = 0

        for trace_file in trace_files:
            trace = load_trace(trace_file)
            llm_calls = extract_llm_calls(trace, hook_config)
            if not llm_calls:
                continue

            max_turn = max(turn for turn, _, _ in llm_calls)
            for start in range(0, max_turn - window_size + 2):
                end = start + window_size - 1
                if calculate_window_score(trace, start, end) >= window_threshold:
                    traces_with_windows.add(trace_file.name)
                    window_count += 1

        metadata["window_filtering"] = {
            "window_size": window_size,
            "threshold": window_threshold,
            "total_traces": len(trace_files),
            "traces_with_qualifying_windows": len(traces_with_windows),
            "total_windows_extracted": window_count,
            "total_examples": window_examples,
            # Reported as the configured window size, not a measured average.
            "avg_examples_per_window": window_size
        }

    # Save metadata
    metadata_file = output_config.get('metadata_file', 'metadata.json')
    with open(output_dir / metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"✓ Metadata saved to {output_dir}/{metadata_file}")
|
|
981
|
+
|
|
982
|
+
|
|
983
|
+
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|