synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +28 -2
- synth_ai/core/system.py +4 -0
- synth_ai/environments/__init__.py +35 -0
- synth_ai/environments/environment/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/__init__.py +1 -0
- synth_ai/environments/environment/artifacts/base.py +50 -0
- synth_ai/environments/environment/core.py +22 -0
- synth_ai/environments/environment/db/__init__.py +1 -0
- synth_ai/environments/environment/db/sqlite.py +45 -0
- synth_ai/environments/environment/registry.py +24 -0
- synth_ai/environments/environment/resources/sqlite.py +46 -0
- synth_ai/environments/environment/results.py +1 -0
- synth_ai/environments/environment/rewards/__init__.py +1 -0
- synth_ai/environments/environment/rewards/core.py +28 -0
- synth_ai/environments/environment/shared_engine.py +26 -0
- synth_ai/environments/environment/tools/__init__.py +34 -0
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/engine.py +502 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/environment.py +255 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
- synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
- synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
- synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
- synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
- synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
- synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
- synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
- synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
- synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
- synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
- synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/red/test_fixes.py +125 -0
- synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
- synth_ai/environments/examples/red/units/__init__.py +1 -0
- synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
- synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
- synth_ai/environments/examples/red/units/test_engine.py +192 -0
- synth_ai/environments/examples/red/units/test_environment.py +455 -0
- synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
- synth_ai/environments/examples/red/units/test_integration.py +217 -0
- synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
- synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
- synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
- synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
- synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
- synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
- synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
- synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
- synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
- synth_ai/environments/examples/red/units/test_taskset.py +116 -0
- synth_ai/environments/examples/red/units/test_tree.py +448 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
- synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
- synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
- synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
- synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
- synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
- synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
- synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
- synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
- synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
- synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
- synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
- synth_ai/environments/reproducibility/core.py +42 -0
- synth_ai/environments/reproducibility/tree.py +364 -0
- synth_ai/environments/service/app.py +78 -0
- synth_ai/environments/service/core_routes.py +775 -0
- synth_ai/environments/service/external_registry.py +57 -0
- synth_ai/environments/service/registry.py +9 -0
- synth_ai/environments/stateful/__init__.py +1 -0
- synth_ai/environments/stateful/core.py +28 -0
- synth_ai/environments/stateful/engine.py +21 -0
- synth_ai/environments/stateful/state.py +7 -0
- synth_ai/environments/tasks/api.py +19 -0
- synth_ai/environments/tasks/core.py +78 -0
- synth_ai/environments/tasks/filters.py +39 -0
- synth_ai/environments/tasks/utils.py +89 -0
- synth_ai/environments/v0_observability/history.py +3 -0
- synth_ai/environments/v0_observability/log.py +2 -0
- synth_ai/lm/caching/constants.py +1 -0
- synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
- synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
- synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
- synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
- synth_ai/{zyk/lms → lm}/config.py +2 -1
- synth_ai/{zyk/lms → lm}/constants.py +2 -2
- synth_ai/{zyk/lms → lm}/core/all.py +10 -10
- synth_ai/{zyk/lms → lm}/core/main.py +57 -33
- synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
- synth_ai/lm/cost/monitor.py +1 -0
- synth_ai/lm/cost/statefulness.py +1 -0
- synth_ai/lm/provider_support/__init__.py +8 -0
- synth_ai/lm/provider_support/anthropic.py +945 -0
- synth_ai/lm/provider_support/openai.py +1115 -0
- synth_ai/lm/provider_support/suppress_logging.py +31 -0
- synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
- synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
- synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
- synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
- synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
- synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
- synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
- synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
- synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
- synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
- synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
- synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
- synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
- synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
- synth_ai/tracing/__init__.py +0 -0
- synth_ai/tracing/abstractions.py +224 -0
- synth_ai/tracing/base_client.py +91 -0
- synth_ai/tracing/client_manager.py +131 -0
- synth_ai/tracing/config.py +140 -0
- synth_ai/tracing/context.py +146 -0
- synth_ai/tracing/decorators.py +679 -0
- synth_ai/tracing/events/__init__.py +0 -0
- synth_ai/tracing/events/manage.py +147 -0
- synth_ai/tracing/events/scope.py +86 -0
- synth_ai/tracing/events/store.py +227 -0
- synth_ai/tracing/immediate_client.py +152 -0
- synth_ai/tracing/local.py +18 -0
- synth_ai/tracing/log_client_base.py +74 -0
- synth_ai/tracing/retry_queue.py +187 -0
- synth_ai/tracing/trackers.py +515 -0
- synth_ai/tracing/upload.py +504 -0
- synth_ai/tracing/utils.py +9 -0
- synth_ai/zyk/__init__.py +28 -2
- synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
- synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai-0.2.0.dist-info/METADATA +0 -36
- synth_ai-0.2.0.dist-info/RECORD +0 -50
- /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
- /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
- /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
- /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
- /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
- /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
- /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
- /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
- {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
|
|
1
|
+
"""
|
2
|
+
Pokemon Collection & Management Reward Components
|
3
|
+
|
4
|
+
Rewards for catching Pokemon, Pokedex progress, and Pokemon development.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from synth_ai.environments.environment.rewards.core import RewardComponent
|
8
|
+
from typing import Dict, Any, Set
|
9
|
+
|
10
|
+
|
11
|
+
class FirstPokemonCaughtReward(RewardComponent):
    """One-off +50 bonus when the party grows from empty to exactly one Pokemon."""

    def __init__(self):
        # Latched to True once the bonus has been paid so it is never paid twice.
        self.first_caught = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 on the step the party goes from 0 to 1 members, else 0.0.

        The previous party is read from ``action["prev_party"]`` and the
        current one from ``state["party"]``; a missing key counts as an
        empty party.
        """
        if not self.first_caught:
            size_before = len(action.get("prev_party", []))
            size_now = len(state.get("party", []))

            # Covers both receiving the starter and a first wild catch.
            if (size_before, size_now) == (0, 1):
                self.first_caught = True
                return 50.0

        return 0.0
|
29
|
+
|
30
|
+
|
31
|
+
class NewSpeciesCaughtReward(RewardComponent):
    """+20 for every species ID that appears in the party for the first time."""

    def __init__(self):
        # Species IDs that have already been rewarded.
        self.species_caught: Set[int] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 per not-yet-seen positive species ID in ``state["party"]``.

        Party members without a ``species_id`` are treated as ID 0 and ignored.
        """
        earned = 0.0

        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            # Skip species already rewarded and non-positive placeholder IDs.
            if sid in self.species_caught or not sid > 0:
                continue
            self.species_caught.add(sid)
            earned += 20.0

        return earned
|
49
|
+
|
50
|
+
|
51
|
+
class RarePokemonCaughtReward(RewardComponent):
    """+40 the first time each species from a fixed "rare" list is in the party."""

    def __init__(self):
        # Rare species that have already been rewarded.
        self.rare_pokemon_caught: Set[int] = set()
        # Hard-coded rare species IDs (would be loaded from game data).
        self.rare_species = {
            130, 131,  # Gyarados, Lapras
            138, 139,  # Omanyte, Omastar
            140, 141,  # Kabuto, Kabutops
            142,  # Aerodactyl
            144, 145, 146,  # Legendary birds
            149,  # Dragonite
            150,  # Mewtwo
        }

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 per rare species in ``state["party"]`` not rewarded before."""
        earned = 0.0

        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            if sid not in self.rare_species:
                continue
            if sid in self.rare_pokemon_caught:
                continue
            self.rare_pokemon_caught.add(sid)
            earned += 40.0

        return earned
|
83
|
+
|
84
|
+
|
85
|
+
class EvolutionStonePokemonReward(RewardComponent):
    """+30 the first time each species from the stone-evolution list is in the party."""

    def __init__(self):
        # Listed species that have already been rewarded.
        self.evolution_stone_pokemon_caught: Set[int] = set()
        # Species IDs this component pays out for, annotated with the stone used.
        self.evolution_stone_pokemon = {
            25,  # Pikachu (Thunder Stone)
            30, 33, 35, 39,  # Nidorina, Nidorino, Clefairy, Jigglypuff (Moon Stone)
            37, 58,  # Vulpix, Growlithe (Fire Stone)
            44, 102,  # Gloom, Exeggcute (Leaf Stone)
            61, 90, 120,  # Poliwhirl, Shellder, Staryu (Water Stone)
            108,  # Lickitung (rare)
        }

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 per listed species in ``state["party"]`` not rewarded before."""
        earned = 0.0

        for member in state.get("party", []):
            sid = member.get("species_id", 0)
            if sid not in self.evolution_stone_pokemon:
                continue
            if sid in self.evolution_stone_pokemon_caught:
                continue
            self.evolution_stone_pokemon_caught.add(sid)
            earned += 30.0

        return earned
|
121
|
+
|
122
|
+
|
123
|
+
class PokedexMilestonesReward(RewardComponent):
    """+100 each time the count of distinct species seen in the party passes 10, 25, 50, 100 or 150."""

    def __init__(self):
        # Thresholds that have already paid out.
        self.milestones_reached: Set[int] = set()
        self.milestones = [10, 25, 50, 100, 150]
        # Every positive species ID ever observed in the party.
        self.unique_species: Set[int] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Record the party's species, then return 100.0 per newly crossed milestone.

        Several milestones can be crossed in one call; each pays once.
        """
        observed_ids = (member.get("species_id", 0) for member in state.get("party", []))
        self.unique_species.update(sid for sid in observed_ids if sid > 0)

        seen_total = len(self.unique_species)
        earned = 0.0

        for threshold in self.milestones:
            if seen_total < threshold or threshold in self.milestones_reached:
                continue
            self.milestones_reached.add(threshold)
            earned += 100.0

        return earned
|
148
|
+
|
149
|
+
|
150
|
+
class AreaPokedexCompletionReward(RewardComponent):
    """Reward for catching all Pokemon available in an area - +50 points"""

    def __init__(self):
        # Map IDs whose completion bonus has already been paid.
        self.completed_areas: Set[int] = set()
        # Area Pokemon lists (would be loaded from game data)
        self.area_pokemon = {
            0: {16, 17, 18},  # Pallet Town area (Pidgey line)
            1: {10, 11, 13, 14},  # Route 1 (Caterpie, Weedle lines)
            # Add more areas
        }
        # Species IDs observed in the party while standing on each map.
        self.caught_by_area: Dict[int, Set[int]] = {}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 once per area when its full species list has been observed.

        Reads ``state["map_id"]`` (required) and ``state["party"]`` (optional).
        Every positive ``species_id`` in the party is recorded against the
        current map. Maps with no entry in ``self.area_pokemon`` never pay
        out and are never marked completed.

        Raises:
            KeyError: if ``state`` has no ``"map_id"`` entry.
        """
        current_map = state["map_id"]

        if current_map in self.completed_areas:
            return 0.0

        # Track caught Pokemon in this area
        caught_here = self.caught_by_area.setdefault(current_map, set())
        for pokemon in state.get("party", []):
            species_id = pokemon.get("species_id", 0)
            if species_id > 0:
                caught_here.add(species_id)

        # An area without a known species list cannot be "completed". Without
        # this guard the empty requirement set is a subset of anything, so
        # every unknown map would pay the bonus on first visit.
        required_pokemon = self.area_pokemon.get(current_map)
        if not required_pokemon:
            return 0.0

        # Check if area is complete
        if required_pokemon.issubset(caught_here):
            self.completed_areas.add(current_map)
            return 50.0

        return 0.0
|
186
|
+
|
187
|
+
|
188
|
+
class TypeCollectionReward(RewardComponent):
    """+25 the first time a Pokemon of each mapped type appears in the party."""

    def __init__(self):
        # Type names that have already been rewarded.
        self.types_collected: Set[str] = set()
        # Simplified species-ID -> type lookup; unmapped species earn nothing.
        self.pokemon_types = {
            1: "grass",
            4: "fire",
            7: "water",
            25: "electric",
            # Add more mappings
        }

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 25.0 per type represented in ``state["party"]`` not rewarded before."""
        earned = 0.0

        for member in state.get("party", []):
            type_name = self.pokemon_types.get(member.get("species_id", 0))
            # Skip species with no known type and types already collected.
            if not type_name or type_name in self.types_collected:
                continue
            self.types_collected.add(type_name)
            earned += 25.0

        return earned
|
215
|
+
|
216
|
+
|
217
|
+
class PokemonEvolutionReward(RewardComponent):
    """Reward for evolving Pokemon - +30 points"""

    def __init__(self):
        # Number of evolutions rewarded so far.
        self.evolution_count = 0
        # Not read or written by this class's methods.
        self.previous_species: Set[int] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 when the party's species set changes in a way that looks like evolution.

        Compares the species IDs in ``action["prev_party"]`` with those in
        ``state["party"]`` (both default to empty). Each rewarded step also
        increments ``self.evolution_count``.
        """
        # Track species changes (evolution)
        prev_party = action.get("prev_party", [])
        current_party = state.get("party", [])

        prev_species = {p.get("species_id", 0) for p in prev_party}
        current_species = {p.get("species_id", 0) for p in current_party}

        # Check for evolution (new species appears, old species disappears)
        evolved_species = current_species - prev_species

        if evolved_species and self._is_evolution(prev_species, current_species):
            # Keep the counter in step with payouts; it was previously never updated.
            self.evolution_count += 1
            return 30.0

        return 0.0

    def _is_evolution(self, prev_species: Set[int], current_species: Set[int]) -> bool:
        """Check if species change represents evolution.

        Simplified heuristic: the two species sets have the same size but
        different contents. Evolution chains are not consulted.
        """
        return len(prev_species) == len(current_species) and prev_species != current_species
|
245
|
+
|
246
|
+
|
247
|
+
class LevelMilestonesReward(RewardComponent):
    """+10 each time a party slot reaches level 10, 20, 30, 40 or 50."""

    def __init__(self):
        # (party index, milestone level) pairs that have already paid out.
        self.level_milestones_reached: Set[tuple] = set()
        self.milestones = [10, 20, 30, 40, 50]

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 10.0 per (party slot, milestone) pair reached for the first time.

        Milestones are keyed by position in ``state["party"]``, not by any
        Pokemon identity. A member without a ``level`` counts as level 0.
        """
        earned = 0.0

        for slot, member in enumerate(state.get("party", [])):
            current_level = member.get("level", 0)

            for threshold in self.milestones:
                key = (slot, threshold)
                if current_level < threshold or key in self.level_milestones_reached:
                    continue
                self.level_milestones_reached.add(key)
                earned += 10.0

        return earned
|
268
|
+
|
269
|
+
|
270
|
+
class MoveLearningReward(RewardComponent):
    """+5 for each move ID seen for the first time in a given party slot."""

    def __init__(self):
        # (party index, move ID) pairs that have already been rewarded.
        self.known_moves: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 per new (party slot, positive move ID) pair in ``state["party"]``.

        Each member's moves are read from its ``"moves"`` list (default empty).
        """
        earned = 0.0

        for slot, member in enumerate(state.get("party", [])):
            for move in member.get("moves", []):
                key = (slot, move)
                # Skip pairs already rewarded and non-positive (empty) move IDs.
                if key in self.known_moves or not move > 0:
                    continue
                self.known_moves.add(key)
                earned += 5.0

        return earned
|
289
|
+
|
290
|
+
|
291
|
+
class TMHMTeachingReward(RewardComponent):
    """+10 the first time a move in the TM/HM ID range appears in a given party slot."""

    def __init__(self):
        # (party index, move ID) pairs that have already been rewarded.
        self.tm_hm_taught: Set[tuple] = set()
        # Example TM/HM move range: IDs 15..64 (would be loaded from game data).
        self.tm_hm_moves = set(range(15, 65))

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 10.0 per new (party slot, TM/HM move ID) pair in ``state["party"]``."""
        earned = 0.0

        for slot, member in enumerate(state.get("party", [])):
            for move in member.get("moves", []):
                if move not in self.tm_hm_moves:
                    continue
                key = (slot, move)
                if key in self.tm_hm_taught:
                    continue
                self.tm_hm_taught.add(key)
                earned += 10.0

        return earned
|
@@ -0,0 +1,147 @@
|
|
1
|
+
"""
|
2
|
+
Social & NPC Interaction Reward Components
|
3
|
+
|
4
|
+
Rewards for dialogue, information gathering, and NPC interactions.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from synth_ai.environments.environment.rewards.core import RewardComponent
|
8
|
+
from typing import Dict, Any, Set
|
9
|
+
|
10
|
+
|
11
|
+
class NewNPCConversationReward(RewardComponent):
    """+5 the first time a text box opens at a given (x, y, map) position."""

    def __init__(self):
        # (player_x, player_y, map_id) positions where a conversation was already rewarded.
        self.npcs_talked_to: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 when a text box has just opened at a position not seen before.

        "Just opened" means ``state["text_box_active"]`` is truthy while
        ``action["prev_text_box_active"]`` (default False) is falsy. The
        player's position stands in for the identity of the NPC.
        """
        just_opened = state["text_box_active"] and not action.get("prev_text_box_active", False)
        if not just_opened:
            return 0.0

        spot = (state["player_x"], state["player_y"], state["map_id"])
        if spot in self.npcs_talked_to:
            return 0.0

        self.npcs_talked_to.add(spot)
        return 5.0
|
24
|
+
|
25
|
+
|
26
|
+
class HelpfulInformationReceivedReward(RewardComponent):
    """+10 when a text box opens at one of a few hard-coded "helpful NPC" positions."""

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 10.0 if a text box has just opened at a listed position, else 0.0.

        Placeholder logic: dialogue content is not analysed, only the
        player's (x, y, map) position. The reward is not latched, so it is
        paid on every qualifying step.
        """
        just_opened = state["text_box_active"] and not action.get("prev_text_box_active", False)
        if not just_opened:
            return 0.0

        # Example helpful NPC locations as (player_x, player_y, map_id).
        helpful_spots = {(5, 3, 0), (2, 4, 3)}
        here = (state["player_x"], state["player_y"], state["map_id"])
        return 10.0 if here in helpful_spots else 0.0
|
39
|
+
|
40
|
+
|
41
|
+
class StoryDialogueProgressionReward(RewardComponent):
    """+15 the first time a text box opens at each hard-coded story position."""

    def __init__(self):
        # Story positions that have already been rewarded.
        self.story_npcs_talked_to: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 15.0 once per story position when a text box has just opened there.

        The position is read from ``state`` on every call, before the
        text-box flags are checked.
        """
        # Oak's lab, important NPCs — as (player_x, player_y, map_id).
        story_spots = {(3, 4, 3), (5, 2, 0)}
        here = (state["player_x"], state["player_y"], state["map_id"])

        if not state["text_box_active"]:
            return 0.0
        if action.get("prev_text_box_active", False):
            return 0.0
        if here not in story_spots or here in self.story_npcs_talked_to:
            return 0.0

        self.story_npcs_talked_to.add(here)
        return 15.0
|
61
|
+
|
62
|
+
|
63
|
+
class ProfessorOakInteractionsReward(RewardComponent):
    """Award +20 when a dialogue opens near the centre of map 3.

    Map 3 is treated as Oak's lab and the box ``3 <= x <= 5``,
    ``4 <= y <= 6`` as "probably talking to Oak"; both are heuristics, not
    verified NPC identification. The reward is not deduplicated, so every
    qualifying dialogue start scores.
    """

    def __init__(self):
        # Number of interactions that have been rewarded so far.
        self.oak_interactions = 0

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 for a dialogue that just started next to Oak, else 0.0.

        Args:
            state: Must contain ``map_id`` and ``text_box_active``;
                ``player_x`` and ``player_y`` are read once a dialogue
                starts on map 3.
            action: May contain ``prev_text_box_active`` (default False).
        """
        # Oak's lab interactions: text box just opened while on map 3.
        if (
            state["map_id"] == 3
            and state["text_box_active"]
            and not action.get("prev_text_box_active", False)
        ):
            # Check if this is likely Oak (center of lab).
            if 3 <= state["player_x"] <= 5 and 4 <= state["player_y"] <= 6:
                # Keep the counter in step with the rewards handed out;
                # previously it was initialised but never updated.
                self.oak_interactions += 1
                return 20.0
        return 0.0
|
80
|
+
|
81
|
+
|
82
|
+
class NPCGiftReceivedReward(RewardComponent):
    """Award +15 when inventory or party grows while a text box is active.

    A gift is inferred from the size of the ``inventory`` / ``party`` lists
    increasing relative to the previous step during a text interaction.
    """

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 15.0 if items or party members were gained mid-dialogue, else 0.0.

        Args:
            state: Must contain ``text_box_active``; ``inventory`` and
                ``party`` default to empty lists.
            action: ``prev_inventory`` and ``prev_party`` default to empty
                lists.
        """
        # Only the list lengths are compared, not their contents.
        inventory_grew = len(action.get("prev_inventory", [])) < len(
            state.get("inventory", [])
        )
        party_grew = len(action.get("prev_party", [])) < len(state.get("party", []))

        if state["text_box_active"] and (inventory_grew or party_grew):
            return 15.0
        return 0.0
|
98
|
+
|
99
|
+
|
100
|
+
class TradeCompletionReward(RewardComponent):
    """Award +25 once per trade position when the party's species set changes.

    A trade is inferred when the player stands on a known trade position,
    the party size is unchanged, and the set of ``species_id`` values
    differs from the previous step.
    """

    def __init__(self):
        # Trade positions that have already paid out.
        self.trades_completed: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 25.0 for a newly detected trade at a trade position, else 0.0.

        Args:
            state: Must contain ``player_x``, ``player_y`` and ``map_id``;
                ``party`` defaults to an empty list.
            action: ``prev_party`` defaults to an empty list.
        """
        # Example trade positions as (x, y, map_id); would be loaded from game data.
        trade_spots = {(2, 3, 15), (4, 5, 20)}
        here = (state["player_x"], state["player_y"], state["map_id"])

        if here not in trade_spots or here in self.trades_completed:
            return 0.0

        party_before = action.get("prev_party", [])
        party_now = state.get("party", [])
        # A trade swaps one member for another, so the party size is unchanged.
        if len(party_before) != len(party_now):
            return 0.0

        species_before = {member.get("species_id") for member in party_before}
        species_now = {member.get("species_id") for member in party_now}
        if species_before == species_now:
            return 0.0

        self.trades_completed.add(here)
        return 25.0
|
124
|
+
|
125
|
+
|
126
|
+
class NameRaterUsageReward(RewardComponent):
    """Award +5 once, the first time a dialogue opens at the Name Rater position."""

    def __init__(self):
        # Latched to True after the single reward has been granted.
        self.name_rater_used = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 on the first dialogue start at the Name Rater, else 0.0.

        Args:
            state: Must contain ``player_x``, ``player_y`` and ``map_id``;
                ``text_box_active`` is read when the position matches.
            action: May contain ``prev_text_box_active`` (default False).
        """
        if self.name_rater_used:
            return 0.0

        # Example Name Rater position (x, y, map_id); would be loaded from game data.
        target = (3, 2, 25)
        here = (state["player_x"], state["player_y"], state["map_id"])
        if here != target:
            return 0.0

        # Require the text box to have just opened on this step.
        if not state["text_box_active"] or action.get("prev_text_box_active", False):
            return 0.0

        self.name_rater_used = True
        return 5.0
|
@@ -0,0 +1,246 @@
|
|
1
|
+
"""
|
2
|
+
Story & Achievement Progression Reward Components
|
3
|
+
|
4
|
+
Rewards for major milestones, story gates, and achievements.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from synth_ai.environments.environment.rewards.core import RewardComponent
|
8
|
+
from typing import Dict, Any, Set
|
9
|
+
|
10
|
+
|
11
|
+
class GymBadgeEarnedReward(RewardComponent):
    """Award +150 for every gym badge gained since the last rewarded count."""

    def __init__(self):
        # Highest badge count that has been rewarded so far.
        self.previous_badge_count = 0

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 150.0 per newly earned badge, or 0.0 when none were gained.

        Args:
            state: ``badges`` is read as an integer bitmask (default 0);
                each set bit counts as one badge.
            action: Unused.
        """
        # Population count of the bitmask gives the number of badges held.
        count_now = bin(state.get("badges", 0)).count("1")
        gained = count_now - self.previous_badge_count
        if gained <= 0:
            return 0.0

        self.previous_badge_count = count_now
        return gained * 150.0
|
29
|
+
|
30
|
+
|
31
|
+
class HMAcquisitionReward(RewardComponent):
    """Award +75 the first time each HM item shows up in the inventory."""

    def __init__(self):
        # HM item ids that have already been rewarded.
        self.hms_acquired: Set[int] = set()
        # Example HM item ids; would be loaded from game data.
        self.hm_items = {200, 201, 202, 203, 204}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 75.0 for each HM seen for the first time in ``inventory``.

        Args:
            state: ``inventory`` is a list of dicts (default empty); each
                entry's ``item_id`` defaults to 0.
            action: Unused.
        """
        reward = 0.0
        for entry in state.get("inventory", []):
            hm_id = entry.get("item_id", 0)
            if hm_id not in self.hm_items or hm_id in self.hms_acquired:
                continue
            self.hms_acquired.add(hm_id)
            reward += 75.0
        return reward
|
50
|
+
|
51
|
+
|
52
|
+
class EliteFourAccessReward(RewardComponent):
    """Award +300 once, the first time the player is on the Pokemon League map."""

    def __init__(self):
        # Latched to True after the reward has been granted.
        self.elite_four_accessed = False
        # Map id treated as the Pokemon League entrance.
        self.elite_four_map = 100

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 300.0 on first arrival at ``elite_four_map``, else 0.0.

        Args:
            state: ``map_id`` is read unless the reward was already granted.
            action: Unused.
        """
        if self.elite_four_accessed or state["map_id"] != self.elite_four_map:
            return 0.0

        self.elite_four_accessed = True
        return 300.0
|
68
|
+
|
69
|
+
|
70
|
+
class HallOfFameEntryReward(RewardComponent):
    """Award +1000 once, the first time the player is on the Hall of Fame map."""

    def __init__(self):
        # Latched to True after the reward has been granted.
        self.hall_of_fame_entered = False
        # Map id treated as the Hall of Fame room.
        self.hall_of_fame_map = 105

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 1000.0 on first arrival at ``hall_of_fame_map``, else 0.0.

        Args:
            state: ``map_id`` is read unless the reward was already granted.
            action: Unused.
        """
        if self.hall_of_fame_entered or state["map_id"] != self.hall_of_fame_map:
            return 0.0

        self.hall_of_fame_entered = True
        return 1000.0
|
86
|
+
|
87
|
+
|
88
|
+
class RivalBattleCompletionReward(RewardComponent):
    """Award +50 once per rival-battle map when a battle there ends in a win.

    Rival battles are not identified directly; any battle that ends with
    ``battle_outcome == 1`` on one of the listed maps counts.
    """

    def __init__(self):
        # Map ids whose rival battle has already been rewarded.
        self.rival_battles_completed: Set[int] = set()
        # Maps where rival battles take place (Oak's lab, Route 22, etc.).
        self.rival_battle_maps = {3, 22, 25, 30}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 for a first won battle on a rival map, else 0.0.

        Args:
            state: Must contain ``in_battle`` and ``map_id``;
                ``battle_outcome`` defaults to 0.
            action: ``prev_in_battle`` defaults to False.
        """
        was_fighting = action.get("prev_in_battle", False)
        still_fighting = state["in_battle"]
        outcome = state.get("battle_outcome", 0)
        map_id = state["map_id"]

        # A battle that just ended with outcome 1.
        just_won = was_fighting and not still_fighting and outcome == 1
        if not just_won:
            return 0.0
        if map_id not in self.rival_battle_maps or map_id in self.rival_battles_completed:
            return 0.0

        self.rival_battles_completed.add(map_id)
        return 50.0
|
114
|
+
|
115
|
+
|
116
|
+
class TeamRocketDefeatReward(RewardComponent):
    """Award +40 once per position for a won battle on a Team Rocket map.

    Placeholder heuristic: the opponent is not inspected. Any battle that
    ends with ``battle_outcome == 1`` on one of the example hideout maps
    counts, keyed by the player's ``(x, y, map_id)``.
    """

    def __init__(self):
        # Position keys of encounters that have already been rewarded.
        self.rocket_encounters: Set[tuple] = set()

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 for a first won battle at a hideout position, else 0.0.

        Args:
            state: Must contain ``in_battle``; ``battle_outcome`` defaults
                to 0. ``map_id``, ``player_x`` and ``player_y`` are read
                once a won battle is detected.
            action: ``prev_in_battle`` defaults to False.
        """
        was_fighting = action.get("prev_in_battle", False)
        still_fighting = state["in_battle"]
        outcome = state.get("battle_outcome", 0)

        # Anything other than "battle just ended with outcome 1" scores nothing.
        if not was_fighting or still_fighting or outcome != 1:
            return 0.0

        # Example Team Rocket hideout maps.
        hideout_maps = {50, 51, 52}
        if state["map_id"] not in hideout_maps:
            return 0.0

        spot = (state["player_x"], state["player_y"], state["map_id"])
        if spot in self.rocket_encounters:
            return 0.0

        self.rocket_encounters.add(spot)
        return 40.0
|
139
|
+
|
140
|
+
|
141
|
+
class LegendaryEncounterReward(RewardComponent):
    """Award +200 once per legendary map when a battle starts there.

    The opponent is not inspected; a battle beginning on one of the listed
    maps is taken to be the legendary encounter.
    """

    def __init__(self):
        # Map ids whose encounter has already been rewarded.
        self.legendary_encounters: Set[int] = set()
        # Maps treated as legendary Pokemon locations.
        self.legendary_maps = {60, 61, 62, 70}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 200.0 when a battle first starts on a legendary map, else 0.0.

        Args:
            state: Must contain ``map_id``; ``in_battle`` is read when the
                map is an unrewarded legendary map.
            action: ``prev_in_battle`` defaults to False.
        """
        map_id = state["map_id"]
        if map_id not in self.legendary_maps or map_id in self.legendary_encounters:
            return 0.0

        was_fighting = action.get("prev_in_battle", False)
        now_fighting = state["in_battle"]
        # Only the not-in-battle -> in-battle transition counts as an encounter.
        if was_fighting or not now_fighting:
            return 0.0

        self.legendary_encounters.add(map_id)
        return 200.0
|
161
|
+
|
162
|
+
|
163
|
+
class SilphCoCompletionReward(RewardComponent):
    """Award +100 once, the first time the player leaves the Silph Co maps.

    Completion is assumed rather than verified: any transition from a
    Silph Co floor to a map outside that range triggers the reward.
    """

    def __init__(self):
        # Latched to True after the reward has been granted.
        self.silph_co_completed = False
        # Map ids 80..89 are treated as the Silph Co floors.
        self.silph_co_maps = set(range(80, 90))

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 100.0 on the first exit from Silph Co, else 0.0.

        Args:
            state: ``map_id`` is read unless the reward was already granted.
            action: ``prev_map_id`` defaults to -1.
        """
        if self.silph_co_completed:
            return 0.0

        came_from = action.get("prev_map_id", -1)
        now_on = state["map_id"]
        left_building = came_from in self.silph_co_maps and now_on not in self.silph_co_maps
        if not left_building:
            return 0.0

        # Assume completion if leaving Silph Co.
        self.silph_co_completed = True
        return 100.0
|
184
|
+
|
185
|
+
|
186
|
+
class SafariZoneSuccessReward(RewardComponent):
    """Award +30 when the player leaves the Safari Zone with a larger party.

    The reward is not deduplicated: every qualifying exit scores, and
    ``safari_zone_runs`` counts how many have been rewarded.
    """

    def __init__(self):
        # Number of Safari Zone runs that have been rewarded so far.
        self.safari_zone_runs = 0
        # Map ids treated as Safari Zone areas.
        self.safari_zone_maps = {90, 91, 92, 93}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 on a Safari Zone exit where the party grew, else 0.0.

        Args:
            state: Must contain ``map_id``; ``party`` defaults to an empty
                list.
            action: ``prev_map_id`` defaults to -1 and ``prev_party`` to an
                empty list.
        """
        # Check if exiting Safari Zone with new Pokemon.
        prev_map = action.get("prev_map_id", -1)
        current_map = state["map_id"]

        if prev_map in self.safari_zone_maps and current_map not in self.safari_zone_maps:
            # Only list lengths are compared; party contents are not inspected.
            prev_party_count = len(action.get("prev_party", []))
            current_party_count = len(state.get("party", []))

            if current_party_count > prev_party_count:
                # Keep the counter in step with the rewards handed out;
                # previously it was initialised but never updated.
                self.safari_zone_runs += 1
                return 30.0

        return 0.0
|
207
|
+
|
208
|
+
|
209
|
+
class GameCornerPrizesReward(RewardComponent):
    """Award +20 the first time each Game Corner prize item is in the inventory."""

    def __init__(self):
        # Prize item ids that have already been rewarded.
        self.game_corner_prizes: Set[int] = set()
        # Item ids treated as Game Corner prizes.
        self.prize_items = {300, 301, 302}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 for each prize item seen for the first time.

        Args:
            state: ``inventory`` is a list of dicts (default empty); each
                entry's ``item_id`` defaults to 0.
            action: Unused.
        """
        reward = 0.0
        for entry in state.get("inventory", []):
            prize_id = entry.get("item_id", 0)
            if prize_id not in self.prize_items or prize_id in self.game_corner_prizes:
                continue
            self.game_corner_prizes.add(prize_id)
            reward += 20.0
        return reward
|
227
|
+
|
228
|
+
|
229
|
+
class FossilRevivalReward(RewardComponent):
    """Award +40 the first time each fossil species appears in the party."""

    def __init__(self):
        # Fossil species ids that have already been rewarded.
        self.fossils_revived: Set[int] = set()
        # Species ids for Omanyte, Kabuto, Aerodactyl.
        self.fossil_pokemon = {138, 140, 142}

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 for each fossil species seen for the first time.

        Args:
            state: ``party`` is a list of dicts (default empty); each
                member's ``species_id`` defaults to 0.
            action: Unused.
        """
        reward = 0.0
        for member in state.get("party", []):
            species = member.get("species_id", 0)
            if species not in self.fossil_pokemon or species in self.fossils_revived:
                continue
            self.fossils_revived.add(species)
            reward += 40.0
        return reward
|