synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (111)
  1. synth_ai/cli/__init__.py +6 -0
  2. synth_ai/cli/demo.py +68 -9
  3. synth_ai/cli/rl_demo.py +137 -0
  4. synth_ai/cli/root.py +65 -0
  5. synth_ai/demos/core/__init__.py +1 -0
  6. synth_ai/demos/core/cli.py +621 -0
  7. synth_ai/demos/demo_task_apps/__init__.py +1 -0
  8. synth_ai/demos/demo_task_apps/core.py +374 -0
  9. synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
  10. synth_ai/demos/demo_task_apps/math/app.py +37 -0
  11. synth_ai/demos/demo_task_apps/math/config.toml +44 -0
  12. synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
  13. synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
  14. synth_ai/environments/examples/bandit/__init__.py +33 -0
  15. synth_ai/environments/examples/bandit/engine.py +294 -0
  16. synth_ai/environments/examples/bandit/environment.py +194 -0
  17. synth_ai/environments/examples/bandit/taskset.py +200 -0
  18. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
  26. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
  27. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
  28. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
  29. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
  30. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
  31. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
  32. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
  33. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
  34. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
  35. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
  38. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  39. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
  40. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
  41. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
  42. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
  43. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
  44. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
  45. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
  46. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
  47. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
  48. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
  49. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
  50. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
  51. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
  52. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
  53. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  54. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
  55. synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
  56. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
  57. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
  58. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
  59. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
  60. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
  61. synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
  62. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
  63. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
  64. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
  65. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
  66. synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
  67. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
  68. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
  69. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
  70. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
  71. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
  72. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
  73. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
  74. synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
  75. synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
  76. synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
  77. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
  78. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
  79. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
  80. synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
  81. synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
  82. synth_ai/environments/examples/crafter_classic/environment.py +41 -2
  83. synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
  84. synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
  85. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
  86. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
  87. synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
  88. synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
  89. synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
  90. synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
  91. synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
  92. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  93. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
  94. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  95. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
  96. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  97. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  98. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
  99. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  100. synth_ai/environments/examples/red/units/__init__.py +1 -0
  101. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
  102. synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
  103. synth_ai/environments/service/app.py +8 -0
  104. synth_ai/install_sqld.sh +40 -0
  105. synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
  106. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +110 -11
  107. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
  108. synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
  109. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
  110. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
  111. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py (new file)
@@ -0,0 +1,386 @@
+#!/usr/bin/env python3
+"""
+Script to generate and filter fine-tuning data from Crafter rollouts
+Applies quality filters and formats data for Modal fine-tuning
+"""
+
+import argparse
+import json
+import os
+from collections import defaultdict
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+
+
+class CrafterDataFilter:
+    """Filter and process Crafter rollout data for fine-tuning."""
+
+    def __init__(
+        self,
+        min_achievements: int = 3,
+        min_reward: float = 0.0,
+        max_steps: int = 100,
+        min_steps: int = 10,
+        achievement_weights: Optional[Dict[str, float]] = None
+    ):
+        self.min_achievements = min_achievements
+        self.min_reward = min_reward
+        self.max_steps = max_steps
+        self.min_steps = min_steps
+
+        # Default achievement weights (higher = more valuable)
+        self.achievement_weights = achievement_weights or {
+            # Basic resources
+            "collect_wood": 1.0,
+            "collect_stone": 1.5,
+            "collect_coal": 3.0,
+            "collect_iron": 5.0,
+            "collect_diamond": 10.0,
+
+            # Crafting
+            "place_table": 2.0,
+            "place_furnace": 3.0,
+            "make_wood_pickaxe": 2.5,
+            "make_stone_pickaxe": 3.5,
+            "make_iron_pickaxe": 6.0,
+            "make_wood_sword": 2.5,
+            "make_stone_sword": 3.5,
+            "make_iron_sword": 6.0,
+
+            # Survival
+            "eat_cow": 2.0,
+            "eat_plant": 1.0,
+            "collect_drink": 1.0,
+            "sleep": 1.5,
+
+            # Combat
+            "defeat_zombie": 3.0,
+            "defeat_skeleton": 4.0,
+        }
+
+    def calculate_episode_score(self, episode: dict) -> float:
+        """Calculate a quality score for an episode."""
+        score = 0.0
+
+        # Achievement score
+        achievements = episode.get("achievements", [])
+        for ach in achievements:
+            score += self.achievement_weights.get(ach, 1.0)
+
+        # Efficiency bonus (achievements per step)
+        num_steps = len(episode.get("steps", []))
+        if num_steps > 0:
+            efficiency = len(achievements) / num_steps
+            score += efficiency * 10
+
+        # Reward contribution
+        total_reward = episode.get("total_reward", 0)
+        score += total_reward * 0.5
+
+        # Diversity bonus (unique actions)
+        unique_actions = len(set(step["action"] for step in episode.get("steps", [])))
+        score += unique_actions * 0.5
+
+        return score
+
+    def filter_episode(self, episode: dict) -> Tuple[bool, str]:
+        """
+        Check if an episode passes quality filters.
+        Returns (passes, reason).
+        """
+        # Check achievements
+        achievements = episode.get("achievements", [])
+        if len(achievements) < self.min_achievements:
+            return False, f"Too few achievements: {len(achievements)} < {self.min_achievements}"
+
+        # Check reward
+        total_reward = episode.get("total_reward", 0)
+        if total_reward < self.min_reward:
+            return False, f"Low reward: {total_reward} < {self.min_reward}"
+
+        # Check step count
+        num_steps = len(episode.get("steps", []))
+        if num_steps < self.min_steps:
+            return False, f"Too few steps: {num_steps} < {self.min_steps}"
+        if num_steps > self.max_steps:
+            return False, f"Too many steps: {num_steps} > {self.max_steps}"
+
+        # Check for errors
+        if episode.get("termination_reason", "").startswith("error"):
+            return False, "Episode ended with error"
+
+        return True, "Passed all filters"
+
+    def optimize_conversation(self, messages: List[dict]) -> List[dict]:
+        """
+        Optimize conversation for fine-tuning by:
+        - Removing redundant information
+        - Condensing observations
+        - Ensuring proper format
+        """
+        optimized = []
+
+        for msg in messages:
+            role = msg.get("role")
+            content = msg.get("content", "")
+
+            if role == "system":
+                # Keep system message as-is
+                optimized.append(msg)
+
+            elif role == "user":
+                # Condense user observations
+                if len(content) > 1000:
+                    # Extract key information
+                    lines = content.split("\n")
+                    key_lines = []
+
+                    for line in lines:
+                        # Keep important lines
+                        if any(keyword in line.lower() for keyword in [
+                            "health:", "hunger:", "thirst:", "inventory:",
+                            "achievements:", "map", "recent"
+                        ]):
+                            key_lines.append(line)
+
+                    # Keep reasoning prompt
+                    if "Think step by step:" in content:
+                        idx = content.index("Think step by step:")
+                        key_lines.append(content[idx:])
+
+                    content = "\n".join(key_lines)
+
+                optimized.append({"role": role, "content": content})
+
+            elif role == "assistant":
+                # Keep assistant responses but trim if very long
+                if len(content) > 500:
+                    # Keep first part (reasoning) and action
+                    lines = content.split("\n")
+                    kept_lines = []
+
+                    # Keep reasoning (usually first few lines)
+                    for i, line in enumerate(lines[:5]):
+                        kept_lines.append(line)
+
+                    # Find and keep action
+                    for line in lines:
+                        if any(action in line for action in [
+                            "move_", "do", "sleep", "place_", "make_"
+                        ]):
+                            if line not in kept_lines:
+                                kept_lines.append(line)
+
+                    content = "\n".join(kept_lines)
+
+                optimized.append({"role": role, "content": content})
+
+        return optimized
+
+    def create_training_example(self, episode: dict) -> dict:
+        """Create a training example from a filtered episode."""
+        messages = episode.get("messages", [])
+
+        # Optimize conversation
+        optimized_messages = self.optimize_conversation(messages)
+
+        # Add metadata as comment in system message
+        metadata_comment = (
+            f"# Episode stats: {len(episode['achievements'])} achievements, "
+            f"{episode['total_reward']:.1f} reward, {len(episode['steps'])} steps"
+        )
+
+        if optimized_messages and optimized_messages[0]["role"] == "system":
+            optimized_messages[0]["content"] += f"\n{metadata_comment}"
+
+        return {"messages": optimized_messages}
+
+
+def analyze_rollout_directory(rollout_dir: Path) -> Dict[str, any]:
+    """Analyze a rollout directory and return statistics."""
+    stats = {
+        "total_episodes": 0,
+        "total_achievements": defaultdict(int),
+        "reward_distribution": [],
+        "step_distribution": [],
+        "achievement_distribution": [],
+    }
+
+    # Find rollout data files
+    rollout_files = list(rollout_dir.glob("**/rollout_data.json"))
+
+    for rollout_file in rollout_files:
+        with open(rollout_file, "r") as f:
+            data = json.load(f)
+
+        episodes = data.get("episodes", [])
+        stats["total_episodes"] += len(episodes)
+
+        for ep in episodes:
+            # Track achievements
+            for ach in ep.get("achievements", []):
+                stats["total_achievements"][ach] += 1
+
+            # Track distributions
+            stats["reward_distribution"].append(ep.get("total_reward", 0))
+            stats["step_distribution"].append(len(ep.get("steps", [])))
+            stats["achievement_distribution"].append(len(ep.get("achievements", [])))
+
+    return stats
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate fine-tuning data from Crafter rollouts")
+    parser.add_argument("rollout_dir", type=str, help="Directory containing rollout data")
+    parser.add_argument("--output", type=str, default=None,
+                        help="Output file (default: rollout_dir/fine_tuning_filtered.jsonl)")
+    parser.add_argument("--min-achievements", type=int, default=3,
+                        help="Minimum achievements required")
+    parser.add_argument("--min-reward", type=float, default=0.0,
+                        help="Minimum total reward required")
+    parser.add_argument("--max-steps", type=int, default=100,
+                        help="Maximum steps allowed")
+    parser.add_argument("--min-steps", type=int, default=10,
+                        help="Minimum steps required")
+    parser.add_argument("--top-k", type=int, default=None,
+                        help="Only keep top K episodes by score")
+    parser.add_argument("--analyze-only", action="store_true",
+                        help="Only analyze data, don't generate output")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Show detailed filtering information")
+
+    args = parser.parse_args()
+
+    rollout_dir = Path(args.rollout_dir)
+    if not rollout_dir.exists():
+        print(f"❌ Rollout directory not found: {rollout_dir}")
+        return 1
+
+    # Analyze rollout data
+    print(f"📊 Analyzing rollout data in {rollout_dir}...")
+    stats = analyze_rollout_directory(rollout_dir)
+
+    print(f"\n📈 Rollout Statistics:")
+    print(f" Total episodes: {stats['total_episodes']}")
+
+    if stats['total_episodes'] == 0:
+        print("❌ No episodes found in rollout directory")
+        return 1
+
+    print(f" Avg achievements: {np.mean(stats['achievement_distribution']):.1f}")
+    print(f" Avg reward: {np.mean(stats['reward_distribution']):.1f}")
+    print(f" Avg steps: {np.mean(stats['step_distribution']):.1f}")
+
+    print(f"\n🏆 Top achievements:")
+    for ach, count in sorted(stats['total_achievements'].items(),
+                             key=lambda x: x[1], reverse=True)[:10]:
+        pct = count / stats['total_episodes'] * 100
+        print(f" {ach}: {count} ({pct:.1f}%)")
+
+    if args.analyze_only:
+        return 0
+
+    # Create filter
+    filter_config = CrafterDataFilter(
+        min_achievements=args.min_achievements,
+        min_reward=args.min_reward,
+        max_steps=args.max_steps,
+        min_steps=args.min_steps
+    )
+
+    # Process episodes
+    print(f"\n🔍 Filtering episodes...")
+    all_episodes = []
+    filtered_episodes = []
+    filter_reasons = defaultdict(int)
+
+    # Load all episodes
+    rollout_files = list(rollout_dir.glob("**/rollout_data.json"))
+    for rollout_file in rollout_files:
+        with open(rollout_file, "r") as f:
+            data = json.load(f)
+
+        episodes = data.get("episodes", [])
+        all_episodes.extend(episodes)
+
+    # Filter episodes
+    for ep in all_episodes:
+        passes, reason = filter_config.filter_episode(ep)
+        if passes:
+            # Calculate score for ranking
+            score = filter_config.calculate_episode_score(ep)
+            filtered_episodes.append((score, ep))
+        else:
+            filter_reasons[reason] += 1
+            if args.verbose:
+                print(f" Filtered out: {reason}")
+
+    print(f"\n📋 Filtering Results:")
+    print(f" Total episodes: {len(all_episodes)}")
+    print(f" Passed filters: {len(filtered_episodes)}")
+    print(f" Pass rate: {len(filtered_episodes)/len(all_episodes)*100:.1f}%")
+
+    if filter_reasons:
+        print(f"\n❌ Filter reasons:")
+        for reason, count in sorted(filter_reasons.items(),
+                                    key=lambda x: x[1], reverse=True):
+            print(f" {reason}: {count}")
+
+    if not filtered_episodes:
+        print("\n❌ No episodes passed the filters!")
+        return 1
+
+    # Sort by score and apply top-k if specified
+    filtered_episodes.sort(key=lambda x: x[0], reverse=True)
+
+    if args.top_k and args.top_k < len(filtered_episodes):
+        print(f"\n🎯 Selecting top {args.top_k} episodes by score")
+        filtered_episodes = filtered_episodes[:args.top_k]
+
+    # Generate training examples
+    print(f"\n✍️ Generating training examples...")
+    training_examples = []
+
+    for score, episode in filtered_episodes:
+        example = filter_config.create_training_example(episode)
+        training_examples.append(example)
+
+        if args.verbose:
+            achievements = episode.get("achievements", [])
+            print(f" Score: {score:.1f}, Achievements: {achievements}")
+
+    # Save output
+    output_file = args.output
+    if not output_file:
+        output_file = rollout_dir / "fine_tuning_filtered.jsonl"
+    else:
+        output_file = Path(output_file)
+
+    with open(output_file, "w") as f:
+        for example in training_examples:
+            f.write(json.dumps(example) + "\n")
+
+    print(f"\n✅ Generated {len(training_examples)} training examples")
+    print(f"📁 Saved to: {output_file}")
+
+    # Calculate token estimate
+    total_chars = sum(len(json.dumps(ex)) for ex in training_examples)
+    est_tokens = total_chars // 4
+    print(f"📊 Estimated tokens: {est_tokens:,}")
+
+    # Show sample
+    print(f"\n📝 Sample training example:")
+    sample = training_examples[0]["messages"]
+    for msg in sample[:3]:  # Show first 3 messages
+        role = msg["role"]
+        content = msg["content"][:100] + "..." if len(msg["content"]) > 100 else msg["content"]
+        print(f" [{role}] {content}")
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit(main())
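
For orientation, here is a minimal sketch (not part of the package) of how the CrafterDataFilter class added above could be exercised on a single episode. The import assumes the script is importable as a module from the same directory with its dependencies (e.g. numpy) installed, and the episode dict is fabricated for illustration:

# Sketch only: exercising CrafterDataFilter on a fabricated episode dict.
from generate_ft_data_modal import CrafterDataFilter

episode = {
    "achievements": ["collect_wood", "place_table", "make_wood_pickaxe"],
    "total_reward": 4.5,
    "steps": [{"action": "move_up"}] * 20,
    "messages": [
        {"role": "system", "content": "You are playing Crafter."},
        {"role": "user", "content": "Observation..."},
        {"role": "assistant", "content": "I will collect wood.\nmove_up"},
    ],
}

f = CrafterDataFilter(min_achievements=3, min_steps=10, max_steps=100)
passes, reason = f.filter_episode(episode)    # -> (True, "Passed all filters")
score = f.calculate_episode_score(episode)    # weighted achievements + efficiency/reward/diversity bonuses
example = f.create_training_example(episode)  # {"messages": [...]} ready to write as one JSONL row
print(passes, reason, round(score, 2), len(example["messages"]))
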
synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py (new file)
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+"""
+Generate metadata for fine-tuning datasets
+"""
+
+import json
+from pathlib import Path
+from collections import defaultdict
+import sys
+
+sys.path.append(str(Path(__file__).parent))
+from filter_traces_sft import load_trace, extract_trajectory_score, extract_llm_calls, calculate_window_score
+
+
+def analyze_trajectory_dataset(traces_dir: Path, threshold: float = 2.0):
+    """Analyze trajectory-based filtering results."""
+    trace_files = sorted(traces_dir.glob("*.json"))
+
+    included_traces = []
+    excluded_traces = []
+    score_distribution = defaultdict(int)
+    achievement_counts = defaultdict(int)
+    total_llm_calls = 0
+
+    for trace_file in trace_files:
+        trace = load_trace(trace_file)
+        score = extract_trajectory_score(trace)
+        score_distribution[int(score)] += 1
+
+        # Get achievements
+        metadata = trace.get('session_metadata', [])
+        if isinstance(metadata, list):
+            for item in metadata:
+                if isinstance(item, dict) and item.get('metadata_type') == 'episode_results':
+                    episode_results = item.get('data', {})
+                    achievements = episode_results.get('achievements', {})
+                    for ach, unlocked in achievements.items():
+                        if unlocked:
+                            achievement_counts[ach] += 1
+                    break
+
+        # Count LLM calls
+        llm_calls = extract_llm_calls(trace)
+
+        if score >= threshold:
+            included_traces.append({
+                'trace_file': trace_file.name,
+                'score': score,
+                'num_llm_calls': len(llm_calls),
+                'achievements': [k for k, v in achievements.items() if v] if 'achievements' in locals() else []
+            })
+            total_llm_calls += len(llm_calls)
+        else:
+            excluded_traces.append({
+                'trace_file': trace_file.name,
+                'score': score
+            })
+
+    return {
+        'threshold': threshold,
+        'total_traces': len(trace_files),
+        'included_traces': len(included_traces),
+        'excluded_traces': len(excluded_traces),
+        'yield_rate': len(included_traces) / len(trace_files) * 100,
+        'total_examples': total_llm_calls,
+        'avg_examples_per_trace': total_llm_calls / len(included_traces) if included_traces else 0,
+        'score_distribution': dict(score_distribution),
+        'achievement_distribution': dict(achievement_counts),
+        'included_trace_details': included_traces
+    }
+
+
+def analyze_window_dataset(traces_dir: Path, window_size: int = 5, threshold: float = 1.0):
+    """Analyze window-based filtering results with greedy extraction."""
+    trace_files = sorted(traces_dir.glob("*.json"))
+
+    window_scores = defaultdict(int)
+    traces_with_windows = 0
+    total_windows = 0
+    total_examples = 0
+    window_details = []
+
+    for trace_file in trace_files:
+        trace = load_trace(trace_file)
+        llm_calls = extract_llm_calls(trace)
+
+        if not llm_calls:
+            continue
+
+        # Get max turn
+        max_turn = max(turn for turn, _ in llm_calls)
+        trace_has_window = False
+        used_turns = set()
+
+        # Greedy extraction - same as in filter_traces_sft.py
+        for start in range(0, max_turn - window_size + 2):
+            end = start + window_size - 1
+
+            # Skip if any turn in window already used
+            if any(t in used_turns for t in range(start, end + 1)):
+                continue
+
+            score = calculate_window_score(trace, start, end)
+
+            if score >= threshold:
+                window_scores[int(score)] += 1
+                total_windows += 1
+                trace_has_window = True
+
+                # Mark turns as used
+                for t in range(start, end + 1):
+                    used_turns.add(t)
+
+                # Count examples in window
+                window_llm_calls = [llm for turn, llm in llm_calls if start <= turn <= end]
+                total_examples += len(window_llm_calls)
+
+                window_details.append({
+                    'trace_file': trace_file.name,
+                    'window': f"[{start}-{end}]",
+                    'score': score,
+                    'num_examples': len(window_llm_calls)
+                })
+
+        if trace_has_window:
+            traces_with_windows += 1
+
+    return {
+        'window_size': window_size,
+        'threshold': threshold,
+        'total_traces': len(trace_files),
+        'traces_with_qualifying_windows': traces_with_windows,
+        'total_windows_extracted': total_windows,
+        'total_examples': total_examples,
+        'avg_examples_per_window': total_examples / total_windows if total_windows else 0,
+        'window_score_distribution': dict(window_scores),
+        'window_details': window_details[:20]  # First 20 for brevity
+    }
+
+
+def main():
+    traces_dir = Path("traces")
+    ft_dir = Path("ft_dataset")
+
+    # Analyze trajectory dataset
+    print("Analyzing trajectory-based dataset...")
+    traj_metadata = analyze_trajectory_dataset(traces_dir, threshold=2.0)
+
+    # Analyze window dataset
+    print("Analyzing window-based dataset...")
+    window_metadata = analyze_window_dataset(traces_dir, window_size=5, threshold=1.0)
+
+    # Create combined metadata
+    combined_metadata = {
+        'dataset_creation': {
+            'source_traces_dir': str(traces_dir),
+            'num_source_traces': traj_metadata['total_traces'],
+            'filtering_methods': ['trajectory_score', 'window_score']
+        },
+        'trajectory_filtering': traj_metadata,
+        'window_filtering': window_metadata,
+        'comparison': {
+            'trajectory_examples': traj_metadata['total_examples'],
+            'window_examples': window_metadata['total_examples'],
+            'trajectory_yield_rate': f"{traj_metadata['yield_rate']:.1f}%",
+            'window_trace_coverage': f"{window_metadata['traces_with_qualifying_windows'] / window_metadata['total_traces'] * 100:.1f}%"
+        }
+    }
+
+    # Save metadata
+    with open(ft_dir / "metadata.json", 'w') as f:
+        json.dump(combined_metadata, f, indent=2)
+
+    # Save trajectory-specific metadata
+    with open(ft_dir / "trajectory_score_metadata.json", 'w') as f:
+        json.dump(traj_metadata, f, indent=2)
+
+    # Save window-specific metadata
+    with open(ft_dir / "window_score_metadata.json", 'w') as f:
+        json.dump(window_metadata, f, indent=2)
+
+    # Print summary
+    print("\n" + "="*60)
+    print("FINE-TUNING DATASET SUMMARY")
+    print("="*60)
+    print(f"Source traces: {traj_metadata['total_traces']}")
+    print(f"\nTrajectory-based filtering (score >= 2.0):")
+    print(f" - Included traces: {traj_metadata['included_traces']} ({traj_metadata['yield_rate']:.1f}%)")
+    print(f" - Total examples: {traj_metadata['total_examples']}")
+    print(f" - Avg examples/trace: {traj_metadata['avg_examples_per_trace']:.1f}")
+
+    print(f"\nWindow-based filtering (window_size=5, score >= 1.0):")
+    print(f" - Traces with windows: {window_metadata['traces_with_qualifying_windows']} ({window_metadata['traces_with_qualifying_windows'] / window_metadata['total_traces'] * 100:.1f}%)")
+    print(f" - Total windows: {window_metadata['total_windows_extracted']}")
+    print(f" - Total examples: {window_metadata['total_examples']}")
+    print(f" - Avg examples/window: {window_metadata['avg_examples_per_window']:.1f}")
+
+    print(f"\nWhy so many examples?")
+    print(f" - Each trace has multiple LLM calls (turns)")
+    print(f" - Trajectory method: {traj_metadata['included_traces']} traces × {traj_metadata['avg_examples_per_trace']:.1f} turns = {traj_metadata['total_examples']} examples")
+    print(f" - Window method: {window_metadata['total_windows_extracted']} windows × {window_metadata['avg_examples_per_window']:.1f} turns = {window_metadata['total_examples']} examples")


+if __name__ == "__main__":
+    main()
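
The greedy, non-overlapping window selection inside analyze_window_dataset is easier to see outside diff form. The following standalone sketch reproduces just that selection loop; toy_score and all numbers are invented for illustration, whereas the real script scores windows with calculate_window_score:

# Standalone sketch of the greedy non-overlapping window selection used above.
def toy_score(start: int, end: int) -> float:
    return 1.0 if start % 3 == 0 else 0.0  # pretend every third window qualifies

window_size, threshold, max_turn = 5, 1.0, 14
used_turns: set[int] = set()
selected = []

for start in range(0, max_turn - window_size + 2):
    end = start + window_size - 1
    if any(t in used_turns for t in range(start, end + 1)):
        continue  # overlaps an already-selected window
    if toy_score(start, end) >= threshold:
        used_turns.update(range(start, end + 1))
        selected.append((start, end))

print(selected)  # [(0, 4), (6, 10)] -> selected windows never share a turn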