synth-ai 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai might be problematic.
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +5 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +66 -3
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +17 -49
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +13 -5
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +15 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +134 -3
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +4 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +6 -3
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +125 -10
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +12 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +58 -1487
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -11
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/validators.py +2 -2
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/utils/env.py +25 -18
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/modal.py +2 -2
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/METADATA +8 -3
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/RECORD +184 -109
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
synth_ai/cli/train/__init__.py
ADDED

@@ -0,0 +1,12 @@
+from __future__ import annotations
+
+from .core import register, train_command
+from .errors import TrainCliError
+from .validation import validate_train_environment
+
+__all__ = [
+    "register",
+    "train_command",
+    "TrainCliError",
+    "validate_train_environment",
+]
synth_ai/cli/train/core.py
ADDED

@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import click
+from synth_ai.api.train.cli import (
+    register as _register_with_cli,
+)
+from synth_ai.api.train.cli import (
+    train_command as _train_command,
+)
+
+__all__ = ["register", "train_command"]
+
+
+def register(cli: click.Group) -> None:
+    """Attach the train command to the root CLI."""
+    _register_with_cli(cli)
+
+
+def train_command(*args, **kwargs):
+    """Entrypoint used by the train CLI command."""
+    return _train_command(*args, **kwargs)
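For orientation, a minimal sketch of how this shim would be wired up, assuming a standard Click group as the root CLI (the `cli` group below is hypothetical, not from this diff):

    import click
    from synth_ai.cli.train.core import register

    @click.group()
    def cli() -> None:
        """Hypothetical root CLI group."""

    # Attaches the train command to the group; under the hood this
    # delegates to synth_ai.api.train.cli.register.
    register(cli)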
synth_ai/cli/train/validation.py
ADDED

@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Dict, Tuple
+
+from synth_ai.api.train.env_resolver import KeySpec, resolve_env
+
+__all__ = ["validate_train_environment"]
+
+
+def validate_train_environment(
+    *,
+    config_path: Path | None,
+    explicit_env_paths: Iterable[str],
+    required_keys: list[KeySpec],
+) -> Tuple[Path, Dict[str, str]]:
+    """Validate and resolve environment secrets used by the train command."""
+    resolved_path, resolved_keys = resolve_env(
+        config_path=config_path,
+        explicit_env_paths=explicit_env_paths,
+        required_keys=required_keys,
+    )
+    return resolved_path, resolved_keys
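A hedged usage sketch of the new helper; the paths are illustrative and the empty key list sidesteps KeySpec's constructor, whose signature this diff does not show:

    from pathlib import Path
    from synth_ai.cli.train.validation import validate_train_environment

    # Resolve secrets for a train run from a config plus explicit .env files.
    env_path, keys = validate_train_environment(
        config_path=Path("configs/train_rl_from_sft.toml"),  # illustrative path
        explicit_env_paths=[".env"],
        required_keys=[],  # KeySpec instances would normally go here
    )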
synth_ai/cli/train.py
CHANGED

@@ -1,18 +1,5 @@
 from __future__ import annotations

-from typing import Any
-
-from synth_ai.api.train.cli import register as _register
-from synth_ai.api.train.cli import train_command as _train_command
+from synth_ai.cli.commands.train.core import register, train_command

 __all__ = ["register", "train_command"]
-
-
-def register(cli: Any) -> None:
-    """Compatibility wrapper for the legacy train CLI location."""
-
-    _register(cli)
-
-
-def train_command(*args: Any, **kwargs: Any) -> Any:
-    return _train_command(*args, **kwargs)
synth_ai/demos/crafter/grpo_crafter_task_app.py
CHANGED

@@ -3,7 +3,7 @@
 This module now delegates to the TaskAppConfig defined in the local example at
 `examples/warming_up_to_rl/task_app/grpo_crafter.py`. It is kept for legacy usage
 (running the file directly or targeting `fastapi_app` from external tooling).
-Prefer using `uvx synth-ai
+Prefer using `uvx synth-ai deploy --runtime uvicorn grpo-crafter` for local development and testing.
 """

 from __future__ import annotations
synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py
CHANGED

@@ -3,7 +3,7 @@
 This module now delegates to the TaskAppConfig defined in the local example at
 `examples/task_apps/crafter/task_app/grpo_crafter.py`. It is kept for legacy usage
 (running the file directly or targeting `fastapi_app` from external tooling).
-Prefer using `uvx synth-ai
+Prefer using `uvx synth-ai deploy --runtime uvicorn grpo-crafter` for local development and testing.
 """

 from __future__ import annotations
synth_ai/environments/examples/red/engine.py
CHANGED

@@ -14,12 +14,15 @@ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngine
 from synth_ai.environments.tasks.core import TaskInstance

 from .engine_helpers.reward_components import (
-
-
-
-
+    RouteExplorationReward,
+    StrategicTrainingReward,
+    BattleProgressionReward,
+    GymPreparationReward,
+    ItemCollectionReward,
+    HealingManagementReward,
+    EfficientExplorationReward,
+    BadgeVictoryReward,
     StepPenaltyComponent,
-    XPGainComponent,
 )
 from .engine_helpers.state_extraction import extract_game_state
@@ -268,15 +271,27 @@ class PokemonRedEngine(StatefulEngine, IReproducibleEngine):
         # For testing purposes, use None emulator
         self.emulator = None

-        # Initialize reward stack with
+        # Initialize reward stack with comprehensive progress-based components
         self.reward_stack = RewardStack(
             components=[
-
-
-
-
-
-
+                # Major progress rewards
+                BadgeVictoryReward(),  # +50.0 for Boulder Badge (main goal)
+                RouteExplorationReward(),  # +1.0-5.0 for reaching key areas
+                GymPreparationReward(),  # +3.0 for being gym-ready
+
+                # Training and battle rewards
+                StrategicTrainingReward(),  # +0.2-3.0 for level ups and milestones
+                BattleProgressionReward(),  # +0.1-1.0 for battles
+
+                # Resource management rewards
+                ItemCollectionReward(),  # +0.1-0.5 for collecting items
+                HealingManagementReward(),  # +0.05-0.8 for healing Pokemon
+
+                # Exploration efficiency
+                EfficientExplorationReward(),  # +0.02 for discovering new positions
+
+                # No penalty for unproductive actions
+                StepPenaltyComponent(penalty=0.0),  # 0.0 per step
             ]
         )
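Read as a whole, the stack is a sum of independently scored events. A minimal sketch of how such a stack combines component scores, assuming RewardStack awaits each component's async score() and sums the results (its actual implementation lives in synth_ai.environments.environment.rewards.core and is not shown in this diff):

    import asyncio
    from typing import Any, Dict, List

    async def stack_total(
        components: List[Any], state: Dict[str, Any], action: Dict[str, Any]
    ) -> float:
        # Each component returns an independent scalar; the reward for one
        # transition is their sum.
        scores = [await component.score(state, action) for component in components]
        return sum(scores)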
@@ -640,6 +655,12 @@ class PokemonRedEngine(StatefulEngine, IReproducibleEngine):
                     "prev_text_box_active": bool(prev_state.get("text_box_active", False)),
                     "prev_enemy_hp_current": int(prev_state.get("enemy_hp_current", 0)),
                     "prev_enemy_hp_percentage": float(prev_state.get("enemy_hp_percentage", 0.0)),
+                    "prev_player_x": int(prev_state.get("player_x", 0)),
+                    "prev_player_y": int(prev_state.get("player_y", 0)),
+                    "prev_party": prev_state.get("party", []),
+                    "prev_inventory": prev_state.get("inventory", []),
+                    "prev_party_hp_current": int(prev_state.get("party_hp_current", 0)),
+                    "prev_party_hp_max": int(prev_state.get("party_hp_max", 0)),
                 },
             )
         except Exception as e:
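These prev_* fields are the snapshot the reward components receive as their `action` argument. An illustrative transition for the inventory-diff logic in the rewritten components below (values are made up; item dicts carry "item_id" per the new ItemCollectionReward):

    # One transition as the components see it:
    action = {"prev_inventory": [{"item_id": 1}], "prev_badges": 0}
    state = {"inventory": [{"item_id": 1}, {"item_id": 4}], "badges": 0}
    # ItemCollectionReward diffs the item_id sets: new_items == {4}, which is
    # in its "valuable" set, so this step would score +0.5.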
synth_ai/environments/examples/red/engine_helpers/reward_components.py
CHANGED

@@ -3,274 +3,246 @@ from typing import Any, Dict, Set
 from synth_ai.environments.environment.rewards.core import RewardComponent


-
-
+# ===== COMPREHENSIVE POKEMON RED PROGRESS REWARD SYSTEM =====
+# Designed for deterministic rewards that guide toward beating Brock at Pewter Gym

-    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-        prev_badges = action.get("prev_badges", 0)
-        current_badges = state["badges"]
-        new_badges = current_badges & ~prev_badges
-        badge_count = bin(new_badges).count("1")
-        return badge_count * 1.0

+class RouteExplorationReward(RewardComponent):
+    """High rewards for reaching key areas on the path to Pewter Gym - guides exploration"""

-
-
+    def __init__(self):
+        self.key_areas_reached: Set[int] = set()

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-        prev_map = action.get("prev_map_id", -1)
         current_map = state["map_id"]
-
-
+        prev_map = action.get("prev_map_id", -1)

-
-
+        # Key maps and rewards for progressing toward Pewter Gym
+        area_rewards = {
+            0: 0.0,  # Pallet Town (starting point)
+            1: 2.0,  # Route 1 - First step out of town (+2.0)
+            2: 1.5,  # Viridian City - Major hub (+1.5)
+            3: 1.0,  # Route 22 - Path to League (+1.0)
+            4: 1.0,  # Route 2 - To Viridian Forest (+1.0)
+            5: 2.0,  # Viridian Forest - Dense area (+2.0)
+            6: 1.5,  # Pewter City - Target city (+1.5)
+            7: 5.0,  # Pewter Gym - GOAL AREA (+5.0 for entering gym)
+        }
+
+        if current_map in area_rewards and current_map not in self.key_areas_reached:
+            if prev_map != current_map:  # Only reward when actually entering new area
+                self.key_areas_reached.add(current_map)
+                return area_rewards[current_map]

-    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-        prev_in_battle = action.get("prev_in_battle", False)
-        current_in_battle = state["in_battle"]
-        battle_outcome = state["battle_outcome"]
-
-        # Transitioning from battle to not in battle with victory
-        if prev_in_battle and not current_in_battle and battle_outcome == 1:
-            return 0.5
         return 0.0


-class
-    """
+class StrategicTrainingReward(RewardComponent):
+    """Rewards for building Pokemon strength strategically"""
+
+    def __init__(self):
+        self.level_milestones: Set[int] = set()
+        self.last_level = 0

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
+        current_level = state.get("party_level", 0)
         prev_level = action.get("prev_party_level", 0)
-        current_level = state["party_level"]
-        level_gain = max(0, current_level - prev_level)
-        return level_gain * 0.3

+        # Reward reaching key level milestones
+        milestone_rewards = {
+            8: 1.0,  # Level 8 - Good for early battles
+            12: 2.0,  # Level 12 - Ready for Brock
+            15: 3.0,  # Level 15 - Strong Pokemon
+        }

-
-
+        if current_level > prev_level and current_level in milestone_rewards:
+            if current_level not in self.level_milestones:
+                self.level_milestones.add(current_level)
+                return milestone_rewards[current_level]

-
-
-
-        xp_gain = max(0, current_xp - prev_xp)
-        return xp_gain * 0.001  # Very small multiplier
+        # Small reward for any level up (0.2 points)
+        if current_level > prev_level:
+            return 0.2

+        return 0.0

-class StepPenaltyComponent(RewardComponent):
-    """Small penalty for each step to encourage efficiency"""

-
-
+class BattleProgressionReward(RewardComponent):
+    """Rewards for winning battles and gaining experience"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
+        prev_in_battle = action.get("prev_in_battle", False)
+        current_in_battle = state.get("in_battle", False)
+        battle_outcome = state.get("battle_outcome", 0)

+        # Large reward for battle victory (+1.0)
+        if prev_in_battle and not current_in_battle and battle_outcome == 1:
+            return 1.0

-
-
+        # Small reward for entering battle (+0.1) - shows engagement
+        if not prev_in_battle and current_in_battle:
+            return 0.1

-    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-        # This would need more sophisticated menu tracking
         return 0.0


-
-
-
-class ExitHouseReward(RewardComponent):
-    """High reward for first time leaving the starting house - +2.0 points"""
+class GymPreparationReward(RewardComponent):
+    """Rewards for preparing to challenge Brock"""

     def __init__(self):
-        self.
+        self.prepared_for_gym = False

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-        if self.
+        if self.prepared_for_gym:
             return 0.0

-
-
+        # Check if in Pewter City area and have decent Pokemon
+        if state["map_id"] in [6, 7]:  # Pewter City or Gym
+            party_level = state.get("party_level", 0)
+            party_count = len(state.get("party", []))
+
+            # Reward being prepared for gym battle
+            if party_level >= 10 and party_count >= 1:
+                self.prepared_for_gym = True
+                return 3.0  # Significant reward for being gym-ready

-        # Exit from house to town (assuming house maps are 1,2 and town is 0)
-        if prev_map in [1, 2] and current_map == 0:
-            self.house_exited = True
-            return 2.0
         return 0.0


-class
-    """
+class ItemCollectionReward(RewardComponent):
+    """Rewards for collecting useful items"""

     def __init__(self):
-        self.
+        self.items_collected: Set[int] = set()

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-
-        # Use position as NPC identifier
-        npc_key = (state["player_x"], state["player_y"], state["map_id"])
-        if npc_key not in self.npcs_talked_to:
-            self.npcs_talked_to.add(npc_key)
-            return 0.8
-        return 0.0
+        prev_inventory = action.get("prev_inventory", [])
+        current_inventory = state.get("inventory", [])

+        # Check for new items
+        prev_item_ids = {item["item_id"] for item in prev_inventory}
+        current_item_ids = {item["item_id"] for item in current_inventory}

-
-    """High reward for finding and entering Oak's lab - +2.5 points"""
+        new_items = current_item_ids - prev_item_ids

-
-
+        # Reward valuable items for gym preparation
+        valuable_items = {1, 2, 3, 4, 5, 10, 11, 12, 13}  # Potions, Balls, etc.
+        reward = 0.0
+
+        for item_id in new_items:
+            if item_id not in self.items_collected:
+                self.items_collected.add(item_id)
+                if item_id in valuable_items:
+                    reward += 0.5  # +0.5 per valuable item
+                else:
+                    reward += 0.1  # +0.1 per other item
+
+        return reward
+
+
+class HealingManagementReward(RewardComponent):
+    """Rewards for keeping Pokemon healthy"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
+        prev_party = action.get("prev_party", [])
+        current_party = state.get("party", [])
+
+        if not prev_party or not current_party:
             return 0.0

-
-
+        # Reward healing Pokemon back to full health
+        prev_hp_pct = sum(p.get("hp_percentage", 0) for p in prev_party) / len(prev_party)
+        current_hp_pct = sum(p.get("hp_percentage", 0) for p in current_party) / len(current_party)
+
+        # Significant improvement in health
+        if current_hp_pct > prev_hp_pct + 20:  # Healed at least 20% overall
+            return 0.8
+
+        # Small reward for maintaining good health
+        if current_hp_pct >= 80 and prev_hp_pct >= 80:
+            return 0.05

-        # Entering Oak's lab (assuming map 3)
-        if prev_map == 0 and current_map == 3:
-            self.lab_discovered = True
-            return 2.5
         return 0.0


-class
-    """
+class EfficientExplorationReward(RewardComponent):
+    """Rewards for exploring efficiently without getting lost"""

     def __init__(self):
-        self.
+        self.positions_visited: Set[tuple] = set()

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-
+        # Track unique positions visited in each map
+        position_key = (state["map_id"], state["player_x"], state["player_y"])

-
-
-
+        if position_key not in self.positions_visited:
+            self.positions_visited.add(position_key)
+            return 0.02  # Small reward for discovering new areas

-        if prev_party_count == 0 and current_party_count == 1:
-            if state["map_id"] == 3:  # In Oak's lab
-                self.starter_obtained = True
-                return 10.0
         return 0.0


-class
-    """
-
-    def __init__(self):
-        self.first_battle = False
+class BadgeVictoryReward(RewardComponent):
+    """HUGE reward for achieving the main goal - Boulder Badge"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-
+        prev_badges = action.get("prev_badges", 0)
+        current_badges = state.get("badges", 0)

-
-
+        # Check if Boulder Badge (bit 0) was newly earned
+        boulder_badge_mask = 0x01
+        prev_has_badge = prev_badges & boulder_badge_mask
+        current_has_badge = current_badges & boulder_badge_mask
+
+        if not prev_has_badge and current_has_badge:
+            return 50.0  # MASSIVE reward for completing the main objective

-        if not prev_in_battle and current_in_battle:
-            self.first_battle = True
-            return 5.0
         return 0.0


-class
-    """
+class StepPenaltyComponent(RewardComponent):
+    """Small penalty for each step to encourage efficiency"""

-    def __init__(self):
-        self.
-        self.reward_given = False
+    def __init__(self, penalty: float = 0.0):  # Changed from -0.005 to 0.0
+        self.penalty = penalty

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-        return 0.0
+        return self.penalty

-        # Track movement directions based on position changes
-        prev_x = action.get("prev_player_x", state["player_x"])
-        prev_y = action.get("prev_player_y", state["player_y"])
-        current_x = state["player_x"]
-        current_y = state["player_y"]
-
-        if current_x > prev_x:
-            self.directions_tried.add("RIGHT")
-        elif current_x < prev_x:
-            self.directions_tried.add("LEFT")
-        elif current_y > prev_y:
-            self.directions_tried.add("DOWN")
-        elif current_y < prev_y:
-            self.directions_tried.add("UP")
-
-        if len(self.directions_tried) >= 4:
-            self.reward_given = True
-            return 1.0
-        return 0.0

+# ===== LEGACY COMPONENTS (kept for compatibility) =====

-class BuildingExplorationReward(RewardComponent):
-    """Reward for entering different buildings - +0.5 points per building"""

-
-
+class BadgeRewardComponent(RewardComponent):
+    """Legacy badge reward - now handled by BadgeVictoryReward"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-        current_map = state["map_id"]
+        return 0.0  # Handled by BadgeVictoryReward

-        # Entering a new building from town
-        if (
-            prev_map == 0 and current_map > 0 and current_map not in [1, 2]
-        ):  # From town to new building
-            if current_map not in self.buildings_entered:
-                self.buildings_entered.add(current_map)
-                return 0.5
-        return 0.0

+class MapTransitionComponent(RewardComponent):
+    """Legacy map transition - now handled by RouteExplorationReward"""
+
+    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
+        return 0.0  # Handled by RouteExplorationReward

-class ObjectInteractionReward(RewardComponent):
-    """Reward for pressing A on various objects - +0.3 points per object"""

-
-
+class BattleVictoryComponent(RewardComponent):
+    """Legacy battle victory - now handled by BattleProgressionReward"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-        #
-        if state["text_box_active"] and not action.get("prev_text_box_active", False):
-            object_key = (state["player_x"], state["player_y"], state["map_id"])
-            if object_key not in self.objects_interacted:
-                self.objects_interacted.add(object_key)
-                return 0.3
-        return 0.0
-
+        return 0.0  # Handled by BattleProgressionReward

-class TownExplorationReward(RewardComponent):
-    """Reward for thorough town exploration - +0.1 per new position"""

-
-
+class LevelUpComponent(RewardComponent):
+    """Legacy level up - now handled by StrategicTrainingReward"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-        position_key = (state["player_x"], state["player_y"])
-        if position_key not in self.positions_visited:
-            self.positions_visited.add(position_key)
-            return 0.1
-        return 0.0
-
+        return 0.0  # Handled by StrategicTrainingReward

-class RouteAttemptReward(RewardComponent):
-    """Reward for trying to leave town (triggers story) - +3.0 points"""

-
-
+class XPGainComponent(RewardComponent):
+    """Legacy XP gain - now handled by StrategicTrainingReward"""

     async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
-
-        return 0.0
-
-        # Detect reaching the edge of Pallet Town (attempting to go north)
-        if state["map_id"] == 0:  # In Pallet Town
-            if state["player_y"] <= 1:  # At northern edge
-                self.route_attempted = True
-                return 3.0
-        return 0.0
+        return 0.0  # Handled by StrategicTrainingReward