synth-ai 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +9 -9
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +2 -4
- examples/sft/export_dataset.py +7 -4
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +0 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +0 -8
- examples/task_apps/crafter/task_app/grpo_crafter.py +4 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +59 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +30 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +62 -31
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +16 -14
- examples/task_apps/enron/__init__.py +1 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +144 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +62 -78
- synth_ai/cli/_modal_wrapper.py +7 -5
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +2 -1
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +71 -31
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +7 -2
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/judge_schemas.py +8 -8
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/apps/__init__.py +4 -2
- synth_ai/task/config.py +6 -4
- synth_ai/task/rubrics/__init__.py +1 -2
- synth_ai/task/rubrics/loaders.py +14 -10
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +24 -11
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +2 -3
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +7 -7
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -9
- synth_ai/tracing_v3/turso/native_manager.py +80 -72
- synth_ai/tracing_v3/utils.py +2 -2
- synth_ai/tui/cli/query_experiments.py +4 -4
- synth_ai/tui/cli/query_experiments_v3.py +4 -4
- synth_ai/tui/dashboard.py +14 -9
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/RECORD +229 -117
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,702 @@
|
|
|
1
|
+
"""Modal task app for Hendrycks MATH single-step Environment."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from functools import lru_cache
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from modal import App, Image, Secret, asgi_app
|
|
11
|
+
from starlette.requests import Request
|
|
12
|
+
|
|
13
|
+
try: # Backward compatibility with older installed SDKs
|
|
14
|
+
from synth_ai.demos.core import DEFAULT_TASK_APP_SECRET_NAME
|
|
15
|
+
except Exception: # pragma: no cover - occurs on older deployments
|
|
16
|
+
DEFAULT_TASK_APP_SECRET_NAME = "hendrycks-math-task-app-secret"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Walk upward from this file looking for a local synth_envs_hosted checkout
# (used in dev trees); give up silently after eight hops or at the FS root.
_HERE = Path(__file__).resolve()
_ROOT = _HERE.parent
_SYNTH_HOSTED = None
try:
    _probe = _HERE
    _hops = 0
    while _hops < 8:
        _candidate = (
            _probe / "backend/app/routes/clustered_training/dev/synth_envs_hosted"
        ).resolve()
        if _candidate.exists():
            _SYNTH_HOSTED = _candidate
            break
        if _probe.parent == _probe:
            # Reached the filesystem root; nothing found.
            break
        _probe = _probe.parent
        _hops += 1
except Exception:
    # Best-effort discovery only; deployment works without the local tree.
    _SYNTH_HOSTED = None
|
|
36
|
+
|
|
37
|
+
# Container image for the Modal function: slim Debian + Python 3.11 with the
# web stack and dataset tooling the task app imports at runtime.
image = Image.debian_slim(python_version="3.11").pip_install(
    "fastapi>=0.110.0",
    "uvicorn>=0.23.0",
    "pydantic>=2.6.0",
    "httpx>=0.24.0",
    "numpy>=1.24.0",
    "aiohttp>=3.8.0",
    "datasets>=2.16.0",
    "synth-ai",
)
# When a local synth_envs_hosted checkout was found above, bundle it into the
# image so the container can import it from /app/synth_envs_hosted.
if _SYNTH_HOSTED is not None:
    image = image.add_local_dir(str(_SYNTH_HOSTED), "/app/synth_envs_hosted")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _build_inline_secret() -> Secret:
    """Assemble a Modal Secret from the deploying shell's environment.

    ENVIRONMENT_API_KEY is mandatory; SYNTH_API_KEY and OPENAI_API_KEY are
    forwarded when present. Raises RuntimeError when a required key is absent.
    """
    payload: dict[str, str] = {}
    missing: list[str] = []

    def _lookup(name: str) -> str:
        # Treat unset and whitespace-only values identically.
        return (os.environ.get(name) or "").strip()

    for name in ("ENVIRONMENT_API_KEY",):
        secret_value = _lookup(name)
        if secret_value:
            payload[name] = secret_value
        else:
            missing.append(name)

    for name in ("SYNTH_API_KEY", "OPENAI_API_KEY"):
        secret_value = _lookup(name)
        if secret_value:
            payload[name] = secret_value

    if missing:
        raise RuntimeError(
            "Missing required environment values for inline secret: " + ", ".join(missing)
        )

    # Log key names and lengths only — never the secret values themselves.
    previews = ", ".join(f"{k}:len={len(v)}" for k, v in payload.items())
    print(f"[startup] TASK_APP_SECRET_NAME={DEFAULT_TASK_APP_SECRET_NAME}")
    print(f"[startup] inline secret prepared ({previews})")

    return Secret.from_dict(payload)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Built at import time so deployment fails fast if required env values are
# missing from the deploying shell.
INLINE_SECRET = _build_inline_secret()

# Modal application handle; the ASGI function below is registered on it.
app = App("hendrycks-math-task-app")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@app.function(
|
|
87
|
+
image=image,
|
|
88
|
+
timeout=600,
|
|
89
|
+
memory=16384,
|
|
90
|
+
cpu=4,
|
|
91
|
+
min_containers=1,
|
|
92
|
+
secrets=[INLINE_SECRET],
|
|
93
|
+
)
|
|
94
|
+
@asgi_app()
|
|
95
|
+
def fastapi_app():
|
|
96
|
+
import httpx
|
|
97
|
+
from fastapi import Body, FastAPI, HTTPException, status
|
|
98
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
99
|
+
from fastapi.responses import JSONResponse
|
|
100
|
+
|
|
101
|
+
try:
    # Prefer the shared auth helpers when the installed synth-ai provides them.
    from synth_ai.task.auth import (
        is_api_key_header_authorized,
        normalize_environment_api_key,
    )
except Exception:  # pragma: no cover - fallback for older synth-ai builds

    def _normalize_env_key_fallback() -> str | None:
        # Promote a dev alias into ENVIRONMENT_API_KEY when the canonical
        # variable is unset, so downstream checks see one consistent name.
        key = os.getenv("ENVIRONMENT_API_KEY")
        if key:
            return key
        for alias in ("dev_environment_api_key", "DEV_ENVIRONMENT_API_KEY"):
            candidate = os.getenv(alias)
            if candidate:
                os.environ["ENVIRONMENT_API_KEY"] = candidate
                return candidate
        return None

    def normalize_environment_api_key() -> str | None:  # type: ignore[override]
        return _normalize_env_key_fallback()

    def _header_values(request: Request, header: str) -> Iterable[str]:
        # Starlette headers are case-insensitive; the double lookup is
        # defensive for older stacks.
        raw = request.headers.get(header) or request.headers.get(header.lower())
        return [raw] if raw is not None else []

    def _split(values: Iterable[str]) -> list[str]:
        # Flatten comma-separated header values into trimmed, non-empty tokens.
        parts: list[str] = []
        for value in values:
            if not isinstance(value, str):
                continue
            for chunk in value.split(","):
                chunk = chunk.strip()
                if chunk:
                    parts.append(chunk)
        return parts

    def is_api_key_header_authorized(request: Request) -> bool:  # type: ignore[override]
        # Accept the expected key via x-api-key, x-api-keys, or a Bearer token.
        expected = normalize_environment_api_key()
        if not expected:
            return False
        single = _header_values(request, "x-api-key")
        multi = _header_values(request, "x-api-keys")
        auth = _header_values(request, "authorization")
        bearer = []
        for token in auth:
            if isinstance(token, str) and token.lower().startswith("bearer "):
                bearer.append(token.split(" ", 1)[1].strip())
        candidates = _split(single + multi + bearer)
        return any(candidate == expected for candidate in candidates)
|
|
150
|
+
|
|
151
|
+
@lru_cache(maxsize=8)
def _hf_split(subject: str, split: str, slice_spec: str | None = None):
    """Load a split of nlile/hendrycks-MATH-benchmark, cached per argument tuple.

    *slice_spec* (e.g. "[:500]") is appended to the split name. Falls back to
    the unnamed config — filtering rows by subject where possible — when the
    named config does not exist.

    Fix: maxsize was 1, which evicted the cached dataset whenever callers
    alternated subjects; _load_hendrycks_problem's empty-subject fallback does
    exactly that, forcing a full dataset reload on each request.
    """
    from datasets import load_dataset  # type: ignore

    s = split
    if slice_spec:
        s = f"{s}{slice_spec}"

    try:
        return load_dataset("nlile/hendrycks-MATH-benchmark", subject, split=s)
    except ValueError:
        # Older dataset layouts expose no named configs; load the default and
        # filter by the "subject" column when it exists.
        base = load_dataset("nlile/hendrycks-MATH-benchmark", split=s)
        if subject and subject not in {"", "default"}:
            if "subject" in base.column_names:
                base = base.filter(lambda ex: ex.get("subject") == subject)
            elif isinstance(base, list):
                base = [ex for ex in base if ex.get("subject") == subject]
        return base
|
|
169
|
+
|
|
170
|
+
def _normalize_answer_text(s: str) -> str:
|
|
171
|
+
import re as _re
|
|
172
|
+
|
|
173
|
+
return _re.sub(r"[^0-9A-Za-z.+\\-/*=]", "", (s or "").strip()).lower()
|
|
174
|
+
|
|
175
|
+
def _extract_boxed(s: str) -> str:
|
|
176
|
+
import re as _re
|
|
177
|
+
|
|
178
|
+
matches = list(_re.finditer(r"\\boxed\\{([^}]+)\\}", s or ""))
|
|
179
|
+
return matches[-1].group(1) if matches else ""
|
|
180
|
+
|
|
181
|
+
def _load_hendrycks_problem(seed: int, subject: str | None = None) -> tuple[str, str]:
    """Deterministically pick a (problem, solution) pair for *seed*.

    The subject defaults to HENDRYCKS_MATH_CONFIG (or "default"); split and
    slice come from HENDRYCKS_MATH_SPLIT / HENDRYCKS_MATH_SLICE. Raises
    RuntimeError when the dataset is empty or an item has no problem text.
    """
    subj = subject or os.getenv("HENDRYCKS_MATH_CONFIG", "default")
    ds = _hf_split(
        subj, os.getenv("HENDRYCKS_MATH_SPLIT", "test"), os.getenv("HENDRYCKS_MATH_SLICE")
    )
    n = len(ds) if hasattr(ds, "__len__") else 0
    # Fix: retry with the default config based on the *effective* subject
    # (subj); the original consulted the raw `subject` argument, which may be
    # None even when the env var selected a non-default config.
    if n == 0 and subj not in {"", "default"}:
        ds = _hf_split(
            "default",
            os.getenv("HENDRYCKS_MATH_SPLIT", "test"),
            os.getenv("HENDRYCKS_MATH_SLICE"),
        )
        n = len(ds) if hasattr(ds, "__len__") else 0
    if n == 0:
        raise RuntimeError("Hendrycks MATH dataset loaded empty")
    # Seed maps onto a dataset index; abs() tolerates negative seeds.
    idx = abs(int(seed)) % n
    ex = ds[int(idx)]
    q = ex.get("problem") or ex.get("question") or ex.get("prompt")
    a = ex.get("solution") or ex.get("answer") or ""
    if not q:
        raise RuntimeError("Hendrycks item missing problem text")
    return str(q), str(a)
|
|
203
|
+
|
|
204
|
+
def create_app():
    """Build the FastAPI application with auth helpers and service routes."""

    app = FastAPI(title="Hendrycks Math Task App", version="0.1.0")
    # Wide-open CORS: the app is protected by API-key checks, not origins.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    import logging

    logger = logging.getLogger("hendrycks_math_task_app")
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

    def _log_env_key_prefix(source: str, env_key: str | None) -> str | None:
        # Emit the first half of the expected key so operators can match it
        # against their config without leaking the full secret.
        if not env_key:
            return None
        half = max(1, len(env_key) // 2)
        prefix = env_key[:half]
        msg = f"[{source}] expected ENVIRONMENT_API_KEY prefix: {prefix}"
        print(msg)
        logger.info(msg)
        return prefix

    def _resolve_env_keys() -> set[str]:
        # Gather every accepted API key: the canonical variable, dev aliases,
        # and a comma-separated ENVIRONMENT_API_KEY_ALIASES list.
        keys: set[str] = set()
        for alias in (
            "ENVIRONMENT_API_KEY",
            "dev_environment_api_key",
            "DEV_ENVIRONMENT_API_KEY",
        ):
            value = os.environ.get(alias)
            if value:
                # Promote the first found value into the canonical variable.
                os.environ.setdefault("ENVIRONMENT_API_KEY", value)
                keys.add(value)
        alias_env = os.environ.get("ENVIRONMENT_API_KEY_ALIASES", "")
        for chunk in alias_env.split(","):
            trimmed = chunk.strip()
            if trimmed:
                keys.add(trimmed)
        return keys

    def _extract_header_candidates(
        request: Request,
        x_api_key: str | None,
        x_api_keys: str | None,
        authorization: str | None,
    ) -> list[str]:
        # Collect candidate keys from x-api-key, x-api-keys (comma list), and
        # Authorization: Bearer headers; explicit arguments take precedence.
        headers = request.headers
        candidates: list[str] = []
        primary = x_api_key or headers.get("x-api-key")
        if primary:
            candidates.append(primary.strip())
        secondary = x_api_keys or headers.get("x-api-keys")
        if secondary:
            candidates.extend(
                [value.strip() for value in secondary.split(",") if value.strip()]
            )
        auth_header = (
            authorization or headers.get("authorization") or headers.get("Authorization")
        )
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
            if token:
                candidates.append(token)
        return [c for c in candidates if c]

    def _is_authorized(
        request: Request,
        x_api_key: str | None,
        x_api_keys: str | None,
        authorization: str | None,
    ) -> bool:
        # NOTE(review): the routes below call is_api_key_header_authorized
        # rather than this helper — confirm whether it is still needed.
        keys = _resolve_env_keys()
        if not keys:
            return False
        candidates = _extract_header_candidates(request, x_api_key, x_api_keys, authorization)
        return any(candidate in keys for candidate in candidates)

    @app.get("/info")
    async def info():
        # Static service metadata; the inference base_url is resolved per
        # rollout, so it is intentionally empty here.
        return {
            "service": {"base_url": os.getenv("SERVICE_BASE_URL", "")},
            "inference": {
                "base_url": "",
                "endpoints": {"chat_completions": "/v1/chat/completions"},
            },
        }

    @app.get("/health")
    async def health(request: Request):
        # 503 when no key is configured at all; otherwise 200, reporting
        # whether the caller presented a valid key.
        env_keys = _resolve_env_keys()
        env_key = next(iter(env_keys), None)
        if not env_key:
            return JSONResponse(
                status_code=503,
                content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
            )
        authorized = is_api_key_header_authorized(request)
        if not authorized:
            # Unauthorized callers still get 200 plus a hint prefix so
            # operators can diagnose key mismatches.
            prefix = _log_env_key_prefix("health", env_key)
            content = {
                "status": "healthy",
                "authorized": False,
            }
            if prefix:
                content["expected_api_key_prefix"] = prefix
            return JSONResponse(status_code=200, content=content)
        return {"status": "healthy", "authorized": True}

    @app.get("/health/rollout")
    async def health_rollout(request: Request):
        # Same contract as /health, but the authorized success body uses the
        # {"ok": true} shape expected by the rollout health checker.
        env_keys = _resolve_env_keys()
        env_key = next(iter(env_keys), None)
        if not env_key:
            return JSONResponse(
                status_code=503,
                content={"status": "unhealthy", "detail": "Missing ENVIRONMENT_API_KEY"},
            )
        authorized = is_api_key_header_authorized(request)
        if not authorized:
            prefix = _log_env_key_prefix("health/rollout", env_key)
            content = {
                "status": "healthy",
                "authorized": False,
            }
            if prefix:
                content["expected_api_key_prefix"] = prefix
            return JSONResponse(status_code=200, content=content)
        return {"ok": True, "authorized": True}

    @app.get("/task_info")
    async def task_info(seed: int = 0, subject: str = "default"):
        # Expose the problem text, expected answer, and tool schema for a seed.
        q, a = _load_hendrycks_problem(int(seed), subject=subject)
        tools = [
            {
                "name": "submit_answer",
                "description": "Provide the final numerical or algebraic answer for the current math problem.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "answer": {
                            "type": "string",
                            "description": "The proposed final answer",
                        },
                    },
                    "required": ["answer"],
                },
            }
        ]
        return {
            "seed": int(seed),
            "subject": subject,
            "system": "",
            "user": q,
            "tools": tools,
            "policy": {"name": "math-react"},
            "answer": a,
        }

    return app
|
|
369
|
+
|
|
370
|
+
api = create_app()
|
|
371
|
+
|
|
372
|
+
from fastapi.exceptions import RequestValidationError
|
|
373
|
+
|
|
374
|
+
@api.exception_handler(RequestValidationError)
async def _on_validation_error(request: Request, exc: RequestValidationError):
    """Log a compact snapshot of 422s and return the first few errors."""
    try:
        headers = request.headers
        debug_info = {
            "path": str(request.url.path),
            "have_x_api_key": bool(headers.get("x-api-key")),
            "have_x_api_keys": bool(headers.get("x-api-keys")),
            "have_authorization": bool(headers.get("authorization")),
            "errors": exc.errors()[:5],
        }
        print("[422] validation", debug_info, flush=True)
    except Exception:
        # Logging is best-effort; never mask the validation response.
        pass
    return JSONResponse(
        status_code=422, content={"status": "invalid", "detail": exc.errors()[:5]}
    )
|
|
391
|
+
|
|
392
|
+
@api.get("/")
async def root_probe():
    # Liveness probe used by load balancers / Modal health checks.
    return {"status": "ok", "service": "math"}

@api.head("/")
async def head_probe():
    # HEAD variant of the root probe.
    return {"status": "ok"}
|
|
399
|
+
|
|
400
|
+
# Fail fast at startup when no environment API key is configured anywhere.
# Fix: the original checked DEV_ENVIRONMENT_API_KEY twice; the lowercase
# alias matches the alias list used by _resolve_env_keys and the auth
# fallback above.
env_key = (
    os.environ.get("ENVIRONMENT_API_KEY")
    or os.environ.get("dev_environment_api_key")
    or os.environ.get("DEV_ENVIRONMENT_API_KEY")
)
if not env_key:
    raise RuntimeError("ENVIRONMENT_API_KEY missing in task app environment")
|
|
407
|
+
|
|
408
|
+
# Vendor-specific fields the OpenAI chat-completions endpoint rejects;
# stripped from every proxied payload.
openai_remove_fields = (
    "stop_after_tool_calls",
    "thinking_mode",
    "thinking_budget",
    "reasoning",
)
# Sampling fields removed before proxying — presumably because gpt-5 models
# reject them; confirm against the proxy sanitizer's intent.
openai_remove_sampling_fields = ("temperature", "top_p")
# Forces the model to respond via the submit_answer tool call.
tool_choice_force = {"type": "function", "function": {"name": "submit_answer"}}
|
|
416
|
+
|
|
417
|
+
def _prepare_openai_payload(model: str | None, payload: dict[str, object]) -> dict[str, object]:
    """Sanitize a chat-completions payload before forwarding it to OpenAI.

    Drops vendor-specific fields, renames max_tokens for gpt-5 models (and
    removes it otherwise), strips sampling fields, and forces the
    submit_answer tool call. The caller's dict is not mutated.

    Fix: removed an unreachable duplicate `return sanitized` statement.
    """
    sanitized = dict(payload)  # shallow copy; caller's payload stays intact
    for key in openai_remove_fields:
        sanitized.pop(key, None)
    if model and "gpt-5" in model:
        # gpt-5 models accept max_completion_tokens instead of max_tokens.
        if "max_tokens" in sanitized and "max_completion_tokens" not in sanitized:
            sanitized["max_completion_tokens"] = sanitized.pop("max_tokens")
    else:
        sanitized.pop("max_tokens", None)
    for field in openai_remove_sampling_fields:
        sanitized.pop(field, None)
    sanitized["tool_choice"] = tool_choice_force
    sanitized["parallel_tool_calls"] = False
    return sanitized
|
|
432
|
+
|
|
433
|
+
@api.post("/proxy/v1/chat/completions")
def proxy_chat_completions(request: dict[str, object] = Body(...)):
    """Forward a chat-completions payload to OpenAI after sanitizing it.

    Returns OpenAI's JSON body; upstream error statuses are passed through
    verbatim. 503 when OPENAI_API_KEY is not configured.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="OPENAI_API_KEY missing"
        )
    model = request.get("model") if isinstance(request, dict) else None
    payload = _prepare_openai_payload(
        model if isinstance(model, str) else None, request if isinstance(request, dict) else {}
    )
    headers = {"Authorization": f"Bearer {key}"}
    # Sync handler: FastAPI runs it in a threadpool, so blocking httpx is fine.
    with httpx.Client(timeout=httpx.Timeout(180.0), follow_redirects=True) as client:
        resp = client.post(
            "https://api.openai.com/v1/chat/completions", json=payload, headers=headers
        )
        try:
            data = resp.json()
        except Exception:
            # Non-JSON upstream body; surface a truncated raw copy for debugging.
            data = {"error": "invalid_json", "raw": resp.text[:400]}
        if resp.status_code >= 400:
            from fastapi.responses import JSONResponse

            return JSONResponse(status_code=resp.status_code, content=data)
        return data
|
|
458
|
+
|
|
459
|
+
@api.post("/rollout")
def rollout(request: dict[str, object] = Body(...)):
    """Run a single-episode math rollout: load a Hendrycks MATH problem,
    query the configured inference endpoint once, score the answer, and
    return a trajectory payload with reward metrics.

    Request keys used: ``run_id``, ``env`` (``env_name``, ``seed``, ``config``),
    ``policy`` (``policy_name``, ``config`` with ``model``/``inference_url``/
    ``max_tokens``/``temperature``), and ``ops`` (accepted but unused here).
    Returns a dict with ``run_id``, ``trajectories``, ``branches``, ``metrics``,
    ``aborted`` and ``ops_executed``.
    """
    import json as _json
    from typing import Any

    run_id = str(request.get("run_id"))
    data = request if isinstance(request, dict) else {}
    env = data.get("env") if isinstance(data, dict) else {}
    policy = data.get("policy") if isinstance(data, dict) else {}
    ops = data.get("ops") if isinstance(data, dict) else []
    if not isinstance(ops, list):
        ops = []
    env_name = (env or {}).get("env_name") or "math"
    policy_cfg = (policy or {}).get("config") or {}
    model = policy_cfg.get("model")
    inference_url = (policy_cfg.get("inference_url") or "").rstrip("/")

    env_cfg = (env or {}).get("config") or {}
    # Seed resolution: prefer env.seed, fall back to env.config.seed; any
    # parse failure (or absence) yields seed 0.
    try:
        seed_val = (
            int((env or {}).get("seed"))
            if isinstance(env, dict) and (env or {}).get("seed") is not None
            else 0
        )
    except Exception:
        seed_val = 0
    if seed_val == 0:
        try:
            seed_val = (
                int(env_cfg.get("seed"))
                if isinstance(env_cfg, dict) and env_cfg.get("seed") is not None
                else 0
            )
        except Exception:
            seed_val = 0
    subject = (env_cfg.get("subject") if isinstance(env_cfg, dict) else None) or os.getenv(
        "HENDRYCKS_MATH_CONFIG", "default"
    )
    qh, ah = _load_hendrycks_problem(seed_val, subject=subject)
    question = qh
    expected_answer = ah

    def _prepare_payload(m: str | None, payload: dict[str, Any]) -> dict[str, Any]:
        # Strip non-standard keys the upstream server would reject, then
        # force a single tool call. gpt-5 models take max_completion_tokens
        # instead of max_tokens; for all other models max_tokens is dropped.
        sanitized = dict(payload)
        for k in ("stop_after_tool_calls", "thinking_mode", "thinking_budget", "reasoning"):
            sanitized.pop(k, None)
        if m and "gpt-5" in m:
            if "max_tokens" in sanitized and "max_completion_tokens" not in sanitized:
                sanitized["max_completion_tokens"] = sanitized.pop("max_tokens")
        else:
            sanitized.pop("max_tokens", None)
        sanitized["tool_choice"] = tool_choice_force
        sanitized["parallel_tool_calls"] = False
        return sanitized

    def _parse_tool_answer(payload: dict[str, Any]) -> str:
        # Extract the "answer" argument from the first submit_answer tool
        # call found in the completion payload; "" when absent/malformed.
        choices = payload.get("choices") if isinstance(payload, dict) else None
        if not isinstance(choices, list):
            return ""
        for choice in choices:
            if not isinstance(choice, dict):
                continue
            # BUG FIX: the OpenAI chat-completions schema nests tool calls
            # under choice["message"]["tool_calls"]; the previous code only
            # looked at choice["tool_calls"], so standard-compliant servers
            # always fell through to the raw-text fallback. Check the
            # choice-level location first (preserving old behavior for
            # proxies that flatten it), then the message-level location.
            tool_calls = choice.get("tool_calls")
            if not isinstance(tool_calls, list):
                message = choice.get("message")
                tool_calls = message.get("tool_calls") if isinstance(message, dict) else None
            if not isinstance(tool_calls, list):
                continue
            for call in tool_calls:
                if not isinstance(call, dict):
                    continue
                function = call.get("function")
                if not isinstance(function, dict):
                    continue
                if function.get("name") != "submit_answer":
                    continue
                arguments = function.get("arguments")
                if isinstance(arguments, str):
                    # Arguments usually arrive JSON-encoded as a string.
                    try:
                        parsed = _json.loads(arguments)
                    except Exception:
                        parsed = {}
                    if isinstance(parsed, dict):
                        answer = parsed.get("answer")
                        if isinstance(answer, str):
                            return answer
                elif isinstance(arguments, dict):
                    answer = arguments.get("answer")
                    if isinstance(answer, str):
                        return answer
        return ""

    steps: list[dict[str, Any]] = []
    history: list[dict[str, Any]] = []
    total_reward = 0.0

    def _call_inference(input_messages: list[dict[str, Any]]):
        # Single blocking chat-completions call; raises on HTTP error.
        # NOTE(review): max_completion_tokens is sent for every model here,
        # while _prepare_payload only converts max_tokens for gpt-5 —
        # confirm non-gpt-5 backends accept max_completion_tokens.
        payload = {
            "model": model,
            "messages": input_messages,
            "max_completion_tokens": policy_cfg.get("max_tokens", 512),
            "temperature": policy_cfg.get("temperature", 0.0),
            "tool_choice": tool_choice_force,
        }
        body = _prepare_payload(model if isinstance(model, str) else None, payload)
        with httpx.Client(timeout=httpx.Timeout(120.0), follow_redirects=True) as client:
            resp = client.post(f"{inference_url}/v1/chat/completions", json=body)
            resp.raise_for_status()
            return resp.json()

    messages = [
        {"role": "system", "content": "You are a math expert. Solve the problem step by step."},
        {"role": "user", "content": question},
    ]

    # Step 0: the observation (prompt) shown to the policy.
    steps.append(
        {
            "obs": {"prompt": question},
            "tool_calls": [],
            "reward": None,
            "done": False,
            "truncated": False,
            "info": None,
        }
    )
    history.append({"question": question, "subject": subject})

    data = _call_inference(messages)

    # Pull plain assistant text (if any) for logging and as a scoring fallback.
    llm_text = None
    try:
        choices = data.get("choices") if isinstance(data, dict) else None
        if isinstance(choices, list) and choices:
            message_obj = choices[0].get("message", {}) if isinstance(choices[0], dict) else {}
            if isinstance(message_obj, dict):
                content = message_obj.get("content")
                if isinstance(content, str) and content.strip():
                    llm_text = content
    except Exception:
        llm_text = None

    # Best-effort debug logging; never let logging break the rollout.
    try:
        if question is not None:
            print(f"[math] question: {question}", flush=True)
        if llm_text is not None:
            print(f"[math] llm: {llm_text}", flush=True)
        if expected_answer is not None and llm_text is not None:
            exp_fragment = str(expected_answer).strip()
            got = llm_text.strip()
            is_correct = exp_fragment and (exp_fragment in got)
            print(f"[math] correct: {bool(is_correct)} (expected fragment: {exp_fragment})", flush=True)
    except Exception:
        pass

    # Step 1: the policy's submitted tool call.
    tool_answer = _parse_tool_answer(data)
    history.append({"answer": tool_answer})
    steps.append(
        {
            "obs": {},
            "tool_calls": [
                {
                    "tool_name": "submit_answer",
                    "arguments": _json.dumps({"answer": tool_answer}),
                }
            ],
            "reward": None,
            "done": False,
            "truncated": False,
            "info": None,
        }
    )

    # Score: 1.0 when the normalized expected answer is contained in the
    # normalized candidate (tool answer preferred, else boxed/raw LLM text).
    reward_val = 0.0
    candidate = tool_answer or ""
    try:
        if not candidate and llm_text is not None:
            candidate = _extract_boxed(llm_text) or llm_text
        if expected_answer is not None:
            exp_raw = _extract_boxed(str(expected_answer)) or str(expected_answer)
            got_raw = candidate
            exp_n = _normalize_answer_text(exp_raw)
            got_n = _normalize_answer_text(got_raw)
            if exp_n and exp_n in got_n:
                reward_val = 1.0
    except Exception:
        reward_val = 0.0

    # Best-effort structured log line with reward components.
    try:
        preview = candidate[:120] + ("…" if len(candidate) > 120 else "")
        components = {
            "env": float(reward_val),
            "rubric_event": 1.0 if bool(candidate.strip()) else 0.0,
            "rubric_outcome": 1.0 if float(reward_val) > 0.0 else 0.0,
        }
        print(
            "[MATH_ROLLOUT] run=",
            run_id,
            " seed=",
            seed_val,
            " subject=",
            subject,
            " tool=submit_answer answer=",
            preview,
            " reward=",
            float(reward_val),
            " components=",
            components,
            flush=True,
        )
    except Exception:
        pass

    # Step 2: terminal step carrying the episode reward.
    total_reward += float(reward_val)
    steps.append(
        {
            "obs": {},
            "tool_calls": [],
            "reward": reward_val,
            "done": True,
            "truncated": False,
            "info": None,
        }
    )

    return {
        "run_id": run_id,
        "trajectories": [
            {
                "env_id": env_name,
                "policy_id": (policy or {}).get("policy_name") or "math-react",
                "steps": steps,
                "final": {"observation": {}},
                "length": len(steps),
            }
        ],
        "branches": {},
        "metrics": {
            "episode_returns": [total_reward],
            "mean_return": float(total_reward),
            "num_steps": len(steps),
            "num_episodes": 1,
        },
        "aborted": False,
        "ops_executed": len(steps),
    }
|
|
701
|
+
|
|
702
|
+
return api
|