synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123):
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +579 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/environments/examples/wordle/__init__.py +29 -0
  100. synth_ai/environments/examples/wordle/engine.py +391 -0
  101. synth_ai/environments/examples/wordle/environment.py +154 -0
  102. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
  103. synth_ai/environments/examples/wordle/taskset.py +222 -0
  104. synth_ai/environments/service/app.py +8 -0
  105. synth_ai/environments/service/core_routes.py +38 -0
  106. synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
  107. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
  108. synth_ai/learning/prompts/mipro.py +273 -1
  109. synth_ai/learning/prompts/random_search.py +247 -0
  110. synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
  111. synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
  112. synth_ai/lm/injection.py +81 -0
  113. synth_ai/lm/overrides.py +204 -0
  114. synth_ai/lm/provider_support/anthropic.py +39 -12
  115. synth_ai/lm/provider_support/openai.py +31 -4
  116. synth_ai/lm/vendors/core/anthropic_api.py +16 -0
  117. synth_ai/lm/vendors/openai_standard.py +35 -5
  118. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
  119. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
  120. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
  121. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
  122. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
  123. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
1
+ """
2
+ Pokemon Collection & Management Reward Components
3
+
4
+ Rewards for catching Pokemon, Pokedex progress, and Pokemon development.
5
+ """
6
+
7
+ from synth_ai.environments.environment.rewards.core import RewardComponent
8
+ from typing import Dict, Any, Set
9
+
10
+
11
class FirstPokemonCaughtReward(RewardComponent):
    """One-off +50 bonus for the step on which the party grows from 0 to 1."""

    def __init__(self):
        # Latched to True once the bonus has been paid; it is never paid again.
        self.first_caught = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 when the party size goes from 0 to exactly 1, else 0.0.

        The previous party is read from ``action["prev_party"]`` and the
        current one from ``state["party"]``; both default to an empty list.
        """
        if not self.first_caught:
            size_before = len(action.get("prev_party", []))
            size_now = len(state.get("party", []))
            acquired_first = size_before == 0 and size_now == 1
            if acquired_first:
                self.first_caught = True
                return 50.0
        return 0.0
29
+
30
+
31
class NewSpeciesCaughtReward(RewardComponent):
    """+20 for every species ID seen in the party for the first time."""

    def __init__(self):
        # Species IDs that have already been rewarded.
        self.species_caught: Set[int] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 per previously unseen, positive species ID in the party."""
        reward = 0.0
        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            unseen = sid not in self.species_caught
            # Non-positive IDs (including the default 0) are never rewarded.
            if unseen and sid > 0:
                self.species_caught.add(sid)
                reward += 20.0
        return reward
49
+
50
+
51
class RarePokemonCaughtReward(RewardComponent):
    """+40 the first time each species from a fixed "rare" list is in the party."""

    def __init__(self):
        # Rare species that have already been paid out.
        self.rare_pokemon_caught: Set[int] = set()
        # Hard-coded species IDs treated as rare (would be loaded from game data).
        self.rare_species = {
            130,
            131,  # Gyarados, Lapras
            138,
            139,  # Omanyte, Omastar
            140,
            141,  # Kabuto, Kabutops
            142,  # Aerodactyl
            144,
            145,
            146,  # Legendary birds
            149,  # Dragonite
            150,  # Mewtwo
        }

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 per rare species in the party not rewarded before."""
        reward = 0.0
        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            if sid not in self.rare_species:
                continue
            if sid in self.rare_pokemon_caught:
                continue
            self.rare_pokemon_caught.add(sid)
            reward += 40.0
        return reward
83
+
84
+
85
class EvolutionStonePokemonReward(RewardComponent):
    """+30 the first time each listed stone-evolution species is in the party."""

    def __init__(self):
        # Listed species that have already been paid out.
        self.evolution_stone_pokemon_caught: Set[int] = set()
        # Hard-coded species IDs that qualify for this reward.
        self.evolution_stone_pokemon = {
            25,  # Pikachu (Thunder Stone)
            30,  # Nidorina (Moon Stone)
            33,  # Nidorino (Moon Stone)
            35,  # Clefairy (Moon Stone)
            37,  # Vulpix (Fire Stone)
            39,  # Jigglypuff (Moon Stone)
            44,  # Gloom (Leaf Stone)
            58,  # Growlithe (Fire Stone)
            61,  # Poliwhirl (Water Stone)
            90,  # Shellder (Water Stone)
            102,  # Exeggcute (Leaf Stone)
            108,  # Lickitung (rare)
            120,  # Staryu (Water Stone)
        }

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 per listed species in the party not rewarded before."""
        reward = 0.0
        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            if sid not in self.evolution_stone_pokemon:
                continue
            if sid in self.evolution_stone_pokemon_caught:
                continue
            self.evolution_stone_pokemon_caught.add(sid)
            reward += 30.0
        return reward
121
+
122
+
123
class PokedexMilestonesReward(RewardComponent):
    """+100 each time the count of distinct species crosses 10, 25, 50, 100 or 150."""

    def __init__(self):
        # Thresholds that have already been paid out.
        self.milestones_reached: Set[int] = set()
        self.milestones = [10, 25, 50, 100, 150]
        # Every positive species ID ever observed in the party.
        self.unique_species: Set[int] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Record party species, then pay 100.0 per newly reached milestone."""
        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            if sid > 0:
                self.unique_species.add(sid)

        seen = len(self.unique_species)
        reward = 0.0
        for threshold in self.milestones:
            if seen >= threshold and threshold not in self.milestones_reached:
                # Several thresholds can be paid in one call if the count jumped.
                self.milestones_reached.add(threshold)
                reward += 100.0
        return reward
148
+
149
+
150
class AreaPokedexCompletionReward(RewardComponent):
    """+50 once per area when every species listed for that area has been seen.

    Only maps that have a non-empty entry in ``area_pokemon`` can be
    completed. Maps without an entry never pay out.
    """

    def __init__(self):
        # Map IDs whose completion bonus has already been paid.
        self.completed_areas: Set[int] = set()
        # Area Pokemon lists (would be loaded from game data).
        self.area_pokemon: Dict[int, Set[int]] = {
            0: {16, 17, 18},  # Pallet Town area (Pidgey line)
            1: {10, 11, 13, 14},  # Route 1 (Caterpie, Weedle lines)
            # Add more areas
        }
        # Species IDs observed in the party while standing on each tracked map.
        self.caught_by_area: Dict[int, Set[int]] = {}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 when the current map's species list becomes complete.

        Raises:
            KeyError: if ``state`` has no ``"map_id"`` entry.
        """
        current_map = state["map_id"]

        if current_map in self.completed_areas:
            return 0.0

        # An unknown area has nothing to complete. Without this guard the
        # empty requirement set would be a subset of anything, so every
        # unmapped map would pay the bonus on first visit.
        required_pokemon = self.area_pokemon.get(current_map)
        if not required_pokemon:
            return 0.0

        caught_here = self.caught_by_area.setdefault(current_map, set())
        for pokemon in state.get("party", []):
            species_id = pokemon.get("species_id", 0)
            if species_id > 0:
                caught_here.add(species_id)

        if required_pokemon.issubset(caught_here):
            self.completed_areas.add(current_map)
            return 50.0

        return 0.0
186
+
187
+
188
class TypeCollectionReward(RewardComponent):
    """+25 the first time a Pokemon of each mapped type is in the party."""

    def __init__(self):
        # Type names that have already been rewarded.
        self.types_collected: Set[str] = set()
        # Species ID -> type name (simplified; only a few species are mapped).
        self.pokemon_types = {
            1: "grass",
            4: "fire",
            7: "water",
            25: "electric",
            # Add more mappings
        }

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 25.0 per mapped type appearing in the party for the first time."""
        reward = 0.0
        for member in state.get("party", []):
            type_name = self.pokemon_types.get(member.get("species_id", 0))
            # Species without a mapping yield None and are skipped.
            if not type_name:
                continue
            if type_name in self.types_collected:
                continue
            self.types_collected.add(type_name)
            reward += 25.0
        return reward
215
+
216
+
217
class PokemonEvolutionReward(RewardComponent):
    """+30 when the party's species set changes in a way treated as an evolution."""

    def __init__(self):
        # Number of evolutions rewarded so far.
        self.evolution_count = 0
        # Not read or written by this class's methods.
        self.previous_species: Set[int] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 when a new species appears and ``_is_evolution`` agrees.

        The previous party is read from ``action["prev_party"]`` and the
        current one from ``state["party"]``; both default to an empty list.
        """
        prev_party = action.get("prev_party", [])
        current_party = state.get("party", [])

        prev_species = {p.get("species_id", 0) for p in prev_party}
        current_species = {p.get("species_id", 0) for p in current_party}

        # Species present now that were absent on the previous step.
        evolved_species = current_species - prev_species

        if evolved_species and self._is_evolution(prev_species, current_species):
            # Keep the counter in sync with the rewards actually paid.
            self.evolution_count += 1
            return 30.0

        return 0.0

    def _is_evolution(self, prev_species: Set[int], current_species: Set[int]) -> bool:
        """Return True when both sets have the same size but different members.

        This is a simplified heuristic and does not consult evolution chains.
        NOTE(review): it compares the number of distinct species, not the
        party length, so parties with duplicate species may be misjudged.
        """
        return len(prev_species) == len(current_species) and prev_species != current_species
245
+
246
+
247
class LevelMilestonesReward(RewardComponent):
    """+10 per party slot for each of the levels 10, 20, 30, 40 and 50 reached."""

    def __init__(self):
        # (party_index, milestone) pairs that have already been paid out.
        self.level_milestones_reached: Set[tuple] = set()
        self.milestones = [10, 20, 30, 40, 50]

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 10.0 for every (slot, milestone) pair newly satisfied."""
        reward = 0.0
        for slot, member in enumerate(state.get("party", [])):
            lvl = member.get("level", 0)
            for threshold in self.milestones:
                key = (slot, threshold)
                if lvl >= threshold:
                    if key not in self.level_milestones_reached:
                        self.level_milestones_reached.add(key)
                        reward += 10.0
        return reward
268
+
269
+
270
class MoveLearningReward(RewardComponent):
    """+5 the first time each move ID is seen in a given party slot."""

    def __init__(self):
        # (pokemon_index, move_id) pairs that have already been rewarded.
        self.known_moves: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 per unseen (slot, move) pair with a positive move ID."""
        reward = 0.0
        for slot, member in enumerate(state.get("party", [])):
            for mv in member.get("moves", []):
                pair = (slot, mv)
                if pair not in self.known_moves:
                    # Non-positive move IDs are never rewarded.
                    if mv > 0:
                        self.known_moves.add(pair)
                        reward += 5.0
        return reward
289
+
290
+
291
class TMHMTeachingReward(RewardComponent):
    """+10 the first time a move from the TM/HM range is seen in a party slot."""

    def __init__(self):
        # (pokemon_index, tm_hm_id) pairs that have already been rewarded.
        self.tm_hm_taught: Set[tuple] = set()
        # Example TM/HM move range: IDs 15..64 (would be loaded from game data).
        self.tm_hm_moves = set(range(15, 65))

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 10.0 per new (slot, move) pair whose move is in ``tm_hm_moves``."""
        reward = 0.0
        for slot, member in enumerate(state.get("party", [])):
            for mv in member.get("moves", []):
                if mv not in self.tm_hm_moves:
                    continue
                pair = (slot, mv)
                if pair in self.tm_hm_taught:
                    continue
                self.tm_hm_taught.add(pair)
                reward += 10.0
        return reward
@@ -0,0 +1,147 @@
1
+ """
2
+ Social & NPC Interaction Reward Components
3
+
4
+ Rewards for dialogue, information gathering, and NPC interactions.
5
+ """
6
+
7
+ from synth_ai.environments.environment.rewards.core import RewardComponent
8
+ from typing import Dict, Any, Set
9
+
10
+
11
class NewNPCConversationReward(RewardComponent):
    """+5 the first time a text box opens at each (x, y, map) player position."""

    def __init__(self):
        # (player_x, player_y, map_id) positions that have already been rewarded.
        self.npcs_talked_to: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 when a text box has just opened at a new position."""
        box_just_opened = state["text_box_active"] and not action.get(
            "prev_text_box_active", False
        )
        if not box_just_opened:
            return 0.0

        # The player's position stands in for the NPC's identity.
        spot = (state["player_x"], state["player_y"], state["map_id"])
        if spot in self.npcs_talked_to:
            return 0.0
        self.npcs_talked_to.add(spot)
        return 5.0
24
+
25
+
26
class HelpfulInformationReceivedReward(RewardComponent):
    """+10 when a text box opens at one of a few hard-coded "helpful" positions."""

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 10.0 when a text box has just opened at a listed position.

        Placeholder: real detection would need dialogue content analysis.
        """
        box_just_opened = state["text_box_active"] and not action.get(
            "prev_text_box_active", False
        )
        if box_just_opened:
            # Example helpful NPC locations as (player_x, player_y, map_id).
            helpful_locations = {(5, 3, 0), (2, 4, 3)}
            here = (state["player_x"], state["player_y"], state["map_id"])
            if here in helpful_locations:
                return 10.0
        return 0.0
39
+
40
+
41
class StoryDialogueProgressionReward(RewardComponent):
    """+15 once per hard-coded story position when a text box opens there."""

    def __init__(self):
        # Story positions that have already been rewarded.
        self.story_npcs_talked_to: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 15.0 the first time a text box opens at each story position."""
        # Oak's lab, important NPCs, as (player_x, player_y, map_id).
        story_locations = {(3, 4, 3), (5, 2, 0)}
        here = (state["player_x"], state["player_y"], state["map_id"])

        box_just_opened = state["text_box_active"] and not action.get(
            "prev_text_box_active", False
        )
        first_visit = (
            box_just_opened
            and here in story_locations
            and here not in self.story_npcs_talked_to
        )
        if not first_visit:
            return 0.0
        self.story_npcs_talked_to.add(here)
        return 15.0
61
+
62
+
63
class ProfessorOakInteractionsReward(RewardComponent):
    """+20 each time a text box opens near the centre of map 3 (Oak's lab)."""

    def __init__(self):
        # Number of interactions rewarded so far.
        self.oak_interactions = 0

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 when a text box has just opened inside the lab's centre box.

        The centre box is 3 <= player_x <= 5 and 4 <= player_y <= 6 on map 3.
        The reward is not capped; every qualifying text-box opening pays out.
        """
        if (
            state["map_id"] == 3
            and state["text_box_active"]
            and not action.get("prev_text_box_active", False)
        ):
            # Check if this is likely Oak (center of lab).
            if 3 <= state["player_x"] <= 5 and 4 <= state["player_y"] <= 6:
                # Keep the counter in sync with the rewards actually paid.
                self.oak_interactions += 1
                return 20.0
        return 0.0
80
+
81
+
82
class NPCGiftReceivedReward(RewardComponent):
    """+15 when inventory or party grows while a text box is active."""

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 15.0 if the item or party count increased during a text box."""
        items_before = len(action.get("prev_inventory", []))
        items_now = len(state.get("inventory", []))
        party_before = len(action.get("prev_party", []))
        party_now = len(state.get("party", []))

        # An increase in either count during dialogue is treated as a gift.
        if not state["text_box_active"]:
            return 0.0
        if items_now > items_before or party_now > party_before:
            return 15.0
        return 0.0
98
+
99
+
100
class TradeCompletionReward(RewardComponent):
    """+25 once per hard-coded trade position when the party's species change."""

    def __init__(self):
        # Trade positions that have already been rewarded.
        self.trades_completed: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 25.0 when species change at a trade spot with party size kept."""
        # Example trade locations as (player_x, player_y, map_id).
        trade_locations = {(2, 3, 15), (4, 5, 20)}
        here = (state["player_x"], state["player_y"], state["map_id"])

        if here not in trade_locations or here in self.trades_completed:
            return 0.0

        party_before = action.get("prev_party", [])
        party_now = state.get("party", [])
        if len(party_before) != len(party_now):
            return 0.0

        # Same party size but a different species set is taken to be a trade.
        species_before = {p.get("species_id") for p in party_before}
        species_now = {p.get("species_id") for p in party_now}
        if species_before == species_now:
            return 0.0

        self.trades_completed.add(here)
        return 25.0
124
+
125
+
126
class NameRaterUsageReward(RewardComponent):
    """One-off +5 when a text box opens at the hard-coded Name Rater position."""

    def __init__(self):
        # Latched to True once the bonus has been paid.
        self.name_rater_used = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 the first time a text box opens at the Name Rater spot."""
        if not self.name_rater_used:
            # Example location as (player_x, player_y, map_id).
            name_rater_location = (3, 2, 25)
            here = (state["player_x"], state["player_y"], state["map_id"])

            at_rater = here == name_rater_location
            if (
                at_rater
                and state["text_box_active"]
                and not action.get("prev_text_box_active", False)
            ):
                self.name_rater_used = True
                return 5.0
        return 0.0
@@ -0,0 +1,246 @@
1
+ """
2
+ Story & Achievement Progression Reward Components
3
+
4
+ Rewards for major milestones, story gates, and achievements.
5
+ """
6
+
7
+ from synth_ai.environments.environment.rewards.core import RewardComponent
8
+ from typing import Dict, Any, Set
9
+
10
+
11
class GymBadgeEarnedReward(RewardComponent):
    """+150 for every newly set bit in the ``badges`` bitmask."""

    def __init__(self):
        # Highest badge count that has been rewarded so far.
        self.previous_badge_count = 0

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 150.0 times the increase in the number of badges held."""
        # The badge count is the number of "1" digits in the mask's binary form.
        held = bin(state.get("badges", 0)).count("1")

        if held > self.previous_badge_count:
            gained = held - self.previous_badge_count
            self.previous_badge_count = held
            return gained * 150.0
        return 0.0
29
+
30
+
31
class HMAcquisitionReward(RewardComponent):
    """+75 the first time each listed HM item ID appears in the inventory."""

    def __init__(self):
        # HM item IDs that have already been rewarded.
        self.hms_acquired: Set[int] = set()
        # Example HM item IDs (would be loaded from game data).
        self.hm_items = {200, 201, 202, 203, 204}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 75.0 per listed HM item in the inventory not rewarded before."""
        reward = 0.0
        for entry in state.get("inventory", []):
            iid = entry.get("item_id", 0)
            if iid not in self.hm_items:
                continue
            if iid in self.hms_acquired:
                continue
            self.hms_acquired.add(iid)
            reward += 75.0
        return reward
50
+
51
+
52
class EliteFourAccessReward(RewardComponent):
    """One-off +300 for first arriving on the Pokemon League entrance map."""

    def __init__(self):
        # Latched to True once the bonus has been paid.
        self.elite_four_accessed = False
        self.elite_four_map = 100  # Pokemon League entrance

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 300.0 the first time ``map_id`` equals ``elite_four_map``."""
        if not self.elite_four_accessed and state["map_id"] == self.elite_four_map:
            self.elite_four_accessed = True
            return 300.0
        return 0.0
68
+
69
+
70
class HallOfFameEntryReward(RewardComponent):
    """One-off +1000 for first arriving on the Hall of Fame map."""

    def __init__(self):
        # Latched to True once the bonus has been paid.
        self.hall_of_fame_entered = False
        self.hall_of_fame_map = 105  # Hall of Fame room

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 1000.0 the first time ``map_id`` equals ``hall_of_fame_map``."""
        if not self.hall_of_fame_entered and state["map_id"] == self.hall_of_fame_map:
            self.hall_of_fame_entered = True
            return 1000.0
        return 0.0
86
+
87
+
88
class RivalBattleCompletionReward(RewardComponent):
    """+50 once per listed map when a battle there ends with outcome 1."""

    def __init__(self):
        # Map IDs whose bonus has already been paid.
        self.rival_battles_completed: Set[int] = set()
        # Rival battle locations: Oak's lab, Route 22, etc.
        self.rival_battle_maps = {3, 22, 25, 30}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 when a battle just ended with outcome 1 on a new rival map."""
        was_fighting = action.get("prev_in_battle", False)
        still_fighting = state["in_battle"]
        outcome = state.get("battle_outcome", 0)
        map_now = state["map_id"]

        # The battle counts as finished when the in-battle flag has just dropped.
        battle_won = was_fighting and not still_fighting and outcome == 1
        if (
            battle_won
            and map_now in self.rival_battle_maps
            and map_now not in self.rival_battles_completed
        ):
            self.rival_battles_completed.add(map_now)
            return 50.0
        return 0.0
114
+
115
+
116
class TeamRocketDefeatReward(RewardComponent):
    """+40 once per position when a battle on a listed hideout map ends with outcome 1."""

    def __init__(self):
        # (player_x, player_y, map_id) positions that have already been rewarded.
        self.rocket_encounters: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 for a battle just ended with outcome 1 at a new hideout position.

        Placeholder: there is no real Team Rocket battle detection; only the
        map ID is checked.
        """
        was_fighting = action.get("prev_in_battle", False)
        still_fighting = state["in_battle"]
        outcome = state.get("battle_outcome", 0)

        battle_won = was_fighting and not still_fighting and outcome == 1
        if not battle_won:
            return 0.0

        # Example Team Rocket hideout maps.
        rocket_maps = {50, 51, 52}
        if state["map_id"] not in rocket_maps:
            return 0.0

        spot = (state["player_x"], state["player_y"], state["map_id"])
        if spot in self.rocket_encounters:
            return 0.0
        self.rocket_encounters.add(spot)
        return 40.0
139
+
140
+
141
class LegendaryEncounterReward(RewardComponent):
    """+200 once per listed map when a battle starts there."""

    def __init__(self):
        # Map IDs whose bonus has already been paid.
        self.legendary_encounters: Set[int] = set()
        # Legendary Pokemon locations.
        self.legendary_maps = {60, 61, 62, 70}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 200.0 the first time a battle begins on each legendary map."""
        map_now = state["map_id"]

        if map_now not in self.legendary_maps:
            return 0.0
        if map_now in self.legendary_encounters:
            return 0.0

        # A battle has started when the in-battle flag has just been raised.
        was_fighting = action.get("prev_in_battle", False)
        now_fighting = state["in_battle"]
        if was_fighting or not now_fighting:
            return 0.0

        self.legendary_encounters.add(map_now)
        return 200.0
161
+
162
+
163
class SilphCoCompletionReward(RewardComponent):
    """One-off +100 on the first step that leaves the Silph Co map range."""

    def __init__(self):
        # Latched to True once the bonus has been paid.
        self.silph_co_completed = False
        # Silph Co floors: map IDs 80..89.
        self.silph_co_maps = set(range(80, 90))

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 100.0 the first time the player moves from a Silph map to another map.

        Leaving the building is assumed to mean completion; no story flag
        is checked.
        """
        if not self.silph_co_completed:
            map_before = action.get("prev_map_id", -1)
            map_now = state["map_id"]
            left_building = (
                map_before in self.silph_co_maps and map_now not in self.silph_co_maps
            )
            if left_building:
                self.silph_co_completed = True
                return 100.0
        return 0.0
184
+
185
+
186
class SafariZoneSuccessReward(RewardComponent):
    """+30 when the player leaves the Safari Zone with a larger party."""

    def __init__(self):
        # Number of Safari Zone exits rewarded so far.
        self.safari_zone_runs = 0
        # Safari Zone areas.
        self.safari_zone_maps = {90, 91, 92, 93}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 on a Safari Zone exit step where the party size grew.

        The party sizes compared are those of the previous step
        (``action["prev_party"]``) and the current step (``state["party"]``).
        The reward is not capped.
        """
        prev_map = action.get("prev_map_id", -1)
        current_map = state["map_id"]

        if prev_map in self.safari_zone_maps and current_map not in self.safari_zone_maps:
            # Check if Pokemon count increased.
            prev_party_count = len(action.get("prev_party", []))
            current_party_count = len(state.get("party", []))

            if current_party_count > prev_party_count:
                # Keep the counter in sync with the rewards actually paid.
                self.safari_zone_runs += 1
                return 30.0

        return 0.0
207
+
208
+
209
class GameCornerPrizesReward(RewardComponent):
    """+20 the first time each listed prize item ID appears in the inventory."""

    def __init__(self):
        # Prize item IDs that have already been rewarded.
        self.game_corner_prizes: Set[int] = set()
        # Game Corner prize item IDs.
        self.prize_items = {300, 301, 302}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 per listed prize item in the inventory not rewarded before."""
        reward = 0.0
        for entry in state.get("inventory", []):
            iid = entry.get("item_id", 0)
            if iid not in self.prize_items:
                continue
            if iid in self.game_corner_prizes:
                continue
            self.game_corner_prizes.add(iid)
            reward += 20.0
        return reward
227
+
228
+
229
class FossilRevivalReward(RewardComponent):
    """+40 the first time each listed fossil species is in the party."""

    def __init__(self):
        # Fossil species IDs that have already been rewarded.
        self.fossils_revived: Set[int] = set()
        # Omanyte, Kabuto, Aerodactyl.
        self.fossil_pokemon = {138, 140, 142}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 per fossil species in the party not rewarded before."""
        reward = 0.0
        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            if sid not in self.fossil_pokemon:
                continue
            if sid in self.fossils_revived:
                continue
            self.fossils_revived.add(sid)
            reward += 40.0
        return reward