synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -11
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +3 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +4 -4
- examples/sft/export_dataset.py +7 -4
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +1 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +2 -8
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +309 -14
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +75 -4
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +55 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +114 -32
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +127 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +145 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +66 -49
- synth_ai/cli/_modal_wrapper.py +9 -6
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +1 -0
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +392 -141
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +62 -0
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +5 -2
- synth_ai/task/config.py +259 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +4 -2
- synth_ai/task/rubrics/loaders.py +27 -4
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +145 -2
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/session_tracer.py +10 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +108 -77
- synth_ai/tracing_v3/utils.py +1 -1
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +911 -0
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/RECORD +286 -135
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -2,20 +2,23 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import json
|
|
5
6
|
import logging
|
|
6
7
|
import os
|
|
7
8
|
import sys
|
|
8
9
|
from collections.abc import Iterable, Sequence
|
|
10
|
+
from contextlib import suppress
|
|
9
11
|
from dataclasses import dataclass
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Any
|
|
12
14
|
|
|
13
15
|
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
14
|
-
from synth_ai.task.contracts import RolloutMetrics, RolloutRequest, RolloutResponse, TaskInfo
|
|
16
|
+
from synth_ai.task.contracts import RolloutMetrics, RolloutMode, RolloutRequest, RolloutResponse, TaskInfo
|
|
15
17
|
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
16
18
|
from synth_ai.task.json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
|
|
17
19
|
from synth_ai.task.rubrics import load_rubric
|
|
18
20
|
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
21
|
+
from synth_ai.task.validators import normalize_inference_url
|
|
19
22
|
from synth_ai.task.tracing_utils import (
|
|
20
23
|
build_tracer_factory,
|
|
21
24
|
resolve_sft_output_dir,
|
|
@@ -24,6 +27,18 @@ from synth_ai.task.tracing_utils import (
|
|
|
24
27
|
)
|
|
25
28
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
26
29
|
|
|
30
|
+
try:
|
|
31
|
+
from .synth_envs_hosted.utils import (
|
|
32
|
+
ensure_chat_completions_url,
|
|
33
|
+
extract_trace_correlation_id,
|
|
34
|
+
)
|
|
35
|
+
except Exception: # pragma: no cover - utils unavailable if optional deps missing
|
|
36
|
+
def ensure_chat_completions_url(raw_url, mode=None):
|
|
37
|
+
"""Fallback to shared utility for URL normalization."""
|
|
38
|
+
return normalize_inference_url(raw_url) if raw_url else raw_url
|
|
39
|
+
|
|
40
|
+
def extract_trace_correlation_id(_raw_url):
|
|
41
|
+
return None
|
|
27
42
|
logger = logging.getLogger(__name__)
|
|
28
43
|
|
|
29
44
|
DEFAULT_ALIAS_OPS: list[str] = ["agent", "env"] * 10
|
|
@@ -95,6 +110,110 @@ SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()
|
|
|
95
110
|
EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()
|
|
96
111
|
RUBRICS_ROOT = (EXAMPLES_ROOT / "multi_step" / "rubrics").resolve()
|
|
97
112
|
|
|
113
|
+
DEFAULT_OUTCOME_RUBRIC_DATA: dict[str, Any] = {
|
|
114
|
+
"version": "1",
|
|
115
|
+
"goal_text": (
|
|
116
|
+
"Reward episodes that climb the Crafter achievement ladder, stockpile key resources "
|
|
117
|
+
"(especially wood), and finish alive with clear understanding of any failure."
|
|
118
|
+
),
|
|
119
|
+
"aggregation": "weighted_sum",
|
|
120
|
+
"criteria": [
|
|
121
|
+
{
|
|
122
|
+
"id": "achievement_progression",
|
|
123
|
+
"description": (
|
|
124
|
+
"Weigh achievements by tier: late-game unlocks (iron tools, furnace, armor) earn "
|
|
125
|
+
"the most, mid-tier crafting (stone tools, furnace prep) gets partial credit, early "
|
|
126
|
+
"tasks (collecting saplings/wood tools) only lightly scored."
|
|
127
|
+
),
|
|
128
|
+
"weight": 0.35,
|
|
129
|
+
},
|
|
130
|
+
{
|
|
131
|
+
"id": "resource_stockpile",
|
|
132
|
+
"description": (
|
|
133
|
+
"Assess resource totals with emphasis on wood stores; high scores require abundant "
|
|
134
|
+
"wood plus supporting materials (stone, coal, iron) that signal readiness for "
|
|
135
|
+
"crafting."
|
|
136
|
+
),
|
|
137
|
+
"weight": 0.2,
|
|
138
|
+
},
|
|
139
|
+
{
|
|
140
|
+
"id": "survival_state",
|
|
141
|
+
"description": (
|
|
142
|
+
"Reward finishing alive with healthy food/drink bars and safe positioning; penalize "
|
|
143
|
+
"deaths, low vitals, or lingering hazards at episode end."
|
|
144
|
+
),
|
|
145
|
+
"weight": 0.2,
|
|
146
|
+
},
|
|
147
|
+
{
|
|
148
|
+
"id": "failure_analysis",
|
|
149
|
+
"description": (
|
|
150
|
+
"If the run ends in death or timeout, clearly identify the cause and deduct unless "
|
|
151
|
+
"the agent mitigated risk; highlight when the agent survives despite danger."
|
|
152
|
+
),
|
|
153
|
+
"weight": 0.15,
|
|
154
|
+
},
|
|
155
|
+
{
|
|
156
|
+
"id": "future_readiness",
|
|
157
|
+
"description": (
|
|
158
|
+
"Describe how prepared the agent is for the next objectives (tools crafted, shelters, "
|
|
159
|
+
"furnaces, smelted materials) and whether the inventory supports further progress."
|
|
160
|
+
),
|
|
161
|
+
"weight": 0.1,
|
|
162
|
+
},
|
|
163
|
+
],
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
DEFAULT_EVENTS_RUBRIC_DATA: dict[str, Any] = {
|
|
167
|
+
"version": "1",
|
|
168
|
+
"goal_text": (
|
|
169
|
+
"Score each decision in proportion to the concrete Crafter achievement progress it "
|
|
170
|
+
"delivers, topping out the scale when the log shows a fresh achievement unlock and keeping "
|
|
171
|
+
"routine upkeep near zero."
|
|
172
|
+
),
|
|
173
|
+
"aggregation": "weighted_sum",
|
|
174
|
+
"criteria": [
|
|
175
|
+
{
|
|
176
|
+
"id": "achievement_unlocks",
|
|
177
|
+
"description": (
|
|
178
|
+
"Assign 0.9-1.0 when the decision explicitly unlocks a new Crafter achievement (look "
|
|
179
|
+
'for "Achievement unlocked" messages or equivalent deterministic completions such as '
|
|
180
|
+
"placing a furnace that immediately crafts ingots). Cap the score at 0.4 when no new "
|
|
181
|
+
"achievement fires, and drop to <=0.1 if the turn repeats known actions without "
|
|
182
|
+
"measurable progress."
|
|
183
|
+
),
|
|
184
|
+
"weight": 0.55,
|
|
185
|
+
},
|
|
186
|
+
{
|
|
187
|
+
"id": "milestone_setup",
|
|
188
|
+
"description": (
|
|
189
|
+
"Give 0.5-0.7 when the action completes the last prerequisite for a specific upcoming "
|
|
190
|
+
"achievement (e.g., gathering the final ore before smelting, crafting sticks right "
|
|
191
|
+
"before a tool). Keep the score <=0.3 if the progress is speculative or still several "
|
|
192
|
+
"steps away."
|
|
193
|
+
),
|
|
194
|
+
"weight": 0.2,
|
|
195
|
+
},
|
|
196
|
+
{
|
|
197
|
+
"id": "inventory_depth",
|
|
198
|
+
"description": (
|
|
199
|
+
"Reward 0.3-0.5 for pulls that clearly deepen critical buffers (fuel, food, ore) and "
|
|
200
|
+
"immediately unblock the next milestone. If resources are already plentiful or the "
|
|
201
|
+
"haul is generic filler, stay at <=0.2."
|
|
202
|
+
),
|
|
203
|
+
"weight": 0.15,
|
|
204
|
+
},
|
|
205
|
+
{
|
|
206
|
+
"id": "execution_quality",
|
|
207
|
+
"description": (
|
|
208
|
+
"Only add up to 0.1 for clean, legal execution that avoids wasted turns; drop to 0.0 "
|
|
209
|
+
"whenever the agent idles, repeats failed moves, or takes damage without compensating "
|
|
210
|
+
"progress."
|
|
211
|
+
),
|
|
212
|
+
"weight": 0.1,
|
|
213
|
+
},
|
|
214
|
+
],
|
|
215
|
+
}
|
|
216
|
+
|
|
98
217
|
for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
|
|
99
218
|
try:
|
|
100
219
|
resolved = path.resolve()
|
|
@@ -115,6 +234,28 @@ try:
|
|
|
115
234
|
except Exception:
|
|
116
235
|
pass
|
|
117
236
|
|
|
237
|
+
def _load_rubric_with_fallback(filename: str, fallback: dict[str, Any]):
|
|
238
|
+
"""Load rubric from JSON file when available, otherwise use bundled fallback."""
|
|
239
|
+
|
|
240
|
+
search_paths = [RUBRICS_ROOT / filename, TASK_APP_ROOT / "rubrics" / filename]
|
|
241
|
+
for path in search_paths:
|
|
242
|
+
try:
|
|
243
|
+
if path.exists():
|
|
244
|
+
logger.debug("Loading rubric from %s", path)
|
|
245
|
+
return load_rubric(str(path))
|
|
246
|
+
except Exception as exc:
|
|
247
|
+
logger.warning("Failed to load rubric %s from %s: %s", filename, path, exc)
|
|
248
|
+
|
|
249
|
+
logger.warning("Falling back to inline rubric %s: file not available", filename)
|
|
250
|
+
try:
|
|
251
|
+
materialized = search_paths[0]
|
|
252
|
+
materialized.parent.mkdir(parents=True, exist_ok=True)
|
|
253
|
+
materialized.write_text(json.dumps(fallback, indent=2), encoding="utf-8")
|
|
254
|
+
except Exception:
|
|
255
|
+
logger.debug("Unable to materialize inline rubric %s", filename, exc_info=True)
|
|
256
|
+
return load_rubric(fallback)
|
|
257
|
+
|
|
258
|
+
|
|
118
259
|
HAS_HOSTED = True
|
|
119
260
|
try:
|
|
120
261
|
import crafter # type: ignore
|
|
@@ -343,9 +484,13 @@ def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
|
|
|
343
484
|
)
|
|
344
485
|
|
|
345
486
|
|
|
346
|
-
OUTCOME_RUBRIC =
|
|
487
|
+
OUTCOME_RUBRIC = _load_rubric_with_fallback(
|
|
488
|
+
"crafter_outcome_rubric.json", DEFAULT_OUTCOME_RUBRIC_DATA
|
|
489
|
+
)
|
|
347
490
|
|
|
348
|
-
EVENTS_RUBRIC =
|
|
491
|
+
EVENTS_RUBRIC = _load_rubric_with_fallback(
|
|
492
|
+
"crafter_events_rubric.json", DEFAULT_EVENTS_RUBRIC_DATA
|
|
493
|
+
)
|
|
349
494
|
|
|
350
495
|
|
|
351
496
|
def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]:
|
|
@@ -470,16 +615,14 @@ def _coerce_math_to_crafter(request: RolloutRequest) -> RolloutRequest:
|
|
|
470
615
|
|
|
471
616
|
coerced = request.model_copy(update={"env": updated_env, "policy": updated_policy, "ops": ops_override})
|
|
472
617
|
|
|
473
|
-
|
|
618
|
+
with suppress(Exception):
|
|
474
619
|
print(
|
|
475
620
|
"[rollout] remapped math request -> crafter "
|
|
476
621
|
f"(env={request.env.env_name!r}→{coerced.env.env_name!r}, "
|
|
477
622
|
f"policy={request.policy.policy_name!r}→{coerced.policy.policy_name!r})",
|
|
478
623
|
flush=True,
|
|
479
624
|
)
|
|
480
|
-
|
|
481
|
-
pass
|
|
482
|
-
try:
|
|
625
|
+
with suppress(Exception):
|
|
483
626
|
logger.info(
|
|
484
627
|
"ROLLOUT_ALIAS: remapped math env/policy to crafter (env=%s→%s, policy=%s→%s)",
|
|
485
628
|
request.env.env_name,
|
|
@@ -487,15 +630,98 @@ def _coerce_math_to_crafter(request: RolloutRequest) -> RolloutRequest:
|
|
|
487
630
|
request.policy.policy_name,
|
|
488
631
|
coerced.policy.policy_name,
|
|
489
632
|
)
|
|
490
|
-
except Exception:
|
|
491
|
-
pass
|
|
492
633
|
|
|
493
634
|
return coerced
|
|
494
635
|
|
|
495
636
|
|
|
637
|
+
def _resolve_trace_correlation_id(policy_cfg: dict[str, Any], mode: Any = None) -> str | None:
|
|
638
|
+
"""Best-effort extraction of the trace correlation identifier."""
|
|
639
|
+
candidates: list[Any] = [
|
|
640
|
+
policy_cfg.get("trace_correlation_id"),
|
|
641
|
+
policy_cfg.get("trace"),
|
|
642
|
+
]
|
|
643
|
+
logger.debug(
|
|
644
|
+
"_resolve_trace_correlation_id: inspecting policy_cfg keys=%s candidates=%s",
|
|
645
|
+
sorted(policy_cfg.keys()),
|
|
646
|
+
candidates,
|
|
647
|
+
)
|
|
648
|
+
for candidate in candidates:
|
|
649
|
+
if isinstance(candidate, str):
|
|
650
|
+
stripped = candidate.strip()
|
|
651
|
+
if stripped:
|
|
652
|
+
return stripped
|
|
653
|
+
|
|
654
|
+
return extract_trace_correlation_id(policy_cfg.get("inference_url"))
|
|
655
|
+
|
|
656
|
+
|
|
496
657
|
async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
|
|
658
|
+
request = _coerce_math_to_crafter(request)
|
|
659
|
+
|
|
660
|
+
policy_cfg = dict(request.policy.config or {})
|
|
661
|
+
logger.info(
|
|
662
|
+
"ROLLOUT_EXEC: incoming policy config keys=%s inference_url=%s run_id=%s mode=%s",
|
|
663
|
+
sorted(policy_cfg.keys()),
|
|
664
|
+
policy_cfg.get("inference_url"),
|
|
665
|
+
request.run_id,
|
|
666
|
+
request.mode,
|
|
667
|
+
)
|
|
668
|
+
inferred_url = ensure_chat_completions_url(policy_cfg.get("inference_url"), mode=request.mode)
|
|
669
|
+
if isinstance(inferred_url, str) and inferred_url:
|
|
670
|
+
if inferred_url != policy_cfg.get("inference_url"):
|
|
671
|
+
logger.warning(
|
|
672
|
+
"ROLLOUT_EXEC: normalized inference_url run_id=%s from %s to %s",
|
|
673
|
+
request.run_id,
|
|
674
|
+
policy_cfg.get("inference_url"),
|
|
675
|
+
inferred_url,
|
|
676
|
+
)
|
|
677
|
+
policy_cfg["inference_url"] = inferred_url
|
|
678
|
+
else:
|
|
679
|
+
logger.warning(
|
|
680
|
+
"ROLLOUT_EXEC: inference_url missing or not normalized run_id=%s raw=%s",
|
|
681
|
+
request.run_id,
|
|
682
|
+
policy_cfg.get("inference_url"),
|
|
683
|
+
)
|
|
684
|
+
|
|
685
|
+
trace_correlation_id = _resolve_trace_correlation_id(policy_cfg, mode=request.mode)
|
|
686
|
+
|
|
687
|
+
# ASSERTION: trace_correlation_id MUST be present for RL mode (but not EVAL mode)
|
|
688
|
+
if request.mode == RolloutMode.RL:
|
|
689
|
+
assert trace_correlation_id is not None, (
|
|
690
|
+
f"FATAL: trace_correlation_id extraction failed for run_id={request.run_id}. "
|
|
691
|
+
f"policy_cfg_keys={sorted(policy_cfg.keys())} "
|
|
692
|
+
f"inference_url={policy_cfg.get('inference_url')}"
|
|
693
|
+
)
|
|
694
|
+
assert isinstance(trace_correlation_id, str) and trace_correlation_id.strip(), (
|
|
695
|
+
f"FATAL: trace_correlation_id is empty for run_id={request.run_id}. "
|
|
696
|
+
f"Got: {trace_correlation_id!r}"
|
|
697
|
+
)
|
|
698
|
+
|
|
699
|
+
if trace_correlation_id:
|
|
700
|
+
policy_cfg["trace_correlation_id"] = trace_correlation_id
|
|
701
|
+
logger.info(
|
|
702
|
+
"ROLLOUT_EXEC: resolved trace_correlation_id=%s run_id=%s",
|
|
703
|
+
trace_correlation_id,
|
|
704
|
+
request.run_id,
|
|
705
|
+
)
|
|
706
|
+
|
|
707
|
+
pipeline_metadata: dict[str, Any] = {}
|
|
708
|
+
if trace_correlation_id:
|
|
709
|
+
pipeline_metadata["trace_correlation_id"] = trace_correlation_id
|
|
710
|
+
if isinstance(policy_cfg.get("inference_url"), str) and policy_cfg["inference_url"]:
|
|
711
|
+
pipeline_metadata.setdefault("inference_url", policy_cfg["inference_url"])
|
|
712
|
+
logger.info(
|
|
713
|
+
"ROLLOUT_EXEC: pipeline metadata prepared run_id=%s metadata=%s",
|
|
714
|
+
request.run_id,
|
|
715
|
+
pipeline_metadata,
|
|
716
|
+
)
|
|
717
|
+
|
|
497
718
|
# If hosted env service code is not bundled, return a no-op rollout response compatible with contracts
|
|
498
719
|
if not HAS_HOSTED:
|
|
720
|
+
logger.warning(
|
|
721
|
+
"ROLLOUT_EXEC: HAS_HOSTED disabled, returning stub response run_id=%s metadata=%s",
|
|
722
|
+
request.run_id,
|
|
723
|
+
pipeline_metadata,
|
|
724
|
+
)
|
|
499
725
|
return RolloutResponse(
|
|
500
726
|
run_id=request.run_id,
|
|
501
727
|
trajectories=[],
|
|
@@ -510,11 +736,10 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
510
736
|
aborted=False,
|
|
511
737
|
ops_executed=0,
|
|
512
738
|
trace=None,
|
|
739
|
+
trace_correlation_id=trace_correlation_id or f"trace_{request.run_id}",
|
|
740
|
+
pipeline_metadata=pipeline_metadata,
|
|
513
741
|
)
|
|
514
742
|
|
|
515
|
-
request = _coerce_math_to_crafter(request)
|
|
516
|
-
|
|
517
|
-
policy_cfg = dict(request.policy.config or {})
|
|
518
743
|
try:
|
|
519
744
|
max_llm_calls = int(policy_cfg.get("max_llm_calls") or 10)
|
|
520
745
|
except Exception:
|
|
@@ -545,6 +770,7 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
545
770
|
converted_ops = converted_ops[:max_ops_allowed]
|
|
546
771
|
legacy_request = LegacyRolloutRequest(
|
|
547
772
|
run_id=request.run_id,
|
|
773
|
+
mode=request.mode, # Preserve mode for nested requests
|
|
548
774
|
env=LegacyRolloutEnvSpec(
|
|
549
775
|
env_id=request.env.env_id,
|
|
550
776
|
env_name=request.env.env_name,
|
|
@@ -568,12 +794,79 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
568
794
|
legacy_response: LegacyRolloutResponse = await legacy_execute_rollout(
|
|
569
795
|
legacy_request, fastapi_request
|
|
570
796
|
)
|
|
797
|
+
logger.info(
|
|
798
|
+
"ROLLOUT_EXEC: legacy rollout completed run_id=%s trace_id=%s",
|
|
799
|
+
request.run_id,
|
|
800
|
+
trace_correlation_id,
|
|
801
|
+
)
|
|
571
802
|
data = legacy_response.model_dump()
|
|
572
803
|
metrics = data.get("metrics", {}) or {}
|
|
573
804
|
metrics.setdefault("outcome_score", None)
|
|
574
805
|
metrics.setdefault("events_score", None)
|
|
575
806
|
metrics.setdefault("details", {})
|
|
576
807
|
data["metrics"] = metrics
|
|
808
|
+
|
|
809
|
+
# Add trace_correlation_id at TOP-LEVEL (REQUIRED for RL training pipeline)
|
|
810
|
+
# Use fallback if somehow missing
|
|
811
|
+
data["trace_correlation_id"] = trace_correlation_id or f"trace_{request.run_id}"
|
|
812
|
+
|
|
813
|
+
# Add trace_correlation_id to pipeline_metadata
|
|
814
|
+
existing_meta = data.get("pipeline_metadata")
|
|
815
|
+
if not isinstance(existing_meta, dict):
|
|
816
|
+
existing_meta = {}
|
|
817
|
+
# ALWAYS set trace_correlation_id (use fallback if needed)
|
|
818
|
+
final_cid = trace_correlation_id or f"trace_{request.run_id}"
|
|
819
|
+
existing_meta["trace_correlation_id"] = final_cid
|
|
820
|
+
if isinstance(policy_cfg.get("inference_url"), str) and policy_cfg["inference_url"]:
|
|
821
|
+
existing_meta.setdefault("inference_url", policy_cfg["inference_url"])
|
|
822
|
+
data["pipeline_metadata"] = existing_meta
|
|
823
|
+
|
|
824
|
+
# Add trace_correlation_id to each trajectory (required for RL training pipeline)
|
|
825
|
+
if "trajectories" in data:
|
|
826
|
+
for traj in data.get("trajectories", []):
|
|
827
|
+
if isinstance(traj, dict):
|
|
828
|
+
traj["trace_correlation_id"] = final_cid
|
|
829
|
+
logger.info(
|
|
830
|
+
"ROLLOUT_EXEC: final pipeline metadata run_id=%s metadata=%s",
|
|
831
|
+
request.run_id,
|
|
832
|
+
existing_meta,
|
|
833
|
+
)
|
|
834
|
+
if trace_correlation_id and existing_meta.get("trace_correlation_id") != trace_correlation_id:
|
|
835
|
+
logger.error(
|
|
836
|
+
"ROLLOUT_EXEC: metadata trace mismatch run_id=%s expected=%s actual=%s",
|
|
837
|
+
request.run_id,
|
|
838
|
+
trace_correlation_id,
|
|
839
|
+
existing_meta.get("trace_correlation_id"),
|
|
840
|
+
)
|
|
841
|
+
if not existing_meta.get("trace_correlation_id"):
|
|
842
|
+
logger.error(
|
|
843
|
+
"ROLLOUT_EXEC: final metadata missing trace_correlation_id run_id=%s metadata=%s",
|
|
844
|
+
request.run_id,
|
|
845
|
+
existing_meta,
|
|
846
|
+
)
|
|
847
|
+
|
|
848
|
+
# ASSERTION: Verify trace_correlation_id is present in response at all required levels
|
|
849
|
+
assert "trace_correlation_id" in data, (
|
|
850
|
+
f"FATAL: trace_correlation_id missing from top-level response data for run_id={request.run_id}. "
|
|
851
|
+
f"Keys: {list(data.keys())}"
|
|
852
|
+
)
|
|
853
|
+
assert data["trace_correlation_id"] == final_cid, (
|
|
854
|
+
f"FATAL: trace_correlation_id mismatch in response for run_id={request.run_id}. "
|
|
855
|
+
f"Expected: {final_cid!r}, Got: {data.get('trace_correlation_id')!r}"
|
|
856
|
+
)
|
|
857
|
+
assert "pipeline_metadata" in data, (
|
|
858
|
+
f"FATAL: pipeline_metadata missing from response for run_id={request.run_id}"
|
|
859
|
+
)
|
|
860
|
+
assert data["pipeline_metadata"].get("trace_correlation_id") == final_cid, (
|
|
861
|
+
f"FATAL: trace_correlation_id missing or mismatched in pipeline_metadata for run_id={request.run_id}. "
|
|
862
|
+
f"Expected: {final_cid!r}, Got: {data['pipeline_metadata'].get('trace_correlation_id')!r}"
|
|
863
|
+
)
|
|
864
|
+
logger.info(
|
|
865
|
+
"ROLLOUT_EXEC: assertions passed - trace_correlation_id present in response run_id=%s cid=%s",
|
|
866
|
+
request.run_id,
|
|
867
|
+
final_cid,
|
|
868
|
+
)
|
|
869
|
+
|
|
577
870
|
return RolloutResponse.model_validate(data)
|
|
578
871
|
|
|
579
872
|
|
|
@@ -617,7 +910,7 @@ def build_config() -> TaskAppConfig:
|
|
|
617
910
|
routers: tuple = (environment_router, policy_router, branching_router) if HAS_HOSTED else ()
|
|
618
911
|
|
|
619
912
|
config = TaskAppConfig(
|
|
620
|
-
app_id="grpo-crafter",
|
|
913
|
+
app_id="grpo-crafter-task-app",
|
|
621
914
|
name="GRPO Crafter Task App",
|
|
622
915
|
description="Crafter Classic environment with GRPO task endpoints and LLM proxies.",
|
|
623
916
|
base_task_info=base_info,
|
|
@@ -638,7 +931,7 @@ def build_config() -> TaskAppConfig:
|
|
|
638
931
|
|
|
639
932
|
register_task_app(
|
|
640
933
|
entry=TaskAppEntry(
|
|
641
|
-
app_id="grpo-crafter",
|
|
934
|
+
app_id="grpo-crafter-task-app",
|
|
642
935
|
description="Crafter Classic task app with rollout + proxy endpoints",
|
|
643
936
|
config_factory=build_config,
|
|
644
937
|
aliases=("crafter", "crafter-task"),
|
|
@@ -665,6 +958,8 @@ register_task_app(
|
|
|
665
958
|
(str(REPO_ROOT), "/opt/synth_ai_repo"),
|
|
666
959
|
(str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
|
|
667
960
|
(str(TASK_APP_ROOT), "/opt/synth_ai_repo/examples/task_apps/crafter/task_app"),
|
|
961
|
+
# Explicitly mount rubrics directory
|
|
962
|
+
(str(RUBRICS_ROOT), "/opt/synth_ai_repo/examples/multi_step/rubrics"),
|
|
668
963
|
),
|
|
669
964
|
secret_names=("groq-api-key", "openai-api-key"),
|
|
670
965
|
memory=16384,
|
|
@@ -209,6 +209,16 @@ class CrafterEnvironmentWrapper:
|
|
|
209
209
|
logger.info("No valid actions provided, defaulting to noop")
|
|
210
210
|
normalized.append(EnvToolCall(tool="interact", args={"action": 0})) # noop action
|
|
211
211
|
|
|
212
|
+
# Limit to first 20 actions to prevent spam from overly long tool calls
|
|
213
|
+
MAX_ACTIONS_PER_STEP = 20
|
|
214
|
+
if len(normalized) > MAX_ACTIONS_PER_STEP:
|
|
215
|
+
logger.warning(
|
|
216
|
+
"Tool call contained %d actions, limiting to first %d to prevent spam",
|
|
217
|
+
len(normalized),
|
|
218
|
+
MAX_ACTIONS_PER_STEP,
|
|
219
|
+
)
|
|
220
|
+
normalized = normalized[:MAX_ACTIONS_PER_STEP]
|
|
221
|
+
|
|
212
222
|
# Pre-step logging: capture current public state and print concise summary
|
|
213
223
|
before_state: dict[str, Any] | None = None
|
|
214
224
|
try:
|
|
@@ -45,6 +45,7 @@ class CrafterPolicy(Policy):
|
|
|
45
45
|
self.model = model
|
|
46
46
|
self.use_tools = True
|
|
47
47
|
self.use_vision = False # Enable vision for VLMs
|
|
48
|
+
self.image_only_mode = False # If True, only send images without text observations
|
|
48
49
|
# Sampling parameters (populated via initialize(config))
|
|
49
50
|
self.temperature: float | None = None
|
|
50
51
|
self.top_p: float | None = None
|
|
@@ -58,6 +59,13 @@ class CrafterPolicy(Policy):
|
|
|
58
59
|
self.trajectory_history: list[dict[str, Any]] = [] # env/policy step records
|
|
59
60
|
|
|
60
61
|
async def initialize(self, config: dict[str, Any]) -> None:
|
|
62
|
+
# DEBUG: Log the incoming config
|
|
63
|
+
import logging
|
|
64
|
+
_logger = logging.getLogger(__name__)
|
|
65
|
+
_logger.debug(f"🔊 [POLICY_INIT] Received config keys: {list(config.keys())}")
|
|
66
|
+
_logger.debug(f"🔊 [POLICY_INIT] use_vision in config: {'use_vision' in config}, value: {config.get('use_vision')}")
|
|
67
|
+
_logger.debug(f"🔊 [POLICY_INIT] image_only_mode in config: {'image_only_mode' in config}, value: {config.get('image_only_mode')}")
|
|
68
|
+
|
|
61
69
|
if "inference_url" in config:
|
|
62
70
|
self.inference_url = config["inference_url"]
|
|
63
71
|
if "model" in config:
|
|
@@ -66,6 +74,12 @@ class CrafterPolicy(Policy):
|
|
|
66
74
|
self.use_tools = bool(config["use_tools"])
|
|
67
75
|
if "use_vision" in config:
|
|
68
76
|
self.use_vision = bool(config["use_vision"])
|
|
77
|
+
_logger.debug(f"🔊 [POLICY_INIT] Set use_vision={self.use_vision} from config")
|
|
78
|
+
if "image_only_mode" in config:
|
|
79
|
+
self.image_only_mode = bool(config["image_only_mode"])
|
|
80
|
+
# If image_only_mode is enabled, automatically enable vision
|
|
81
|
+
if self.image_only_mode:
|
|
82
|
+
self.use_vision = True
|
|
69
83
|
# Auto-detect vision capability from model name if not explicitly set
|
|
70
84
|
if "use_vision" not in config and self.model:
|
|
71
85
|
self.use_vision = self._is_vision_model(self.model)
|
|
@@ -91,6 +105,9 @@ class CrafterPolicy(Policy):
|
|
|
91
105
|
self.history_messages = []
|
|
92
106
|
self.turn_index = 0
|
|
93
107
|
self.trajectory_history = []
|
|
108
|
+
|
|
109
|
+
# DEBUG: Log final state
|
|
110
|
+
_logger.debug(f"🔊 [POLICY_INIT] FINAL STATE: use_vision={self.use_vision}, image_only_mode={self.image_only_mode}, model={self.model}")
|
|
94
111
|
|
|
95
112
|
def _append_user_observation(self, observation_text: str) -> None:
|
|
96
113
|
self.history_messages.append({"role": "user", "content": observation_text})
|
|
@@ -125,10 +142,36 @@ class CrafterPolicy(Policy):
|
|
|
125
142
|
history=history,
|
|
126
143
|
turn=turn,
|
|
127
144
|
image_parts=image_parts,
|
|
145
|
+
image_only_mode=self.image_only_mode,
|
|
128
146
|
)
|
|
147
|
+
|
|
148
|
+
# DEBUG: Log message structure
|
|
149
|
+
import logging
|
|
150
|
+
_logger = logging.getLogger(__name__)
|
|
151
|
+
_logger.debug(f"🔊 [BUILD_REQUEST] Built {len(messages)} messages")
|
|
152
|
+
for idx, msg in enumerate(messages):
|
|
153
|
+
role = msg.get("role")
|
|
154
|
+
content = msg.get("content")
|
|
155
|
+
if isinstance(content, list):
|
|
156
|
+
_logger.debug(f"🔊 [BUILD_REQUEST] Message[{idx}] role={role}, content=list[{len(content)}]")
|
|
157
|
+
for part_idx, part in enumerate(content):
|
|
158
|
+
if isinstance(part, dict):
|
|
159
|
+
part_type = part.get("type")
|
|
160
|
+
_logger.debug(f"🔊 [BUILD_REQUEST] Part[{part_idx}]: type={part_type}")
|
|
161
|
+
else:
|
|
162
|
+
content_len = len(str(content)) if content else 0
|
|
163
|
+
_logger.debug(f"🔊 [BUILD_REQUEST] Message[{idx}] role={role}, content_len={content_len}")
|
|
164
|
+
|
|
129
165
|
payload: dict[str, Any] = {
|
|
130
166
|
"messages": messages,
|
|
131
167
|
}
|
|
168
|
+
|
|
169
|
+
# DEBUG: Verify messages are in payload correctly
|
|
170
|
+
_logger.debug(f"🔊 [BUILD_REQUEST_PAYLOAD] Created payload with {len(payload['messages'])} messages")
|
|
171
|
+
for idx, msg in enumerate(payload["messages"]):
|
|
172
|
+
content = msg.get("content")
|
|
173
|
+
_logger.debug(f"🔊 [BUILD_REQUEST_PAYLOAD] Payload message[{idx}]: type={type(content).__name__}, is_list={isinstance(content, list)}, len={len(content) if isinstance(content, list) else len(str(content)) if content else 0}")
|
|
174
|
+
|
|
132
175
|
if self.model is not None:
|
|
133
176
|
payload["model"] = self.model
|
|
134
177
|
# Thinking controls
|
|
@@ -354,7 +397,18 @@ class CrafterPolicy(Policy):
|
|
|
354
397
|
raw_candidate = metadata.get("raw_observation")
|
|
355
398
|
if isinstance(raw_candidate, dict):
|
|
356
399
|
raw_observation = raw_candidate
|
|
400
|
+
|
|
401
|
+
# DEBUG: Log image extraction
|
|
402
|
+
import logging
|
|
403
|
+
_logger = logging.getLogger(__name__)
|
|
404
|
+
_logger.debug(f"🔊 [POLICY] use_vision={self.use_vision}, has_raw_obs={raw_observation is not None}")
|
|
405
|
+
if raw_observation:
|
|
406
|
+
obs = raw_observation.get("observation", raw_observation)
|
|
407
|
+
data_url = obs.get("observation_image_data_url") if isinstance(obs, dict) else None
|
|
408
|
+
_logger.debug(f"🔊 [POLICY] has_data_url={data_url is not None}, url_preview={data_url[:50] if data_url else 'NONE'}...")
|
|
409
|
+
|
|
357
410
|
image_parts = self._extract_image_parts(raw_observation)
|
|
411
|
+
_logger.debug(f"🔊 [POLICY] Extracted {len(image_parts)} image parts")
|
|
358
412
|
|
|
359
413
|
payload = self.build_inference_request(
|
|
360
414
|
combined_text,
|
|
@@ -362,7 +416,17 @@ class CrafterPolicy(Policy):
|
|
|
362
416
|
turn=self.turn_index,
|
|
363
417
|
image_parts=image_parts,
|
|
364
418
|
)
|
|
365
|
-
|
|
419
|
+
|
|
420
|
+
# DEBUG: Verify payload before returning
|
|
421
|
+
_logger.debug(f"🔊 [POLICY_STEP_RETURN] About to return payload with {len(payload.get('messages', []))} messages")
|
|
422
|
+
for idx, msg in enumerate(payload.get("messages", [])):
|
|
423
|
+
content = msg.get("content")
|
|
424
|
+
_logger.debug(f"🔊 [POLICY_STEP_RETURN] Return message[{idx}]: type={type(content).__name__}, is_list={isinstance(content, list)}")
|
|
425
|
+
if isinstance(content, list):
|
|
426
|
+
_logger.debug(f"🔊 [POLICY_STEP_RETURN] Content list has {len(content)} items")
|
|
427
|
+
# Add assertion to catch corruption early
|
|
428
|
+
assert len(content) > 0, f"Message content list is empty! This should contain images."
|
|
429
|
+
|
|
366
430
|
meta_out = {
|
|
367
431
|
"inference_url": self.inference_url,
|
|
368
432
|
"inference_request": payload,
|
|
@@ -417,14 +481,21 @@ class CrafterPolicy(Policy):
|
|
|
417
481
|
"""Prepare an inference request (implementing abstract method)."""
|
|
418
482
|
# Format observation with rich contextual information
|
|
419
483
|
observation_text = self._format_observation_for_llm(observation)
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
484
|
+
|
|
485
|
+
# Extract image parts based on vision settings
|
|
486
|
+
if self.use_vision:
|
|
487
|
+
image_parts = self._extract_image_parts(observation)
|
|
488
|
+
else:
|
|
489
|
+
# Text-only mode: don't include any images
|
|
490
|
+
image_parts = []
|
|
491
|
+
|
|
492
|
+
# Build messages with appropriate mode
|
|
423
493
|
messages = CrafterReActAgent.build_messages(
|
|
424
494
|
observation=observation_text,
|
|
425
495
|
history=history,
|
|
426
496
|
turn=self.turn_index,
|
|
427
497
|
image_parts=image_parts,
|
|
498
|
+
image_only_mode=self.image_only_mode,
|
|
428
499
|
)
|
|
429
500
|
|
|
430
501
|
# Return messages and tools schema
|
|
@@ -85,8 +85,17 @@ class CrafterReActAgent:
|
|
|
85
85
|
history: list[dict[str, Any]] | None = None,
|
|
86
86
|
turn: int | None = None,
|
|
87
87
|
image_parts: list[dict[str, Any]] | None = None,
|
|
88
|
+
image_only_mode: bool = False,
|
|
88
89
|
) -> list[dict[str, Any]]:
|
|
89
|
-
"""Construct OpenAI-style messages list for vLLM generation.
|
|
90
|
+
"""Construct OpenAI-style messages list for vLLM generation.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
observation: Text observation to include
|
|
94
|
+
history: Previous conversation history
|
|
95
|
+
turn: Current turn number
|
|
96
|
+
image_parts: Image content parts in OpenAI format
|
|
97
|
+
image_only_mode: If True, only include images without text observation
|
|
98
|
+
"""
|
|
90
99
|
msgs: list[dict[str, Any]] = [
|
|
91
100
|
{"role": "system", "content": CrafterReActAgent.get_system_prompt()}
|
|
92
101
|
]
|
|
@@ -94,8 +103,14 @@ class CrafterReActAgent:
|
|
|
94
103
|
msgs.extend(history)
|
|
95
104
|
user_content: Any
|
|
96
105
|
if image_parts:
|
|
97
|
-
|
|
106
|
+
# Image-only mode: send only images without text observation
|
|
107
|
+
if image_only_mode:
|
|
108
|
+
user_content = list(image_parts)
|
|
109
|
+
else:
|
|
110
|
+
# Normal vision mode: send both text and images
|
|
111
|
+
user_content = [{"type": "text", "text": observation}] + list(image_parts)
|
|
98
112
|
else:
|
|
113
|
+
# Text-only mode (default): no images
|
|
99
114
|
user_content = observation
|
|
100
115
|
msgs.append({"role": "user", "content": user_content})
|
|
101
116
|
return msgs
|