synth-ai 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: automated checks flagged this version of synth-ai. Review the changes below for details.

Files changed (192)
  1. examples/analyze_semantic_words.sh +2 -2
  2. examples/blog_posts/pokemon_vl/README.md +98 -0
  3. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
  4. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  5. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  6. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
  7. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  8. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  9. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  10. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  11. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  12. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  13. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
  14. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  15. examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
  16. examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
  17. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
  18. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
  19. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
  20. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  21. examples/multi_step/configs/verilog_rl_lora.toml +80 -123
  22. examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
  23. examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
  24. examples/qwen_coder/configs/coder_lora_small.toml +1 -3
  25. examples/qwen_vl/README.md +10 -12
  26. examples/qwen_vl/SETUP_COMPLETE.md +7 -8
  27. examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
  28. examples/qwen_vl/collect_data_via_cli.md +76 -84
  29. examples/qwen_vl/collect_vision_traces.py +4 -4
  30. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
  31. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
  32. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
  33. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
  34. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  35. examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
  36. examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
  37. examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
  38. examples/qwen_vl/run_vision_comparison.sh +6 -7
  39. examples/rl/README.md +5 -5
  40. examples/rl/configs/rl_from_base_qwen.toml +26 -1
  41. examples/rl/configs/rl_from_base_qwen17.toml +5 -2
  42. examples/rl/task_app/README.md +1 -2
  43. examples/rl/task_app/math_single_step.py +2 -2
  44. examples/run_crafter_demo.sh +2 -2
  45. examples/sft/README.md +1 -1
  46. examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
  47. examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
  48. examples/swe/task_app/README.md +32 -2
  49. examples/swe/task_app/grpo_swe_mini.py +4 -0
  50. examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
  51. examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
  52. examples/swe/task_app/hosted/inference/openai_client.py +4 -4
  53. examples/swe/task_app/morph_backend.py +178 -0
  54. examples/task_apps/crafter/task_app/README.md +1 -1
  55. examples/task_apps/crafter/task_app/grpo_crafter.py +66 -3
  56. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
  57. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
  58. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
  59. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +17 -49
  60. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +13 -5
  61. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +15 -1
  62. examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
  63. examples/task_apps/math/README.md +1 -2
  64. examples/task_apps/pokemon_red/README.md +3 -4
  65. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
  66. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
  67. examples/task_apps/pokemon_red/task_app.py +36 -5
  68. examples/task_apps/sokoban/README.md +2 -3
  69. examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
  70. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
  71. examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
  72. examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
  73. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
  74. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -2
  75. examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
  76. examples/warming_up_to_rl/task_app/README.md +1 -1
  77. examples/warming_up_to_rl/task_app/grpo_crafter.py +134 -3
  78. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +4 -4
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +6 -3
  83. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
  84. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
  85. synth_ai/api/train/builders.py +9 -3
  86. synth_ai/api/train/cli.py +125 -10
  87. synth_ai/api/train/configs/__init__.py +8 -1
  88. synth_ai/api/train/configs/rl.py +32 -7
  89. synth_ai/api/train/configs/sft.py +6 -2
  90. synth_ai/api/train/configs/shared.py +59 -2
  91. synth_ai/auth/credentials.py +119 -0
  92. synth_ai/cli/__init__.py +12 -4
  93. synth_ai/cli/commands/__init__.py +17 -0
  94. synth_ai/cli/commands/demo/__init__.py +6 -0
  95. synth_ai/cli/commands/demo/core.py +163 -0
  96. synth_ai/cli/commands/deploy/__init__.py +23 -0
  97. synth_ai/cli/commands/deploy/core.py +614 -0
  98. synth_ai/cli/commands/deploy/errors.py +72 -0
  99. synth_ai/cli/commands/deploy/validation.py +11 -0
  100. synth_ai/cli/commands/eval/__init__.py +19 -0
  101. synth_ai/cli/commands/eval/core.py +1109 -0
  102. synth_ai/cli/commands/eval/errors.py +81 -0
  103. synth_ai/cli/commands/eval/validation.py +133 -0
  104. synth_ai/cli/commands/filter/__init__.py +12 -0
  105. synth_ai/cli/commands/filter/core.py +388 -0
  106. synth_ai/cli/commands/filter/errors.py +55 -0
  107. synth_ai/cli/commands/filter/validation.py +77 -0
  108. synth_ai/cli/commands/help/__init__.py +177 -0
  109. synth_ai/cli/commands/help/core.py +73 -0
  110. synth_ai/cli/commands/status/__init__.py +64 -0
  111. synth_ai/cli/commands/status/client.py +192 -0
  112. synth_ai/cli/commands/status/config.py +92 -0
  113. synth_ai/cli/commands/status/errors.py +20 -0
  114. synth_ai/cli/commands/status/formatters.py +164 -0
  115. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  116. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  117. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  118. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  119. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  120. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  121. synth_ai/cli/commands/status/utils.py +114 -0
  122. synth_ai/cli/commands/train/__init__.py +53 -0
  123. synth_ai/cli/commands/train/core.py +21 -0
  124. synth_ai/cli/commands/train/errors.py +117 -0
  125. synth_ai/cli/commands/train/judge_schemas.py +199 -0
  126. synth_ai/cli/commands/train/judge_validation.py +304 -0
  127. synth_ai/cli/commands/train/validation.py +443 -0
  128. synth_ai/cli/demo.py +2 -162
  129. synth_ai/cli/deploy/__init__.py +28 -0
  130. synth_ai/cli/deploy/core.py +5 -0
  131. synth_ai/cli/deploy/errors.py +23 -0
  132. synth_ai/cli/deploy/validation.py +5 -0
  133. synth_ai/cli/eval/__init__.py +36 -0
  134. synth_ai/cli/eval/core.py +5 -0
  135. synth_ai/cli/eval/errors.py +31 -0
  136. synth_ai/cli/eval/validation.py +5 -0
  137. synth_ai/cli/filter/__init__.py +28 -0
  138. synth_ai/cli/filter/core.py +5 -0
  139. synth_ai/cli/filter/errors.py +23 -0
  140. synth_ai/cli/filter/validation.py +5 -0
  141. synth_ai/cli/modal_serve/__init__.py +12 -0
  142. synth_ai/cli/modal_serve/core.py +14 -0
  143. synth_ai/cli/modal_serve/errors.py +8 -0
  144. synth_ai/cli/modal_serve/validation.py +11 -0
  145. synth_ai/cli/serve/__init__.py +12 -0
  146. synth_ai/cli/serve/core.py +14 -0
  147. synth_ai/cli/serve/errors.py +8 -0
  148. synth_ai/cli/serve/validation.py +11 -0
  149. synth_ai/cli/setup.py +20 -265
  150. synth_ai/cli/status.py +7 -126
  151. synth_ai/cli/task_app_deploy.py +1 -10
  152. synth_ai/cli/task_app_modal_serve.py +4 -9
  153. synth_ai/cli/task_app_serve.py +4 -11
  154. synth_ai/cli/task_apps.py +58 -1487
  155. synth_ai/cli/train/__init__.py +12 -0
  156. synth_ai/cli/train/core.py +21 -0
  157. synth_ai/cli/train/errors.py +8 -0
  158. synth_ai/cli/train/validation.py +24 -0
  159. synth_ai/cli/train.py +1 -14
  160. synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
  161. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  162. synth_ai/environments/examples/red/engine.py +33 -12
  163. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  164. synth_ai/environments/examples/red/environment.py +26 -0
  165. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  166. synth_ai/http.py +12 -0
  167. synth_ai/judge_schemas.py +10 -11
  168. synth_ai/learning/rl/client.py +3 -1
  169. synth_ai/streaming/__init__.py +29 -0
  170. synth_ai/streaming/config.py +94 -0
  171. synth_ai/streaming/handlers.py +469 -0
  172. synth_ai/streaming/streamer.py +301 -0
  173. synth_ai/streaming/types.py +95 -0
  174. synth_ai/task/validators.py +2 -2
  175. synth_ai/tracing_v3/migration_helper.py +1 -2
  176. synth_ai/utils/env.py +25 -18
  177. synth_ai/utils/http.py +4 -1
  178. synth_ai/utils/modal.py +2 -2
  179. {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/METADATA +8 -3
  180. {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/RECORD +184 -109
  181. examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
  182. synth_ai/cli/tui.py +0 -62
  183. synth_ai/tui/__init__.py +0 -5
  184. synth_ai/tui/__main__.py +0 -13
  185. synth_ai/tui/cli/__init__.py +0 -1
  186. synth_ai/tui/cli/query_experiments.py +0 -164
  187. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  188. synth_ai/tui/dashboard.py +0 -911
  189. {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
  190. {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
  191. {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
  192. {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from .core import register, train_command
4
+ from .errors import TrainCliError
5
+ from .validation import validate_train_environment
6
+
7
+ __all__ = [
8
+ "register",
9
+ "train_command",
10
+ "TrainCliError",
11
+ "validate_train_environment",
12
+ ]
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ import click
4
+ from synth_ai.api.train.cli import (
5
+ register as _register_with_cli,
6
+ )
7
+ from synth_ai.api.train.cli import (
8
+ train_command as _train_command,
9
+ )
10
+
11
+ __all__ = ["register", "train_command"]
12
+
13
+
14
def register(cli: click.Group) -> None:
    """Attach the train command to the root CLI.

    Thin compatibility shim: delegates to
    ``synth_ai.api.train.cli.register``, which performs the actual
    click command registration on *cli*.
    """
    _register_with_cli(cli)
17
+
18
+
19
def train_command(*args, **kwargs):
    """Entrypoint used by the train CLI command.

    Pass-through wrapper around ``synth_ai.api.train.cli.train_command``:
    forwards all positional and keyword arguments unchanged and returns
    its result.
    """
    return _train_command(*args, **kwargs)
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+
4
class TrainCliError(RuntimeError):
    """Base exception for train CLI failures.

    Derives from ``RuntimeError`` so existing callers that catch
    ``RuntimeError`` continue to work; more specific train CLI errors
    are expected to subclass this type.
    """
6
+
7
+
8
+ __all__ = ["TrainCliError"]
@@ -0,0 +1,24 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+ from pathlib import Path
5
+ from typing import Dict, Tuple
6
+
7
+ from synth_ai.api.train.env_resolver import KeySpec, resolve_env
8
+
9
+ __all__ = ["validate_train_environment"]
10
+
11
+
12
def validate_train_environment(
    *,
    config_path: Path | None,
    explicit_env_paths: Iterable[str],
    required_keys: list[KeySpec],
) -> Tuple[Path, Dict[str, str]]:
    """Validate and resolve environment secrets used by the train command.

    Delegates entirely to ``resolve_env`` and returns its result as-is:
    the environment file path that was selected and a mapping of resolved
    key values (see ``resolve_env`` for the exact semantics).

    Args:
        config_path: Optional path to the train config used to locate
            candidate ``.env`` files.
        explicit_env_paths: Environment file paths supplied explicitly
            by the caller (e.g. via CLI flags).
        required_keys: Specifications of the secret keys that must be
            present after resolution.
    """
    return resolve_env(
        config_path=config_path,
        explicit_env_paths=explicit_env_paths,
        required_keys=required_keys,
    )
synth_ai/cli/train.py CHANGED
@@ -1,18 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any
4
-
5
- from synth_ai.api.train.cli import register as _register
6
- from synth_ai.api.train.cli import train_command as _train_command
3
+ from synth_ai.cli.commands.train.core import register, train_command
7
4
 
8
5
  __all__ = ["register", "train_command"]
9
-
10
-
11
- def register(cli: Any) -> None:
12
- """Compatibility wrapper for the legacy train CLI location."""
13
-
14
- _register(cli)
15
-
16
-
17
- def train_command(*args: Any, **kwargs: Any) -> Any:
18
- return _train_command(*args, **kwargs)
@@ -3,7 +3,7 @@
3
3
  This module now delegates to the TaskAppConfig defined in the local example at
4
4
  `examples/warming_up_to_rl/task_app/grpo_crafter.py`. It is kept for legacy usage
5
5
  (running the file directly or targeting `fastapi_app` from external tooling).
6
- Prefer using `uvx synth-ai serve grpo-crafter` for local development and testing.
6
+ Prefer using `uvx synth-ai deploy --runtime uvicorn grpo-crafter` for local development and testing.
7
7
  """
8
8
 
9
9
  from __future__ import annotations
@@ -3,7 +3,7 @@
3
3
  This module now delegates to the TaskAppConfig defined in the local example at
4
4
  `examples/task_apps/crafter/task_app/grpo_crafter.py`. It is kept for legacy usage
5
5
  (running the file directly or targeting `fastapi_app` from external tooling).
6
- Prefer using `uvx synth-ai serve grpo-crafter` for local development and testing.
6
+ Prefer using `uvx synth-ai deploy --runtime uvicorn grpo-crafter` for local development and testing.
7
7
  """
8
8
 
9
9
  from __future__ import annotations
@@ -14,12 +14,15 @@ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngine
14
14
  from synth_ai.environments.tasks.core import TaskInstance
15
15
 
16
16
  from .engine_helpers.reward_components import (
17
- BadgeRewardComponent,
18
- BattleVictoryComponent,
19
- LevelUpComponent,
20
- MapTransitionComponent,
17
+ RouteExplorationReward,
18
+ StrategicTrainingReward,
19
+ BattleProgressionReward,
20
+ GymPreparationReward,
21
+ ItemCollectionReward,
22
+ HealingManagementReward,
23
+ EfficientExplorationReward,
24
+ BadgeVictoryReward,
21
25
  StepPenaltyComponent,
22
- XPGainComponent,
23
26
  )
24
27
  from .engine_helpers.state_extraction import extract_game_state
25
28
 
@@ -268,15 +271,27 @@ class PokemonRedEngine(StatefulEngine, IReproducibleEngine):
268
271
  # For testing purposes, use None emulator
269
272
  self.emulator = None
270
273
 
271
- # Initialize reward stack with dense components
274
+ # Initialize reward stack with comprehensive progress-based components
272
275
  self.reward_stack = RewardStack(
273
276
  components=[
274
- BadgeRewardComponent(),
275
- MapTransitionComponent(),
276
- BattleVictoryComponent(),
277
- LevelUpComponent(),
278
- XPGainComponent(),
279
- StepPenaltyComponent(),
277
+ # Major progress rewards
278
+ BadgeVictoryReward(), # +50.0 for Boulder Badge (main goal)
279
+ RouteExplorationReward(), # +1.0-5.0 for reaching key areas
280
+ GymPreparationReward(), # +3.0 for being gym-ready
281
+
282
+ # Training and battle rewards
283
+ StrategicTrainingReward(), # +0.2-3.0 for level ups and milestones
284
+ BattleProgressionReward(), # +0.1-1.0 for battles
285
+
286
+ # Resource management rewards
287
+ ItemCollectionReward(), # +0.1-0.5 for collecting items
288
+ HealingManagementReward(), # +0.05-0.8 for healing Pokemon
289
+
290
+ # Exploration efficiency
291
+ EfficientExplorationReward(), # +0.02 for discovering new positions
292
+
293
+ # No penalty for unproductive actions
294
+ StepPenaltyComponent(penalty=0.0), # 0.0 per step
280
295
  ]
281
296
  )
282
297
 
@@ -640,6 +655,12 @@ class PokemonRedEngine(StatefulEngine, IReproducibleEngine):
640
655
  "prev_text_box_active": bool(prev_state.get("text_box_active", False)),
641
656
  "prev_enemy_hp_current": int(prev_state.get("enemy_hp_current", 0)),
642
657
  "prev_enemy_hp_percentage": float(prev_state.get("enemy_hp_percentage", 0.0)),
658
+ "prev_player_x": int(prev_state.get("player_x", 0)),
659
+ "prev_player_y": int(prev_state.get("player_y", 0)),
660
+ "prev_party": prev_state.get("party", []),
661
+ "prev_inventory": prev_state.get("inventory", []),
662
+ "prev_party_hp_current": int(prev_state.get("party_hp_current", 0)),
663
+ "prev_party_hp_max": int(prev_state.get("party_hp_max", 0)),
643
664
  },
644
665
  )
645
666
  except Exception as e:
@@ -3,274 +3,246 @@ from typing import Any, Dict, Set
3
3
  from synth_ai.environments.environment.rewards.core import RewardComponent
4
4
 
5
5
 
6
- class BadgeRewardComponent(RewardComponent):
7
- """Reward for earning gym badges"""
6
+ # ===== COMPREHENSIVE POKEMON RED PROGRESS REWARD SYSTEM =====
7
+ # Designed for deterministic rewards that guide toward beating Brock at Pewter Gym
8
8
 
9
- async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
10
- prev_badges = action.get("prev_badges", 0)
11
- current_badges = state["badges"]
12
- new_badges = current_badges & ~prev_badges
13
- badge_count = bin(new_badges).count("1")
14
- return badge_count * 1.0
15
9
 
10
+ class RouteExplorationReward(RewardComponent):
11
+ """High rewards for reaching key areas on the path to Pewter Gym - guides exploration"""
16
12
 
17
- class MapTransitionComponent(RewardComponent):
18
- """Reward for exploring new areas"""
13
+ def __init__(self):
14
+ self.key_areas_reached: Set[int] = set()
19
15
 
20
16
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
21
- prev_map = action.get("prev_map_id", -1)
22
17
  current_map = state["map_id"]
23
- return 0.1 if current_map != prev_map else 0.0
24
-
18
+ prev_map = action.get("prev_map_id", -1)
25
19
 
26
- class BattleVictoryComponent(RewardComponent):
27
- """Reward for winning battles"""
20
+ # Key maps and rewards for progressing toward Pewter Gym
21
+ area_rewards = {
22
+ 0: 0.0, # Pallet Town (starting point)
23
+ 1: 2.0, # Route 1 - First step out of town (+2.0)
24
+ 2: 1.5, # Viridian City - Major hub (+1.5)
25
+ 3: 1.0, # Route 22 - Path to League (+1.0)
26
+ 4: 1.0, # Route 2 - To Viridian Forest (+1.0)
27
+ 5: 2.0, # Viridian Forest - Dense area (+2.0)
28
+ 6: 1.5, # Pewter City - Target city (+1.5)
29
+ 7: 5.0, # Pewter Gym - GOAL AREA (+5.0 for entering gym)
30
+ }
31
+
32
+ if current_map in area_rewards and current_map not in self.key_areas_reached:
33
+ if prev_map != current_map: # Only reward when actually entering new area
34
+ self.key_areas_reached.add(current_map)
35
+ return area_rewards[current_map]
28
36
 
29
- async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
30
- prev_in_battle = action.get("prev_in_battle", False)
31
- current_in_battle = state["in_battle"]
32
- battle_outcome = state["battle_outcome"]
33
-
34
- # Transitioning from battle to not in battle with victory
35
- if prev_in_battle and not current_in_battle and battle_outcome == 1:
36
- return 0.5
37
37
  return 0.0
38
38
 
39
39
 
40
- class LevelUpComponent(RewardComponent):
41
- """Reward for Pokemon leveling up"""
40
+ class StrategicTrainingReward(RewardComponent):
41
+ """Rewards for building Pokemon strength strategically"""
42
+
43
+ def __init__(self):
44
+ self.level_milestones: Set[int] = set()
45
+ self.last_level = 0
42
46
 
43
47
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
48
+ current_level = state.get("party_level", 0)
44
49
  prev_level = action.get("prev_party_level", 0)
45
- current_level = state["party_level"]
46
- level_gain = max(0, current_level - prev_level)
47
- return level_gain * 0.3
48
50
 
51
+ # Reward reaching key level milestones
52
+ milestone_rewards = {
53
+ 8: 1.0, # Level 8 - Good for early battles
54
+ 12: 2.0, # Level 12 - Ready for Brock
55
+ 15: 3.0, # Level 15 - Strong Pokemon
56
+ }
49
57
 
50
- class XPGainComponent(RewardComponent):
51
- """Small reward for XP gains"""
58
+ if current_level > prev_level and current_level in milestone_rewards:
59
+ if current_level not in self.level_milestones:
60
+ self.level_milestones.add(current_level)
61
+ return milestone_rewards[current_level]
52
62
 
53
- async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
54
- prev_xp = action.get("prev_party_xp", 0)
55
- current_xp = state["party_xp"]
56
- xp_gain = max(0, current_xp - prev_xp)
57
- return xp_gain * 0.001 # Very small multiplier
63
+ # Small reward for any level up (0.2 points)
64
+ if current_level > prev_level:
65
+ return 0.2
58
66
 
67
+ return 0.0
59
68
 
60
- class StepPenaltyComponent(RewardComponent):
61
- """Small penalty for each step to encourage efficiency"""
62
69
 
63
- def __init__(self, penalty: float = -0.001):
64
- self.penalty = penalty
70
+ class BattleProgressionReward(RewardComponent):
71
+ """Rewards for winning battles and gaining experience"""
65
72
 
66
73
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
67
- return self.penalty
74
+ prev_in_battle = action.get("prev_in_battle", False)
75
+ current_in_battle = state.get("in_battle", False)
76
+ battle_outcome = state.get("battle_outcome", 0)
68
77
 
78
+ # Large reward for battle victory (+1.0)
79
+ if prev_in_battle and not current_in_battle and battle_outcome == 1:
80
+ return 1.0
69
81
 
70
- class MenuPenaltyComponent(RewardComponent):
71
- """Penalty for excessive menu usage"""
82
+ # Small reward for entering battle (+0.1) - shows engagement
83
+ if not prev_in_battle and current_in_battle:
84
+ return 0.1
72
85
 
73
- async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
74
- # This would need more sophisticated menu tracking
75
86
  return 0.0
76
87
 
77
88
 
78
- # ===== NEW EARLY GAME PALLET TOWN REWARDS =====
79
-
80
-
81
- class ExitHouseReward(RewardComponent):
82
- """High reward for first time leaving the starting house - +2.0 points"""
89
+ class GymPreparationReward(RewardComponent):
90
+ """Rewards for preparing to challenge Brock"""
83
91
 
84
92
  def __init__(self):
85
- self.house_exited = False
93
+ self.prepared_for_gym = False
86
94
 
87
95
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
88
- if self.house_exited:
96
+ if self.prepared_for_gym:
89
97
  return 0.0
90
98
 
91
- prev_map = action.get("prev_map_id", -1)
92
- current_map = state["map_id"]
99
+ # Check if in Pewter City area and have decent Pokemon
100
+ if state["map_id"] in [6, 7]: # Pewter City or Gym
101
+ party_level = state.get("party_level", 0)
102
+ party_count = len(state.get("party", []))
103
+
104
+ # Reward being prepared for gym battle
105
+ if party_level >= 10 and party_count >= 1:
106
+ self.prepared_for_gym = True
107
+ return 3.0 # Significant reward for being gym-ready
93
108
 
94
- # Exit from house to town (assuming house maps are 1,2 and town is 0)
95
- if prev_map in [1, 2] and current_map == 0:
96
- self.house_exited = True
97
- return 2.0
98
109
  return 0.0
99
110
 
100
111
 
101
- class NPCInteractionReward(RewardComponent):
102
- """Reward for talking to NPCs - +0.8 points per unique NPC"""
112
+ class ItemCollectionReward(RewardComponent):
113
+ """Rewards for collecting useful items"""
103
114
 
104
115
  def __init__(self):
105
- self.npcs_talked_to: Set[tuple] = set()
116
+ self.items_collected: Set[int] = set()
106
117
 
107
118
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
108
- # Detect NPC conversations
109
- if state["text_box_active"] and not action.get("prev_text_box_active", False):
110
- # Use position as NPC identifier
111
- npc_key = (state["player_x"], state["player_y"], state["map_id"])
112
- if npc_key not in self.npcs_talked_to:
113
- self.npcs_talked_to.add(npc_key)
114
- return 0.8
115
- return 0.0
119
+ prev_inventory = action.get("prev_inventory", [])
120
+ current_inventory = state.get("inventory", [])
116
121
 
122
+ # Check for new items
123
+ prev_item_ids = {item["item_id"] for item in prev_inventory}
124
+ current_item_ids = {item["item_id"] for item in current_inventory}
117
125
 
118
- class OakLabDiscoveryReward(RewardComponent):
119
- """High reward for finding and entering Oak's lab - +2.5 points"""
126
+ new_items = current_item_ids - prev_item_ids
120
127
 
121
- def __init__(self):
122
- self.lab_discovered = False
128
+ # Reward valuable items for gym preparation
129
+ valuable_items = {1, 2, 3, 4, 5, 10, 11, 12, 13} # Potions, Balls, etc.
130
+ reward = 0.0
131
+
132
+ for item_id in new_items:
133
+ if item_id not in self.items_collected:
134
+ self.items_collected.add(item_id)
135
+ if item_id in valuable_items:
136
+ reward += 0.5 # +0.5 per valuable item
137
+ else:
138
+ reward += 0.1 # +0.1 per other item
139
+
140
+ return reward
141
+
142
+
143
+ class HealingManagementReward(RewardComponent):
144
+ """Rewards for keeping Pokemon healthy"""
123
145
 
124
146
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
125
- if self.lab_discovered:
147
+ prev_party = action.get("prev_party", [])
148
+ current_party = state.get("party", [])
149
+
150
+ if not prev_party or not current_party:
126
151
  return 0.0
127
152
 
128
- prev_map = action.get("prev_map_id", -1)
129
- current_map = state["map_id"]
153
+ # Reward healing Pokemon back to full health
154
+ prev_hp_pct = sum(p.get("hp_percentage", 0) for p in prev_party) / len(prev_party)
155
+ current_hp_pct = sum(p.get("hp_percentage", 0) for p in current_party) / len(current_party)
156
+
157
+ # Significant improvement in health
158
+ if current_hp_pct > prev_hp_pct + 20: # Healed at least 20% overall
159
+ return 0.8
160
+
161
+ # Small reward for maintaining good health
162
+ if current_hp_pct >= 80 and prev_hp_pct >= 80:
163
+ return 0.05
130
164
 
131
- # Entering Oak's lab (assuming map 3)
132
- if prev_map == 0 and current_map == 3:
133
- self.lab_discovered = True
134
- return 2.5
135
165
  return 0.0
136
166
 
137
167
 
138
- class StarterPokemonReward(RewardComponent):
139
- """Very high reward for getting first Pokemon - +10.0 points"""
168
+ class EfficientExplorationReward(RewardComponent):
169
+ """Rewards for exploring efficiently without getting lost"""
140
170
 
141
171
  def __init__(self):
142
- self.starter_obtained = False
172
+ self.positions_visited: Set[tuple] = set()
143
173
 
144
174
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
145
- if self.starter_obtained:
146
- return 0.0
175
+ # Track unique positions visited in each map
176
+ position_key = (state["map_id"], state["player_x"], state["player_y"])
147
177
 
148
- # Detect getting first Pokemon
149
- prev_party_count = len(action.get("prev_party", []))
150
- current_party_count = len(state.get("party", []))
178
+ if position_key not in self.positions_visited:
179
+ self.positions_visited.add(position_key)
180
+ return 0.02 # Small reward for discovering new areas
151
181
 
152
- if prev_party_count == 0 and current_party_count == 1:
153
- if state["map_id"] == 3: # In Oak's lab
154
- self.starter_obtained = True
155
- return 10.0
156
182
  return 0.0
157
183
 
158
184
 
159
- class FirstBattleReward(RewardComponent):
160
- """High reward for engaging in first battle - +5.0 points"""
161
-
162
- def __init__(self):
163
- self.first_battle = False
185
+ class BadgeVictoryReward(RewardComponent):
186
+ """HUGE reward for achieving the main goal - Boulder Badge"""
164
187
 
165
188
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
166
- if self.first_battle:
167
- return 0.0
189
+ prev_badges = action.get("prev_badges", 0)
190
+ current_badges = state.get("badges", 0)
168
191
 
169
- prev_in_battle = action.get("prev_in_battle", False)
170
- current_in_battle = state["in_battle"]
192
+ # Check if Boulder Badge (bit 0) was newly earned
193
+ boulder_badge_mask = 0x01
194
+ prev_has_badge = prev_badges & boulder_badge_mask
195
+ current_has_badge = current_badges & boulder_badge_mask
196
+
197
+ if not prev_has_badge and current_has_badge:
198
+ return 50.0 # MASSIVE reward for completing the main objective
171
199
 
172
- if not prev_in_battle and current_in_battle:
173
- self.first_battle = True
174
- return 5.0
175
200
  return 0.0
176
201
 
177
202
 
178
- class DirectionExplorationReward(RewardComponent):
179
- """Reward for trying all movement directions - +1.0 points when complete"""
203
+ class StepPenaltyComponent(RewardComponent):
204
+ """Small penalty for each step to encourage efficiency"""
180
205
 
181
- def __init__(self):
182
- self.directions_tried: Set[str] = set()
183
- self.reward_given = False
206
+ def __init__(self, penalty: float = 0.0): # Changed from -0.005 to 0.0
207
+ self.penalty = penalty
184
208
 
185
209
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
186
- if self.reward_given:
187
- return 0.0
210
+ return self.penalty
188
211
 
189
- # Track movement directions based on position changes
190
- prev_x = action.get("prev_player_x", state["player_x"])
191
- prev_y = action.get("prev_player_y", state["player_y"])
192
- current_x = state["player_x"]
193
- current_y = state["player_y"]
194
-
195
- if current_x > prev_x:
196
- self.directions_tried.add("RIGHT")
197
- elif current_x < prev_x:
198
- self.directions_tried.add("LEFT")
199
- elif current_y > prev_y:
200
- self.directions_tried.add("DOWN")
201
- elif current_y < prev_y:
202
- self.directions_tried.add("UP")
203
-
204
- if len(self.directions_tried) >= 4:
205
- self.reward_given = True
206
- return 1.0
207
- return 0.0
208
212
 
213
+ # ===== LEGACY COMPONENTS (kept for compatibility) =====
209
214
 
210
- class BuildingExplorationReward(RewardComponent):
211
- """Reward for entering different buildings - +0.5 points per building"""
212
215
 
213
- def __init__(self):
214
- self.buildings_entered: Set[int] = set()
216
+ class BadgeRewardComponent(RewardComponent):
217
+ """Legacy badge reward - now handled by BadgeVictoryReward"""
215
218
 
216
219
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
217
- prev_map = action.get("prev_map_id", -1)
218
- current_map = state["map_id"]
220
+ return 0.0 # Handled by BadgeVictoryReward
219
221
 
220
- # Entering a new building from town
221
- if (
222
- prev_map == 0 and current_map > 0 and current_map not in [1, 2]
223
- ): # From town to new building
224
- if current_map not in self.buildings_entered:
225
- self.buildings_entered.add(current_map)
226
- return 0.5
227
- return 0.0
228
222
 
223
+ class MapTransitionComponent(RewardComponent):
224
+ """Legacy map transition - now handled by RouteExplorationReward"""
225
+
226
+ async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
227
+ return 0.0 # Handled by RouteExplorationReward
229
228
 
230
- class ObjectInteractionReward(RewardComponent):
231
- """Reward for pressing A on various objects - +0.3 points per object"""
232
229
 
233
- def __init__(self):
234
- self.objects_interacted: Set[tuple] = set()
230
+ class BattleVictoryComponent(RewardComponent):
231
+ """Legacy battle victory - now handled by BattleProgressionReward"""
235
232
 
236
233
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
237
- # Detect A button interactions that trigger text
238
- if state["text_box_active"] and not action.get("prev_text_box_active", False):
239
- object_key = (state["player_x"], state["player_y"], state["map_id"])
240
- if object_key not in self.objects_interacted:
241
- self.objects_interacted.add(object_key)
242
- return 0.3
243
- return 0.0
244
-
234
+ return 0.0 # Handled by BattleProgressionReward
245
235
 
246
- class TownExplorationReward(RewardComponent):
247
- """Reward for thorough town exploration - +0.1 per new position"""
248
236
 
249
- def __init__(self):
250
- self.positions_visited: Set[tuple] = set()
237
+ class LevelUpComponent(RewardComponent):
238
+ """Legacy level up - now handled by StrategicTrainingReward"""
251
239
 
252
240
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
253
- if state["map_id"] == 0: # In Pallet Town
254
- position_key = (state["player_x"], state["player_y"])
255
- if position_key not in self.positions_visited:
256
- self.positions_visited.add(position_key)
257
- return 0.1
258
- return 0.0
259
-
241
+ return 0.0 # Handled by StrategicTrainingReward
260
242
 
261
- class RouteAttemptReward(RewardComponent):
262
- """Reward for trying to leave town (triggers story) - +3.0 points"""
263
243
 
264
- def __init__(self):
265
- self.route_attempted = False
244
+ class XPGainComponent(RewardComponent):
245
+ """Legacy XP gain - now handled by StrategicTrainingReward"""
266
246
 
267
247
  async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
268
- if self.route_attempted:
269
- return 0.0
270
-
271
- # Detect reaching the edge of Pallet Town (attempting to go north)
272
- if state["map_id"] == 0: # In Pallet Town
273
- if state["player_y"] <= 1: # At northern edge
274
- self.route_attempted = True
275
- return 3.0
276
- return 0.0
248
+ return 0.0 # Handled by StrategicTrainingReward