synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -11
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +3 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +4 -4
- examples/sft/export_dataset.py +7 -4
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +1 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +2 -8
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +309 -14
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +75 -4
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +55 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +114 -32
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +127 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +145 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +66 -49
- synth_ai/cli/_modal_wrapper.py +9 -6
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +1 -0
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +392 -141
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +62 -0
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +5 -2
- synth_ai/task/config.py +259 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +4 -2
- synth_ai/task/rubrics/loaders.py +27 -4
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +145 -2
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/session_tracer.py +10 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +108 -77
- synth_ai/tracing_v3/utils.py +1 -1
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +911 -0
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/RECORD +286 -135
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,696 @@
|
|
|
1
|
+
"""Task App configuration for the GRPO Crafter example."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
from collections.abc import Iterable, Sequence
|
|
9
|
+
from contextlib import suppress
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
15
|
+
from synth_ai.task.contracts import RolloutMetrics, RolloutRequest, RolloutResponse, TaskInfo
|
|
16
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
17
|
+
from synth_ai.task.json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
|
|
18
|
+
from synth_ai.task.rubrics import load_rubric
|
|
19
|
+
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
20
|
+
from synth_ai.task.tracing_utils import (
|
|
21
|
+
build_tracer_factory,
|
|
22
|
+
resolve_sft_output_dir,
|
|
23
|
+
resolve_tracing_db_url,
|
|
24
|
+
tracing_env_enabled,
|
|
25
|
+
)
|
|
26
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
# Default op schedule: ten consecutive ("agent", "env") pairs, 20 entries in total.
DEFAULT_ALIAS_OPS: list[str] = [role for _ in range(10) for role in ("agent", "env")]

# Default step-reward settings.
DEFAULT_ALIAS_STEP_REWARDS: dict[str, Any] = dict(
    enabled=True,
    mode="decision_stepwise",
    indicator_lambda=1.0,
    step_beta=0.0,
)

# Absolute, symlink-resolved path of this module file.
_HERE = Path(__file__).resolve()
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _resolve_repo_root() -> Path:
    """Return the directory that most plausibly is the Synth AI repo root.

    Candidates are probed in this order: the ``SYNTH_AI_REPO_ROOT``
    environment variable (when set and non-empty), ``/opt/synth_ai_repo``,
    then this module's directory followed by each of its ancestors. The first
    existing candidate that contains ``pyproject.toml``, ``uv.lock`` or a
    ``synth_ai`` directory is returned. When nothing qualifies the fourth
    ancestor of this file is used, or the file's own directory if the path is
    too shallow to have one.
    """

    def _is_repo_root(directory: Path) -> bool:
        # Same marker order as the probing below relies on: project files first.
        return (
            (directory / "pyproject.toml").exists()
            or (directory / "uv.lock").exists()
            or (directory / "synth_ai").is_dir()
        )

    search_order: list[Path] = []
    override = os.getenv("SYNTH_AI_REPO_ROOT")
    if override:
        search_order.append(Path(override).expanduser())
    search_order.append(Path("/opt/synth_ai_repo"))
    search_order += [_HERE.parent, *_HERE.parents]

    for location in search_order:
        try:
            root = location.resolve()
        except Exception:
            # A candidate that cannot be resolved is simply not considered.
            continue
        if root.exists() and _is_repo_root(root):
            return root

    ancestors = _HERE.parents
    return ancestors[3] if len(ancestors) > 3 else _HERE.parent
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _resolve_task_app_root(repo_root: Path) -> Path:
    """Find the ``task_app`` directory for this example.

    Lookup order:
      1. ``<repo_root>/examples/warming_up_to_rl/task_app`` if it is a directory;
      2. this module's own directory if it holds ``synth_envs_hosted``;
      3. the nearest ancestor of this file holding ``synth_envs_hosted``;
      4. ``/opt/synth_ai_repo/examples/warming_up_to_rl/task_app`` if present;
      5. otherwise this module's own directory.
    """

    def _contains_hosted(directory: Path) -> bool:
        return (directory / "synth_envs_hosted").is_dir()

    in_repo = (repo_root / "examples" / "warming_up_to_rl" / "task_app").resolve()
    if in_repo.is_dir():
        return in_repo

    module_dir = _HERE.parent.resolve()
    if _contains_hosted(module_dir):
        return module_dir

    # Walk upwards lazily so resolution stops at the first match.
    hosted_ancestor = next(
        (
            ancestor
            for ancestor in (parent.resolve() for parent in _HERE.parents)
            if _contains_hosted(ancestor)
        ),
        None,
    )
    if hosted_ancestor is not None:
        return hosted_ancestor

    modal_mount = Path("/opt/synth_ai_repo/examples/warming_up_to_rl/task_app")
    return modal_mount.resolve() if modal_mount.is_dir() else module_dir
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Directory layout, resolved once at import time.
REPO_ROOT = _resolve_repo_root()
TASK_APP_ROOT = _resolve_task_app_root(REPO_ROOT)
SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()

EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()

# Prepend each existing directory to sys.path (skipping ones already listed),
# so entries later in the tuple end up earlier on the path.
for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
    try:
        resolved = path.resolve()
    except Exception:
        resolved = path
    if not resolved.exists():
        continue
    path_str = str(resolved)
    if path_str in sys.path:
        continue
    sys.path.insert(0, path_str)

# Fallback: explicitly add Modal mount path for 'examples' if REPO_ROOT detection fails.
# Best effort only: any error while probing the mount is ignored.
with suppress(Exception):
    _hard_examples = Path("/opt/synth_ai_repo/examples")
    if _hard_examples.exists():
        _hard_examples_str = str(_hard_examples.resolve())
        if _hard_examples_str not in sys.path:
            sys.path.insert(0, _hard_examples_str)
|
|
117
|
+
|
|
118
|
+
# Flipped to False below when the optional ``synth_envs_hosted`` package is absent.
HAS_HOSTED = True
try:
    import crafter  # type: ignore
    import crafter.constants as crafter_constants  # type: ignore
    from synth_ai.environments.examples.crafter_classic.taskset import TRAIT_BOUNDS
    from synth_envs_hosted.branching import router as branching_router  # type: ignore
    from synth_envs_hosted.environment_routes import router as environment_router  # type: ignore
    from synth_envs_hosted.hosted_app import TaskApp as HostedTaskApp  # type: ignore
    from synth_envs_hosted.policy_routes import router as policy_router  # type: ignore
    from synth_envs_hosted.rollout import (  # type: ignore
        RolloutEnvSpec as LegacyRolloutEnvSpec,
        RolloutPolicySpec as LegacyRolloutPolicySpec,
        RolloutRecordConfig as LegacyRolloutRecordConfig,
        RolloutRequest as LegacyRolloutRequest,
        RolloutResponse as LegacyRolloutResponse,
        RolloutSafetyConfig as LegacyRolloutSafetyConfig,
        execute_rollout as legacy_execute_rollout,
    )
except Exception as exc:  # pragma: no cover - import-time validation
    # Provide a more actionable error with the missing module and fix hints
    missing_mod = None
    if isinstance(exc, ModuleNotFoundError):
        # Prefer the structured ``name`` attribute and only parse the message
        # when it is empty. The previous one-liner
        # ``name or msg.split("'")[1] if "'" in msg else None`` bound as
        # ``(name or ...) if ... else None`` and threw away a valid ``name``
        # whenever the message contained no quote.
        missing_mod = getattr(exc, "name", None)
        if not missing_mod:
            message = str(exc)
            if "'" in message:
                missing_mod = message.split("'")[1]
    fix_hint = None
    if missing_mod:
        # Import name -> installable distribution name.
        mapping = {
            "dotenv": "python-dotenv",
            "crafter": "crafter",
            "httpx": "httpx",
            "aiohttp": "aiohttp",
            "fastapi": "fastapi",
            "uvicorn": "uvicorn",
            "sqlalchemy": "sqlalchemy",
            "aiosqlite": "aiosqlite",
            "greenlet": "greenlet",
        }
        pkg = mapping.get(missing_mod, missing_mod)
        fix_hint = (
            f"Missing Python module '{missing_mod}'. Install the package '{pkg}'.\n"
            f"For Modal: add '{pkg}' to ModalDeploymentConfig.pip_packages in synth_ai/task/apps/grpo_crafter.py.\n"
            f"Locally: pip install {pkg}"
        )
    # Allow running without synth_envs_hosted; gate hosted features off
    if missing_mod == "synth_envs_hosted":
        HAS_HOSTED = False
    else:
        detailed = (
            "grpo_crafter task app requires example dependencies and runtime libs.\n"
            + (fix_hint + "\n" if fix_hint else "")
            + f"Original error: {exc}"
        )
        raise RuntimeError(detailed) from exc
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# System-prompt hint listing the Crafter tool/weapon recipes, one per line
# (no trailing newline).
CRAFTING_RULES_SYSTEM_HINT = "\n".join(
    [
        "Crafter crafting rules (from the paper):",
        "- Make Wood Pickaxe: Nearby a table; have wood in inventory.",
        "- Make Stone Pickaxe: Nearby a table; have wood and stone in inventory.",
        "- Make Iron Pickaxe: Nearby a table; furnace exists; have wood, coal, and iron in inventory.",
        "- Make Wood Sword: Nearby a table; have wood in inventory.",
        "- Make Stone Sword: Nearby a table; have wood and stone in inventory.",
        "- Make Iron Sword: Nearby a table; furnace exists; have wood, coal, and iron in inventory.",
    ]
)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Dataset descriptor for the procedurally seeded Crafter worlds; it declares a
# single "train" split, which is also the default.
DATASET_SPEC = TaskDatasetSpec(
    id="crafter_classic_procedural",
    version="1.0.0",
    name="Crafter Classic Procedural Seeds",
    description="Procedural Crafter Classic seeds with reproducible world traits.",
    splits=["train"],
    default_split="train",
)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@dataclass
class CrafterDataset:
    """Seed-indexed view over procedurally generated Crafter worlds.

    Only ``spec`` is a dataclass field. Seed bounds, world area and episode
    length are read from environment variables in ``__post_init__``, and
    per-seed summaries are memoised in ``_cache``.
    """

    spec: TaskDatasetSpec

    def __post_init__(self) -> None:
        """Load settings from ``CRAFTER_*`` environment variables (with defaults)."""
        self.default_seed = int(env_value("CRAFTER_DEFAULT_SEED", 42))
        self.seed_min = 0
        self.seed_max = int(env_value("CRAFTER_MAX_SEED", 2**31 - 1))
        # CRAFTER_AREA is a comma-separated list of integers, e.g. "64,64".
        raw_area = env_value("CRAFTER_AREA", "64,64")
        self.area = tuple(map(int, str(raw_area).split(",")))
        self.length = int(env_value("CRAFTER_EPISODE_LENGTH", 10000))
        self._cache: dict[int, dict[str, Any]] = {}

    def config_for_seed(self, seed: int) -> dict[str, Any]:
        """Return the environment config (seed, area, length) for ``seed``."""
        config: dict[str, Any] = {"seed": int(seed)}
        config["area"] = list(self.area)
        config["length"] = self.length
        return config

    def describe_seed(self, seed: int) -> dict[str, Any]:
        """Summarise the freshly reset world for ``seed``.

        A ``crafter.Env`` is built and reset, inspected, and closed again (when
        it offers a callable ``close``). The resulting summary is cached, and
        the cached object itself is returned on later calls.
        """
        seed = int(seed)
        cached = self._cache.get(seed)
        if cached is not None:
            return cached

        world = crafter.Env(area=self.area, length=self.length, seed=seed)
        try:
            world.reset()
            traits = _compute_world_traits(world)
            player = getattr(world, "_player", None)
            if player:
                inventory = dict(getattr(player, "inventory", {}))
            else:
                inventory = {}
            position = getattr(player, "pos", None)
        finally:
            closer = getattr(world, "close", None)
            if callable(closer):
                closer()

        summary = {
            "seed": seed,
            "difficulty": self._difficulty(traits),
            "traits": traits,
            "inventory": inventory,
            "player_position": None if position is None else list(position),
            "config": self.config_for_seed(seed),
        }
        self._cache[seed] = summary
        return summary

    def _difficulty(self, traits: dict[str, int]) -> str:
        """Return the first ``TRAIT_BOUNDS`` label whose limits ``traits`` meet, else ``"custom"``."""
        for label, limits in TRAIT_BOUNDS.items():
            has_enough_trees = traits.get("trees", 0) >= limits.get("min_trees", 0)
            if not has_enough_trees:
                continue
            if traits.get("hostiles", 0) <= limits.get("max_hostiles", 0):
                return label
        return "custom"

    @property
    def seed_range(self) -> list[int]:
        """Two-element list ``[seed_min, seed_max]``."""
        lower, upper = self.seed_min, self.seed_max
        return [lower, upper]
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _compute_world_traits(env: crafter.Env, radius: int = 10) -> dict[str, int]:
    """Count trees, cows and hostile mobs close to the player.

    "Close" means the sum of absolute coordinate differences between an
    object's ``pos`` and the player's ``pos`` is at most ``radius``. Objects
    whose distance cannot be computed are skipped. All counts are zero when
    the environment has no ``_player``.
    """
    # Local copy to avoid import-time issues; mirrors synth_ai.environments.examples.crafter_classic.taskset.world_traits
    import numpy as _np  # type: ignore
    from crafter import objects as _objects  # type: ignore

    player = getattr(env, "_player", None)
    if player is None:
        return {"trees": 0, "cows": 0, "hostiles": 0}

    origin = _np.array(getattr(player, "pos", [0, 0]))
    tally = {"trees": 0, "cows": 0, "hostiles": 0}
    world = getattr(env, "_world", None)
    nearby_candidates = [] if world is None else getattr(world, "_objects", [])

    for entity in nearby_candidates:
        if entity is None or entity is player:
            continue
        try:
            too_far = bool(_np.abs(entity.pos - origin).sum() > radius)
        except Exception:
            continue
        if too_far:
            continue

        # NOTE(review): confirm that crafter ``Plant`` objects expose a ``kind``
        # attribute equal to "tree"; if they do not, "trees" always stays 0.
        bucket = None
        if isinstance(entity, _objects.Plant) and getattr(entity, "kind", "") == "tree":
            bucket = "trees"
        elif isinstance(entity, _objects.Cow):
            bucket = "cows"
        elif isinstance(entity, (_objects.Zombie, _objects.Skeleton)):
            bucket = "hostiles"
        if bucket is not None:
            tally[bucket] += 1
    return tally
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def env_value(key: str, default: Any) -> Any:
    """Return the environment variable ``key`` as a string, or ``default`` when it is unset."""
    return os.environ.get(key, default)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def build_dataset() -> tuple[TaskDatasetRegistry, CrafterDataset]:
    """Create a dataset registry plus the Crafter dataset registered in it.

    Returns:
        ``(registry, dataset)``; the registry's loader for ``DATASET_SPEC``
        always hands back that same ``dataset`` instance.
    """
    dataset_registry = TaskDatasetRegistry()
    crafter_dataset = CrafterDataset(DATASET_SPEC)

    def _loader(_spec):
        return crafter_dataset

    dataset_registry.register(DATASET_SPEC, _loader, cache=True)
    return dataset_registry, crafter_dataset
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
    """Assemble the seed-independent ``TaskInfo`` for the Crafter task.

    Args:
        dataset: Source of ``seed_range`` and ``default_seed`` for the dataset
            section.

    Returns:
        A ``TaskInfo`` describing the task, action space, observation layout,
        dataset, rubric, inference endpoints, capabilities and limits.
    """
    task_meta = {"id": "crafter_classic", "name": "Crafter Classic", "version": "1.0.0"}

    action_space = {
        "type": "discrete",
        "size": len(crafter_constants.actions),
        "actions": list(crafter_constants.actions),
    }

    observation = {
        "summary": "RGB frame plus inventory, achievements, and semantic map patches.",
        "keys": ["image", "inventory", "achievements", "semantic_map_patch7"],
        "image_shape": [64, 64, 3],
    }

    # Start from the dataset spec and layer the seed information on top.
    dataset_section = {
        **DATASET_SPEC.model_dump(),
        "seed_range": dataset.seed_range,
        "default_seed": dataset.default_seed,
    }

    rubric_meta = {
        "version": "1",
        "criteria_count": 2,
        "source": "inline",
        "aggregation": "weighted_sum",
    }

    proxy_endpoints = {
        "openai": "/proxy/v1/chat/completions",
        "groq": "/proxy/groq/v1/chat/completions",
    }
    inference = {
        "supports_proxy": True,
        "endpoints": proxy_endpoints,
        "tool": {"name": "interact", "parallel_tool_calls": False},
    }

    capabilities = {
        "supports_rollout": True,
        "supports_env_lifecycle": True,
        "requires_api_key_header": True,
    }

    return TaskInfo(
        task=task_meta,
        environments=["crafter"],
        action_space=action_space,
        observation=observation,
        dataset=dataset_section,
        rubric=rubric_meta,
        inference=inference,
        capabilities=capabilities,
        limits={"max_ops": 100000, "max_time_s": 3600},
    )
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _weighted_sum_rubric(goal_text: str, criteria: list[tuple[str, str]]) -> dict[str, Any]:
    """Build a version-"1" ``weighted_sum`` rubric payload with weight 1.0 per criterion.

    ``criteria`` is a list of ``(id, description)`` pairs, kept in order.
    """
    return {
        "version": "1",
        "goal_text": goal_text,
        "aggregation": "weighted_sum",
        "criteria": [
            {"id": criterion_id, "description": text, "weight": 1.0}
            for criterion_id, text in criteria
        ],
    }


# Episode-level rubric: achievements and survival.
OUTCOME_RUBRIC = load_rubric(
    _weighted_sum_rubric(
        "Reward unlocking Crafter achievements and survival.",
        [
            ("achievements", "Unlock achievements or crafting milestones."),
            ("survival", "Maintain health, food, and drink levels."),
        ],
    )
)

# Step-level rubric: a single progress criterion.
EVENTS_RUBRIC = load_rubric(
    _weighted_sum_rubric(
        "Encourage purposeful step-wise exploration and crafting.",
        [
            ("progress_steps", "Actions progress quests, crafting, or exploration."),
        ],
    )
)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]:
    """Describe the taskset: the dataset spec plus seed info and world config.

    The result contains every field of ``DATASET_SPEC.model_dump()`` together
    with ``seed_range``, ``default_seed`` and a ``config`` dict holding the
    dataset's ``area`` (as a list) and ``length``.
    """
    description: dict[str, Any] = dict(DATASET_SPEC.model_dump())
    description["seed_range"] = dataset.seed_range
    description["default_seed"] = dataset.default_seed
    description["config"] = {"area": list(dataset.area), "length": dataset.length}
    return description
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def provide_task_instances(
    dataset: CrafterDataset, base_info: TaskInfo, seeds: Sequence[int]
) -> Iterable[TaskInfo]:
    """Build one ``TaskInfo`` per seed, derived from ``base_info``.

    For every seed, ``dataset.describe_seed`` supplies a summary whose
    ``traits``, ``inventory`` and ``player_position`` entries are merged into
    the observation mapping, and whose ``difficulty`` and ``config`` entries
    are merged into the dataset mapping. The seed itself is recorded in both.
    All other fields are copied from ``base_info`` unchanged.

    Returns:
        A list of ``TaskInfo`` objects in the same order as ``seeds``.
    """

    def _instance_for(seed_value: int) -> TaskInfo:
        # Per-seed summary provided by the dataset.
        seed_summary = dataset.describe_seed(seed_value)
        return TaskInfo(
            task=base_info.task,
            environments=base_info.environments,
            action_space=base_info.action_space,
            observation={
                **base_info.observation,
                "seed": seed_value,
                "traits": seed_summary["traits"],
                "inventory": seed_summary["inventory"],
                "player_position": seed_summary["player_position"],
            },
            dataset={
                **base_info.dataset,
                "seed": seed_value,
                "difficulty": seed_summary["difficulty"],
                "config": seed_summary["config"],
            },
            rubric=base_info.rubric,
            inference=base_info.inference,
            capabilities=base_info.capabilities,
            limits=base_info.limits,
        )

    return [_instance_for(seed_value) for seed_value in seeds]
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _normalise_op(op_value: Any, index: int) -> str:
|
|
430
|
+
if isinstance(op_value, str):
|
|
431
|
+
candidate = op_value
|
|
432
|
+
elif isinstance(op_value, dict):
|
|
433
|
+
candidate = op_value.get("type") or op_value.get("op")
|
|
434
|
+
else:
|
|
435
|
+
candidate = None
|
|
436
|
+
if not candidate:
|
|
437
|
+
raise ValueError(f"Missing op type at index {index}")
|
|
438
|
+
lowered = str(candidate).strip().lower()
|
|
439
|
+
if lowered in {"policy", "agent", "model"}:
|
|
440
|
+
return "agent"
|
|
441
|
+
if lowered in {"env", "environment", "step"}:
|
|
442
|
+
return "env"
|
|
443
|
+
raise ValueError(f"Unsupported op type '{candidate}' at index {index}")
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def _coerce_math_to_crafter(request: RolloutRequest) -> RolloutRequest:
    """Map legacy math env/policy names to crafter and enrich rollout defaults.

    A name counts as a legacy alias when its trimmed, lower-cased text starts
    with ``"math"``. Aliased env/policy names are replaced with ``"crafter"``
    and ``"crafter-react"``; aliased env/policy ids are cleared to ``None``.
    If nothing is aliased, the request is returned untouched.

    When at least one alias is found, missing env and policy config keys are
    filled with defaults, and ``ops`` is replaced by ``DEFAULT_ALIAS_OPS``
    when it is empty or shorter than that default. The remap is reported via
    ``print`` and ``logger.info`` on a best-effort basis.

    Returns:
        The original request, or a copy with remapped env, policy and ops.
    """

    def _is_math_alias(name: str | None) -> bool:
        # Falsy names never match; otherwise compare the trimmed, lower-cased text.
        return bool(name) and str(name).strip().lower().startswith("math")

    env_overrides: dict[str, Any] = {}
    if _is_math_alias(request.env.env_name):
        env_overrides["env_name"] = "crafter"
    if _is_math_alias(request.env.env_id):
        env_overrides["env_id"] = None

    policy_overrides: dict[str, Any] = {}
    if _is_math_alias(request.policy.policy_name):
        policy_overrides["policy_name"] = "crafter-react"
    if _is_math_alias(request.policy.policy_id):
        policy_overrides["policy_id"] = None

    # Nothing referred to the legacy math names: pass the request through.
    if not env_overrides and not policy_overrides:
        return request

    # Environment: apply the renames, then fill in any missing config defaults.
    env_spec = request.env.model_copy(update=env_overrides) if env_overrides else request.env
    env_config = dict(env_spec.config or {})
    for key, fallback in (
        ("difficulty", "normal"),
        ("step_rewards", dict(DEFAULT_ALIAS_STEP_REWARDS)),
        ("env_params", {"max_steps_per_episode": 200}),
    ):
        env_config.setdefault(key, fallback)
    env_spec = env_spec.model_copy(update={"config": env_config})

    # Policy: same treatment with its own defaults.
    policy_spec = (
        request.policy.model_copy(update=policy_overrides) if policy_overrides else request.policy
    )
    policy_config = dict(policy_spec.config or {})
    for key, fallback in (
        ("max_llm_calls", 10),
        ("max_completion_tokens", 1024),
        ("temperature", 0.2),
        ("step_rewards", dict(DEFAULT_ALIAS_STEP_REWARDS)),
    ):
        policy_config.setdefault(key, fallback)
    policy_spec = policy_spec.model_copy(update={"config": policy_config})

    # Use the default op sequence when the request has none or too few ops.
    ops = request.ops
    if not ops or len(ops) < len(DEFAULT_ALIAS_OPS):
        ops = list(DEFAULT_ALIAS_OPS)

    coerced = request.model_copy(update={"env": env_spec, "policy": policy_spec, "ops": ops})

    # Reporting is best-effort: a failure to print or log must not fail the rollout.
    with suppress(Exception):
        print(
            "[rollout] remapped math request -> crafter "
            f"(env={request.env.env_name!r}→{coerced.env.env_name!r}, "
            f"policy={request.policy.policy_name!r}→{coerced.policy.policy_name!r})",
            flush=True,
        )
    with suppress(Exception):
        logger.info(
            "ROLLOUT_ALIAS: remapped math env/policy to crafter (env=%s→%s, policy=%s→%s)",
            request.env.env_name,
            coerced.env.env_name,
            request.policy.policy_name,
            coerced.policy.policy_name,
        )

    return coerced
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
    """Run a rollout by delegating to the legacy hosted rollout executor.

    Steps:
      1. Without hosted support (``HAS_HOSTED`` false), return an empty
         rollout response.
      2. Remap legacy math aliases via ``_coerce_math_to_crafter``.
      3. Fill policy sampling defaults and raise the env's
         ``max_steps_per_episode`` to at least ``max_llm_calls``.
      4. Normalise the ops and cap them at ``2 * max_llm_calls`` entries.
      5. Build a ``LegacyRolloutRequest``, await ``legacy_execute_rollout``,
         and validate its dump as a ``RolloutResponse`` after ensuring the
         ``outcome_score``, ``events_score`` and ``details`` metric keys exist.

    Raises:
        ValueError: Propagated from ``_normalise_op`` for an invalid op entry.
    """
    # If hosted env service code is not bundled, return a no-op rollout response compatible with contracts
    if not HAS_HOSTED:
        return RolloutResponse(
            run_id=request.run_id,
            trajectories=[],
            branches={},
            metrics=RolloutMetrics(
                episode_returns=[],
                mean_return=0.0,
                num_steps=0,
                num_episodes=0,
                details={},
            ),
            aborted=False,
            ops_executed=0,
            trace=None,
        )

    def _int_or(raw: Any, fallback: int) -> int:
        # Falsy or non-convertible values fall back to the supplied default.
        try:
            return int(raw or fallback)
        except Exception:
            return fallback

    request = _coerce_math_to_crafter(request)

    # Policy config: resolve the LLM-call budget, then fill sampling defaults.
    policy_cfg = dict(request.policy.config or {})
    max_llm_calls = _int_or(policy_cfg.get("max_llm_calls"), 10)
    for key, fallback in (
        ("max_llm_calls", max_llm_calls),
        ("max_tokens", 512),
        ("max_completion_tokens", 512),
        ("temperature", 0.2),
        ("top_p", 0.95),
    ):
        policy_cfg.setdefault(key, fallback)

    # Env config: the episode step limit must be at least the LLM-call budget.
    env_cfg = dict(request.env.config or {})
    env_params = dict(env_cfg.get("env_params") or {})
    max_steps_episode = _int_or(env_params.get("max_steps_per_episode"), max_llm_calls)
    env_params["max_steps_per_episode"] = int(max(max_llm_calls, max_steps_episode))
    env_cfg["env_params"] = env_params

    request = request.model_copy(
        update={
            "policy": request.policy.model_copy(update={"config": policy_cfg}),
            "env": request.env.model_copy(update={"config": env_cfg}),
        }
    )

    # Normalise every op first (invalid entries raise), then cap the op count.
    normalised_ops: list[str] = [
        _normalise_op(op, position) for position, op in enumerate(request.ops)
    ]
    op_budget = max_llm_calls * 2 if max_llm_calls > 0 else len(normalised_ops)
    if 0 < op_budget < len(normalised_ops):
        normalised_ops = normalised_ops[:op_budget]

    legacy_request = LegacyRolloutRequest(
        run_id=request.run_id,
        env=LegacyRolloutEnvSpec(
            env_id=request.env.env_id,
            env_name=request.env.env_name,
            config=env_cfg,
            seed=request.env.seed,
        ),
        policy=LegacyRolloutPolicySpec(
            policy_id=request.policy.policy_id,
            policy_name=request.policy.policy_name,
            config=policy_cfg,
        ),
        ops=normalised_ops,
        record=LegacyRolloutRecordConfig(**request.record.model_dump()),
        on_done=request.on_done,
        branch=None,
        safety=LegacyRolloutSafetyConfig(**request.safety.model_dump()),
        training_session_id=request.training_session_id,
        synth_base_url=request.synth_base_url,
    )

    legacy_response: LegacyRolloutResponse = await legacy_execute_rollout(
        legacy_request, fastapi_request
    )

    # Make sure the metric keys set below are present before validation.
    payload = legacy_response.model_dump()
    metrics = payload.get("metrics") or {}
    for key, fallback in (("outcome_score", None), ("events_score", None), ("details", {})):
        metrics.setdefault(key, fallback)
    payload["metrics"] = metrics
    return RolloutResponse.model_validate(payload)
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the GRPO Crafter task app.

    Builds the dataset and base task info, creates the hosted task app when
    ``HAS_HOSTED`` is true, resolves the tracing and SFT output settings, and
    wires everything (rubrics, proxy settings, routers, app state) into the
    config. Prints one status line for tracing and, when an SFT output
    directory is configured, one for SFT output.
    """
    registry, dataset = build_dataset()
    base_info = _base_task_info(dataset)

    hosted_task_app = HostedTaskApp() if HAS_HOSTED else None

    tracing_enabled = tracing_env_enabled()
    tracing_db_url = resolve_tracing_db_url()
    tracer_factory = build_tracer_factory(
        SessionTracer, enabled=tracing_enabled, db_url=tracing_db_url
    )
    sft_output_dir = resolve_sft_output_dir()

    app_state: dict[str, Any] = {
        "task_app": hosted_task_app,
        "allowed_environments": ["crafter"],
        "tracing_enabled": tracing_enabled,
    }
    if tracer_factory is not None:
        app_state["session_tracer_factory"] = tracer_factory

    print(
        f"[task:tracing] enabled (db={tracing_db_url or 'default'})"
        if tracing_enabled
        else "[task:tracing] disabled",
        flush=True,
    )
    if sft_output_dir:
        app_state["sft_output_dir"] = sft_output_dir
        print(f"[task:sft] writing JSONL to {sft_output_dir}", flush=True)

    # Closures binding the dataset/base info to the module-level helpers.
    def _taskset_description() -> dict[str, Any]:
        return describe_taskset(dataset)

    def _instances_for(seeds: Sequence[int]):
        return provide_task_instances(dataset, base_info, seeds)

    # The hosted routers are only referenced when the hosted code is available.
    routers: tuple = (environment_router, policy_router, branching_router) if HAS_HOSTED else ()

    return TaskAppConfig(
        app_id="grpo-crafter",
        name="GRPO Crafter Task App",
        description="Crafter Classic environment with GRPO task endpoints and LLM proxies.",
        base_task_info=base_info,
        describe_taskset=_taskset_description,
        provide_task_instances=_instances_for,
        rollout=rollout_executor,
        dataset_registry=registry,
        rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
        proxy=ProxyConfig(
            enable_openai=True, enable_groq=True, system_hint=CRAFTING_RULES_SYSTEM_HINT
        ),
        routers=routers,
        app_state=app_state,
        cors_origins=["*"],
    )
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
# Register the Crafter task app under the id "grpo-crafter" (aliases
# "crafter" and "crafter-task"), with build_config as the config factory and
# the Modal deployment settings below.
register_task_app(
    entry=TaskAppEntry(
        app_id="grpo-crafter",
        description="Crafter Classic task app with rollout + proxy endpoints",
        config_factory=build_config,
        aliases=("crafter", "crafter-task"),
        modal=ModalDeploymentConfig(
            app_name="grpo-crafter-task-app",
            python_version="3.11",
            pip_packages=(
                "fastapi>=0.100.0",
                "uvicorn>=0.23.0",
                "pydantic>=2.0.0",
                "numpy>=1.24.0",
                "aiohttp>=3.8.0",
                "httpx>=0.24.0",
                "python-dotenv>=1.0.1",
                # Tracing/DB runtime deps
                "sqlalchemy>=2.0.42",
                "aiosqlite>=0.21.0",
                "greenlet>=3.2.3",
                "crafter",
            ),
            extra_local_dirs=(
                # Mount repo root so local modules resolve when deployed on Modal
                (str(REPO_ROOT), "/opt/synth_ai_repo"),
                (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
                (str(TASK_APP_ROOT), "/opt/synth_ai_repo/examples/warming_up_to_rl/task_app"),
            ),
            secret_names=("groq-api-key", "openai-api-key"),
            memory=16384,
            cpu=4.0,
            max_containers=10,
        ),
    )
)
|