synth-ai 0.2.4.dev8__py3-none-any.whl → 0.2.4.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic; see the registry's release page for more details.

Files changed (111)
  1. synth_ai/cli/__init__.py +6 -0
  2. synth_ai/cli/demo.py +68 -9
  3. synth_ai/cli/rl_demo.py +137 -0
  4. synth_ai/cli/root.py +65 -0
  5. synth_ai/demos/core/__init__.py +1 -0
  6. synth_ai/demos/core/cli.py +621 -0
  7. synth_ai/demos/demo_task_apps/__init__.py +1 -0
  8. synth_ai/demos/demo_task_apps/core.py +374 -0
  9. synth_ai/demos/demo_task_apps/math/__init__.py +1 -0
  10. synth_ai/demos/demo_task_apps/math/app.py +37 -0
  11. synth_ai/demos/demo_task_apps/math/config.toml +44 -0
  12. synth_ai/demos/demo_task_apps/math/deploy_modal.py +60 -0
  13. synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +22 -0
  14. synth_ai/environments/examples/bandit/__init__.py +33 -0
  15. synth_ai/environments/examples/bandit/engine.py +294 -0
  16. synth_ai/environments/examples/bandit/environment.py +194 -0
  17. synth_ai/environments/examples/bandit/taskset.py +200 -0
  18. synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +250 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +59 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +24 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +56 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +32 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +724 -0
  26. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +384 -0
  27. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +53 -0
  28. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +178 -0
  29. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +222 -0
  30. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +183 -0
  31. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +210 -0
  32. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +206 -0
  33. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +49 -0
  34. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +64 -0
  35. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +88 -0
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +77 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +324 -0
  38. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  39. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +362 -0
  40. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +49 -0
  41. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +332 -0
  42. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +97 -0
  43. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +217 -0
  44. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +87 -0
  45. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +88 -0
  46. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +195 -0
  47. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +400 -0
  48. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +195 -0
  49. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +56 -0
  50. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +858 -0
  51. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +52 -0
  52. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +874 -0
  53. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  54. synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +216 -0
  55. synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +296 -0
  56. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +58 -0
  57. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +464 -0
  58. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +152 -0
  59. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +51 -0
  60. synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +1412 -0
  61. synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +112 -0
  62. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +203 -0
  63. synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +305 -0
  64. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +126 -0
  65. synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +94 -0
  66. synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +142 -0
  67. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +26 -0
  68. synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +984 -0
  69. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +724 -0
  70. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +386 -0
  71. synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +205 -0
  72. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +150 -0
  73. synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +283 -0
  74. synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +280 -0
  75. synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +456 -0
  76. synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +166 -0
  77. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +102 -0
  78. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +128 -0
  79. synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +655 -0
  80. synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +202 -0
  81. synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +166 -0
  82. synth_ai/environments/examples/crafter_classic/environment.py +41 -2
  83. synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +1 -0
  84. synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +202 -0
  85. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +159 -0
  86. synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +158 -0
  87. synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +71 -0
  88. synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +105 -0
  89. synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +119 -0
  90. synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +52 -0
  91. synth_ai/environments/examples/enron/units/keyword_stats.py +112 -0
  92. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  93. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +48 -0
  94. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  95. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +221 -0
  96. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  97. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  98. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +831 -0
  99. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  100. synth_ai/environments/examples/red/units/__init__.py +1 -0
  101. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +899 -0
  102. synth_ai/environments/examples/sokoban/units/astar_common.py +95 -0
  103. synth_ai/environments/service/app.py +8 -0
  104. synth_ai/install_sqld.sh +40 -0
  105. synth_ai-0.2.4.dev9.dist-info/METADATA +91 -0
  106. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/RECORD +110 -11
  107. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/entry_points.txt +1 -0
  108. synth_ai-0.2.4.dev8.dist-info/METADATA +0 -635
  109. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/WHEEL +0 -0
  110. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/licenses/LICENSE +0 -0
  111. {synth_ai-0.2.4.dev8.dist-info → synth_ai-0.2.4.dev9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,724 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate Fine-tuning Data for Gemini Models
4
+ ===========================================
5
+ This script generates high-quality trajectories from Crafter using Gemini models
6
+ and converts them to JSONL format suitable for Vertex AI fine-tuning.
7
+ """
8
+
9
+ import asyncio
10
+ import json
11
+ import uuid
12
+ import argparse
13
+ import toml
14
+ import logging
15
+ from datetime import datetime
16
+ from typing import Dict, Any, Optional, List, Set, Tuple
17
+ from pathlib import Path
18
+ import sys
19
+ import os
20
+ import numpy as np
21
+ from collections import defaultdict
22
+ import time
23
+ from tqdm.asyncio import tqdm_asyncio
24
+ from httpx import AsyncClient
25
+
26
+ # Add the src directory to the path
27
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "src"))
28
+
29
+ from synth_ai.lm.core.main import LM
30
+ from synth_ai.lm.tools.base import BaseTool
31
+ from pydantic import BaseModel, Field
32
+
33
+ # Import TaskInstance and related classes
34
+ from synth_ai.environments.tasks.core import (
35
+ Impetus,
36
+ Intent,
37
+ Task,
38
+ TaskInstance,
39
+ )
40
+ from synth_ai.environments.examples.crafter_classic.taskset import CrafterTaskInstance, CrafterTaskInstanceMetadata
41
+
42
+ # Import trace evaluation utilities
43
+ sys.path.append(str(Path(__file__).parent))
44
+ from trace_eval import evaluate_trace, WEIGHTS
45
+ from filter_traces_sft import load_trace, extract_trajectory_score, extract_llm_calls
46
+
47
+
48
+ # --- Helper Functions ---
49
def _add_achievements(obs_data: Dict[str, Any], listing: str) -> None:
    """Mark every comma-separated name in *listing* as unlocked in ``obs_data``."""
    for token in listing.split(","):
        name = token.strip().lstrip("- ").strip()
        # Skip empty entries and an explicit "none" placeholder.
        if name and name.lower() != "none":
            obs_data["achievements_dict"][name] = True


def parse_observation_text(obs_text: str) -> Dict[str, Any]:
    """Parse a text observation into a structured observation dict.

    Recognised lines (all other lines are ignored):

    * ``Health: N[/M]``, ``Hunger: N[/M]``, ``Thirst: N[/M]`` -> integer stats.
    * ``Position: (x, y)`` -> ``player_position``.
    * ``Inventory:`` followed by ``- item: count`` or ``item: count`` lines
      (the older ``name - label: count`` form is still accepted, keyed by
      ``name``).
    * ``Achievements: a, b`` and/or following comma-separated lines.

    Any other line ending in ``:`` (for example
    ``Nearby (5x5 grid around you):``) closes the current Inventory or
    Achievements section. This keeps later content, such as map rows, from
    being read as inventory items or achievements.

    Args:
        obs_text: Observation text; may be empty or ``None``.

    Returns:
        Dict with keys ``health``, ``hunger``, ``thirst``, ``inventory``,
        ``achievements_dict``, ``player_position``, ``semantic_map`` and
        ``done``. Fields that are absent or malformed keep their defaults.
    """
    obs_data: Dict[str, Any] = {
        "health": 10,
        "hunger": 10,
        "thirst": 10,
        "inventory": {},
        "achievements_dict": {},
        "player_position": [0, 0],
        "semantic_map": [],
        "done": False,
    }
    if not obs_text:
        return obs_data

    # Checked in this order; the first label found on a line wins.
    stat_fields = (("Health:", "health"), ("Hunger:", "hunger"), ("Thirst:", "thirst"))
    section: Optional[str] = None

    for raw_line in obs_text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        stat_key = next((key for label, key in stat_fields if label in line), None)
        if stat_key is not None:
            try:
                obs_data[stat_key] = int(line.split(":")[1].strip().split("/")[0])
            except ValueError:
                pass  # Malformed value: keep the default.
        elif "Position:" in line:
            try:
                x_str, y_str = line.split(":")[1].strip().strip("()").split(",")
                obs_data["player_position"] = [int(x_str), int(y_str)]
            except ValueError:
                pass  # Malformed position: keep the default.
        elif "Inventory:" in line:
            section = "inventory"
        elif "Achievements:" in line:
            section = "achievements"
            # The list may share the header line: "Achievements: a, b".
            _add_achievements(obs_data, line.split("Achievements:", 1)[1])
        elif line.endswith(":"):
            # Some other section header: stop collecting inventory/achievements.
            section = None
        elif section == "inventory" and ":" in line:
            label, _, qty = line.lstrip("- ").rpartition(":")
            item = label.split(" - ")[0].strip().strip("-").strip()
            try:
                count = int(qty.strip())
            except ValueError:
                continue  # Not an "item: count" line.
            if item:
                obs_data["inventory"][item] = count
        elif section == "achievements":
            _add_achievements(obs_data, line)

    return obs_data
120
+
121
+
122
+ # --- Configuration ---
123
class GenerationConfig:
    """Settings for one fine-tuning data generation run.

    Attributes start at built-in defaults. A TOML file can override them
    through ``[generation]``, ``[service]``, ``[quality]`` and ``[output]``
    tables (see :meth:`load_from_toml`).
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create a config from defaults, then apply *config_path* if given.

        Args:
            config_path: Optional path to a TOML file. A path that does not
                exist is reported on stdout and otherwise ignored, so the
                defaults stay in effect.
        """
        # Rollout settings
        self.model_name = "gemini-2.5-flash"
        self.num_rollouts = 100
        self.max_turns = 30
        self.difficulty = "easy"
        self.service_base_url = "http://localhost:8901"
        self.service_timeout = 30.0  # passed to httpx.AsyncClient(timeout=...), in seconds
        self.seed = 42  # base world seed; rollout i uses seed + i
        self.traces_dir = Path("traces_gemini")
        self.ft_data_dir = Path("ft_data_gemini")

        # Quality filtering
        self.min_score_threshold = 2.0  # minimum trajectory score
        self.min_achievements = 3  # minimum number of unlocked achievements
        # NOTE(review): the two settings below are not read by the generation
        # code in this file.
        self.enable_thinking = True
        self.thinking_budget = 15000

        if config_path:
            if os.path.exists(config_path):
                self.load_from_toml(config_path)
            else:
                # Make a mistyped --config visible instead of silently using defaults.
                print(f"Warning: config file not found: {config_path}; using defaults")

    def load_from_toml(self, config_path: str):
        """Override settings from the TOML file at *config_path*.

        Missing tables or keys leave the current values unchanged.

        * ``[generation]``: model_name, num_rollouts, max_turns, difficulty, seed
        * ``[service]``: base_url, timeout
        * ``[quality]``: min_score_threshold, min_achievements, enable_thinking,
          thinking_budget
        * ``[output]``: traces_dir, ft_data_dir
        """
        config = toml.load(config_path)

        gen_config = config.get("generation", {})
        self.model_name = gen_config.get("model_name", self.model_name)
        self.num_rollouts = gen_config.get("num_rollouts", self.num_rollouts)
        self.max_turns = gen_config.get("max_turns", self.max_turns)
        self.difficulty = gen_config.get("difficulty", self.difficulty)
        self.seed = gen_config.get("seed", self.seed)

        service_config = config.get("service", {})
        self.service_base_url = service_config.get("base_url", self.service_base_url)
        self.service_timeout = service_config.get("timeout", self.service_timeout)

        quality_config = config.get("quality", {})
        self.min_score_threshold = quality_config.get("min_score_threshold", self.min_score_threshold)
        self.min_achievements = quality_config.get("min_achievements", self.min_achievements)
        self.enable_thinking = quality_config.get("enable_thinking", self.enable_thinking)
        self.thinking_budget = quality_config.get("thinking_budget", self.thinking_budget)

        output_config = config.get("output", {})
        self.traces_dir = Path(output_config.get("traces_dir", self.traces_dir))
        self.ft_data_dir = Path(output_config.get("ft_data_dir", self.ft_data_dir))
171
+
172
+
173
+ # --- Crafter Action Tool ---
174
class CrafterAction(BaseTool):
    """Tool that sends one Crafter action to the environment service.

    Action names are converted to integer ids and POSTed to
    ``{base_url}/env/CrafterClassic/step`` as an ``interact`` tool call.
    """

    name: str = "crafter_action"
    description: str = "Perform an action in the Crafter environment"
    params: List[tuple] = [
        ("action", "str", "The action to perform (e.g., 'move_north', 'collect', 'craft_wood_pickaxe')")
    ]

    def __init__(self, instance_id: str, client: AsyncClient, base_url: str = "http://localhost:8901"):
        """Bind the tool to one environment instance.

        Args:
            instance_id: ``env_id`` of the environment to step.
            client: HTTP client used for the step request.
            base_url: Root URL of the environment service. Defaults to the
                local address the tool previously hard-coded, so existing
                callers behave the same.
        """
        super().__init__()
        self.instance_id = instance_id
        self.client = client
        self.base_url = base_url.rstrip("/")

        # Crafter's native action space, in the order of the crafter package's
        # action list (index == action id passed to env.step).
        # NOTE(review): assumes the service forwards this id unchanged to
        # Crafter's env.step — confirm against the CrafterClassic interact tool.
        native_actions = [
            "noop",
            "move_left",
            "move_right",
            "move_up",
            "move_down",
            "do",
            "sleep",
            "place_stone",
            "place_table",
            "place_furnace",
            "place_plant",
            "make_wood_pickaxe",
            "make_stone_pickaxe",
            "make_iron_pickaxe",
            "make_wood_sword",
            "make_stone_sword",
            "make_iron_sword",
        ]
        self.action_map = {action_name: index for index, action_name in enumerate(native_actions)}

        # Names used in the agent's prompt, mapped onto native actions.
        # Collecting, attacking, eating and drinking are all Crafter's single
        # "do" (interact with the facing tile) action.
        self.action_map.update({
            "move_west": self.action_map["move_left"],
            "move_east": self.action_map["move_right"],
            "move_north": self.action_map["move_up"],
            "move_south": self.action_map["move_down"],
            "collect": self.action_map["do"],
            "attack": self.action_map["do"],
            "eat": self.action_map["do"],
            "drink": self.action_map["do"],
            "craft_wood_pickaxe": self.action_map["make_wood_pickaxe"],
            "craft_stone_pickaxe": self.action_map["make_stone_pickaxe"],
            "craft_iron_pickaxe": self.action_map["make_iron_pickaxe"],
            "craft_wood_sword": self.action_map["make_wood_sword"],
            "craft_stone_sword": self.action_map["make_stone_sword"],
            "craft_iron_sword": self.action_map["make_iron_sword"],
        })

    def _action_to_int(self, action: str) -> int:
        """Return the action id for *action* (case/whitespace-insensitive); unknown names map to 0 (noop)."""
        return self.action_map.get(str(action).strip().lower(), 0)

    async def _run(self, action: str) -> Dict[str, Any]:
        """Send *action* to the environment and return the service's JSON reply.

        Raises:
            httpx.HTTPStatusError: if the service responds with an error status.
        """
        response = await self.client.post(
            f"{self.base_url}/env/CrafterClassic/step",
            json={"env_id": self.instance_id, "tool_calls": [{
                "tool_name": "interact",
                "tool_call_id": str(uuid.uuid4()),
                "tool_args": {"action": self._action_to_int(action)},
            }]},
        )
        response.raise_for_status()
        return response.json()
232
+
233
+
234
+ # --- Gemini Agent ---
235
class GeminiCrafterAgent:
    """Choose Crafter actions with a Gemini model through synth-ai's ``LM``.

    Each :meth:`step` call formats the current observation, asks the model for
    a ``crafter_action`` tool call and returns the chosen action name. The
    agent does NOT send the action to the environment. The caller must do that
    exactly once; ``generate_trajectory`` posts its own ``/step`` request, so
    executing the action here too would apply every decision twice.
    """

    def __init__(self, model_name: str, instance_id: str, client: AsyncClient):
        """Create the agent.

        Args:
            model_name: Model passed to ``LM`` (also used as the formatting model).
            instance_id: Environment id the action tool is bound to.
            client: HTTP client handed to the ``CrafterAction`` tool.
        """
        self.model_name = model_name
        self.instance_id = instance_id
        self.client = client

        # Non-zero temperature for some diversity across rollouts.
        self.lm = LM(
            model_name=model_name,
            formatting_model_name=model_name,
            temperature=0.7,
        )

        # Passed to the model as the tool schema for action calls.
        self.action_tool = CrafterAction(instance_id, client)

        # Whole conversation, beginning with the system prompt.
        self.messages: List[Dict[str, Any]] = []

        self.system_prompt = """You are an expert Crafter player. Your goal is to achieve as many objectives as possible in the game.

Key objectives (achievements) in order of importance:
1. Basic survival: collect resources, eat when hungry, drink when thirsty
2. Tool progression: craft pickaxe → stone pickaxe → iron pickaxe
3. Advanced goals: make iron sword, defeat enemies

Action format: Use the crafter_action tool with one of these actions:
- Movement: move_north, move_south, move_east, move_west
- Resource gathering: collect (gathers wood/stone/etc), attack (mines harder materials)
- Crafting: craft_wood_pickaxe, craft_stone_pickaxe, craft_iron_pickaxe, craft_wood_sword, craft_stone_sword, craft_iron_sword
- Survival: eat, drink, sleep
- Placing: place_stone, place_table, place_furnace, place_plant

Tips:
- Start by collecting wood (stand near trees and use 'collect')
- Craft a wood pickaxe early to mine stone
- Monitor your health, hunger, and thirst
- Explore to find water, coal, and iron
- Use the semantic map to navigate efficiently"""

        self.messages.append({"role": "system", "content": self.system_prompt})

    @staticmethod
    def _field(obj: Any, name: str) -> Any:
        """Read *name* from a dict-style or attribute-style object (``None`` if absent)."""
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    def _action_from_tool_call(self, tool_call: Any) -> str:
        """Extract the ``action`` argument of a tool call, defaulting to ``"noop"``.

        ``function.arguments`` may be a dict or a JSON-encoded string, and the
        tool call itself may be a dict or an object; all forms are accepted.
        """
        arguments = self._field(self._field(tool_call, "function"), "arguments")
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                arguments = None
        if not isinstance(arguments, dict):
            return "noop"
        return str(arguments.get("action", "noop")).strip() or "noop"

    @staticmethod
    def _as_grid(semantic_map: Any) -> Optional[np.ndarray]:
        """Return *semantic_map* as a 2-D array, or ``None`` if it cannot be read as one.

        Accepts a flattened sequence whose length is a perfect square, or an
        already two-dimensional grid.
        """
        if semantic_map is None:
            return None
        try:
            grid = np.asarray(semantic_map)
        except ValueError:  # e.g. ragged nested lists
            return None
        if grid.size == 0:
            return None
        if grid.ndim == 1:
            side = int(np.sqrt(grid.size))
            return grid.reshape(side, side) if side * side == grid.size else None
        return grid if grid.ndim == 2 else None

    def _format_observation(self, obs_data: Dict[str, Any]) -> str:
        """Render an observation dict as the text prompt sent to the model.

        Includes position, health/hunger/thirst (shown out of 10), positive
        inventory counts, unlocked achievements, and a 5x5 window around the
        centre of ``semantic_map`` when it can be read as a grid.
        """
        lines = ["=== Current State ==="]

        position = obs_data.get('player_position', [0, 0])
        lines.append(f"Position: ({position[0]}, {position[1]})")
        lines.append(f"Health: {obs_data.get('health', 0)}/10")
        lines.append(f"Hunger: {obs_data.get('hunger', 0)}/10")
        lines.append(f"Thirst: {obs_data.get('thirst', 0)}/10")

        inventory = obs_data.get('inventory', {})
        if inventory:
            lines.append("\nInventory:")
            for item, count in inventory.items():
                if count > 0:
                    lines.append(f" - {item}: {count}")

        achievements = obs_data.get('achievements_dict', {})
        unlocked = [k for k, v in achievements.items() if v]
        if unlocked:
            lines.append(f"\nAchievements: {', '.join(unlocked)}")

        lines.append("\nNearby (5x5 grid around you):")
        grid = self._as_grid(obs_data.get('semantic_map', []))
        if grid is not None:
            rows, cols = grid.shape
            center_row, center_col = rows // 2, cols // 2
            view_radius = 2

            # NOTE(review): presumed cell-id meanings; verify against the
            # service's semantic-map encoding.
            id_to_symbol = {
                0: '.',   # void/empty
                1: 'G',   # grass
                2: 'T',   # tree
                3: 'S',   # stone
                4: 'W',   # water
                5: 'C',   # coal
                6: 'I',   # iron
                7: '@',   # player
                8: 'E',   # enemy
                9: 'F',   # furnace
                10: 'P',  # plant
            }

            for dy in range(-view_radius, view_radius + 1):
                row = []
                for dx in range(-view_radius, view_radius + 1):
                    y, x = center_row + dy, center_col + dx
                    if 0 <= y < rows and 0 <= x < cols:
                        symbol = id_to_symbol.get(int(grid[y, x]), '?')
                        if dy == 0 and dx == 0:
                            symbol = '@'  # the grid centre is drawn as the player
                        row.append(symbol)
                    else:
                        row.append(' ')
                lines.append(' ' + ' '.join(row))

        return '\n'.join(lines)

    async def step(self, obs_data: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
        """Ask the model for the next action given *obs_data*.

        Appends the observation (user), the model reply (assistant) and, when
        the model made a tool call, an acknowledging ``tool`` message to
        ``self.messages``. The action is not sent to the environment here.

        Returns:
            ``(action, metadata)``. ``action`` is the chosen action name, or
            ``"noop"`` if the model made no usable tool call. ``metadata`` is
            ``{"thinking": ...}`` when the raw response carried thinking
            output, otherwise ``{}``.
        """
        self.messages.append({"role": "user", "content": self._format_observation(obs_data)})

        # NOTE(review): relies on LM exposing ainvoke(messages, tools=..., tool_choice=...).
        response = await self.lm.ainvoke(
            self.messages,
            tools=[self.action_tool],
            tool_choice="required",
        )

        tool_calls = getattr(response, "tool_calls", None)
        if tool_calls:
            first_call = tool_calls[0]
            action = self._action_from_tool_call(first_call)
            self.messages.append({
                "role": "assistant",
                "content": getattr(response, "content", None) or f"Taking action: {action}",
                "tool_calls": tool_calls,
            })
            # Answer the tool call so the next request has a well-formed
            # history. The caller performs the environment step; the resulting
            # state arrives as the next user message.
            self.messages.append({
                "role": "tool",
                "content": json.dumps({"action": action, "status": "submitted"}),
                "tool_call_id": self._field(first_call, "id"),
            })
        else:
            content = response.content if hasattr(response, "content") else str(response)
            self.messages.append({"role": "assistant", "content": content})
            action = "noop"

        thinking = None
        raw = getattr(response, "_raw_response", None)
        if isinstance(raw, dict) and "thinking" in raw:
            thinking = raw["thinking"]

        return action, ({"thinking": thinking} if thinking else {})
403
+
404
+
405
+ # --- Main Generation Functions ---
406
def _build_task_instance(config: GenerationConfig, instance_num: int) -> CrafterTaskInstance:
    """Build the Crafter task instance for rollout *instance_num*.

    The world seed is ``config.seed + instance_num``. ``num_hostiles_radius``
    is 0 on the "easy" difficulty and 2 otherwise.
    """
    return CrafterTaskInstance(
        id=uuid.uuid4(),
        impetus=Impetus(
            instructions="Survive and unlock achievements. Focus on collecting resources, crafting tools, and progressing through the game."
        ),
        intent=Intent(
            rubric={"goal": "Unlock as many achievements as possible."},
            gold_trajectories=None,
            gold_state_diff={},
            deterministic_eval_functions=[],
        ),
        metadata=CrafterTaskInstanceMetadata(
            difficulty=config.difficulty,
            seed=config.seed + instance_num,
            num_trees_radius=4,
            num_cows_radius=2,
            num_hostiles_radius=0 if config.difficulty == "easy" else 2,
            world_config="normal",
        ),
        is_reproducible=True,
        initial_engine_snapshot=None,  # no pre-built engine snapshot is supplied
    )


def _observation_from_step(step_data: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a ``/step`` response into the observation dict used by the agent.

    Returns ``{"done": True}`` when the response carries no observations, and
    ``{"done": <step done flag>}`` when the first observation has no
    ``human_observation`` text.
    """
    observations = step_data.get("observations") or []
    if not observations:
        return {"done": True}
    done = step_data.get("done", False)
    obs = observations[0]
    if "human_observation" not in obs:
        return {"done": done}
    obs_data = parse_observation_text(obs["human_observation"])
    obs_data["raw_text"] = obs["human_observation"]
    obs_data["done"] = done
    return obs_data


async def generate_trajectory(config: GenerationConfig, instance_num: int) -> Optional[Dict[str, Any]]:
    """Play one Crafter episode with the Gemini agent and record it.

    Creates an environment on the service and lets the agent choose up to
    ``config.max_turns`` actions, sending each one to the service via
    ``/step``. The environment is always terminated afterwards, including
    when the episode fails part-way.

    Args:
        config: Generation settings (service URL/timeout, model, seed, ...).
        instance_num: Rollout index; offsets the world seed.

    Returns:
        The trajectory dict (``actions``, ``observations``, ``llm_calls``,
        ``achievements``, ``final_score`` plus identity/timing fields), or
        ``None`` if creating or playing the episode raised.
    """
    async with AsyncClient(timeout=config.service_timeout) as client:
        instance_id: Optional[str] = None
        try:
            task_instance = _build_task_instance(config, instance_num)
            create_response = await client.post(
                f"{config.service_base_url}/env/CrafterClassic/initialize",
                json={"task_instance": await task_instance.serialize()},
            )
            create_response.raise_for_status()
            env_data = create_response.json()
            instance_id = env_data["env_id"]

            print(f"🎮 Instance {instance_num}: Created {instance_id}")

            # Initial observation; structured fields come from its text view when present.
            obs_data = (env_data.get("observations") or [{}])[0]
            if "human_observation" in obs_data:
                obs_text = obs_data["human_observation"]
                obs_data = parse_observation_text(obs_text)
                obs_data["raw_text"] = obs_text

            agent = GeminiCrafterAgent(
                model_name=config.model_name,
                instance_id=instance_id,
                client=client,
            )

            trajectory: Dict[str, Any] = {
                "instance_id": instance_id,
                "instance_num": instance_num,
                "model": config.model_name,
                "start_time": datetime.now().isoformat(),
                "actions": [],
                "observations": [],
                "llm_calls": [],
                "achievements": {},
                "final_score": 0.0,
            }

            for turn in range(config.max_turns):
                # Stop before querying the model once the episode is over, so a
                # finished game produces no extra model call or training example.
                if obs_data.get("done", False):
                    print(f"✅ Instance {instance_num}: Episode done at turn {turn}")
                    break

                action, metadata = await agent.step(obs_data)

                trajectory["llm_calls"].append({
                    "turn": turn,
                    "messages": agent.messages[-3:],  # Last 3 messages (user, assistant, tool)
                    "action": action,
                    "metadata": metadata,
                })
                trajectory["actions"].append(action)
                trajectory["observations"].append(obs_data)

                step_response = await client.post(
                    f"{config.service_base_url}/env/CrafterClassic/step",
                    json={"env_id": instance_id, "tool_calls": [{
                        "tool_name": "interact",
                        "tool_call_id": str(uuid.uuid4()),
                        "tool_args": {"action": agent.action_tool._action_to_int(action)},
                    }]},
                )
                step_response.raise_for_status()
                obs_data = _observation_from_step(step_response.json())

            trajectory["end_time"] = datetime.now().isoformat()

            # Merge achievements from the last observation acted on and the final
            # observation (fresh dict, so no observation is mutated); an
            # achievement counts if either source reports it as unlocked.
            achievements: Dict[str, bool] = {}
            for source in trajectory["observations"][-1:] + [obs_data]:
                for name, unlocked in (source.get("achievements_dict") or {}).items():
                    achievements[name] = achievements.get(name, False) or bool(unlocked)
            trajectory["achievements"] = achievements

            unlocked_achievements = sum(1 for v in achievements.values() if v)
            trajectory["final_score"] = float(unlocked_achievements)

            print(f"📊 Instance {instance_num}: Score={trajectory['final_score']:.1f}, Achievements={unlocked_achievements}")

            return trajectory

        except Exception as e:
            print(f"❌ Instance {instance_num}: Error - {type(e).__name__}: {e}")
            return None

        finally:
            # Always release the environment once it exists. Best effort: a
            # failed terminate must not discard a completed trajectory.
            if instance_id is not None:
                try:
                    await client.post(
                        f"{config.service_base_url}/env/CrafterClassic/terminate",
                        json={"env_id": instance_id},
                    )
                except Exception as exc:
                    print(f"⚠️ Instance {instance_num}: terminate failed - {exc}")
546
+
547
+
548
async def generate_all_trajectories(config: GenerationConfig, max_concurrency: Optional[int] = None) -> List[Dict[str, Any]]:
    """Generate ``config.num_rollouts`` trajectories concurrently.

    Args:
        config: Generation settings.
        max_concurrency: Optional cap on rollouts in flight at once. ``None``
            (the default) or a non-positive value starts all rollouts at once.

    Returns:
        The successful trajectories, ordered by ``instance_num``. Failed
        rollouts return ``None`` and are dropped. Sorting makes saved file
        indices follow rollout order rather than completion order.
    """
    print(f"\n🚀 Generating {config.num_rollouts} trajectories with {config.model_name}")

    limiter = asyncio.Semaphore(max_concurrency) if max_concurrency and max_concurrency > 0 else None

    async def _run_one(instance_num: int) -> Optional[Dict[str, Any]]:
        # Run one rollout, respecting the concurrency cap when one is set.
        if limiter is None:
            return await generate_trajectory(config, instance_num)
        async with limiter:
            return await generate_trajectory(config, instance_num)

    pending = [_run_one(i) for i in range(config.num_rollouts)]

    trajectories: List[Dict[str, Any]] = []
    with tqdm_asyncio(total=config.num_rollouts, desc="Generating") as pbar:
        for next_finished in asyncio.as_completed(pending):
            trajectory = await next_finished
            if trajectory:
                trajectories.append(trajectory)
            pbar.update(1)

    trajectories.sort(key=lambda traj: traj.get("instance_num", 0))
    return trajectories
565
+
566
+
567
def filter_high_quality_trajectories(trajectories: List[Dict[str, Any]],
                                     min_score: float = 2.0,
                                     min_achievements: int = 3) -> List[Dict[str, Any]]:
    """Keep the trajectories that meet both quality thresholds.

    A trajectory passes when ``final_score >= min_score`` and its number of
    truthy ``achievements`` entries is ``>= min_achievements``. Missing or
    ``None`` values count as 0. A short summary is printed.

    Returns:
        The passing trajectories, in input order.
    """
    def unlocked_count(traj: Dict[str, Any]) -> int:
        # Number of achievements marked as unlocked.
        return sum(1 for unlocked in (traj.get("achievements") or {}).values() if unlocked)

    filtered = [
        traj
        for traj in trajectories
        if (traj.get("final_score") or 0.0) >= min_score
        and unlocked_count(traj) >= min_achievements
    ]

    print("\n📊 Filtering Results:")
    print(f" Total trajectories: {len(trajectories)}")
    if trajectories:
        print(f" High quality: {len(filtered)} ({len(filtered)/len(trajectories)*100:.1f}%)")
    else:
        print(" High quality: 0 (no trajectories generated)")

    return filtered
593
+
594
+
595
def convert_to_vertex_ai_format(trajectories: List[Dict[str, Any]], output_path: Path):
    """Write one user/assistant chat example per recorded LLM call as JSONL.

    For each ``llm_call`` with at least two messages, the contents of the last
    ``user`` message and the last ``assistant`` message are paired into
    ``{"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}``.
    Calls missing either message are skipped. Other roles (e.g. ``tool``) and
    any tool-call structure are not copied. The output's parent directories
    are created as needed.

    NOTE(review): this is the chat ``messages`` layout; confirm it matches the
    Vertex AI tuning dataset schema expected for the target model.

    Returns:
        Number of examples written.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    examples = []
    for traj in trajectories:
        for llm_call in traj.get("llm_calls") or []:
            messages = llm_call.get("messages") or []
            if len(messages) < 2:
                continue

            # Keep the most recent content seen for each relevant role.
            latest: Dict[str, Any] = {}
            for msg in messages:
                role = msg.get("role")
                if role in ("user", "assistant"):
                    latest[role] = msg.get("content")

            if latest.get("user") and latest.get("assistant"):
                examples.append({
                    "messages": [
                        {"role": "user", "content": latest["user"]},
                        {"role": "assistant", "content": latest["assistant"]},
                    ]
                })

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(example) + "\n" for example in examples)

    print(f"\n✅ Wrote {len(examples)} examples to {output_path}")
    return len(examples)
637
+
638
+
639
def save_trajectories(trajectories: List[Dict[str, Any]], traces_dir: Path):
    """Write each trajectory to ``traces_dir/trajectory_NNNN.json`` as indented JSON.

    Files are numbered by position in *trajectories* (0000, 0001, ...). The
    directory is created if needed. Existing files with higher indices (e.g.
    from an earlier, larger run) are not removed.
    """
    traces_dir.mkdir(parents=True, exist_ok=True)

    for index, traj in enumerate(trajectories):
        target = traces_dir / f"trajectory_{index:04d}.json"
        with open(target, "w", encoding="utf-8") as f:
            json.dump(traj, f, indent=2)

    print(f"💾 Saved {len(trajectories)} trajectories to {traces_dir}")
649
+
650
+
651
+ # --- Main ---
652
def _load_saved_trajectories(traces_dir: Path) -> List[Dict[str, Any]]:
    """Load every ``*.json`` file in *traces_dir*, sorted by name; unreadable files are skipped with a warning."""
    trajectories: List[Dict[str, Any]] = []
    for trace_file in sorted(traces_dir.glob("*.json")):
        try:
            with open(trace_file, encoding="utf-8") as f:
                trajectories.append(json.load(f))
        except (OSError, json.JSONDecodeError) as exc:
            print(f"⚠️ Skipping unreadable trace {trace_file}: {exc}")
    return trajectories


async def main():
    """Command-line entry point.

    Without ``--filter-only``: generate rollouts, save them to
    ``config.traces_dir``, filter them and write
    ``config.ft_data_dir / "crafter_gemini_ft.jsonl"``.
    With ``--filter-only``: load the ``*.json`` traces already in
    ``config.traces_dir`` instead of generating, then filter and convert them
    the same way.
    """
    parser = argparse.ArgumentParser(description="Generate Gemini fine-tuning data for Crafter")
    parser.add_argument("--config", type=str, help="Path to TOML config file")
    parser.add_argument("--num-rollouts", type=int, help="Number of rollouts to generate")
    parser.add_argument("--model", type=str, help="Gemini model name")
    parser.add_argument("--filter-only", action="store_true", help="Only filter existing traces")
    parser.add_argument("--min-achievements", type=int, help="Minimum achievements for filtering")
    parser.add_argument("--min-score", type=float, help="Minimum trajectory score for filtering")

    args = parser.parse_args()

    config = GenerationConfig(args.config)

    # Compare numeric overrides with None so explicit zeros
    # (e.g. `--min-achievements 0`) are honoured.
    if args.num_rollouts is not None:
        config.num_rollouts = args.num_rollouts
    if args.model:
        config.model_name = args.model
    if args.min_achievements is not None:
        config.min_achievements = args.min_achievements
    if args.min_score is not None:
        config.min_score_threshold = args.min_score

    if args.filter_only:
        print("🔍 Filtering existing trajectories...")
        trajectories = _load_saved_trajectories(config.traces_dir)
    else:
        trajectories = await generate_all_trajectories(config)
        save_trajectories(trajectories, config.traces_dir)

    filtered = filter_high_quality_trajectories(
        trajectories,
        min_score=config.min_score_threshold,
        min_achievements=config.min_achievements,
    )

    output_path = config.ft_data_dir / "crafter_gemini_ft.jsonl"
    num_examples = convert_to_vertex_ai_format(filtered, output_path)

    print("\n🎯 Summary:")
    if args.filter_only:
        print(f" Filtered trajectories: {len(filtered)}")
        print(f" Total training examples: {num_examples}")
    else:
        print(f" Generated trajectories: {len(trajectories)}")
        print(f" High quality trajectories: {len(filtered)}")
        print(f" Total training examples: {num_examples}")
        print(f" Output: {output_path}")
721
+
722
+
723
def _run_cli() -> None:
    """Run the async :func:`main` entry point to completion on a new event loop."""
    asyncio.run(main())


if __name__ == "__main__":
    _run_cli()