synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/tracing_v3/examples/basic_usage.py +188 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,323 @@
|
|
1
|
+
"""TaskSet generation for NetHack environment."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import random
|
6
|
+
from uuid import uuid4
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Dict, Any, List, Optional, Set
|
9
|
+
|
10
|
+
from synth_ai.environments.tasks.core import (
|
11
|
+
TaskInstance,
|
12
|
+
TaskInstanceMetadata,
|
13
|
+
TaskInstanceSet,
|
14
|
+
Impetus,
|
15
|
+
Intent,
|
16
|
+
SplitInfo,
|
17
|
+
)
|
18
|
+
|
19
|
+
|
20
|
+
@dataclass
class NetHackTaskInstanceMetadata(TaskInstanceMetadata):
    """Per-task configuration attached to a NetHack task instance.

    Attributes:
        character_role: Role key such as "wizard" or "knight".
        starting_level: Dungeon level the episode starts on.
        target_depth: Dungeon level the agent is asked to reach.
        time_limit: Maximum number of turns allowed.
        difficulty: Difficulty tier label. The generator in this module
            assigns "tutorial", "beginner", "intermediate", "advanced"
            or "expert".
        special_objectives: Extra goals on top of the primary descent goal.
        seed: Optional random seed recorded for reproducibility.
    """

    character_role: str
    starting_level: int
    target_depth: int
    time_limit: int
    difficulty: str
    special_objectives: List[str]
    seed: Optional[int] = None
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class NetHackTaskInstance(TaskInstance):
    """NetHack task instance with dict-based (de)serialization."""

    async def serialize(self) -> dict:
        """Convert the instance to a plain dict.

        The id is written as a string. Gold trajectories and the initial
        engine snapshot are always written as ``None``.

        Returns:
            A dict with the keys ``id``, ``impetus``, ``intent``,
            ``metadata``, ``is_reproducible`` and ``initial_engine_snapshot``.
        """
        meta = self.metadata
        return {
            "id": str(self.id),
            "impetus": {"instructions": self.impetus.instructions},
            "intent": {
                "rubric": self.intent.rubric,
                "gold_trajectories": None,
                "gold_state_diff": self.intent.gold_state_diff,
            },
            "metadata": {
                "character_role": meta.character_role,
                "starting_level": meta.starting_level,
                "target_depth": meta.target_depth,
                "time_limit": meta.time_limit,
                "difficulty": meta.difficulty,
                "special_objectives": meta.special_objectives,
                "seed": meta.seed,
            },
            "is_reproducible": self.is_reproducible,
            "initial_engine_snapshot": None,
        }

    @classmethod
    async def deserialize(cls, data: dict) -> "NetHackTaskInstance":
        """Rebuild an instance from the dict produced by ``serialize``.

        Args:
            data: Serialized task. ``impetus``, ``intent`` and ``metadata``
                are required; ``id``, ``metadata["seed"]`` and
                ``is_reproducible`` are optional.

        Returns:
            The restored task instance. A missing or empty ``id`` yields a
            fresh ``uuid4()``.

        Raises:
            KeyError: If a required key is absent from ``data``.
        """
        # Local import: the module only imports ``uuid4`` at file level.
        from uuid import UUID

        raw_id = data.get("id")
        if not raw_id:
            instance_id = uuid4()
        elif isinstance(raw_id, UUID):
            instance_id = raw_id
        else:
            # ``serialize`` writes the id as ``str(uuid)``; turn it back into a
            # UUID so a round trip preserves the id's type and equality with
            # the UUIDs stored in split sets.
            try:
                instance_id = UUID(str(raw_id))
            except ValueError:
                # Not a UUID string: keep the caller's value unchanged.
                instance_id = raw_id

        intent_data = data["intent"]
        meta_data = data["metadata"]
        return cls(
            id=instance_id,
            impetus=Impetus(instructions=data["impetus"]["instructions"]),
            intent=Intent(
                rubric=intent_data["rubric"],
                gold_trajectories=None,
                gold_state_diff=intent_data["gold_state_diff"],
            ),
            metadata=NetHackTaskInstanceMetadata(
                character_role=meta_data["character_role"],
                starting_level=meta_data["starting_level"],
                target_depth=meta_data["target_depth"],
                time_limit=meta_data["time_limit"],
                difficulty=meta_data["difficulty"],
                special_objectives=meta_data["special_objectives"],
                seed=meta_data.get("seed"),
            ),
            is_reproducible=data.get("is_reproducible", True),
            initial_engine_snapshot=None,
        )
|
83
|
+
|
84
|
+
|
85
|
+
# Character role definitions
|
86
|
+
# Playable roles keyed by role name. Each entry holds flavour text, a numeric
# difficulty modifier, starting items, and the strengths/weaknesses that are
# quoted in generated task instructions.
CHARACTER_ROLES: Dict[str, Dict[str, Any]] = {
    "tourist": {
        "description": "A tourist with a camera and Hawaiian shirt",
        "difficulty_modifier": 0.8,  # lowest modifier in the table
        "starting_items": ["camera", "credit card", "hawaiian shirt"],
        "strengths": ["gold finding", "luck"],
        "weaknesses": ["combat", "magic"],
    },
    "knight": {
        "description": "A noble knight in shining armor",
        "difficulty_modifier": 1.0,
        "starting_items": ["long sword", "armor", "shield"],
        "strengths": ["combat", "riding"],
        "weaknesses": ["magic"],
    },
    "wizard": {
        "description": "A powerful wizard with magical abilities",
        "difficulty_modifier": 1.2,
        "starting_items": ["quarterstaff", "spellbook", "cloak"],
        "strengths": ["magic", "identify"],
        "weaknesses": ["physical combat", "low hp"],
    },
    "barbarian": {
        "description": "A fierce barbarian warrior",
        "difficulty_modifier": 0.9,
        "starting_items": ["battle axe", "leather armor"],
        "strengths": ["combat", "hp", "strength"],
        "weaknesses": ["magic", "intelligence"],
    },
    "ranger": {
        "description": "A skilled ranger and tracker",
        "difficulty_modifier": 1.0,
        "starting_items": ["bow", "arrows", "cloak"],
        "strengths": ["ranged combat", "stealth"],
        "weaknesses": ["melee combat"],
    },
    "priest": {
        "description": "A holy priest with divine powers",
        "difficulty_modifier": 1.1,
        "starting_items": ["mace", "robe", "holy water"],
        "strengths": ["healing", "undead turning"],
        "weaknesses": ["edged weapons"],
    },
    "monk": {
        "description": "A disciplined monk with martial arts skills",
        "difficulty_modifier": 1.3,  # highest modifier in the table
        "starting_items": ["robe"],
        "strengths": ["martial arts", "speed"],
        "weaknesses": ["armor", "weapons"],
    },
    "rogue": {
        "description": "A stealthy rogue and thief",
        "difficulty_modifier": 1.1,
        "starting_items": ["dagger", "leather armor", "lock pick"],
        "strengths": ["stealth", "backstab", "traps"],
        "weaknesses": ["direct combat"],
    },
}
|
144
|
+
|
145
|
+
# Special objectives for variety
|
146
|
+
# Optional secondary goals, grouped by theme. Task generation draws from these
# lists to vary the "Additional objectives" section of each task.
SPECIAL_OBJECTIVES: Dict[str, List[str]] = {
    "exploration": [
        "Explore at least 3 different dungeon levels",
        "Find and enter a shop",
        "Discover a special room (vault, zoo, etc.)",
        "Find the entrance to the Gnomish Mines",
    ],
    "combat": [
        "Defeat 10 monsters",
        "Defeat a monster using magic",
        "Defeat a monster using ranged weapons",
        "Survive an encounter with a tough monster",
    ],
    "collection": [
        "Collect 100 gold pieces",
        "Find and identify a magical item",
        "Collect food rations for survival",
        "Find a valuable gem",
    ],
    "survival": [
        "Survive for 500 turns",
        "Maintain full health for 100 turns",
        "Never let hunger status reach 'Weak'",
        "Avoid all traps",
    ],
    "progression": [
        "Reach experience level 3",
        "Improve at least one skill",
        "Successfully pray to your deity",
        "Complete a quest or mission",
    ],
}
|
178
|
+
|
179
|
+
|
180
|
+
def _pick_unique_objectives(count: int) -> List[str]:
    """Return up to ``count`` distinct objectives pooled from SPECIAL_OBJECTIVES."""
    # De-duplicate while keeping order so the pool itself cannot yield repeats.
    pool = list(
        dict.fromkeys(obj for group in SPECIAL_OBJECTIVES.values() for obj in group)
    )
    return random.sample(pool, min(count, len(pool)))


def _build_instructions(
    role_info: Dict[str, Any],
    target_depth: int,
    time_limit: int,
    objectives: List[str],
) -> str:
    """Render the instruction text shown to the agent for one task."""
    description = role_info["description"]
    objective_lines = "\n".join(f"- {obj}" for obj in objectives)
    strengths = ", ".join(role_info["strengths"])
    weaknesses = ", ".join(role_info["weaknesses"])
    return f"""You are a {description}.

Your primary goal is to descend to dungeon level {target_depth} within {time_limit} turns.

Additional objectives:
{objective_lines}

Character strengths: {strengths}
Character weaknesses: {weaknesses}

Tips:
- Use 'inventory' to check your items
- Use 'search' to find secret doors
- Eat food before you become weak from hunger
- Save valuable items for when you need them
- Be cautious around unfamiliar monsters

Remember: In NetHack, careful planning often beats hasty action!"""


async def create_nethack_taskset() -> TaskInstanceSet:
    """Generate a randomized set of NetHack task instances.

    For every difficulty tier a fixed number of tasks is created. Each task
    gets a random role from the tier's role list, a random target depth and
    turn limit inside the tier's ranges, a set of distinct secondary
    objectives, and a random seed. The instances are then shuffled; the
    first tenth becomes the validation split and the next tenth the test
    split.

    Randomness comes from the module-level ``random`` functions and
    ``uuid4``, so results differ between calls.

    Returns:
        A ``TaskInstanceSet`` holding all generated instances and the
        validation/test split definition.
    """
    instances = []

    # Per-tier generation parameters: eligible roles, inclusive ranges for
    # depth and turn limit, number of secondary objectives, and task count.
    DIFFICULTY_CONFIGS = {
        "tutorial": {
            "roles": ["tourist"],
            "target_depth_range": (1, 3),
            "time_limit_range": (500, 1000),
            "objective_count": 1,
            "count": 20,
        },
        "beginner": {
            "roles": ["knight", "barbarian"],
            "target_depth_range": (3, 5),
            "time_limit_range": (1000, 2000),
            "objective_count": 2,
            "count": 30,
        },
        "intermediate": {
            "roles": ["wizard", "ranger", "priest"],
            "target_depth_range": (5, 10),
            "time_limit_range": (2000, 5000),
            "objective_count": 2,
            "count": 25,
        },
        "advanced": {
            "roles": ["monk", "rogue"],
            "target_depth_range": (10, 15),
            "time_limit_range": (5000, 10000),
            "objective_count": 3,
            "count": 15,
        },
        "expert": {
            "roles": list(CHARACTER_ROLES.keys()),
            "target_depth_range": (15, 20),
            "time_limit_range": (10000, 20000),
            "objective_count": 4,
            "count": 10,
        },
    }

    for difficulty, config in DIFFICULTY_CONFIGS.items():
        for _ in range(config["count"]):
            role = random.choice(config["roles"])
            role_info = CHARACTER_ROLES[role]

            min_depth, max_depth = config["target_depth_range"]
            target_depth = random.randint(min_depth, max_depth)
            min_time, max_time = config["time_limit_range"]
            time_limit = random.randint(min_time, max_time)

            # Sample without replacement so a task never lists the same
            # objective twice (which would also inflate the objective count).
            objectives = _pick_unique_objectives(config["objective_count"])

            instructions = _build_instructions(
                role_info, target_depth, time_limit, objectives
            )

            rubric = {
                "goal": f"Reach dungeon level {target_depth}",
                "success_criteria": {
                    "primary": f"Reach dungeon level {target_depth} within {time_limit} turns",
                    "secondary": objectives,
                },
                "evaluation_metrics": {
                    "depth_reached": target_depth,
                    "time_limit": time_limit,
                    "objectives_completed": len(objectives),
                },
            }

            metadata = NetHackTaskInstanceMetadata(
                character_role=role,
                starting_level=1,
                target_depth=target_depth,
                time_limit=time_limit,
                difficulty=difficulty,
                special_objectives=objectives,
                seed=random.randint(0, 2**31 - 1),
            )

            instances.append(
                NetHackTaskInstance(
                    id=uuid4(),
                    impetus=Impetus(instructions=instructions),
                    intent=Intent(rubric=rubric, gold_trajectories=None, gold_state_diff={}),
                    metadata=metadata,
                    is_reproducible=True,
                    initial_engine_snapshot=None,
                )
            )

    # Splits: after shuffling, the first tenth is validation and the next
    # tenth is test; everything else is implicitly training data.
    random.shuffle(instances)
    n_instances = len(instances)
    n_val = n_instances // 10
    n_test = n_instances // 10

    val_ids = {inst.id for inst in instances[:n_val]}
    test_ids = {inst.id for inst in instances[n_val : n_val + n_test]}

    split_info = SplitInfo(
        val_instance_ids=val_ids, test_instance_ids=test_ids, _is_split_defined=True
    )

    return TaskInstanceSet(
        name="NetHack TaskSet",
        description="A comprehensive set of NetHack dungeon exploration tasks with varying difficulty levels, character roles, and objectives",
        instances=instances,
        split_info=split_info,
    )
|
320
|
+
|
321
|
+
|
322
|
+
# Module-level export
|
323
|
+
taskset = create_nethack_taskset
|
@@ -0,0 +1,110 @@
|
|
1
|
+
"""
|
2
|
+
Logging configuration for Pokemon Red environment.
|
3
|
+
Suppresses obnoxious JAX debug messages and sets appropriate log levels.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
import warnings
|
9
|
+
|
10
|
+
|
11
|
+
def configure_logging():
    """Quiet JAX and other chatty libraries and set up root logging.

    Steps, in order:
      1. Cap the listed JAX loggers at WARNING and stop them from
         propagating records to ancestor loggers.
      2. Provide default JAX environment variables; values already present
         in the environment are left untouched.
      3. Call ``logging.basicConfig`` with level INFO.
      4. Cap matplotlib, PIL and urllib3 loggers at WARNING.
      5. Ignore ``UserWarning`` and ``FutureWarning`` raised from jax modules.
    """
    for name in (
        "jax._src.cache_key",
        "jax._src.compilation_cache",
        "jax._src.compiler",
        "jax._src.dispatch",
        "jax",
        "jaxlib",
    ):
        jax_logger = logging.getLogger(name)
        jax_logger.setLevel(logging.WARNING)
        jax_logger.propagate = False

    # setdefault keeps any value the user exported before importing us.
    env_defaults = (
        ("JAX_PLATFORMS", "cpu"),
        ("JAX_ENABLE_X64", "False"),
        ("JAX_LOG_COMPILES", "0"),
        ("JAX_COMPILATION_CACHE_DIR", "/tmp/jax_cache"),
    )
    for key, value in env_defaults:
        os.environ.setdefault(key, value)

    logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")

    for noisy_library in ("matplotlib", "PIL", "urllib3"):
        logging.getLogger(noisy_library).setLevel(logging.WARNING)

    for warning_category in (UserWarning, FutureWarning):
        warnings.filterwarnings("ignore", category=warning_category, module="jax")
|
48
|
+
|
49
|
+
|
50
|
+
def safe_compare(left, right, operation="<"):
    """Compare two values without raising on str/number mismatches.

    Coercion rules applied before comparing:
      * Both strings: if *both* parse as floats they are compared
        numerically; otherwise both stay strings and are compared as
        strings.
      * One string, one int/float: the string is converted with the
        number's own type (``int`` or ``float``). If that fails, a warning
        is logged and False is returned. Note that ``int("3.5")`` fails, so
        a non-integer string against an int yields False.

    Args:
        left: Left operand.
        right: Right operand.
        operation: One of '<', '>', '<=', '>=', '==', '!='.

    Returns:
        The result of the comparison, or False if the operands cannot be
        compared or ``operation`` is not supported (the error is logged,
        never raised).
    """
    # Local import keeps this block self-contained; the module does not
    # import ``operator`` at file level.
    import operator

    comparators = {
        "<": operator.lt,
        ">": operator.gt,
        "<=": operator.le,
        ">=": operator.ge,
        "==": operator.eq,
        "!=": operator.ne,
    }

    try:
        if isinstance(left, str) and isinstance(right, str):
            try:
                # Convert atomically: if either side fails, neither operand
                # is rebound, so we really do fall back to string comparison
                # instead of comparing a float against a str.
                left, right = float(left), float(right)
            except ValueError:
                pass
        elif isinstance(left, str) and isinstance(right, (int, float)):
            try:
                left = type(right)(left)
            except ValueError:
                logging.warning("Cannot compare string '%s' with number %s", left, right)
                return False
        elif isinstance(left, (int, float)) and isinstance(right, str):
            try:
                right = type(left)(right)
            except ValueError:
                logging.warning("Cannot compare number %s with string '%s'", left, right)
                return False

        compare = comparators.get(operation)
        if compare is None:
            # Handled by the generic handler below: logged, returns False.
            raise ValueError(f"Unsupported operation: {operation}")
        return compare(left, right)

    except TypeError as e:
        logging.error("Type error in comparison: %s %s %s - %s", left, operation, right, e)
        return False
    except Exception as e:
        # Deliberate catch-all: this helper must never raise to its caller.
        logging.error("Unexpected error in comparison: %s %s %s - %s", left, operation, right, e)
        return False
|
107
|
+
|
108
|
+
|
109
|
+
# Configure logging when module is imported
|
110
|
+
configure_logging()
|