synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (112) hide show
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  98. synth_ai/api/train/config_finder.py +18 -18
  99. synth_ai/api/train/env_resolver.py +28 -1
  100. synth_ai/cli/task_apps.py +264 -55
  101. synth_ai/demo_registry.py +7 -7
  102. synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
  103. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
  104. synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
  105. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
  106. synth_ai/task/apps/__init__.py +54 -13
  107. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
  108. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
  109. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
  110. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
  111. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
  112. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,962 @@
1
+ import modal
2
+ from typing import Any, Optional
3
+ from dataclasses import dataclass, field
4
+ import os as _os
5
+ import sys as _sys
6
+ from pathlib import Path as _Path
7
+ import time
8
+
9
# Make local helper packages importable when running locally.
# _HERE is this file's resolved path, so `.parent` is the directory holding the
# file and `.parent.parent` is that directory's parent.
_HERE = _Path(__file__).resolve()
_LOCAL_CRAFTER_PARENT = _HERE.parent.parent
if str(_LOCAL_CRAFTER_PARENT) not in _sys.path:
    _sys.path.insert(0, str(_LOCAL_CRAFTER_PARENT))
# "/opt" is inserted last so it ends up ahead of the local parent on sys.path.
if "/opt" not in _sys.path:
    _sys.path.insert(0, "/opt")

# Use environment-aware names to avoid collisions across dev/prod.
# The first non-empty value among these variables is normalized and compared
# against the production markers below.
_env_flag = (
    _os.getenv("SYNTH_BACKEND_URL_OVERRIDE", "")
    or _os.getenv("ENVIRONMENT", "")
    or _os.getenv("APP_ENVIRONMENT", "")
).strip().lower()
_is_prod = _env_flag in ("prod", "production")

# Secret name must be provided explicitly via TASK_APP_SECRET_NAME.
# An explicit raise is used instead of `assert` so the check still runs when
# Python is started with -O (which strips assert statements).
MODAL_SECRET_NAME = _os.getenv("TASK_APP_SECRET_NAME")
if not MODAL_SECRET_NAME:
    raise RuntimeError("TASK_APP_SECRET_NAME must be set before launching the task app")

# Modal app name (overridable via TASK_APP_NAME). An empty override falls back
# to the default so the app is never created with a blank name.
_default_app_name = "grpo-task-service-sdk-prod" if _is_prod else "grpo-task-service-sdk"
app = modal.App(_os.getenv("TASK_APP_NAME") or _default_app_name)
29
+
30
# Third-party packages installed into the container image via pip.
_IMAGE_PIP_PACKAGES = [
    "fastapi",
    "uvicorn",
    "pydantic>=2",
    "httpx",
    "requests",
    "tqdm",
    "urllib3>=2.3.0",
    "jsonschema>=4.23.0",
    "typing_extensions>=4.0.0",
    "numpy",
    "pandas",
    "sqlalchemy",
    "aiosqlite",
    "asyncpg>=0.30.0",
    "crafter",
    "pillow",
    "imageio",
    "opensimplex",
    "ruamel.yaml",
    "networkx>=3.4.2",
    "redis>=6.2.0",
    "duckdb>=1.0.0",
    "ty>=0.0.1a5",
    "toml>=0.10.2",
    "libsql>=0.1.8",
    "python-dotenv",
    "anthropic",
    "openai",
    "diskcache",
    "backoff",
    "groq",
    "google-genai",
    "google-generativeai",
    "google-api-python-client",
    "google-api-core>=2.25.1",
    "google-auth",
    "google-auth-httplib2",
    "opentelemetry-api>=1.26.0,<1.27.0",
    "opentelemetry-sdk>=1.26.0,<1.27.0",
    "opentelemetry-exporter-otlp-proto-http>=1.26.0,<1.27.0",
    "wrapt",
    "langfuse>=2.53.9,<3.0.0",
    "together",
    "mistralai>=1.9.2",
    "click>=8.1.0",
    "textual>=1.1.0",
    "openai-harmony>=0.0.1",
    "aiohttp>=3.8.0",
    "datasets>=4.0.0",
    "gymnasium>=0.29.1",
    "minigrid>=2.3.1",
]

# Helper package that lives next to this file; copied to /opt so it can be
# imported inside the container.
_HELPERS_SRC_DIR = (_HERE.parent / "crafter_task_app_helpers").resolve()
# The synth_ai package is looked up three directory levels above this file
# (file -> its folder -> parent -> grandparent) and copied to /opt/synth_ai.
_SYNTH_AI_SRC_DIR = (_HERE.parent.parent.parent / "synth_ai").resolve()

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(_IMAGE_PIP_PACKAGES)
    .add_local_dir(str(_HELPERS_SRC_DIR), "/opt/crafter_task_app_helpers")
    .add_local_dir(str(_SYNTH_AI_SRC_DIR), "/opt/synth_ai")
)
93
+
94
+ # --- OpenAI payload sanitizer (local) ---
95
+ OPENAI_MAX_COMPLETION_TOKENS_MIN = 16000
96
+ OPENAI_REMOVE_FIELDS = (
97
+ "stop_after_tool_calls",
98
+ "thinking_mode",
99
+ "thinking_budget",
100
+ "reasoning",
101
+ )
102
+ OPENAI_REMOVE_SAMPLING_FIELDS = ("temperature", "top_p")
103
+ OPENAI_TOOL_CHOICE_FORCED = {"type": "function", "function": {"name": "interact"}}
104
+
105
+ def prepare_inference_payload_for_model(model: str | None, payload: dict[str, Any]) -> dict[str, Any]:
106
+ """Sanitize payload for OpenAI API.
107
+
108
+ - Always strip Synth-specific fields not supported by OpenAI (e.g., stop_after_tool_calls).
109
+ - For gpt-5 family: map max_tokens->max_completion_tokens, enforce tool_choice, disable parallel tools,
110
+ and remove vendor-specific sampling fields.
111
+ """
112
+ out = dict(payload)
113
+ # Always remove unsupported fields for OpenAI
114
+ for k in OPENAI_REMOVE_FIELDS:
115
+ if k in out:
116
+ out.pop(k)
117
+
118
+ # gpt-5 family specific adjustments
119
+ if model and "gpt-5" in model:
120
+ if "max_completion_tokens" not in out and "max_tokens" in out:
121
+ out["max_completion_tokens"] = out.pop("max_tokens")
122
+ # Ensure we don't send both
123
+ if "max_tokens" in out:
124
+ out.pop("max_tokens")
125
+ for k in OPENAI_REMOVE_SAMPLING_FIELDS:
126
+ if k in out:
127
+ out.pop(k)
128
+ mct = out.get("max_completion_tokens")
129
+ if not isinstance(mct, int) or mct < OPENAI_MAX_COMPLETION_TOKENS_MIN:
130
+ out["max_completion_tokens"] = OPENAI_MAX_COMPLETION_TOKENS_MIN
131
+ out["tool_choice"] = OPENAI_TOOL_CHOICE_FORCED
132
+ out["parallel_tool_calls"] = False
133
+ return out
134
+
135
+ @app.function(image=image, secrets=[modal.Secret.from_name(MODAL_SECRET_NAME)], min_containers=1, max_containers=1)
136
+ @modal.asgi_app()
137
+ def fastapi_app():
138
+ # Import FastAPI/Pydantic inside the container runtime to avoid local import errors
139
+ from fastapi import FastAPI, Body, HTTPException, status
140
+ from starlette.requests import Request
141
+ from fastapi.responses import JSONResponse
142
+ from pydantic import BaseModel
143
+ import logging
144
+ import sys
145
+ import os
146
+ import httpx
147
+ # Logger for debug output
148
+ logger = logging.getLogger(__name__)
149
+
150
+ # Preload synth_ai modules and vendor deps so missing packages surface early
151
+ if "/opt/synth_ai" not in sys.path:
152
+ sys.path.insert(0, "/opt/synth_ai")
153
+ # Ensure tracing DB points to a writable location in the container
154
+ os.environ.setdefault("TURSO_LOCAL_DB_URL", "sqlite+aiosqlite:////tmp/synth_ai.db")
155
+
156
+ import importlib
157
+ preload_modules = [
158
+ # synth_ai core
159
+ "synth_ai",
160
+ "synth_ai.lm",
161
+ "synth_ai.lm.core.main",
162
+ "synth_ai.lm.core.main_v3",
163
+ "synth_ai.lm.core.vendor_clients",
164
+ "synth_ai.lm.core.all",
165
+ # vendors
166
+ "synth_ai.lm.vendors.core.anthropic_api",
167
+ "synth_ai.lm.vendors.core.openai_api",
168
+ "synth_ai.lm.vendors.openai_standard",
169
+ "synth_ai.lm.vendors.core.gemini_api",
170
+ # environments
171
+ "synth_ai.environments",
172
+ "synth_ai.environments.environment.rewards.core",
173
+ "synth_ai.environments.examples.crafter_classic.environment",
174
+ # tracing
175
+ "synth_ai.tracing_v3.turso.models",
176
+ "synth_ai.tracing_v3.turso.manager",
177
+ # common 3p libs these modules rely on
178
+ "anthropic",
179
+ "openai",
180
+ "groq",
181
+ "google.genai",
182
+ "google.generativeai",
183
+ "googleapiclient.discovery",
184
+ "google.auth",
185
+ "google_auth_httplib2",
186
+ "requests",
187
+ "tqdm",
188
+ "langfuse",
189
+ "diskcache",
190
+ "backoff",
191
+ "together",
192
+ "dotenv",
193
+ "grpc",
194
+ ]
195
+ for mod in preload_modules:
196
+ try:
197
+ importlib.import_module(mod)
198
+ except Exception as _e:
199
+ print(f"[task:crafter] preload missing/err: {mod}: {_e}", flush=True)
200
+
201
+ # Make packaged local crafter modules importable ahead of site-packages 'crafter'
202
+ if "/opt/crafter_task_app_helpers" not in sys.path:
203
+ sys.path.insert(0, "/opt/crafter_task_app_helpers")
204
+ if "/opt" not in sys.path:
205
+ sys.path.insert(0, "/opt")
206
+ if "/opt/synth_ai" not in sys.path:
207
+ sys.path.insert(0, "/opt/synth_ai")
208
+ from crafter_task_app_helpers.env import EnvRegistry
209
+ from crafter_task_app_helpers.rewards import compute_decision_rewards
210
+ from crafter_task_app_helpers.config import ACTION_SPACE, ENV_NAME
211
+ from crafter_task_app_helpers.policy import CrafterPolicy
212
+
213
+ _registry = EnvRegistry()
214
+
215
+ # --- JSON sanitization for responses (convert numpy -> python primitives, arrays -> shapes) ---
216
+ import numpy as _np
217
+
218
+ def _to_jsonable(value):
219
+ # Numpy types first: scalars vs arrays
220
+ if isinstance(value, (_np.generic,)):
221
+ return value.item()
222
+ if isinstance(value, _np.ndarray):
223
+ return f"<ndarray shape={tuple(value.shape)} dtype={str(value.dtype)}>"
224
+ # Basic containers
225
+ if isinstance(value, dict):
226
+ return {k: _to_jsonable(v) for k, v in value.items()}
227
+ if isinstance(value, (list, tuple)):
228
+ return [_to_jsonable(v) for v in value]
229
+ # Sets to lists
230
+ if isinstance(value, set):
231
+ return [_to_jsonable(v) for v in value]
232
+ return value
233
+
234
class InitRequest(BaseModel):
    """Request body accepted by the environment ``initialize`` endpoint."""

    # Optional environment name supplied by the client.
    env_name: str | None = None
    # Optional environment configuration; the initialize handler forwards it
    # to the environment registry.
    env_config: dict[str, Any] | None = None
237
+
238
class StepRequest(BaseModel):
    """Request body accepted by the environment ``step`` endpoint."""

    # Identifier of the environment instance to step.
    env_id: str
    # Action to apply; passed to the registry's step call as-is.
    action: str
241
+
242
+ api = FastAPI(debug=True)
243
+
244
+ # Basic root endpoints so HEAD/GET / succeeds for preflight checks
245
@api.head("/")
def head_root():  # type: ignore[empty-body]
    """Answer HEAD / with a 200 so preflight checks succeed."""
    ok_response = JSONResponse(status_code=status.HTTP_200_OK, content=None)
    return ok_response
248
+
249
@api.get("/")
def get_root():
    """Answer GET / with a small static payload identifying the service."""
    body = {"ok": True, "service": "synth-ai task app"}
    return body
252
+
253
@api.get("/health")
def health(request: Request):
    """Report liveness and whether the caller's API key is accepted.

    Raises:
        HTTPException: 503 when ENVIRONMENT_API_KEY is not set in the service
            environment.

    Returns:
        ``{"healthy": True, "authorized": True}`` for an authorized caller;
        otherwise a 200 JSON response with ``authorized: False`` so the CLI
        preflight can proceed.
    """
    import hmac

    env_key = os.environ.get("ENVIRONMENT_API_KEY")
    if not env_key:
        raise HTTPException(status_code=503, detail="Auth not configured: missing ENVIRONMENT_API_KEY in task service environment")
    # Authorize using all header variants; avoid typed Header to prevent 422s.
    try:
        from synth_ai.task.auth import is_api_key_header_authorized
        authorized = is_api_key_header_authorized(request)
    except Exception:
        # Fallback: check only x-api-key. compare_digest keeps the comparison
        # time independent of how many leading characters match.
        header_key = request.headers.get("x-api-key") or ""
        authorized = bool(header_key) and hmac.compare_digest(
            header_key.encode("utf-8"), env_key.encode("utf-8")
        )
    if not authorized:
        # Soft 200 with authorized flag so CLI preflight can proceed.
        # NOTE(review): this returns the first half of the expected key to an
        # unauthorized caller; confirm the CLI needs it, otherwise shorten it.
        prefix = env_key[: max(1, len(env_key) // 2)]
        content = {"status": "healthy", "authorized": False, "expected_api_key_prefix": prefix}
        return JSONResponse(status_code=200, content=content)
    return {"healthy": True, "authorized": True}
272
+
273
+ # Rollout health endpoint used by CLI configure flow
274
@api.get("/health/rollout")
def health_rollout(request: Request):
    """Rollout health check: report whether the caller's API key is accepted.

    Raises:
        HTTPException: 503 when ENVIRONMENT_API_KEY is not set in the service
            environment.

    Returns:
        ``{"ok": True, "authorized": True}`` for an authorized caller;
        otherwise a 200 JSON response with ``authorized: False``.
    """
    import hmac

    expected = os.environ.get("ENVIRONMENT_API_KEY")
    if not expected:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Missing ENVIRONMENT_API_KEY in service env")
    try:
        from synth_ai.task.auth import is_api_key_header_authorized
        authorized = is_api_key_header_authorized(request)
    except Exception:
        # Fallback: check only x-api-key, using a constant-time comparison.
        header_key = request.headers.get("x-api-key") or ""
        authorized = bool(header_key) and hmac.compare_digest(
            header_key.encode("utf-8"), expected.encode("utf-8")
        )
    if not authorized:
        # NOTE(review): this returns the first half of the expected key to an
        # unauthorized caller; confirm the CLI needs it, otherwise shorten it.
        prefix = expected[: max(1, len(expected) // 2)]
        content = {"status": "healthy", "authorized": False, "expected_api_key_prefix": prefix}
        return JSONResponse(status_code=200, content=content)
    return {"ok": True, "authorized": True}
290
+
291
+ # Log and surface 422 validation errors with header presence
292
+ from fastapi.exceptions import RequestValidationError
293
@api.exception_handler(RequestValidationError)
async def _on_validation_error(request: Request, exc: RequestValidationError):
    """Log a short diagnostic for a 422 and return the first validation errors.

    The log line records the request path, which auth-related headers were
    present (presence only, never values) and up to five validation errors.
    The response body carries the same errors under ``detail``.
    """
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()[:5]
    try:
        hdr = request.headers
        snapshot = {
            "path": str(request.url.path),
            "have_x_api_key": bool(hdr.get("x-api-key")),
            "have_x_api_keys": bool(hdr.get("x-api-keys")),
            "have_authorization": bool(hdr.get("authorization")),
            "errors": errors,
        }
        print("[422] validation", snapshot, flush=True)
    except Exception:
        # Logging is best-effort; it must never mask the 422 response itself.
        pass
    # Validation errors may carry values that the JSON encoder cannot handle
    # directly, so encode them before building the response.
    return JSONResponse(
        status_code=422,
        content={"status": "invalid", "detail": jsonable_encoder(errors)},
    )
308
+
309
@api.post(f"/env/{ENV_NAME}/initialize")
async def initialize(req: InitRequest, request: Request):
    """Create an environment instance and return its id and first observation.

    A run id taken from the X-Run-Id / X-Run-ID request header (if any) is
    passed to the registry so the environment can be tied to that run.
    """
    incoming = request.headers
    run_id_hdr = incoming.get("X-Run-Id")
    if not run_id_hdr:
        run_id_hdr = incoming.get("X-Run-ID")
    created = await _registry.initialize(req.env_config, run_id=run_id_hdr)
    env_id, obs = created
    response = {"env_id": env_id, "observation": _to_jsonable(obs)}
    return response
315
+
316
@api.post(f"/env/{ENV_NAME}/step")
async def step(req: StepRequest):
    """Apply one action to an environment and return the JSON-safe result.

    Returns a dict with the sanitized observation, the reward (coerced to
    ``float`` when numeric), the ``done`` flag as ``bool`` and the sanitized
    ``info`` (or ``None``).
    """
    obs, reward, done, info = await _registry.step(req.env_id, req.action)
    # Numpy scalar rewards (e.g. float32) are not instances of int/float, so
    # unwrap them first; otherwise they would reach the encoder unconverted.
    if isinstance(reward, _np.generic):
        reward = reward.item()
    return {
        "observation": _to_jsonable(obs),
        "reward": float(reward) if isinstance(reward, (int, float)) else reward,
        "done": bool(done),
        "info": _to_jsonable(info) if info is not None else None,
    }
325
+
326
@api.post(f"/env/{ENV_NAME}/terminate")
async def terminate(req: dict[str, str] = Body(...)):
    """Terminate the environment named by ``env_id`` in the request body.

    Raises:
        HTTPException: 400 when the body has no (or an empty) ``env_id``;
            previously a missing id was forwarded as the string "None".
    """
    env_id = req.get("env_id")
    if not env_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing env_id in request body",
        )
    return await _registry.terminate(str(env_id))
330
+
331
@api.get("/actions")
def actions():
    """Return the configured action space under the ``actions`` key."""
    listing = {"actions": ACTION_SPACE}
    return listing
334
+
335
+ # OpenAI proxy: forward chat/completions to OpenAI using env OPENAI_API_KEY
336
@api.post("/proxy/v1/chat/completions")
def proxy_chat_completions(req: dict[str, Any]):
    """Forward a chat-completions request to OpenAI using OPENAI_API_KEY.

    The payload is sanitized for the target model before being sent. Upstream
    error statuses (>= 400) are passed through with the upstream body.

    Raises:
        HTTPException: 503 when OPENAI_API_KEY is missing, 504 when the
            upstream call times out, 502 on any other transport failure.

    NOTE(review): no inbound credential is checked in this handler, so any
    caller that can reach it spends the service's OpenAI key — confirm access
    is restricted elsewhere.
    """
    openai_key = os.environ.get("OPENAI_API_KEY")
    if not openai_key:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Missing OPENAI_API_KEY in task service environment")
    # Sanitize payload for OpenAI models (e.g., gpt-5-*).
    model = req.get("model")
    payload = prepare_inference_payload_for_model(model, req)
    headers = {"Authorization": f"Bearer {openai_key}"}
    # Increase timeout for proxy calls (models may be slower).
    try:
        with httpx.Client(timeout=120.0) as client:
            resp = client.post("https://api.openai.com/v1/chat/completions", json=payload, headers=headers)
    except httpx.TimeoutException as exc:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"OpenAI request timed out: {exc}",
        ) from exc
    except httpx.RequestError as exc:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"OpenAI request failed: {type(exc).__name__}: {exc}",
        ) from exc
    try:
        data = resp.json()
    except ValueError:
        # Body was not valid JSON; return a truncated copy of the raw text.
        data = {"error": "invalid_json", "raw": resp.text[:800]}
    if resp.status_code >= 400:
        return JSONResponse(status_code=resp.status_code, content=data)
    return data
355
+
356
+ # Unified rollout schema imported from SDK task contracts
357
+ from synth_ai.task.contracts import (
358
+ RolloutEnvSpec,
359
+ RolloutPolicySpec,
360
+ RolloutRecordConfig,
361
+ RolloutSafetyConfig,
362
+ RolloutRequest,
363
+ RolloutStep,
364
+ RolloutTrajectory,
365
+ RolloutMetrics,
366
+ RolloutResponse,
367
+ )
368
+
369
+ @api.post("/rollout", response_model=RolloutResponse)
370
+ async def rollout(req: RolloutRequest, request: Request):
371
+ expected = os.environ.get("ENVIRONMENT_API_KEY")
372
+ if not expected:
373
+ logger.error("rollout.auth.misconfigured: missing ENVIRONMENT_API_KEY")
374
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Auth not configured: missing ENVIRONMENT_API_KEY")
375
+ # Compute masked diagnostics (never log full keys)
376
+ try:
377
+ exp_len = len(expected)
378
+ exp_suf = expected[-5:] if exp_len >= 5 else "" # last 5 chars
379
+ # Collect candidates from headers: X-API-Key, X-API-Keys (CSV), Authorization: Bearer
380
+ hdr = request.headers
381
+ single = hdr.get("x-api-key") or ""
382
+ multi = [p.strip() for p in (hdr.get("x-api-keys") or "").split(",") if p.strip()]
383
+ auth = hdr.get("authorization") or ""
384
+ bearer = auth.split(" ", 1)[1].strip() if auth.lower().startswith("bearer ") else ""
385
+ candidates = [c for c in [single, bearer, *multi] if c]
386
+ # Assert server sees ALL keys sent by client
387
+ if multi:
388
+ logger.info("rollout.auth.candidates: n=%s first15=%s", len(candidates), [c[:15] for c in candidates])
389
+ got_len = len(single or bearer or "")
390
+ got_suf = (single or bearer or "")[-5:] if got_len >= 5 else ""
391
+ except Exception:
392
+ exp_len = -1
393
+ exp_suf = ""
394
+ got_len = -1
395
+ got_suf = ""
396
+ # Authorize if ANY candidate matches expected
397
+ authorized = any(c == expected for c in candidates)
398
+ if not authorized:
399
+ logger.warning(
400
+ "rollout.auth.failed: have_any=%s expect_len=%s expect_last5=%s got_len=%s got_last5=%s",
401
+ bool(candidates), exp_len, exp_suf, got_len, got_suf,
402
+ )
403
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API key")
404
+ else:
405
+ logger.info(
406
+ "rollout.auth.ok: expect_len=%s expect_last5=%s got_len=%s got_last5=%s",
407
+ exp_len, exp_suf, got_len, got_suf,
408
+ )
409
+
410
+ # Extract policy config
411
+ inference_url = req.policy.config["inference_url"]
412
+ model = req.policy.config.get("model")
413
+ max_steps = int(req.env.config.get("max_steps_per_episode", 10))
414
+ policy = CrafterPolicy(inference_url=inference_url, model=model)
415
+
416
+ # Debug: request summary
417
+ print(
418
+ "[task:crafter] ROLLOUT req: ",
419
+ {
420
+ "run_id": req.run_id,
421
+ "env": req.env.env_name,
422
+ "seed": req.env.seed,
423
+ "ops": len(req.ops),
424
+ "model": model,
425
+ "inference_url": inference_url,
426
+ "max_steps": max_steps,
427
+ },
428
+ flush=True,
429
+ )
430
+
431
+ # Initialize env (preemptively terminate any prior instance for this run to avoid sharing)
432
+ cfg = dict(req.env.config or {})
433
+ if req.env.seed is not None:
434
+ cfg["seed"] = int(req.env.seed)
435
+ env_id, observation = await _registry.initialize(cfg, run_id=req.run_id)
436
+
437
+ trajectory_steps: list[RolloutStep] = []
438
+ # Track per-decision achievement flips for stepwise shaping
439
+ decision_summaries: list[dict[str, Any]] = []
440
+ prev_ach: dict[str, bool] | None = None
441
+ total_reward = 0.0
442
+ ops_executed = 0
443
+ pending_tool_calls: list[dict[str, Any]] | None = None
444
+ try:
445
+ for op in req.ops:
446
+ if ops_executed >= req.safety.max_ops:
447
+ break
448
+ if op == "agent":
449
+ # Format current observation for the prompt
450
+ # Cache for mapping semantic ids to names
451
+ _id_to_item_cache: list[str] | None = None
452
+
453
+ def _ensure_semantic_mapping() -> list[str] | None:
454
+ nonlocal _id_to_item_cache
455
+ if _id_to_item_cache is not None:
456
+ return _id_to_item_cache
457
+ # Build mapping using crafter's internal ids
458
+ import itertools as _it
459
+ import crafter as _crafter
460
+ dummy = None
461
+ try:
462
+ dummy = _crafter.Env()
463
+ max_id = (
464
+ max(max(dummy._world._mat_ids.values()), max(dummy._sem_view._obj_ids.values()))
465
+ + 1
466
+ )
467
+ id_to_item = ["void"] * max_id
468
+ for name, ind in _it.chain(
469
+ dummy._world._mat_ids.items(), dummy._sem_view._obj_ids.items()
470
+ ):
471
+ if name is None:
472
+ clean = "none"
473
+ elif hasattr(name, "__name__"):
474
+ clean = name.__name__
475
+ else:
476
+ clean = str(name)
477
+ id_to_item[ind] = clean.lower()
478
+ _id_to_item_cache = id_to_item
479
+ finally:
480
+ if dummy is not None:
481
+ try:
482
+ dummy.close()
483
+ except Exception:
484
+ pass
485
+ return _id_to_item_cache
486
+
487
+ def _format_obs(obs: dict[str, Any]) -> str:
488
+ if not isinstance(obs, dict):
489
+ # Avoid dumping raw matrices; encourage exploration to gather context
490
+ return "no salient state; explore to gather context"
491
+ inv = obs.get("inventory") or {}
492
+ pos = obs.get("player_position")
493
+ steps = obs.get("num_steps_taken")
494
+ direction = obs.get("player_direction")
495
+ ach = obs.get("achievements_status") or {}
496
+ inv_lines = ", ".join(f"{k}:{v}" for k, v in inv.items() if v)
497
+ ach_on = [k for k, v in ach.items() if v]
498
+ lines = []
499
+ if pos is not None:
500
+ px, py = int(pos[0]), int(pos[1])
501
+ lines.append(f"position: (x={px}, y={py})")
502
+ if direction is not None:
503
+ dx, dy = int(direction[0]), int(direction[1])
504
+ dir_label = {
505
+ (1, 0): "→ east/right",
506
+ (-1, 0): "← west/left",
507
+ (0, 1): "↓ south/down",
508
+ (0, -1): "↑ north/up",
509
+ (0, 0): "• idle",
510
+ }.get((dx, dy), f"({dx},{dy})")
511
+ lines.append(f"direction: {dir_label}")
512
+ if steps is not None:
513
+ lines.append(f"steps: {int(steps)}")
514
+ if inv_lines:
515
+ lines.append(f"inventory: {inv_lines}")
516
+ if ach:
517
+ all_achievements = list(ach.keys())
518
+ lines.append(f"achievements_available: {', '.join(all_achievements)}")
519
+ lines.append(f"achievements_unlocked: {', '.join(ach_on)}" if ach_on else "achievements_unlocked: ")
520
+ lines.append(f"achievements_progress: {len(ach_on)}/{len(all_achievements)}")
521
+ # Local surroundings (7x7) using semantic_map
522
+ smap = obs.get("semantic_map")
523
+ if smap is not None and pos is not None:
524
+ try:
525
+ px, py = int(pos[0]), int(pos[1])
526
+ view_size = 7
527
+ half = view_size // 2
528
+ id_to_item = _ensure_semantic_mapping() or []
529
+ grid_rows: list[str] = []
530
+ # Build matrix centered at player, then transpose for human-friendly view
531
+ matrix: list[list[str]] = []
532
+ for dy in range(-half, half + 1):
533
+ row: list[str] = []
534
+ for dx in range(-half, half + 1):
535
+ x, y = px + dx, py + dy
536
+ if not (0 <= x < smap.shape[0] and 0 <= y < smap.shape[1]):
537
+ row.append("void")
538
+ elif dx == 0 and dy == 0:
539
+ row.append("player")
540
+ else:
541
+ idx = int(smap[x, y])
542
+ name = id_to_item[idx] if 0 <= idx < len(id_to_item) else str(idx)
543
+ row.append(name)
544
+ matrix.append(row)
545
+ # Transpose to match visual orientation
546
+ transposed = list(zip(*matrix))
547
+ for row in transposed:
548
+ grid_rows.append(" ".join(row))
549
+ if grid_rows:
550
+ lines.append(f"Local Map View (7x7):\n" + "\n".join(grid_rows))
551
+ except Exception:
552
+ # If any issue occurs, skip map rendering without crashing
553
+ pass
554
+ if not lines:
555
+ lines.append("no salient state; explore to gather context")
556
+ return "\n".join(lines)
557
+ # Build compact context from last few tool calls (gpt-5-nano friendly)
558
+ lines: list[str] = []
559
+ for rec in reversed(trajectory_steps):
560
+ if len(lines) >= 3:
561
+ break
562
+ tcs = rec.tool_calls
563
+ if not tcs:
564
+ continue
565
+ tc0 = tcs[0] if isinstance(tcs, list) and tcs else None
566
+ if not isinstance(tc0, dict):
567
+ continue
568
+ name = tc0.get("tool_name") or tc0.get("name") or "unknown"
569
+ args = tc0.get("arguments")
570
+ lines.append(f"- {name}: {args}")
571
+ context_text = "Previous tool calls (most recent first):\n" + ("\n".join(lines) if lines else "- none")
572
+ obs_text = _format_obs(observation)
573
+ combined_text = f"Current observation:\n{obs_text}\n\n{context_text}"
574
+ payload = policy.build_inference_request(combined_text, history=[], turn=len(trajectory_steps))
575
+ # Debug: print the full prompt content in a stable labeled block for grepability
576
+ try:
577
+ print("PROMPT_DUMP_BEGIN")
578
+ print(combined_text)
579
+ print("PROMPT_DUMP_END")
580
+ except Exception:
581
+ pass
582
+ # Debug: print user prompt and achievements unlocked list
583
+ try:
584
+ _msgs = payload.get("messages", [])
585
+ _last_user = None
586
+ for _m in reversed(_msgs):
587
+ if isinstance(_m, dict) and _m.get("role") == "user":
588
+ _last_user = _m
589
+ break
590
+ if _last_user is not None:
591
+ _content = _last_user.get("content")
592
+ print("[task:crafter] user prompt:", _content, flush=True)
593
+ except Exception:
594
+ pass
595
+ try:
596
+ _ach = observation.get("achievements_status") if isinstance(observation, dict) else {}
597
+ _ach_on = [k for k, v in (_ach or {}).items() if v]
598
+ print(f"[task:crafter] achievements_unlocked: {_ach_on}", flush=True)
599
+ except Exception:
600
+ pass
601
+ # Prepare payload based on model family (OpenAI vs vLLM)
602
def _prepare_payload(p: dict, mdl: str | None) -> dict:
    """Sanitize payload ``p`` for model ``mdl`` via the module-level sanitizer.

    Note the argument order is (payload, model), the reverse of the helper it
    delegates to.
    """
    sanitized = prepare_inference_payload_for_model(mdl, p)
    return sanitized
604
+ # Debug: payload shape
605
+ print(
606
+ "[task:crafter] inference payload: ",
607
+ {
608
+ "has_model": bool(payload.get("model") is not None),
609
+ "messages": payload.get("messages", []),
610
+ "tools": isinstance(payload.get("tools"), list),
611
+ "tool_choice": payload.get("tool_choice"),
612
+ "stop_after_tool_calls": payload.get("stop_after_tool_calls"),
613
+ },
614
+ flush=True,
615
+ )
616
+ headers: dict[str, str] = {}
617
+ _okey = os.environ.get("OPENAI_API_KEY")
618
+ # Configure granular timeouts for slow model/tool runs
619
+ _timeouts = httpx.Timeout(connect=10.0, read=180.0, write=60.0, pool=60.0)
620
+ with httpx.Client(timeout=_timeouts) as client:
621
+ # Decide endpoint: avoid calling our own /proxy inside the same request
622
+ _direct = ("api.openai.com" in inference_url)
623
+ if _direct:
624
+ # Call OpenAI directly
625
+ if _okey:
626
+ headers["Authorization"] = f"Bearer {_okey}"
627
+ to_send = _prepare_payload(payload, model)
628
+ endpoint_base = "https://api.openai.com"
629
+ else:
630
+ # Non-OpenAI inference endpoint
631
+ to_send = payload
632
+ endpoint_base = inference_url
633
+ # If targeting Synth proxy, attach backend auth
634
+ if "/proxy" in endpoint_base:
635
+ _skey = os.environ.get("SYNTH_API_KEY")
636
+ if _skey:
637
+ headers["Authorization"] = f"Bearer {_skey}"
638
+
639
+ # Debug: outbound request diagnostics
640
+ try:
641
+ import json as _json
642
+ _size = len(_json.dumps(to_send))
643
+ except Exception:
644
+ _size = -1
645
+ print(
646
+ "[task:crafter] inference dispatch:",
647
+ {
648
+ "endpoint": f"{endpoint_base.rstrip('/')}/v1/chat/completions",
649
+ "direct_openai": bool(_direct),
650
+ "timeout": {"read": 180.0, "connect": 10.0, "write": 60.0, "pool": 60.0},
651
+ "payload_bytes": _size,
652
+ "has_auth": bool(headers.get("Authorization")),
653
+ },
654
+ flush=True,
655
+ )
656
+
657
+ _t0 = time.time()
658
+ try:
659
+ resp = client.post(
660
+ f"{endpoint_base.rstrip('/')}/v1/chat/completions",
661
+ json=to_send,
662
+ headers=headers,
663
+ )
664
+ except httpx.ReadTimeout as rte:
665
+ _elapsed = time.time() - _t0
666
+ print(f"[task:crafter][timeout] read timeout after {_elapsed:.1f}s: {rte}", flush=True)
667
+ raise
668
+ except Exception as re:
669
+ _elapsed = time.time() - _t0
670
+ print(f"[task:crafter][error] request failed after {_elapsed:.1f}s: {type(re).__name__}: {re}", flush=True)
671
+ raise
672
+ _elapsed = time.time() - _t0
673
+ print(f"[task:crafter] inference status= {resp.status_code} elapsed={_elapsed:.2f}s", flush=True)
674
+ # Emit a light-weight perf snapshot for visibility
675
+ try:
676
+ print(
677
+ "[metric] perf ",
678
+ "tok/s=n/a",
679
+ f"decision p50=n/a p95=n/a",
680
+ "roll n/a",
681
+ flush=True,
682
+ )
683
+ except Exception:
684
+ pass
685
+ # Debug: response status and body (on errors)
686
+ print("[task:crafter] inference status=", resp.status_code, flush=True)
687
+ if resp.status_code >= 400:
688
+ body_preview = resp.text[:800]
689
+ print("[task:crafter] inference error body:", body_preview, flush=True)
690
+ data = resp.json()
691
+ print(f"[task:crafter] inference response: {data}")
692
+ parsed = CrafterPolicy.parse_response_to_tool_calls(data, use_tools=True) or []
693
+ # Debug: parsed tool call summary
694
+ print(
695
+ "[task:crafter] parsed tool_calls: ",
696
+ {
697
+ "n": len(parsed),
698
+ "first": (parsed[0] if isinstance(parsed, list) and parsed else None),
699
+ },
700
+ flush=True,
701
+ )
702
+ # Print full tool call payloads for inspection
703
+ try:
704
+ import json as _json
705
+ for _i, _tc in enumerate(parsed):
706
+ try:
707
+ print(
708
+ f"[task:crafter] tool_call[{_i}]:",
709
+ _json.dumps(_tc, separators=(",", ":")),
710
+ flush=True,
711
+ )
712
+ except Exception:
713
+ print(f"[task:crafter] tool_call[{_i}]: {_tc}", flush=True)
714
+ except Exception:
715
+ pass
716
+ if not parsed:
717
+ # Dump compact body preview to understand schema when no tools parsed
718
+ try:
719
+ import json as _json
720
+ preview = _json.dumps(data, separators=(",",":"))
721
+ print("[task:crafter] body(no_tools) preview:", preview[:800], flush=True)
722
+ except Exception:
723
+ pass
724
+ # Early terminate the episode to avoid hanging on empty tool calls
725
+ print("[task:crafter] NO_TOOL_CALLS: terminating episode early", flush=True)
726
+ break
727
+ pending_tool_calls = parsed
728
+ ops_executed += 1
729
+ elif op == "env":
730
+ if not pending_tool_calls:
731
+ print("[task:crafter] no tool_calls; skipping env step", flush=True)
732
+ continue
733
+ info: dict[str, Any] | None = None
734
+ for tc in pending_tool_calls:
735
+ name = tc.get("tool_name")
736
+ if name == "interact":
737
+ # Parse the JSON arguments string
738
+ import json
739
+ args_str = tc.get("arguments", "{}")
740
+ try:
741
+ args_dict = json.loads(args_str)
742
+ actions = args_dict.get("actions", [])
743
+ reasoning = args_dict.get("reasoning", "")
744
+ print(f"[task:crafter] reasoning: {reasoning}", flush=True)
745
+ except (json.JSONDecodeError, TypeError):
746
+ print(f"[task:crafter] ERROR: Failed to parse arguments: {args_str}", flush=True)
747
+ actions = []
748
+ reasoning = "Parse error"
749
+
750
+ print(f"[task:crafter] env actions: {actions}", flush=True)
751
+ # Print a compact echo of the current prompt + tool call for easier triage
752
+ try:
753
+ import json as _json
754
+ print("TOOLCALL_CONFIG:", _json.dumps({
755
+ "policy": req.policy.policy_name,
756
+ "tools_present": True,
757
+ "tool_choice": "required",
758
+ "stop_after": 1,
759
+ }))
760
+ except Exception:
761
+ pass
762
+
763
+ # Execute each action individually
764
+ # Reset decision-level flip set for this decision
765
+ decision_flips: set[str] = set()
766
+ for act in actions:
767
+ observation, reward, done, _info = await _registry.step(env_id, act)
768
+ total_reward += float(reward)
769
+ # Debug: print step outcome (compact)
770
+ try:
771
+ ok = list(observation.keys()) if isinstance(observation, dict) else []
772
+ print(f"[task:crafter] step => a={act} r={float(reward)} done={bool(done)} obs_keys={ok[:5]}", flush=True)
773
+ except Exception:
774
+ pass
775
+ step = RolloutStep(obs=observation, tool_calls=pending_tool_calls, reward=float(reward), done=bool(done), truncated=False, info=info)
776
+ trajectory_steps.append(step)
777
+ ops_executed += 1
778
+
779
+ # Check for achievement-based termination
780
+ if isinstance(observation, dict):
781
+ current_achievements = observation.get("achievements_status", {})
782
+ # Track flips 0→1 within this decision
783
+ try:
784
+ if not isinstance(current_achievements, dict):
785
+ current_achievements = {}
786
+ if prev_ach is None:
787
+ prev_ach = {k: bool(v) for k, v in (current_achievements or {}).items()}
788
+ else:
789
+ for name, on in (current_achievements or {}).items():
790
+ if bool(on) and not bool(prev_ach.get(name, False)):
791
+ decision_flips.add(str(name))
792
+ # Update prev_ach to latest snapshot
793
+ prev_ach = {k: bool(v) for k, v in (current_achievements or {}).items()}
794
+ except Exception:
795
+ pass
796
+ achieved_count = sum(1 for v in current_achievements.values() if v)
797
+ total_achievements = len(current_achievements)
798
+
799
+ # Terminate if we've achieved a significant portion of available achievements
800
+ if total_achievements > 0 and achieved_count >= max(3, total_achievements // 2):
801
+ print(f"[task:crafter] achievement_termination: {achieved_count}/{total_achievements} achievements reached", flush=True)
802
+ print(f"[task:crafter] achieved: {[k for k, v in current_achievements.items() if v]}", flush=True)
803
+ break
804
+
805
+ if done or len(trajectory_steps) >= max_steps:
806
+ print(f"[task:crafter] episode_end: done={bool(done)} steps={len(trajectory_steps)} total_reward={total_reward}", flush=True)
807
+ break
808
+ elif name == "terminate":
809
+ # Handle termination
810
+ print("[task:crafter] Agent requested termination", flush=True)
811
+ break
812
+ else:
813
+ # Non-interact tool call: count as a step without env change
814
+ print("[task:crafter] non-interact tool_call:", name, flush=True)
815
+ step = RolloutStep(obs=observation, tool_calls=pending_tool_calls, reward=None, done=False, truncated=False, info=info)
816
+ trajectory_steps.append(step)
817
+ ops_executed += 1
818
+ # End of decision: record indicator_i for shaping
819
+ try:
820
+ indicator_i = 1 if decision_flips else 0
821
+ decision_summaries.append({"indicator_i": indicator_i})
822
+ except Exception:
823
+ pass
824
+ pending_tool_calls = None
825
+ if len(trajectory_steps) >= max_steps:
826
+ print(f"[task:crafter] max_steps_reached: steps={len(trajectory_steps)} total_reward={total_reward}", flush=True)
827
+ break
828
+ else:
829
+ # Unknown op: skip
830
+ continue
831
+ if len(trajectory_steps) >= max_steps:
832
+ break
833
+ finally:
834
+ await _registry.terminate(env_id)
835
+
836
+ # Sanitize steps for JSON
837
+ safe_steps = [
838
+ RolloutStep(
839
+ obs=_to_jsonable(s.obs),
840
+ tool_calls=s.tool_calls,
841
+ reward=float(s.reward) if s.reward is not None else None,
842
+ done=bool(s.done),
843
+ truncated=bool(s.truncated) if s.truncated is not None else None,
844
+ info=_to_jsonable(s.info) if s.info is not None else None,
845
+ )
846
+ for s in trajectory_steps
847
+ ]
848
+
849
+ trajectory = RolloutTrajectory(
850
+ env_id=env_id,
851
+ policy_id=req.policy.policy_name or "crafter-policy",
852
+ steps=safe_steps,
853
+ final={"observation": _to_jsonable(observation)},
854
+ length=len(safe_steps),
855
+ )
856
+ # Calculate achievements for this episode
857
+ final_obs = observation
858
+ if isinstance(final_obs, dict):
859
+ final_achievements = final_obs.get("achievements_status", {})
860
+ else:
861
+ # Handle numpy array case - no achievements available
862
+ final_achievements = {}
863
+ total_achievements = sum(1 for v in final_achievements.values() if v)
864
+
865
+ # Step-reward shaping: compute decision-level rewards if enabled
866
+ branches: dict[str, Any] = {}
867
+ try:
868
+ sr_cfg = (req.record.config or {}).get("step_rewards") if isinstance(req.record, RolloutRecordConfig) else None
869
+ except Exception:
870
+ sr_cfg = None
871
+ try:
872
+ enabled = False
873
+ mode = None
874
+ step_beta = 0.0
875
+ indicator_lambda = 0.0
876
+ if isinstance(sr_cfg, dict):
877
+ enabled = bool(sr_cfg.get("enabled", False))
878
+ mode = (sr_cfg.get("mode") or "off").strip().lower()
879
+ step_beta = float(sr_cfg.get("step_beta", 0.0))
880
+ indicator_lambda = float(sr_cfg.get("indicator_lambda", 0.0))
881
+ # Env overrides
882
+ import os as _os2
883
+ if _os2.getenv("STEP_BETA"):
884
+ step_beta = float(_os2.getenv("STEP_BETA"))
885
+ if _os2.getenv("STEP_LAMBDA"):
886
+ indicator_lambda = float(_os2.getenv("STEP_LAMBDA"))
887
+ if enabled and mode == "decision_stepwise" and decision_summaries:
888
+ dec_rewards = compute_decision_rewards(
889
+ decision_summaries=decision_summaries,
890
+ total_achievements=total_achievements,
891
+ step_beta=step_beta,
892
+ indicator_lambda=indicator_lambda,
893
+ )
894
+ branches["decision_rewards"] = dec_rewards
895
+ print(
896
+ "[task:crafter] step_rewards: ",
897
+ {
898
+ "enabled": True,
899
+ "mode": mode,
900
+ "step_beta": step_beta,
901
+ "indicator_lambda": indicator_lambda,
902
+ "decisions": len(dec_rewards),
903
+ },
904
+ flush=True,
905
+ )
906
+ except Exception as _e_sr:
907
+ print(f"[task:crafter] step_rewards_error: {_e_sr}", flush=True)
908
+
909
+ # Optional tracing of episode/rewards (gated)
910
+ try:
911
+ import os as _os3
912
+ if _os3.getenv("TRACE_RL", "0") == "1":
913
+ from synth_ai.tracing_v3.session_tracer import SessionTracer # type: ignore
914
+ tracer = SessionTracer()
915
+ await tracer.initialize()
916
+ meta = {
917
+ "env": req.env.env_name,
918
+ "policy": req.policy.policy_name,
919
+ "step_rewards": {
920
+ "enabled": bool(sr_cfg.get("enabled", False)) if isinstance(sr_cfg, dict) else False,
921
+ "mode": (sr_cfg.get("mode") if isinstance(sr_cfg, dict) else None),
922
+ },
923
+ }
924
+ async with tracer.session(metadata=meta):
925
+ # Record episode outcome at end
926
+ await tracer.record_outcome_reward(
927
+ total_reward=int(total_reward),
928
+ achievements_count=int(total_achievements),
929
+ total_steps=int(len(trajectory_steps)),
930
+ )
931
+ except Exception as _te:
932
+ print(f"[task:crafter] tracing_error: {_te}", flush=True)
933
+
934
+ metrics = RolloutMetrics(
935
+ episode_returns=[total_reward],
936
+ mean_return=float(total_achievements),
937
+ num_steps=len(trajectory_steps),
938
+ num_episodes=1,
939
+ )
940
+ # Debug: print reward and achievement metrics
941
+ print(f"[task:crafter] Rollout metrics: total_reward={total_reward}, total_achievements={total_achievements}, mean_return={metrics.mean_return}, episode_returns={metrics.episode_returns}", flush=True)
942
+ return RolloutResponse(
943
+ run_id=req.run_id,
944
+ trajectories=[trajectory],
945
+ branches=branches,
946
+ metrics=metrics,
947
+ aborted=False,
948
+ ops_executed=ops_executed,
949
+ )
950
+
951
@api.get("/test_auth")
def test_auth(request: Request):
    """Check that the caller presented the service's environment API key.

    The expected key is read from the ``ENVIRONMENT_API_KEY`` environment
    variable and compared against the ``x-api-key`` request header.

    Args:
        request: Incoming request whose headers carry the caller's key.

    Returns:
        ``{"ok": True}`` when the header value matches the configured key.

    Raises:
        HTTPException: 503 if ``ENVIRONMENT_API_KEY`` is unset or empty in the
            service environment; 401 if the header is missing or does not
            match.
    """
    # Local import so this block does not depend on a file-level import of hmac.
    import hmac

    expected = os.environ.get("ENVIRONMENT_API_KEY")
    if not expected:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Missing ENVIRONMENT_API_KEY in service env",
        )

    header_key = request.headers.get("x-api-key") or request.headers.get("X-API-Key")

    # Compare in constant time: a plain `==` returns as soon as the first
    # differing character is found, which leaks timing information about the
    # secret to an unauthenticated caller. Both sides are encoded to bytes
    # because compare_digest rejects non-ASCII str arguments.
    ok = bool(header_key) and hmac.compare_digest(
        header_key.encode("utf-8"), expected.encode("utf-8")
    )
    if not ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
        )
    return {"ok": True}
961
+
962
+ return api