synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai has been flagged as possibly problematic — click here for more details.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/__init__.py +6 -0
- synth_ai/cli/demo.py +68 -9
- synth_ai/cli/rl_demo.py +137 -0
- synth_ai/cli/root.py +65 -0
- synth_ai/demos/core/__init__.py +1 -0
- synth_ai/demos/core/cli.py +685 -0
- synth_ai/demos/demo_task_apps/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/core.py +374 -0
- synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/math/app.py +37 -0
- synth_ai/demos/demo_task_apps/math/config.toml +44 -0
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
- synth_ai/environments/examples/bandit/__init__.py +33 -0
- synth_ai/environments/examples/bandit/engine.py +294 -0
- synth_ai/environments/examples/bandit/environment.py +194 -0
- synth_ai/environments/examples/bandit/taskset.py +200 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
- synth_ai/environments/examples/crafter_classic/environment.py +41 -2
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/install_sqld.sh +40 -0
- synth_ai-0.2.5.dist-info/METADATA +106 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/RECORD +111 -12
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/entry_points.txt +1 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Script to generate and filter fine-tuning data from Crafter rollouts
|
|
4
|
+
Applies quality filters and formats data for Modal fine-tuning
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
import json
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CrafterDataFilter:
    """Filter and process Crafter rollout data for fine-tuning.

    Episodes are kept only if they clear configurable quality gates
    (achievement count, total reward, step-count bounds, clean termination),
    scored for ranking, and finally converted into chat-format training
    examples with condensed observations.
    """

    def __init__(
        self,
        min_achievements: int = 3,
        min_reward: float = 0.0,
        max_steps: int = 100,
        min_steps: int = 10,
        achievement_weights: Optional[Dict[str, float]] = None
    ):
        """Configure filter thresholds.

        Args:
            min_achievements: Minimum number of unlocked achievements.
            min_reward: Minimum total episode reward.
            max_steps: Maximum episode length in steps.
            min_steps: Minimum episode length in steps.
            achievement_weights: Optional per-achievement score weights;
                defaults to the built-in table below when omitted.
        """
        self.min_achievements = min_achievements
        self.min_reward = min_reward
        self.max_steps = max_steps
        self.min_steps = min_steps

        # Default achievement weights (higher = more valuable)
        self.achievement_weights = achievement_weights or {
            # Basic resources
            "collect_wood": 1.0,
            "collect_stone": 1.5,
            "collect_coal": 3.0,
            "collect_iron": 5.0,
            "collect_diamond": 10.0,

            # Crafting
            "place_table": 2.0,
            "place_furnace": 3.0,
            "make_wood_pickaxe": 2.5,
            "make_stone_pickaxe": 3.5,
            "make_iron_pickaxe": 6.0,
            "make_wood_sword": 2.5,
            "make_stone_sword": 3.5,
            "make_iron_sword": 6.0,

            # Survival
            "eat_cow": 2.0,
            "eat_plant": 1.0,
            "collect_drink": 1.0,
            "sleep": 1.5,

            # Combat
            "defeat_zombie": 3.0,
            "defeat_skeleton": 4.0,
        }

    def calculate_episode_score(self, episode: dict) -> float:
        """Calculate a quality score for an episode.

        Score = weighted achievements + 10x achievements-per-step efficiency
        + 0.5x total reward + 0.5 per unique action taken.
        """
        score = 0.0

        # Achievement score
        achievements = episode.get("achievements", [])
        for ach in achievements:
            # Unknown achievements fall back to a weight of 1.0.
            score += self.achievement_weights.get(ach, 1.0)

        # Efficiency bonus (achievements per step)
        num_steps = len(episode.get("steps", []))
        if num_steps > 0:
            efficiency = len(achievements) / num_steps
            score += efficiency * 10

        # Reward contribution
        total_reward = episode.get("total_reward", 0)
        score += total_reward * 0.5

        # Diversity bonus (unique actions)
        unique_actions = len(set(step["action"] for step in episode.get("steps", [])))
        score += unique_actions * 0.5

        return score

    def filter_episode(self, episode: dict) -> Tuple[bool, str]:
        """
        Check if an episode passes quality filters.
        Returns (passes, reason).
        """
        # Check achievements
        achievements = episode.get("achievements", [])
        if len(achievements) < self.min_achievements:
            return False, f"Too few achievements: {len(achievements)} < {self.min_achievements}"

        # Check reward
        total_reward = episode.get("total_reward", 0)
        if total_reward < self.min_reward:
            return False, f"Low reward: {total_reward} < {self.min_reward}"

        # Check step count
        num_steps = len(episode.get("steps", []))
        if num_steps < self.min_steps:
            return False, f"Too few steps: {num_steps} < {self.min_steps}"
        if num_steps > self.max_steps:
            return False, f"Too many steps: {num_steps} > {self.max_steps}"

        # Check for errors. `or ""` guards against an explicit None value,
        # which would otherwise raise AttributeError on .startswith().
        if (episode.get("termination_reason") or "").startswith("error"):
            return False, "Episode ended with error"

        return True, "Passed all filters"

    def optimize_conversation(self, messages: List[dict]) -> List[dict]:
        """
        Optimize conversation for fine-tuning by:
        - Removing redundant information
        - Condensing observations
        - Ensuring proper format

        NOTE: system messages are carried over by reference; user/assistant
        messages are emitted as fresh dicts.
        """
        optimized = []

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "system":
                # Keep system message as-is
                optimized.append(msg)

            elif role == "user":
                # Condense user observations
                if len(content) > 1000:
                    # Extract key information
                    lines = content.split("\n")
                    key_lines = []

                    for line in lines:
                        # Keep important lines
                        if any(keyword in line.lower() for keyword in [
                            "health:", "hunger:", "thirst:", "inventory:",
                            "achievements:", "map", "recent"
                        ]):
                            key_lines.append(line)

                    # Keep reasoning prompt
                    if "Think step by step:" in content:
                        idx = content.index("Think step by step:")
                        key_lines.append(content[idx:])

                    content = "\n".join(key_lines)

                optimized.append({"role": role, "content": content})

            elif role == "assistant":
                # Keep assistant responses but trim if very long
                if len(content) > 500:
                    # Keep first part (reasoning) and action
                    lines = content.split("\n")
                    kept_lines = []

                    # Keep reasoning (usually first few lines)
                    for i, line in enumerate(lines[:5]):
                        kept_lines.append(line)

                    # Find and keep action
                    for line in lines:
                        if any(action in line for action in [
                            "move_", "do", "sleep", "place_", "make_"
                        ]):
                            if line not in kept_lines:
                                kept_lines.append(line)

                    content = "\n".join(kept_lines)

                optimized.append({"role": role, "content": content})

        return optimized

    def create_training_example(self, episode: dict) -> dict:
        """Create a training example from a filtered episode.

        The caller's episode is never mutated: optimize_conversation keeps
        the original system-message dict by reference, so it is copied here
        before the metadata comment is appended (previously the append
        mutated the source episode, stacking comments on repeated calls).
        """
        messages = episode.get("messages", [])

        # Optimize conversation
        optimized_messages = self.optimize_conversation(messages)

        # Add metadata as comment in system message
        metadata_comment = (
            f"# Episode stats: {len(episode['achievements'])} achievements, "
            f"{episode['total_reward']:.1f} reward, {len(episode['steps'])} steps"
        )

        if optimized_messages and optimized_messages[0]["role"] == "system":
            system_msg = dict(optimized_messages[0])  # copy before mutating
            system_msg["content"] += f"\n{metadata_comment}"
            optimized_messages[0] = system_msg

        return {"messages": optimized_messages}
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def analyze_rollout_directory(rollout_dir: Path) -> Dict[str, Any]:
    """Analyze a rollout directory and return aggregate statistics.

    Scans every ``rollout_data.json`` under *rollout_dir* (recursively) and
    accumulates episode counts, per-achievement unlock counts, and the raw
    reward / step / achievement-count distributions.

    Note: the return annotation previously used the builtin ``any`` function
    instead of ``typing.Any``; fixed here.
    """
    stats: Dict[str, Any] = {
        "total_episodes": 0,
        "total_achievements": defaultdict(int),
        "reward_distribution": [],
        "step_distribution": [],
        "achievement_distribution": [],
    }

    # Find rollout data files
    for rollout_file in rollout_dir.glob("**/rollout_data.json"):
        with open(rollout_file, "r") as f:
            data = json.load(f)

        episodes = data.get("episodes", [])
        stats["total_episodes"] += len(episodes)

        for ep in episodes:
            # Track achievements
            for ach in ep.get("achievements", []):
                stats["total_achievements"][ach] += 1

            # Track distributions
            stats["reward_distribution"].append(ep.get("total_reward", 0))
            stats["step_distribution"].append(len(ep.get("steps", [])))
            stats["achievement_distribution"].append(len(ep.get("achievements", [])))

    return stats
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _load_all_episodes(rollout_dir: Path) -> List[dict]:
    """Load every episode from all rollout_data.json files under rollout_dir."""
    episodes: List[dict] = []
    for rollout_file in rollout_dir.glob("**/rollout_data.json"):
        with open(rollout_file, "r") as f:
            data = json.load(f)
        episodes.extend(data.get("episodes", []))
    return episodes


def main():
    """CLI entry point: analyze rollouts, filter episodes, export JSONL.

    Returns a process exit code (0 on success, 1 on any failure), so the
    caller can pass it to SystemExit.
    """
    parser = argparse.ArgumentParser(description="Generate fine-tuning data from Crafter rollouts")
    parser.add_argument("rollout_dir", type=str, help="Directory containing rollout data")
    parser.add_argument("--output", type=str, default=None,
                        help="Output file (default: rollout_dir/fine_tuning_filtered.jsonl)")
    parser.add_argument("--min-achievements", type=int, default=3,
                        help="Minimum achievements required")
    parser.add_argument("--min-reward", type=float, default=0.0,
                        help="Minimum total reward required")
    parser.add_argument("--max-steps", type=int, default=100,
                        help="Maximum steps allowed")
    parser.add_argument("--min-steps", type=int, default=10,
                        help="Minimum steps required")
    parser.add_argument("--top-k", type=int, default=None,
                        help="Only keep top K episodes by score")
    parser.add_argument("--analyze-only", action="store_true",
                        help="Only analyze data, don't generate output")
    parser.add_argument("--verbose", action="store_true",
                        help="Show detailed filtering information")

    args = parser.parse_args()

    rollout_dir = Path(args.rollout_dir)
    if not rollout_dir.exists():
        print(f"❌ Rollout directory not found: {rollout_dir}")
        return 1

    # Analyze rollout data
    print(f"📊 Analyzing rollout data in {rollout_dir}...")
    stats = analyze_rollout_directory(rollout_dir)

    print(f"\n📈 Rollout Statistics:")
    print(f" Total episodes: {stats['total_episodes']}")

    if stats['total_episodes'] == 0:
        print("❌ No episodes found in rollout directory")
        return 1

    print(f" Avg achievements: {np.mean(stats['achievement_distribution']):.1f}")
    print(f" Avg reward: {np.mean(stats['reward_distribution']):.1f}")
    print(f" Avg steps: {np.mean(stats['step_distribution']):.1f}")

    print(f"\n🏆 Top achievements:")
    for ach, count in sorted(stats['total_achievements'].items(),
                             key=lambda x: x[1], reverse=True)[:10]:
        pct = count / stats['total_episodes'] * 100
        print(f" {ach}: {count} ({pct:.1f}%)")

    if args.analyze_only:
        return 0

    # Create filter
    filter_config = CrafterDataFilter(
        min_achievements=args.min_achievements,
        min_reward=args.min_reward,
        max_steps=args.max_steps,
        min_steps=args.min_steps
    )

    # Process episodes (single shared loader instead of re-inlining the
    # glob + json.load loop that analyze_rollout_directory already does).
    print(f"\n🔍 Filtering episodes...")
    all_episodes = _load_all_episodes(rollout_dir)

    filtered_episodes = []
    filter_reasons = defaultdict(int)

    for ep in all_episodes:
        passes, reason = filter_config.filter_episode(ep)
        if passes:
            # Calculate score for ranking
            score = filter_config.calculate_episode_score(ep)
            filtered_episodes.append((score, ep))
        else:
            filter_reasons[reason] += 1
            if args.verbose:
                print(f" Filtered out: {reason}")

    print(f"\n📋 Filtering Results:")
    print(f" Total episodes: {len(all_episodes)}")
    print(f" Passed filters: {len(filtered_episodes)}")
    print(f" Pass rate: {len(filtered_episodes)/len(all_episodes)*100:.1f}%")

    if filter_reasons:
        print(f"\n❌ Filter reasons:")
        for reason, count in sorted(filter_reasons.items(),
                                    key=lambda x: x[1], reverse=True):
            print(f" {reason}: {count}")

    if not filtered_episodes:
        print("\n❌ No episodes passed the filters!")
        return 1

    # Sort by score and apply top-k if specified
    filtered_episodes.sort(key=lambda x: x[0], reverse=True)

    if args.top_k and args.top_k < len(filtered_episodes):
        print(f"\n🎯 Selecting top {args.top_k} episodes by score")
        filtered_episodes = filtered_episodes[:args.top_k]

    # Generate training examples
    print(f"\n✍️ Generating training examples...")
    training_examples = []

    for score, episode in filtered_episodes:
        training_examples.append(filter_config.create_training_example(episode))

        if args.verbose:
            achievements = episode.get("achievements", [])
            print(f" Score: {score:.1f}, Achievements: {achievements}")

    # Save output
    output_file = Path(args.output) if args.output else rollout_dir / "fine_tuning_filtered.jsonl"

    with open(output_file, "w") as f:
        for example in training_examples:
            f.write(json.dumps(example) + "\n")

    print(f"\n✅ Generated {len(training_examples)} training examples")
    print(f"📁 Saved to: {output_file}")

    # Rough token estimate: ~4 characters per token.
    total_chars = sum(len(json.dumps(ex)) for ex in training_examples)
    est_tokens = total_chars // 4
    print(f"📊 Estimated tokens: {est_tokens:,}")

    # Show sample
    print(f"\n📝 Sample training example:")
    sample = training_examples[0]["messages"]
    for msg in sample[:3]:  # Show first 3 messages
        role = msg["role"]
        content = msg["content"][:100] + "..." if len(msg["content"]) > 100 else msg["content"]
        print(f" [{role}] {content}")

    return 0
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
if __name__ == "__main__":
    # Use SystemExit rather than the site-provided exit() helper, which is
    # not guaranteed to exist (e.g. when run with python -S).
    raise SystemExit(main())
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Generate metadata for fine-tuning datasets
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
import sys
|
|
10
|
+
|
|
11
|
+
sys.path.append(str(Path(__file__).parent))
|
|
12
|
+
from filter_traces_sft import load_trace, extract_trajectory_score, extract_llm_calls, calculate_window_score
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def analyze_trajectory_dataset(traces_dir: Path, threshold: float = 2.0):
    """Analyze trajectory-based filtering results.

    Loads every ``*.json`` trace in *traces_dir*, scores each trajectory, and
    reports which traces clear *threshold*, the score and achievement
    distributions, and how many LLM-call training examples the kept traces
    contribute.
    """
    trace_files = sorted(traces_dir.glob("*.json"))

    included_traces = []
    excluded_traces = []
    score_distribution = defaultdict(int)
    achievement_counts = defaultdict(int)
    total_llm_calls = 0

    for trace_file in trace_files:
        trace = load_trace(trace_file)
        score = extract_trajectory_score(trace)
        score_distribution[int(score)] += 1

        # Get achievements. Reset per trace: the previous implementation used
        # `if 'achievements' in locals()`, so a trace without episode_results
        # silently inherited the achievements of an earlier trace.
        achievements = {}
        metadata = trace.get('session_metadata', [])
        if isinstance(metadata, list):
            for item in metadata:
                if isinstance(item, dict) and item.get('metadata_type') == 'episode_results':
                    episode_results = item.get('data', {})
                    achievements = episode_results.get('achievements', {})
                    for ach, unlocked in achievements.items():
                        if unlocked:
                            achievement_counts[ach] += 1
                    break

        # Count LLM calls
        llm_calls = extract_llm_calls(trace)

        if score >= threshold:
            included_traces.append({
                'trace_file': trace_file.name,
                'score': score,
                'num_llm_calls': len(llm_calls),
                'achievements': [k for k, v in achievements.items() if v]
            })
            total_llm_calls += len(llm_calls)
        else:
            excluded_traces.append({
                'trace_file': trace_file.name,
                'score': score
            })

    return {
        'threshold': threshold,
        'total_traces': len(trace_files),
        'included_traces': len(included_traces),
        'excluded_traces': len(excluded_traces),
        # Guard against an empty traces directory (previously ZeroDivisionError).
        'yield_rate': len(included_traces) / len(trace_files) * 100 if trace_files else 0,
        'total_examples': total_llm_calls,
        'avg_examples_per_trace': total_llm_calls / len(included_traces) if included_traces else 0,
        'score_distribution': dict(score_distribution),
        'achievement_distribution': dict(achievement_counts),
        'included_trace_details': included_traces
    }
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def analyze_window_dataset(traces_dir: Path, window_size: int = 5, threshold: float = 1.0):
    """Summarize window-based filtering with greedy, non-overlapping extraction.

    Mirrors the extraction in filter_traces_sft.py: slide a window of
    ``window_size`` turns over each trace, keep every window whose score
    reaches ``threshold``, and never reuse a turn once it belongs to a kept
    window. Returns aggregate counts plus the first 20 kept-window details.
    """
    trace_files = sorted(traces_dir.glob("*.json"))

    window_scores = defaultdict(int)
    traces_with_windows = 0
    total_windows = 0
    total_examples = 0
    window_details = []

    for trace_path in trace_files:
        trace = load_trace(trace_path)
        calls = extract_llm_calls(trace)
        if not calls:
            continue

        last_turn = max(turn for turn, _ in calls)
        claimed = set()
        found_any = False

        for lo in range(0, last_turn - window_size + 2):
            hi = lo + window_size - 1

            # Greedy: skip any window touching an already-claimed turn.
            if any(t in claimed for t in range(lo, hi + 1)):
                continue

            score = calculate_window_score(trace, lo, hi)
            if score < threshold:
                continue

            window_scores[int(score)] += 1
            total_windows += 1
            found_any = True
            claimed.update(range(lo, hi + 1))

            # Every LLM call whose turn falls inside the window is one example.
            in_window = [call for turn, call in calls if lo <= turn <= hi]
            total_examples += len(in_window)

            window_details.append({
                'trace_file': trace_path.name,
                'window': f"[{lo}-{hi}]",
                'score': score,
                'num_examples': len(in_window)
            })

        if found_any:
            traces_with_windows += 1

    return {
        'window_size': window_size,
        'threshold': threshold,
        'total_traces': len(trace_files),
        'traces_with_qualifying_windows': traces_with_windows,
        'total_windows_extracted': total_windows,
        'total_examples': total_examples,
        'avg_examples_per_window': total_examples / total_windows if total_windows else 0,
        'window_score_distribution': dict(window_scores),
        'window_details': window_details[:20]  # First 20 for brevity
    }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def main():
    """Build and save metadata summarizing both fine-tuning datasets."""
    traces_dir = Path("traces")
    ft_dir = Path("ft_dataset")
    # Create the output directory up front; the open(..., 'w') calls below
    # would otherwise raise FileNotFoundError when it does not exist yet.
    ft_dir.mkdir(parents=True, exist_ok=True)

    # Analyze trajectory dataset
    print("Analyzing trajectory-based dataset...")
    traj_metadata = analyze_trajectory_dataset(traces_dir, threshold=2.0)

    # Analyze window dataset
    print("Analyzing window-based dataset...")
    window_metadata = analyze_window_dataset(traces_dir, window_size=5, threshold=1.0)

    # Create combined metadata
    combined_metadata = {
        'dataset_creation': {
            'source_traces_dir': str(traces_dir),
            'num_source_traces': traj_metadata['total_traces'],
            'filtering_methods': ['trajectory_score', 'window_score']
        },
        'trajectory_filtering': traj_metadata,
        'window_filtering': window_metadata,
        'comparison': {
            'trajectory_examples': traj_metadata['total_examples'],
            'window_examples': window_metadata['total_examples'],
            'trajectory_yield_rate': f"{traj_metadata['yield_rate']:.1f}%",
            'window_trace_coverage': f"{window_metadata['traces_with_qualifying_windows'] / window_metadata['total_traces'] * 100:.1f}%"
        }
    }

    # Save combined plus per-method metadata files.
    with open(ft_dir / "metadata.json", 'w') as f:
        json.dump(combined_metadata, f, indent=2)

    with open(ft_dir / "trajectory_score_metadata.json", 'w') as f:
        json.dump(traj_metadata, f, indent=2)

    with open(ft_dir / "window_score_metadata.json", 'w') as f:
        json.dump(window_metadata, f, indent=2)

    # Print summary
    print("\n" + "="*60)
    print("FINE-TUNING DATASET SUMMARY")
    print("="*60)
    print(f"Source traces: {traj_metadata['total_traces']}")
    print(f"\nTrajectory-based filtering (score >= 2.0):")
    print(f" - Included traces: {traj_metadata['included_traces']} ({traj_metadata['yield_rate']:.1f}%)")
    print(f" - Total examples: {traj_metadata['total_examples']}")
    print(f" - Avg examples/trace: {traj_metadata['avg_examples_per_trace']:.1f}")

    print(f"\nWindow-based filtering (window_size=5, score >= 1.0):")
    print(f" - Traces with windows: {window_metadata['traces_with_qualifying_windows']} ({window_metadata['traces_with_qualifying_windows'] / window_metadata['total_traces'] * 100:.1f}%)")
    print(f" - Total windows: {window_metadata['total_windows_extracted']}")
    print(f" - Total examples: {window_metadata['total_examples']}")
    print(f" - Avg examples/window: {window_metadata['avg_examples_per_window']:.1f}")

    print(f"\nWhy so many examples?")
    print(f" - Each trace has multiple LLM calls (turns)")
    print(f" - Trajectory method: {traj_metadata['included_traces']} traces × {traj_metadata['avg_examples_per_trace']:.1f} turns = {traj_metadata['total_examples']} examples")
    print(f" - Window method: {window_metadata['total_windows_extracted']} windows × {window_metadata['avg_examples_per_window']:.1f} turns = {window_metadata['total_examples']} examples")
|
|
204
|
+
# Script entry point.
if __name__ == "__main__":
    main()
|