synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +186 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +7 -1
- examples/swe/task_app/grpo_swe_mini.py +55 -26
- examples/swe/task_app/hosted/rollout.py +40 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/task_app/__init__.py +2 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +21 -46
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +67 -49
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +242 -193
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
- examples/task_apps/pokemon_red/task_app.py +606 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
- examples/warming_up_to_rl/run_eval.py +127 -18
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +73 -29
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +134 -0
- synth_ai/api/train/configs/sft.py +95 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +7 -51
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +49 -43
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/rl_demo.py +86 -106
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/task_apps.py +1710 -186
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +30 -4
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +82 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +127 -0
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +14 -5
- synth_ai/task/contracts.py +124 -38
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +53 -0
- synth_ai/task/rubrics/loaders.py +133 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +113 -0
- synth_ai/task/rubrics/strict.py +149 -0
- synth_ai/task/server.py +8 -7
- synth_ai/task/validators.py +269 -6
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/native_manager.py +3 -3
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +228 -89
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -1
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
|
@@ -6,7 +6,7 @@ import logging
|
|
|
6
6
|
import os
|
|
7
7
|
import time as _time
|
|
8
8
|
from datetime import datetime
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any, Mapping
|
|
10
10
|
|
|
11
11
|
from fastapi import APIRouter, HTTPException, Request, status
|
|
12
12
|
from pydantic import BaseModel, Field
|
|
@@ -184,6 +184,121 @@ def _coerce_k_limits(raw_limits: Any) -> dict[str, int]:
|
|
|
184
184
|
return limits
|
|
185
185
|
|
|
186
186
|
|
|
187
|
+
def _coerce_int_value(value: Any) -> int | None:
|
|
188
|
+
if isinstance(value, bool):
|
|
189
|
+
return int(value)
|
|
190
|
+
try:
|
|
191
|
+
return int(value) # type: ignore[arg-type]
|
|
192
|
+
except Exception:
|
|
193
|
+
try:
|
|
194
|
+
return int(float(value)) # type: ignore[arg-type]
|
|
195
|
+
except Exception:
|
|
196
|
+
return None
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _compute_resource_reward(
|
|
200
|
+
prev_inventory: Mapping[str, Any] | None,
|
|
201
|
+
new_inventory: Mapping[str, Any] | None,
|
|
202
|
+
prev_counts: Mapping[str, Any] | None,
|
|
203
|
+
new_counts: Mapping[str, Any] | None,
|
|
204
|
+
) -> tuple[float, list[dict[str, Any]], dict[str, int], dict[str, int]]:
|
|
205
|
+
reward_total = 0.0
|
|
206
|
+
components: list[dict[str, Any]] = []
|
|
207
|
+
inventory_deltas: dict[str, int] = {}
|
|
208
|
+
achievement_deltas: dict[str, int] = {}
|
|
209
|
+
|
|
210
|
+
resource_weights = {
|
|
211
|
+
"wood": 0.10,
|
|
212
|
+
"sapling": 0.08,
|
|
213
|
+
"stone": 0.15,
|
|
214
|
+
"coal": 0.18,
|
|
215
|
+
"iron": 0.22,
|
|
216
|
+
"plant": 0.06,
|
|
217
|
+
"meat": 0.12,
|
|
218
|
+
"drink": 0.07,
|
|
219
|
+
"food": 0.07,
|
|
220
|
+
"water": 0.07,
|
|
221
|
+
"energy": 0.04,
|
|
222
|
+
}
|
|
223
|
+
tool_weights = {
|
|
224
|
+
"wood_pickaxe": 0.40,
|
|
225
|
+
"stone_pickaxe": 0.55,
|
|
226
|
+
"iron_pickaxe": 0.75,
|
|
227
|
+
"wood_sword": 0.35,
|
|
228
|
+
"stone_sword": 0.50,
|
|
229
|
+
"iron_sword": 0.70,
|
|
230
|
+
"furnace": 0.45,
|
|
231
|
+
"table": 0.30,
|
|
232
|
+
"bow": 0.45,
|
|
233
|
+
}
|
|
234
|
+
achievement_weights = {
|
|
235
|
+
"collect_wood": 0.08,
|
|
236
|
+
"collect_sapling": 0.06,
|
|
237
|
+
"collect_stone": 0.10,
|
|
238
|
+
"collect_coal": 0.12,
|
|
239
|
+
"collect_iron": 0.14,
|
|
240
|
+
"collect_drink": 0.06,
|
|
241
|
+
"collect_food": 0.06,
|
|
242
|
+
"collect_plant": 0.06,
|
|
243
|
+
}
|
|
244
|
+
default_resource_weight = 0.05
|
|
245
|
+
default_achievement_weight = 0.05
|
|
246
|
+
|
|
247
|
+
prev_inv = prev_inventory or {}
|
|
248
|
+
new_inv = new_inventory or {}
|
|
249
|
+
for key, raw_value in new_inv.items():
|
|
250
|
+
new_val = _coerce_int_value(raw_value)
|
|
251
|
+
if new_val is None:
|
|
252
|
+
continue
|
|
253
|
+
prev_val = _coerce_int_value(prev_inv.get(key, 0)) or 0
|
|
254
|
+
delta = new_val - prev_val
|
|
255
|
+
if delta <= 0:
|
|
256
|
+
continue
|
|
257
|
+
weight = resource_weights.get(key)
|
|
258
|
+
if weight is None and key in tool_weights:
|
|
259
|
+
weight = tool_weights[key]
|
|
260
|
+
if weight is None:
|
|
261
|
+
weight = default_resource_weight
|
|
262
|
+
gain = weight * delta
|
|
263
|
+
reward_total += gain
|
|
264
|
+
inventory_deltas[str(key)] = delta
|
|
265
|
+
components.append(
|
|
266
|
+
{
|
|
267
|
+
"type": "inventory",
|
|
268
|
+
"item": str(key),
|
|
269
|
+
"delta": delta,
|
|
270
|
+
"weight": weight,
|
|
271
|
+
"reward": gain,
|
|
272
|
+
}
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
prev_ct = prev_counts or {}
|
|
276
|
+
new_ct = new_counts or {}
|
|
277
|
+
for key, raw_value in new_ct.items():
|
|
278
|
+
new_val = _coerce_int_value(raw_value)
|
|
279
|
+
if new_val is None:
|
|
280
|
+
continue
|
|
281
|
+
prev_val = _coerce_int_value(prev_ct.get(key, 0)) or 0
|
|
282
|
+
delta = new_val - prev_val
|
|
283
|
+
if delta <= 0:
|
|
284
|
+
continue
|
|
285
|
+
weight = achievement_weights.get(key, default_achievement_weight)
|
|
286
|
+
gain = weight * delta
|
|
287
|
+
reward_total += gain
|
|
288
|
+
achievement_deltas[str(key)] = delta
|
|
289
|
+
components.append(
|
|
290
|
+
{
|
|
291
|
+
"type": "achievement_count",
|
|
292
|
+
"name": str(key),
|
|
293
|
+
"delta": delta,
|
|
294
|
+
"weight": weight,
|
|
295
|
+
"reward": gain,
|
|
296
|
+
}
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
return reward_total, components, inventory_deltas, achievement_deltas
|
|
300
|
+
|
|
301
|
+
|
|
187
302
|
def compute_stepwise_reward(
|
|
188
303
|
prev_achievements: dict[str, bool],
|
|
189
304
|
new_achievements: dict[str, bool],
|
|
@@ -195,6 +310,10 @@ def compute_stepwise_reward(
|
|
|
195
310
|
weights: dict[str, float] | None = None,
|
|
196
311
|
k_limits: dict[str, int] | None = None,
|
|
197
312
|
episode_counts: dict[str, int] | None = None,
|
|
313
|
+
prev_inventory: dict[str, int] | None = None,
|
|
314
|
+
new_inventory: dict[str, int] | None = None,
|
|
315
|
+
prev_counts: dict[str, int] | None = None,
|
|
316
|
+
new_counts: dict[str, int] | None = None,
|
|
198
317
|
) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
|
|
199
318
|
"""Compute stepwise reward metadata given achievement states before/after a decision."""
|
|
200
319
|
|
|
@@ -202,13 +321,13 @@ def compute_stepwise_reward(
|
|
|
202
321
|
next_map = new_achievements or {}
|
|
203
322
|
|
|
204
323
|
unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
|
|
205
|
-
|
|
324
|
+
indicator_from_achievements = 1 if unlocked else 0
|
|
206
325
|
normalized_strategy = _normalize_step_strategy(strategy)
|
|
207
326
|
base_reward = 0.0
|
|
208
327
|
reward_components: list[dict[str, Any]] = []
|
|
209
328
|
credited: list[str] = []
|
|
210
329
|
|
|
211
|
-
if
|
|
330
|
+
if indicator_from_achievements:
|
|
212
331
|
if normalized_strategy == "per_achievement":
|
|
213
332
|
weight_map = weights or {}
|
|
214
333
|
limit_map = k_limits or {}
|
|
@@ -253,7 +372,26 @@ def compute_stepwise_reward(
|
|
|
253
372
|
}
|
|
254
373
|
)
|
|
255
374
|
|
|
256
|
-
|
|
375
|
+
resource_reward = 0.0
|
|
376
|
+
resource_components: list[dict[str, Any]] = []
|
|
377
|
+
inventory_deltas: dict[str, int] = {}
|
|
378
|
+
achievement_deltas: dict[str, int] = {}
|
|
379
|
+
if normalized_strategy == "per_achievement":
|
|
380
|
+
(
|
|
381
|
+
resource_reward,
|
|
382
|
+
resource_components,
|
|
383
|
+
inventory_deltas,
|
|
384
|
+
achievement_deltas,
|
|
385
|
+
) = _compute_resource_reward(prev_inventory, new_inventory, prev_counts, new_counts)
|
|
386
|
+
if resource_components:
|
|
387
|
+
reward_components.extend(resource_components)
|
|
388
|
+
base_reward += resource_reward
|
|
389
|
+
|
|
390
|
+
indicator = 1 if base_reward > 0 else 0
|
|
391
|
+
if indicator == 0 and indicator_from_achievements:
|
|
392
|
+
indicator = indicator_from_achievements
|
|
393
|
+
lambda_effective = indicator_lambda if indicator_lambda not in (None, 0) else 1.0
|
|
394
|
+
reward_value = float(lambda_effective) * float(base_reward)
|
|
257
395
|
|
|
258
396
|
stepwise_info = {
|
|
259
397
|
"decision_index": decision_index,
|
|
@@ -263,10 +401,18 @@ def compute_stepwise_reward(
|
|
|
263
401
|
"strategy": normalized_strategy,
|
|
264
402
|
"base_reward": float(base_reward),
|
|
265
403
|
}
|
|
404
|
+
if indicator_from_achievements and not unlocked:
|
|
405
|
+
stepwise_info["indicator_from_achievements"] = indicator_from_achievements
|
|
266
406
|
if reward_components:
|
|
267
407
|
stepwise_info["components"] = reward_components
|
|
268
408
|
if credited:
|
|
269
409
|
stepwise_info["credited_achievements"] = credited
|
|
410
|
+
if resource_reward:
|
|
411
|
+
stepwise_info["resource_reward"] = float(resource_reward)
|
|
412
|
+
if inventory_deltas:
|
|
413
|
+
stepwise_info["inventory_deltas"] = inventory_deltas
|
|
414
|
+
if achievement_deltas:
|
|
415
|
+
stepwise_info["achievement_count_deltas"] = achievement_deltas
|
|
270
416
|
|
|
271
417
|
decision_sample = {
|
|
272
418
|
"decision_index": decision_index,
|
|
@@ -278,6 +424,8 @@ def compute_stepwise_reward(
|
|
|
278
424
|
}
|
|
279
425
|
if reward_components:
|
|
280
426
|
decision_sample["components"] = reward_components
|
|
427
|
+
if resource_reward:
|
|
428
|
+
decision_sample["resource_reward"] = float(resource_reward)
|
|
281
429
|
|
|
282
430
|
stats = {
|
|
283
431
|
"indicator": float(indicator),
|
|
@@ -286,6 +434,8 @@ def compute_stepwise_reward(
|
|
|
286
434
|
"base_reward": float(base_reward),
|
|
287
435
|
"credited_achievements_count": float(len(credited)),
|
|
288
436
|
}
|
|
437
|
+
if resource_reward:
|
|
438
|
+
stats["resource_reward"] = float(resource_reward)
|
|
289
439
|
return stepwise_info, decision_sample, stats
|
|
290
440
|
|
|
291
441
|
|
|
@@ -368,7 +518,7 @@ class RolloutTracingContext:
|
|
|
368
518
|
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
369
519
|
)
|
|
370
520
|
except Exception as exc:
|
|
371
|
-
logger.
|
|
521
|
+
logger.info("TRACING_START_FAIL: %s", exc)
|
|
372
522
|
self.enabled = False
|
|
373
523
|
self.tracer = None
|
|
374
524
|
|
|
@@ -1190,6 +1340,34 @@ async def execute_rollout(
|
|
|
1190
1340
|
return {str(k): bool(v) for k, v in ach.items()}
|
|
1191
1341
|
return {}
|
|
1192
1342
|
|
|
1343
|
+
def _extract_inventory(obs: Any) -> dict[str, int]:
|
|
1344
|
+
if not isinstance(obs, dict):
|
|
1345
|
+
return {}
|
|
1346
|
+
inv = obs.get("inventory")
|
|
1347
|
+
if not isinstance(inv, dict):
|
|
1348
|
+
return {}
|
|
1349
|
+
cleaned: dict[str, int] = {}
|
|
1350
|
+
for key, value in inv.items():
|
|
1351
|
+
coerced = _coerce_int_value(value)
|
|
1352
|
+
if coerced is None:
|
|
1353
|
+
continue
|
|
1354
|
+
cleaned[str(key)] = coerced
|
|
1355
|
+
return cleaned
|
|
1356
|
+
|
|
1357
|
+
def _extract_achievement_counts(obs: Any) -> dict[str, int]:
|
|
1358
|
+
if not isinstance(obs, dict):
|
|
1359
|
+
return {}
|
|
1360
|
+
counts = obs.get("achievements_counts")
|
|
1361
|
+
if not isinstance(counts, dict):
|
|
1362
|
+
return {}
|
|
1363
|
+
cleaned: dict[str, int] = {}
|
|
1364
|
+
for key, value in counts.items():
|
|
1365
|
+
coerced = _coerce_int_value(value)
|
|
1366
|
+
if coerced is None:
|
|
1367
|
+
continue
|
|
1368
|
+
cleaned[str(key)] = coerced
|
|
1369
|
+
return cleaned
|
|
1370
|
+
|
|
1193
1371
|
def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
|
|
1194
1372
|
if not tool_calls:
|
|
1195
1373
|
return []
|
|
@@ -1226,6 +1404,8 @@ async def execute_rollout(
|
|
|
1226
1404
|
session_trace = None
|
|
1227
1405
|
finalized = False
|
|
1228
1406
|
prev_achievements = _extract_achievements(current_obs)
|
|
1407
|
+
prev_inventory_state = _extract_inventory(current_obs)
|
|
1408
|
+
prev_achievement_counts_state = _extract_achievement_counts(current_obs)
|
|
1229
1409
|
# Track episode-level achievements that have been seen as true at any point so far
|
|
1230
1410
|
episode_seen_achievements: set[str] = {
|
|
1231
1411
|
k for k, v in (prev_achievements or {}).items() if bool(v)
|
|
@@ -1233,6 +1413,7 @@ async def execute_rollout(
|
|
|
1233
1413
|
episode_achievement_counts: dict[str, int] = {}
|
|
1234
1414
|
stepwise_indicator_sum = 0.0
|
|
1235
1415
|
stepwise_reward_sum = 0.0
|
|
1416
|
+
stepwise_resource_reward_sum = 0.0
|
|
1236
1417
|
stepwise_new_achievements_total = 0
|
|
1237
1418
|
final_achievement_count = sum(1 for v in prev_achievements.values() if v)
|
|
1238
1419
|
|
|
@@ -1346,58 +1527,14 @@ async def execute_rollout(
|
|
|
1346
1527
|
req,
|
|
1347
1528
|
)
|
|
1348
1529
|
except Exception as _pe:
|
|
1349
|
-
#
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
str(_pe),
|
|
1356
|
-
)
|
|
1357
|
-
|
|
1358
|
-
# Build partial trajectory and return HTTP 200
|
|
1359
|
-
trajectory = RolloutTrajectory(
|
|
1360
|
-
env_id=env_id,
|
|
1361
|
-
policy_id=policy_id,
|
|
1362
|
-
steps=trajectory_steps,
|
|
1363
|
-
final={
|
|
1364
|
-
"observation": current_obs,
|
|
1365
|
-
"rollout_status": "partial_policy_error",
|
|
1366
|
-
"error": str(_pe),
|
|
1367
|
-
"at_op": op,
|
|
1368
|
-
},
|
|
1369
|
-
length=len(trajectory_steps),
|
|
1370
|
-
decision_samples=decision_samples if step_rewards_active else None,
|
|
1371
|
-
)
|
|
1372
|
-
metrics = RolloutMetrics(
|
|
1373
|
-
episode_returns=[total_reward],
|
|
1374
|
-
mean_return=total_reward,
|
|
1375
|
-
num_steps=len(trajectory_steps),
|
|
1376
|
-
num_episodes=1,
|
|
1377
|
-
)
|
|
1378
|
-
aborted = registry.is_run_aborted(request.run_id)
|
|
1379
|
-
if not aborted:
|
|
1380
|
-
registry.complete_run(request.run_id)
|
|
1381
|
-
if decision_open:
|
|
1382
|
-
await tracing_context.end_decision()
|
|
1383
|
-
decision_open = False
|
|
1384
|
-
if not finalized:
|
|
1385
|
-
session_trace = await tracing_context.finalize(
|
|
1386
|
-
total_reward=total_reward,
|
|
1387
|
-
achievement_state=prev_achievements,
|
|
1388
|
-
total_steps=len(trajectory_steps),
|
|
1389
|
-
)
|
|
1390
|
-
finalized = True
|
|
1391
|
-
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1392
|
-
return RolloutResponse(
|
|
1393
|
-
run_id=request.run_id,
|
|
1394
|
-
trajectories=[trajectory],
|
|
1395
|
-
branches={},
|
|
1396
|
-
metrics=metrics,
|
|
1397
|
-
aborted=aborted,
|
|
1398
|
-
ops_executed=ops_executed,
|
|
1399
|
-
trace=trace_payload,
|
|
1530
|
+
# Hard fail the rollout on policy step error (e.g., inference auth 4xx)
|
|
1531
|
+
logger.error(
|
|
1532
|
+
"POLICY_STEP_HARD_FAIL: run_id=%s op_idx=%s err=%s",
|
|
1533
|
+
request.run_id,
|
|
1534
|
+
str(op_idx),
|
|
1535
|
+
str(_pe),
|
|
1400
1536
|
)
|
|
1537
|
+
raise HTTPException(status_code=500, detail=f"policy_step_failed: {str(_pe)}")
|
|
1401
1538
|
|
|
1402
1539
|
agent_response_ts = _time.perf_counter()
|
|
1403
1540
|
if isinstance(policy_response.meta, dict):
|
|
@@ -1464,69 +1601,15 @@ async def execute_rollout(
|
|
|
1464
1601
|
|
|
1465
1602
|
elif op == "env":
|
|
1466
1603
|
if not pending_tool_calls:
|
|
1467
|
-
# Treat absence of tool calls as a soft terminal condition; yield partial trajectory
|
|
1468
1604
|
with contextlib.suppress(Exception):
|
|
1469
1605
|
logger.warning(
|
|
1470
|
-
"
|
|
1606
|
+
"POLICY_STEP_FAIL: missing tool_calls; failing rollout run_id=%s op_idx=%s",
|
|
1471
1607
|
request.run_id,
|
|
1472
1608
|
str(op_idx),
|
|
1473
1609
|
)
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
)
|
|
1478
|
-
term_step = RolloutStep(
|
|
1479
|
-
obs=current_obs,
|
|
1480
|
-
tool_calls=[],
|
|
1481
|
-
reward=None,
|
|
1482
|
-
done=True,
|
|
1483
|
-
truncated=False,
|
|
1484
|
-
info={
|
|
1485
|
-
"terminated": True,
|
|
1486
|
-
"reason": "no_tool_calls",
|
|
1487
|
-
},
|
|
1488
|
-
)
|
|
1489
|
-
trajectory_steps.append(term_step)
|
|
1490
|
-
trajectory = RolloutTrajectory(
|
|
1491
|
-
env_id=env_id,
|
|
1492
|
-
policy_id=policy_id,
|
|
1493
|
-
steps=trajectory_steps,
|
|
1494
|
-
final={
|
|
1495
|
-
"observation": current_obs,
|
|
1496
|
-
"rollout_status": "partial_no_tool_calls",
|
|
1497
|
-
"at_op": op,
|
|
1498
|
-
},
|
|
1499
|
-
length=len(trajectory_steps),
|
|
1500
|
-
decision_samples=decision_samples if step_rewards_active else None,
|
|
1501
|
-
)
|
|
1502
|
-
metrics = RolloutMetrics(
|
|
1503
|
-
episode_returns=[total_reward],
|
|
1504
|
-
mean_return=total_reward,
|
|
1505
|
-
num_steps=len(trajectory_steps),
|
|
1506
|
-
num_episodes=1,
|
|
1507
|
-
)
|
|
1508
|
-
aborted = registry.is_run_aborted(request.run_id)
|
|
1509
|
-
if not aborted:
|
|
1510
|
-
registry.complete_run(request.run_id)
|
|
1511
|
-
if decision_open:
|
|
1512
|
-
await tracing_context.end_decision()
|
|
1513
|
-
decision_open = False
|
|
1514
|
-
if not finalized:
|
|
1515
|
-
session_trace = await tracing_context.finalize(
|
|
1516
|
-
total_reward=total_reward,
|
|
1517
|
-
achievement_state=prev_achievements,
|
|
1518
|
-
total_steps=len(trajectory_steps),
|
|
1519
|
-
)
|
|
1520
|
-
finalized = True
|
|
1521
|
-
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1522
|
-
return RolloutResponse(
|
|
1523
|
-
run_id=request.run_id,
|
|
1524
|
-
trajectories=[trajectory],
|
|
1525
|
-
branches={},
|
|
1526
|
-
metrics=metrics,
|
|
1527
|
-
aborted=aborted,
|
|
1528
|
-
ops_executed=ops_executed,
|
|
1529
|
-
trace=trace_payload,
|
|
1610
|
+
raise HTTPException(
|
|
1611
|
+
status_code=500,
|
|
1612
|
+
detail="policy_step_failed: missing tool_calls (no_tool_calls)",
|
|
1530
1613
|
)
|
|
1531
1614
|
|
|
1532
1615
|
# Environment step
|
|
@@ -1555,85 +1638,16 @@ async def execute_rollout(
|
|
|
1555
1638
|
timing_env["env_step_end_s"] = env_step_end
|
|
1556
1639
|
|
|
1557
1640
|
if env_step_error is not None:
|
|
1558
|
-
# Invalid action or environment rejection — terminate episode early with partial trajectory
|
|
1559
1641
|
with contextlib.suppress(Exception):
|
|
1560
1642
|
logger.warning(
|
|
1561
|
-
"ENV_STEP_FAIL:
|
|
1643
|
+
"ENV_STEP_FAIL: failing rollout run_id=%s op_idx=%s err=%s",
|
|
1562
1644
|
request.run_id,
|
|
1563
1645
|
str(op_idx),
|
|
1564
1646
|
str(env_step_error),
|
|
1565
1647
|
)
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
tool_calls=pending_tool_calls,
|
|
1570
|
-
reward=None,
|
|
1571
|
-
done=True,
|
|
1572
|
-
truncated=False,
|
|
1573
|
-
info={
|
|
1574
|
-
"terminated": True,
|
|
1575
|
-
"reason": "invalid_action",
|
|
1576
|
-
"error": str(env_step_error),
|
|
1577
|
-
},
|
|
1578
|
-
)
|
|
1579
|
-
trajectory_steps.append(term_step)
|
|
1580
|
-
# Build partial response
|
|
1581
|
-
trajectory = RolloutTrajectory(
|
|
1582
|
-
env_id=env_id,
|
|
1583
|
-
policy_id=policy_id,
|
|
1584
|
-
steps=trajectory_steps,
|
|
1585
|
-
final={
|
|
1586
|
-
"observation": current_obs,
|
|
1587
|
-
"rollout_status": "partial_invalid_action",
|
|
1588
|
-
"error": str(env_step_error),
|
|
1589
|
-
"at_op": op,
|
|
1590
|
-
},
|
|
1591
|
-
length=len(trajectory_steps),
|
|
1592
|
-
decision_samples=decision_samples if step_rewards_active else None,
|
|
1593
|
-
)
|
|
1594
|
-
metrics = RolloutMetrics(
|
|
1595
|
-
episode_returns=[total_reward],
|
|
1596
|
-
mean_return=total_reward,
|
|
1597
|
-
num_steps=len(trajectory_steps),
|
|
1598
|
-
num_episodes=1,
|
|
1599
|
-
)
|
|
1600
|
-
aborted = registry.is_run_aborted(request.run_id)
|
|
1601
|
-
if not aborted:
|
|
1602
|
-
registry.complete_run(request.run_id)
|
|
1603
|
-
if (
|
|
1604
|
-
last_policy_meta is not None
|
|
1605
|
-
and last_agent_response_ts is not None
|
|
1606
|
-
and "decision_ms" not in last_policy_meta.get("timing", {})
|
|
1607
|
-
):
|
|
1608
|
-
with contextlib.suppress(Exception):
|
|
1609
|
-
timing_last = last_policy_meta.setdefault("timing", {})
|
|
1610
|
-
decision_ms = max(
|
|
1611
|
-
0.0,
|
|
1612
|
-
(env_step_end - float(last_agent_response_ts)) * 1000.0,
|
|
1613
|
-
)
|
|
1614
|
-
timing_last["decision_ms"] = decision_ms
|
|
1615
|
-
timing_last.setdefault(
|
|
1616
|
-
"overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
|
|
1617
|
-
)
|
|
1618
|
-
if decision_open:
|
|
1619
|
-
await tracing_context.end_decision()
|
|
1620
|
-
decision_open = False
|
|
1621
|
-
if not finalized:
|
|
1622
|
-
session_trace = await tracing_context.finalize(
|
|
1623
|
-
total_reward=total_reward,
|
|
1624
|
-
achievement_state=prev_achievements,
|
|
1625
|
-
total_steps=len(trajectory_steps),
|
|
1626
|
-
)
|
|
1627
|
-
finalized = True
|
|
1628
|
-
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1629
|
-
return RolloutResponse(
|
|
1630
|
-
run_id=request.run_id,
|
|
1631
|
-
trajectories=[trajectory],
|
|
1632
|
-
branches={},
|
|
1633
|
-
metrics=metrics,
|
|
1634
|
-
aborted=aborted,
|
|
1635
|
-
ops_executed=ops_executed,
|
|
1636
|
-
trace=trace_payload,
|
|
1648
|
+
raise HTTPException(
|
|
1649
|
+
status_code=500,
|
|
1650
|
+
detail=f"env_step_failed: {str(env_step_error)}",
|
|
1637
1651
|
)
|
|
1638
1652
|
|
|
1639
1653
|
# Reaching here means env step succeeded
|
|
@@ -1664,12 +1678,16 @@ async def execute_rollout(
|
|
|
1664
1678
|
decision_index += 1
|
|
1665
1679
|
next_obs = env_response.observation
|
|
1666
1680
|
new_achievement_state = _extract_achievements(next_obs)
|
|
1681
|
+
new_inventory_state = _extract_inventory(next_obs)
|
|
1682
|
+
new_achievement_counts_state = _extract_achievement_counts(next_obs)
|
|
1667
1683
|
final_achievement_count = sum(
|
|
1668
1684
|
1 for _, unlocked in new_achievement_state.items() if unlocked
|
|
1669
1685
|
)
|
|
1670
1686
|
indicator_val = 0
|
|
1671
1687
|
reward_stepwise = 0.0
|
|
1672
1688
|
decision_rewards_meta: dict[str, Any] | None = None
|
|
1689
|
+
decision_record = None
|
|
1690
|
+
_info = {} if not isinstance(_info, dict) else dict(_info)
|
|
1673
1691
|
if step_rewards_active:
|
|
1674
1692
|
decision_actions = _summarize_tool_calls(pending_tool_calls)
|
|
1675
1693
|
stepwise_info, decision_record, stats = compute_stepwise_reward(
|
|
@@ -1682,13 +1700,20 @@ async def execute_rollout(
|
|
|
1682
1700
|
weights=step_rewards_weights,
|
|
1683
1701
|
k_limits=step_rewards_k_limits,
|
|
1684
1702
|
episode_counts=episode_achievement_counts,
|
|
1703
|
+
prev_inventory=prev_inventory_state,
|
|
1704
|
+
new_inventory=new_inventory_state,
|
|
1705
|
+
prev_counts=prev_achievement_counts_state,
|
|
1706
|
+
new_counts=new_achievement_counts_state,
|
|
1685
1707
|
)
|
|
1686
1708
|
indicator_val = int(stats.get("indicator", 0.0))
|
|
1687
1709
|
reward_stepwise = float(stats.get("reward", 0.0))
|
|
1688
1710
|
stepwise_indicator_sum += float(stats.get("indicator", 0.0))
|
|
1689
1711
|
stepwise_reward_sum += reward_stepwise
|
|
1690
1712
|
stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
|
|
1691
|
-
|
|
1713
|
+
with contextlib.suppress(Exception):
|
|
1714
|
+
resource_component = stats.get("resource_reward")
|
|
1715
|
+
if resource_component is not None:
|
|
1716
|
+
stepwise_resource_reward_sum += float(resource_component)
|
|
1692
1717
|
_info["stepwise"] = stepwise_info
|
|
1693
1718
|
# Compute decision-level rewards (absolute vs unique) and attach to metadata
|
|
1694
1719
|
with contextlib.suppress(Exception):
|
|
@@ -1710,13 +1735,16 @@ async def execute_rollout(
|
|
|
1710
1735
|
"all": all_list,
|
|
1711
1736
|
"unique": new_unique,
|
|
1712
1737
|
}
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1738
|
+
decision_rewards_meta = decision_rewards
|
|
1739
|
+
meta_block["decision_rewards"] = decision_rewards
|
|
1740
|
+
_info["meta"] = meta_block
|
|
1741
|
+
# Update episode-level seen set after attributing uniqueness to this decision
|
|
1742
|
+
episode_seen_achievements.update(turned_true)
|
|
1743
|
+
if decision_record is not None:
|
|
1718
1744
|
decision_samples.append(decision_record)
|
|
1719
1745
|
prev_achievements = new_achievement_state
|
|
1746
|
+
prev_inventory_state = new_inventory_state
|
|
1747
|
+
prev_achievement_counts_state = new_achievement_counts_state
|
|
1720
1748
|
|
|
1721
1749
|
await tracing_context.record_decision_reward(
|
|
1722
1750
|
event_id=event_id,
|
|
@@ -1815,12 +1843,22 @@ async def execute_rollout(
|
|
|
1815
1843
|
timing_final.setdefault("overhead_ms", 0.0)
|
|
1816
1844
|
|
|
1817
1845
|
# Build trajectory
|
|
1846
|
+
# Extract inference_url from policy meta
|
|
1847
|
+
inference_url = None
|
|
1848
|
+
if policy_handle is not None:
|
|
1849
|
+
try:
|
|
1850
|
+
policy_snapshot = policy_handle.snapshot()
|
|
1851
|
+
inference_url = policy_snapshot.get("config", {}).get("inference_url")
|
|
1852
|
+
except Exception:
|
|
1853
|
+
pass
|
|
1854
|
+
|
|
1818
1855
|
trajectory = RolloutTrajectory(
|
|
1819
1856
|
env_id=env_id,
|
|
1820
1857
|
policy_id=policy_id,
|
|
1821
1858
|
steps=trajectory_steps,
|
|
1822
1859
|
final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
|
|
1823
1860
|
length=len(trajectory_steps),
|
|
1861
|
+
inference_url=inference_url, # NEW: Required for trace correlation
|
|
1824
1862
|
decision_samples=decision_samples if step_rewards_active else None,
|
|
1825
1863
|
)
|
|
1826
1864
|
|
|
@@ -1835,6 +1873,7 @@ async def execute_rollout(
|
|
|
1835
1873
|
stepwise_summary: dict[str, Any] = {
|
|
1836
1874
|
"indicator_sum": float(stepwise_indicator_sum),
|
|
1837
1875
|
"reward_sum": float(stepwise_reward_sum),
|
|
1876
|
+
"resource_reward": float(stepwise_resource_reward_sum),
|
|
1838
1877
|
"new_achievements_total": int(stepwise_new_achievements_total),
|
|
1839
1878
|
"mode": step_rewards_mode,
|
|
1840
1879
|
"strategy": step_rewards_strategy,
|
|
@@ -1847,6 +1886,12 @@ async def execute_rollout(
|
|
|
1847
1886
|
stepwise_summary["weights"] = dict(step_rewards_weights)
|
|
1848
1887
|
if step_rewards_k_limits:
|
|
1849
1888
|
stepwise_summary["k_limits"] = dict(step_rewards_k_limits)
|
|
1889
|
+
final_achievements_list = sorted(
|
|
1890
|
+
key for key, val in (prev_achievements or {}).items() if bool(val)
|
|
1891
|
+
)
|
|
1892
|
+
stepwise_summary["unique_achievements_total"] = int(len(episode_seen_achievements))
|
|
1893
|
+
stepwise_summary["unique_achievements"] = sorted(episode_seen_achievements)
|
|
1894
|
+
stepwise_summary["final_achievements"] = final_achievements_list
|
|
1850
1895
|
metrics.details["stepwise"] = stepwise_summary
|
|
1851
1896
|
|
|
1852
1897
|
# Environment-specific: Log summary if available
|
|
@@ -1904,6 +1949,10 @@ async def execute_rollout(
|
|
|
1904
1949
|
finalized = True
|
|
1905
1950
|
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1906
1951
|
|
|
1952
|
+
# Hard-fail if no steps executed (avg_turns == 0 scenario)
|
|
1953
|
+
if metrics.num_steps <= 0:
|
|
1954
|
+
raise HTTPException(status_code=500, detail="no_steps_executed: avg_turns == 0")
|
|
1955
|
+
|
|
1907
1956
|
return RolloutResponse(
|
|
1908
1957
|
run_id=request.run_id,
|
|
1909
1958
|
trajectories=[trajectory],
|
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
Simple test script for the GRPO Synth Envs Hosted Service.
|
|
4
|
-
|
|
5
|
-
Run this after starting the service with:
|
|
6
|
-
python main.py
|
|
7
|
-
"""
|
|
2
|
+
"""Manual smoke script for the GRPO Synth Envs Hosted Service."""
|
|
8
3
|
|
|
9
4
|
import asyncio
|
|
10
5
|
import json
|
|
11
6
|
|
|
12
7
|
import httpx
|
|
8
|
+
import pytest
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
pytestmark = pytest.mark.skip(reason="Requires running hosted service on localhost:8000")
|
|
13
12
|
|
|
14
13
|
|
|
15
14
|
async def test_service():
|