synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (112)
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  98. synth_ai/api/train/config_finder.py +18 -18
  99. synth_ai/api/train/env_resolver.py +28 -1
  100. synth_ai/cli/task_apps.py +264 -55
  101. synth_ai/demo_registry.py +7 -7
  102. synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
  103. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
  104. synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
  105. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
  106. synth_ai/task/apps/__init__.py +54 -13
  107. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
  108. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
  109. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
  110. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
  111. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
  112. {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1749 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import json
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ from fastapi import APIRouter, HTTPException, Request, status
10
+ import os
11
+ import time as _time
12
+ from pydantic import BaseModel
13
+ from synth_ai.lm.vendors.base import BaseLMResponse
14
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
15
+ from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
16
+ from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
17
+ from synth_ai.task.tracing_utils import unique_sft_path
18
+
19
+ from .registry import registry
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # --- Seeding utilities (robust, optional deps) ---
24
+ def _set_global_seed(seed_value: int) -> Dict[str, Any]:
25
+ """Set global RNG seeds across common libraries; return details for logging/restoration.
26
+
27
+ Returns a dict containing which libraries were seeded and prior states if obtainable.
28
+ """
29
+ seeded: Dict[str, Any] = {"seed": int(seed_value), "libs": []}
30
+ try:
31
+ import random as _random # type: ignore
32
+ _random.seed(seed_value)
33
+ seeded["libs"].append("random")
34
+ except Exception:
35
+ pass
36
+ try:
37
+ import numpy as _np # type: ignore
38
+ _np.random.seed(seed_value)
39
+ seeded["libs"].append("numpy")
40
+ except Exception:
41
+ pass
42
+ try:
43
+ import torch as _torch # type: ignore
44
+ if hasattr(_torch, "manual_seed"):
45
+ _torch.manual_seed(seed_value)
46
+ seeded["libs"].append("torch")
47
+ # Make CUDA deterministic if present (best-effort)
48
+ try:
49
+ if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
50
+ _torch.cuda.manual_seed_all(seed_value)
51
+ seeded.setdefault("cuda", True)
52
+ except Exception:
53
+ pass
54
+ # CUDNN deterministic flags (optional)
55
+ try:
56
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
57
+ _torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
58
+ _torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
59
+ except Exception:
60
+ pass
61
+ except Exception:
62
+ pass
63
+ return seeded
64
+
65
+ def _clear_seed_side_effects() -> None:
66
+ """Best-effort cleanup to avoid global deterministic side-effects between requests."""
67
+ # We cannot truly restore prior RNG states without capturing them; we just avoid
68
+ # leaving aggressive deterministic flags enabled where it matters.
69
+ try:
70
+ import torch as _torch # type: ignore
71
+ try:
72
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
73
+ # Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
74
+ # We'll keep deterministic False to avoid global impact; benchmark left False for stability.
75
+ _torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
76
+ except Exception:
77
+ pass
78
+ except Exception:
79
+ pass
80
+
81
# FastAPI router to which this module's rollout endpoints (e.g. POST /rollout) are attached.
router = APIRouter()
82
+
83
+
84
class RolloutEnvSpec(BaseModel):
    """Environment selection for a rollout request.

    ``config`` is passed through untouched; ``execute_rollout`` reads
    ``config["env_params"]["max_steps_per_episode"]`` from it.
    """

    env_id: str | None = None
    env_name: str | None = None
    config: dict[str, Any] = {}
    seed: int | None = None
89
+
90
+
91
class RolloutPolicySpec(BaseModel):
    """Policy selection for a rollout request (by id and/or name, plus free-form config)."""

    policy_id: str | None = None
    policy_name: str | None = None
    config: dict[str, Any] = {}
95
+
96
+
97
class RolloutBranchConfig(BaseModel):
    """Optional branching settings for a rollout.

    NOTE(review): how these fields are interpreted is decided by the rollout
    executor, which is not visible here; zero/False defaults presumably mean
    "no branching" — confirm.
    """

    branch_every_n_steps: int = 0
    branch_on_condition: str | None = None
    max_branches: int = 0
    branch_policy: bool = False
    branch_env: bool = False
103
+
104
+
105
class RolloutRecordConfig(BaseModel):
    """What the rollout should record and return.

    ``return_trace`` and ``trace_format`` are read by ``RolloutTracingContext``:
    a (case-insensitive) ``"full"`` format returns the whole session trace,
    while any other value yields the compact summary.
    """

    trajectories: bool = True
    logprobs: bool = False
    value: bool = False
    return_trace: bool = False
    trace_format: str = "compact"
111
+
112
+
113
class RolloutSafetyConfig(BaseModel):
    """Safety limits for a rollout.

    NOTE(review): presumably an op-count cap and a wall-clock cap in seconds,
    enforced by the rollout executor — confirm against ``execute_rollout``.
    """

    max_ops: int = 100000
    max_time_s: float = 3600.0
116
+
117
+
118
class RolloutRequest(BaseModel):
    """Request body for ``POST /rollout``.

    ``ops`` is the sequence of operations to execute; ``execute_rollout``
    truncates it to ``2 * max_steps_per_episode`` entries.
    """

    run_id: str
    env: RolloutEnvSpec
    policy: RolloutPolicySpec
    # Operation names, e.g. ["agent", "env", ...].
    ops: list[str]
    record: RolloutRecordConfig = RolloutRecordConfig()
    # Either "reset" or "terminate".
    on_done: str = "reset"
    branch: RolloutBranchConfig | None = None
    safety: RolloutSafetyConfig = RolloutSafetyConfig()
    # Optional run/session context, copied into tracing metadata.
    training_session_id: str | None = None
    synth_base_url: str | None = None
130
+
131
+
132
class RolloutStep(BaseModel):
    """A single step of a rollout trajectory: observation, tool calls, and outcome."""

    obs: dict[str, Any]
    tool_calls: list[dict[str, Any]]
    reward: float | None = None
    done: bool = False
    truncated: bool | None = None
    logprob: float | None = None
    value: float | None = None
    info: dict[str, Any] | None = None
141
+
142
+
143
class RolloutTrajectory(BaseModel):
    """Ordered steps collected for one (env, policy) pair, plus optional per-decision samples."""

    env_id: str
    policy_id: str
    steps: list[RolloutStep]
    final: dict[str, Any] | None = None
    length: int
    decision_samples: list[dict[str, Any]] | None = None
150
+
151
+
152
def compute_stepwise_reward(
    prev_achievements: Dict[str, bool],
    new_achievements: Dict[str, bool],
    decision_index: int,
    actions_summary: List[Dict[str, Any]],
    indicator_lambda: float,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, float]]:
    """Score one decision by whether it unlocked any new achievement.

    An achievement counts as newly unlocked when it is truthy in
    ``new_achievements`` but not truthy in ``prev_achievements``. The reward is
    ``float(indicator_lambda)`` if at least one achievement was unlocked and zero
    otherwise; the number of unlocked achievements does not scale it.

    Args:
        prev_achievements: Achievement flags before the decision (``None`` -> empty).
        new_achievements: Achievement flags after the decision (``None`` -> empty).
        decision_index: Index of the decision, copied into the outputs.
        actions_summary: Actions taken during the decision, stored in the sample.
        indicator_lambda: Reward granted when the indicator fires.

    Returns:
        ``(stepwise_info, decision_sample, stats)``: a per-step info dict, a
        decision sample (``r_i`` is the reward), and float-valued statistics.
    """
    before = prev_achievements or {}
    after = new_achievements or {}

    newly_unlocked: List[str] = []
    for achievement, achieved in after.items():
        if achieved and not before.get(achievement, False):
            newly_unlocked.append(achievement)

    indicator = int(bool(newly_unlocked))
    reward = float(indicator_lambda) * indicator

    stepwise_info: Dict[str, Any] = dict(
        decision_index=decision_index,
        indicator=indicator,
        new_achievements=newly_unlocked,
        reward=reward,
    )
    decision_sample: Dict[str, Any] = dict(
        decision_index=decision_index,
        indicator=indicator,
        r_i=reward,
        actions=actions_summary,
    )
    stats: Dict[str, float] = dict(
        indicator=float(indicator),
        reward=reward,
        new_achievements_count=float(len(newly_unlocked)),
    )
    return stepwise_info, decision_sample, stats
190
+
191
+
192
class RolloutMetrics(BaseModel):
    """Aggregate return statistics reported for a rollout."""

    episode_returns: list[float]
    mean_return: float
    num_steps: int
    num_episodes: int = 0
197
+
198
+
199
class RolloutResponse(BaseModel):
    """Response body of ``POST /rollout``.

    ``trace`` carries the payload built by
    ``RolloutTracingContext.build_trace_payload`` when tracing output is requested.
    """

    run_id: str
    trajectories: list[RolloutTrajectory]
    branches: dict[str, list[str]] = {}
    metrics: RolloutMetrics
    aborted: bool = False
    ops_executed: int = 0
    trace: dict[str, Any] | None = None
207
+
208
+
209
class RolloutTracingContext:
    """Collect tracing_v3 events and optional SFT records for a single rollout.

    The constructor registers the instance on ``fastapi_request.state`` so code
    running later in the same request can find it. When ``tracer`` is ``None``
    (or the session fails to start), tracer calls are skipped. The in-memory
    summaries (``lm_calls_summary``, ``decision_rewards``) and SFT records are
    collected either way.
    """

    def __init__(
        self,
        tracer: SessionTracer | None,
        request: RolloutRequest,
        fastapi_request: Request,
    ) -> None:
        """Capture the request context and expose it on ``fastapi_request.state``.

        Args:
            tracer: tracing_v3 session tracer, or ``None`` to disable tracing.
            request: The validated rollout request (run, env and policy identity).
            fastapi_request: The incoming FastAPI request. If
                ``app.state.sft_output_dir`` is set, SFT JSONL dumps are enabled.
        """
        self.tracer = tracer
        self.enabled = tracer is not None
        self.request = request
        self.fastapi_request = fastapi_request
        self.run_id = request.run_id
        # Identity of the decision currently open (set by start_decision).
        self.current_step_id: str | None = None
        self.current_turn: int | None = None
        # Per-rollout summaries surfaced in the compact trace payload.
        self.lm_calls_summary: list[dict[str, Any]] = []
        self.decision_rewards: list[dict[str, Any]] = []
        # Buffered SFT examples; flushed to disk by write_sft_records().
        self.sft_records: list[dict[str, Any]] = []
        # Latest prompts, reused as the dialogue of the next SFT record.
        self.latest_system_messages: list[str] = []
        self.latest_user_messages: list[str] = []
        self.trace_format = (getattr(request.record, "trace_format", "compact") or "compact").lower()
        self.return_trace = bool(getattr(request.record, "return_trace", False))
        self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
        self.session_trace = None
        # Extra metadata merged into the outcome reward and the returned trace.
        self.metadata_updates: dict[str, Any] = {}
        self.policy_name = request.policy.policy_name or ""
        self.env_name = request.env.env_name or ""
        self.metadata_base: dict[str, Any] = {
            "run_id": self.run_id,
            "policy_name": self.policy_name,
            "policy_id": request.policy.policy_id,
            "env_name": self.env_name,
            "env_id": request.env.env_id,
            "seed": request.env.seed,
            "training_session_id": request.training_session_id,
            "synth_base_url": request.synth_base_url,
        }

        # Expose context for downstream calls inside this request lifecycle
        fastapi_request.state.rollout_tracing = self
        fastapi_request.state.rollout_run_id = self.run_id

    async def start_session(self) -> None:
        """Initialize the tracer and open a session keyed by ``run_id``.

        An initialization failure is only logged. A failure to start the session
        disables tracing for the rest of the rollout.
        """
        if not self.enabled or self.tracer is None:
            return
        try:
            await self.tracer.initialize()
        except Exception as exc:
            logger.debug("TRACING_INIT_FAIL: %s", exc)
        try:
            await self.tracer.start_session(session_id=self.run_id, metadata=dict(self.metadata_base))
        except Exception as exc:
            logger.warning("TRACING_START_FAIL: %s", exc)
            self.enabled = False
            self.tracer = None

    async def start_decision(self, turn_number: int) -> None:
        """Mark the start of decision ``turn_number`` and open a tracer timestep."""
        self.current_turn = turn_number
        self.current_step_id = f"decision_{turn_number}"
        if not self.enabled or self.tracer is None:
            return
        try:
            await self.tracer.start_timestep(step_id=self.current_step_id, turn_number=turn_number)
        except Exception as exc:
            logger.debug("TRACING_STEP_START_FAIL: %s", exc)

    async def end_decision(self) -> None:
        """Close the timestep opened by ``start_decision`` (best-effort).

        ``current_step_id`` is cleared even when tracing is disabled, so it never
        points at a decision that has already ended.
        """
        try:
            if self.enabled and self.tracer is not None:
                await self.tracer.end_timestep(step_id=self.current_step_id)
        except Exception as exc:
            logger.debug("TRACING_STEP_END_FAIL: %s", exc)
        finally:
            self.current_step_id = None

    def _message_metadata(self) -> dict[str, Any]:
        """Metadata attached to every recorded message: current turn and step id."""
        return {
            "turn": self.current_turn,
            "step_id": self.current_step_id,
        }

    async def record_policy_prompts(
        self,
        system_messages: list[str],
        user_messages: list[str],
    ) -> None:
        """Remember the latest prompts (for SFT) and record each one as a tracer message."""
        self.latest_system_messages = list(system_messages)
        self.latest_user_messages = list(user_messages)
        if not self.enabled or self.tracer is None:
            return
        for msg in system_messages:
            try:
                await self.tracer.record_message(
                    content=msg,
                    message_type="policy_system_prompt",
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
        for msg in user_messages:
            try:
                await self.tracer.record_message(
                    content=msg,
                    message_type="policy_user_prompt",
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug("TRACING_USER_MSG_FAIL: %s", exc)

    def _content_to_text(self, content: Any) -> str:
        """Flatten message content to plain text.

        Strings are returned unchanged. For lists, the ``text`` (or ``content``)
        string of each dict segment is concatenated and other segments are
        skipped. ``None`` becomes ``""`` and anything else is passed to ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for seg in content:
                if isinstance(seg, dict):
                    text_val = seg.get("text") or seg.get("content")
                    if isinstance(text_val, str):
                        parts.append(text_val)
            return "".join(parts)
        if content is None:
            return ""
        return str(content)

    def _safe_json(self, payload: Any, limit: int = 4000) -> str:
        """JSON-encode ``payload`` (falling back to ``str``), truncated to ``limit`` chars plus an ellipsis."""
        try:
            text = json.dumps(payload, ensure_ascii=False)
        except Exception:
            text = str(payload)
        if len(text) > limit:
            return text[:limit] + "…"
        return text

    async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
        """Record the policy's tool calls as a single JSON message (no-op for ``None``)."""
        if tool_calls is None:
            return
        if self.enabled and self.tracer is not None:
            try:
                await self.tracer.record_message(
                    content=self._safe_json(tool_calls),
                    message_type="policy_tool_call",
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)

    async def _record_event(self, event: Any) -> int | None:
        """Forward ``event`` to the tracer; return its id, or ``None`` if disabled or on failure."""
        if not self.enabled or self.tracer is None:
            return None
        try:
            return await self.tracer.record_event(event)
        except Exception as exc:
            logger.debug("TRACING_EVENT_FAIL: %s", exc)
            return None

    async def record_llm_call(
        self,
        *,
        inference_request: dict[str, Any],
        inference_response: dict[str, Any],
        tool_calls: list[dict[str, Any]] | None,
        provider: str,
        model_name: str,
        started_at: datetime,
        completed_at: datetime,
        latency_ms: int | None,
    ) -> None:
        """Record one policy LLM call as an ``LMCAISEvent`` and update the summaries.

        The event reaches the tracer only when tracing is enabled. An entry in
        ``lm_calls_summary`` is always appended. When ``sft_output_dir`` is set,
        an SFT record built from the latest prompts is also buffered.

        Args:
            inference_request: Chat-completions style request body; its
                ``messages``, ``temperature`` and ``tools`` are read.
            inference_response: Chat-completions style response; its ``usage``
                and ``choices[0].message`` are read.
            tool_calls: Parsed tool calls for this turn; only their count is kept.
            provider: Provider label stored on the event.
            model_name: Model label stored on the event.
            started_at: Call start time.
            completed_at: Call end time; also used as the event time.
            latency_ms: Measured latency in milliseconds, if known.
        """
        usage = inference_response.get("usage") or {}
        # Accept both input/output_tokens and prompt/completion_tokens naming.
        input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
        output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
        total_tokens = usage.get("total_tokens")
        cost_usd = (
            usage.get("cost_usd")
            or usage.get("cost")
            or usage.get("total_cost")
        )

        assistant_message = None
        choices = inference_response.get("choices") or []
        if choices:
            assistant_message = choices[0].get("message") or {}
        assistant_content = assistant_message.get("content") if isinstance(assistant_message, dict) else None

        assistant_text = self._content_to_text(assistant_content)
        # Fall back to a truncated JSON dump when the message carries no text.
        raw_response = assistant_text or self._safe_json(inference_response, limit=2000)

        base_response = BaseLMResponse(
            raw_response=raw_response,
            tool_calls=assistant_message.get("tool_calls") if isinstance(assistant_message, dict) else None,
            usage=usage or None,
            api_type="chat_completions",
        )

        request_messages = inference_request.get("messages") or []
        try:
            temperature = float(inference_request.get("temperature"))
        except (TypeError, ValueError, OverflowError):
            # Missing or non-numeric temperature: record 0.0 rather than failing the call.
            temperature = 0.0

        call_record = create_llm_call_record_from_response(
            response=base_response,
            model_name=model_name,
            provider=provider,
            messages=request_messages,
            temperature=temperature,
            request_params=inference_request,
            tools=inference_request.get("tools"),
            started_at=started_at,
            completed_at=completed_at,
            latency_ms=latency_ms,
        )

        event_metadata = {
            "policy_id": self.request.policy.policy_id,
            "turn": self.current_turn,
            "run_id": self.run_id,
        }

        event = LMCAISEvent(
            system_instance_id=f"policy:{self.policy_name or 'unknown'}",
            time_record=TimeRecord(event_time=completed_at.timestamp()),
            model_name=model_name,
            provider=provider,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            cost_usd=cost_usd,
            latency_ms=latency_ms,
            call_records=[call_record],
            metadata=event_metadata,
        )

        await self._record_event(event)

        self.lm_calls_summary.append(
            {
                "turn": self.current_turn,
                "model": model_name,
                "provider": provider,
                "total_tokens": total_tokens,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "latency_ms": latency_ms,
                "tool_calls": len(tool_calls or []),
            }
        )

        if self.sft_output_dir is not None:
            record = {
                "run_id": self.run_id,
                "turn": self.current_turn,
                "model": model_name,
                "provider": provider,
                "dialogue": (
                    [{"role": "system", "content": s} for s in self.latest_system_messages]
                    + [{"role": "user", "content": u} for u in self.latest_user_messages]
                ),
                "assistant": {
                    "content": assistant_text,
                    "tool_calls": assistant_message.get("tool_calls") if isinstance(assistant_message, dict) else [],
                },
                "timestamp": datetime.utcnow().isoformat(),
            }
            self.sft_records.append(record)

    async def record_environment_event(
        self,
        *,
        env_handle: Any,
        prev_obs: Dict[str, Any] | None,
        env_response: Any,
        next_obs: Dict[str, Any] | None,
        metadata: Dict[str, Any] | None = None,
    ) -> int | None:
        """Record one environment transition as an ``EnvironmentEvent``.

        Observations are stored in a compacted form
        (``_summarize_observation_for_storage``). ``reward``, ``done`` and
        ``truncated`` are read from ``env_response`` when present.

        Returns:
            The tracer's event id, or ``None`` when tracing is disabled or
            recording failed.
        """
        if not self.enabled or self.tracer is None:
            return None

        try:
            prev_summary = _summarize_observation_for_storage(env_handle, prev_obs or {}) if prev_obs is not None else None
        except Exception:
            prev_summary = None
        try:
            next_summary = _summarize_observation_for_storage(env_handle, next_obs or {}) if next_obs is not None else None
        except Exception:
            next_summary = None

        reward_val = getattr(env_response, "reward", None)
        try:
            reward_float = float(reward_val) if reward_val is not None else 0.0
        except Exception:
            reward_float = 0.0

        event = EnvironmentEvent(
            system_instance_id=f"environment:{self.env_name or 'unknown'}",
            # time.time() is the real POSIX timestamp. datetime.utcnow().timestamp()
            # would interpret the naive UTC value as local time and skew it by the
            # host's UTC offset.
            time_record=TimeRecord(event_time=_time.time()),
            reward=reward_float,
            terminated=bool(getattr(env_response, "done", False)),
            truncated=bool(getattr(env_response, "truncated", False)),
            system_state_before=prev_summary,
            system_state_after=next_summary,
            metadata={
                "turn": self.current_turn,
                "run_id": self.run_id,
                **(metadata or {}),
            },
        )

        return await self._record_event(event)

    async def record_decision_reward(
        self,
        *,
        event_id: int | None,
        decision_meta: Dict[str, Any] | None,
    ) -> None:
        """Store the achievement deltas of the current decision and attach them as event rewards.

        The summary entry is always appended to ``decision_rewards``. Tracer
        rewards are written only when tracing is enabled and ``event_id`` is
        known. The ``unique_achievement_delta`` reward is added only when that
        delta is non-zero.
        """
        decision_meta = decision_meta or {}
        # ``or 0`` treats explicit None values like missing keys instead of crashing int().
        ach_delta = int(decision_meta.get("ach_delta") or 0)
        unique_delta = int(decision_meta.get("unique_delta") or 0)
        all_ach = list(decision_meta.get("all") or [])
        unique_ach = list(decision_meta.get("unique") or [])

        self.decision_rewards.append(
            {
                "turn": self.current_turn,
                "ach_delta": ach_delta,
                "unique_delta": unique_delta,
                "achievements": all_ach,
                "unique_achievements": unique_ach,
            }
        )

        if not self.enabled or self.tracer is None or event_id is None:
            return
        try:
            await self.tracer.record_event_reward(
                event_id=event_id,
                turn_number=self.current_turn,
                reward_value=float(ach_delta),
                reward_type="achievement_delta",
                annotation={"achievements": all_ach},
                source="environment",
            )
            if unique_delta:
                await self.tracer.record_event_reward(
                    event_id=event_id,
                    turn_number=self.current_turn,
                    reward_value=float(unique_delta),
                    reward_type="unique_achievement_delta",
                    annotation={"achievements": unique_ach},
                    source="environment",
                )
        except Exception as exc:
            logger.debug("TRACING_REWARD_FAIL: %s", exc)

    def update_metadata(self, **kwargs: Any) -> None:
        """Merge non-``None`` keyword values into the metadata reported at finalization."""
        self.metadata_updates.update({k: v for k, v in kwargs.items() if v is not None})

    async def finalize(
        self,
        *,
        total_reward: float,
        achievement_state: Dict[str, bool] | None,
        total_steps: int,
    ) -> Any:
        """Emit the outcome reward, end and close the tracer session, and flush SFT records.

        This also clears ``fastapi_request.state.rollout_tracing``.

        Args:
            total_reward: Episode return. It is converted with ``int()`` before
                being passed to ``record_outcome_reward``.
            achievement_state: Final achievement flags; truthy entries are reported.
            total_steps: Number of steps taken in the rollout.

        Returns:
            The session trace from ``tracer.end_session()``, or ``None`` when
            tracing is disabled or ending the session failed.
        """
        final_achievements = [key for key, val in (achievement_state or {}).items() if val]
        self.metadata_updates.setdefault("final_achievements", final_achievements)
        if self.enabled and self.tracer is not None:
            try:
                # NOTE(review): int() truncates fractional rewards — confirm that
                # record_outcome_reward expects an integer total.
                await self.tracer.record_outcome_reward(
                    total_reward=int(total_reward),
                    achievements_count=len(final_achievements),
                    total_steps=int(total_steps),
                    reward_metadata=dict(self.metadata_updates),
                )
            except Exception as exc:
                logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
            try:
                self.session_trace = await self.tracer.end_session()
                if self.session_trace is not None:
                    self.session_trace.metadata.update(self.metadata_updates)
            except Exception as exc:
                logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
                self.session_trace = None
            try:
                await self.tracer.close()
            except Exception:
                pass

        if self.sft_records and self.sft_output_dir:
            self.write_sft_records()

        # Clear context from request state to avoid leaks
        self.fastapi_request.state.rollout_tracing = None

        return self.session_trace

    def write_sft_records(self) -> None:
        """Write buffered SFT records as JSON lines under ``sft_output_dir``.

        Write errors are logged, not raised. The buffer is cleared in either case.
        """
        if not self.sft_output_dir or not self.sft_records:
            return
        try:
            path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
            path.parent.mkdir(parents=True, exist_ok=True)
            with path.open("w", encoding="utf-8") as fh:
                for record in self.sft_records:
                    json.dump(record, fh, ensure_ascii=False)
                    fh.write("\n")
            logger.info("SFT_WRITTEN: %s", path)
        except Exception as exc:
            logger.warning("SFT_WRITE_FAIL: %s", exc)
        finally:
            self.sft_records.clear()

    def build_trace_payload(self, session_trace: Any) -> Dict[str, Any] | None:
        """Build the ``trace`` field of the rollout response.

        Returns ``None`` unless ``return_trace`` was requested and a session
        trace exists. The ``"full"`` format returns ``session_trace.to_dict()``.
        Any other format returns a compact summary with counts and the in-memory
        LM-call/decision-reward summaries. Both include ``metadata_updates``.
        """
        if not self.return_trace or session_trace is None:
            return None
        if self.trace_format == "full":
            payload = session_trace.to_dict()
            payload.setdefault("metadata", {}).update(self.metadata_updates)
            return payload
        metadata = dict(session_trace.metadata)
        metadata.update(self.metadata_updates)
        return {
            "session_id": session_trace.session_id,
            "created_at": session_trace.created_at.isoformat(),
            "metadata": metadata,
            "events_count": len(session_trace.event_history),
            "messages_count": len(session_trace.markov_blanket_message_history),
            "lm_calls": self.lm_calls_summary,
            "decision_rewards": self.decision_rewards,
        }
643
+ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, Any]) -> Dict[str, Any]:
644
+ """Return a compact dict for trajectory storage instead of the raw observation.
645
+
646
+ - For Crafter, use the same summary used for the policy user prompt
647
+ - For others, keep a minimal subset or plain text preview
648
+ """
649
+ # Try Crafter-specific formatter
650
+ try:
651
+ from .envs.crafter.environment import CrafterEnvironmentWrapper as _CrafterWrapper # type: ignore
652
+ except Exception:
653
+ _CrafterWrapper = None # type: ignore
654
+
655
+ if _CrafterWrapper is not None and isinstance(getattr(env_handle, "env", None), _CrafterWrapper):
656
+ try:
657
+ from .envs.crafter.shared import format_observation as _fmt # type: ignore
658
+ text = _fmt(observation or {})
659
+ return {"text": text}
660
+ except Exception:
661
+ pass
662
+
663
+ # Generic fallback: extract a few small fields if present; avoid huge arrays
664
+ try:
665
+ inv = observation.get("inventory") if isinstance(observation, dict) else None
666
+ ach = observation.get("achievements_status") if isinstance(observation, dict) else None
667
+ pos = observation.get("player_position") if isinstance(observation, dict) else None
668
+ health = None
669
+ if isinstance(inv, dict):
670
+ health = inv.get("health")
671
+ summary = {
672
+ "position": pos,
673
+ "health": health,
674
+ "inventory_keys": sorted([k for k, v in (inv or {}).items() if v])[:10] if isinstance(inv, dict) else None,
675
+ "achievements_unlocked": sorted([k for k, v in (ach or {}).items() if v])[:10] if isinstance(ach, dict) else None,
676
+ }
677
+ return {"text": json.dumps(summary, ensure_ascii=False)}
678
+ except Exception:
679
+ pass
680
+
681
+ # Last resort: plain string preview
682
+ try:
683
+ return {"text": str(observation)[:10000]}
684
+ except Exception:
685
+ return {"text": ""}
686
+
687
+
688
+
689
class RunAbortRequest(BaseModel):
    """Request body naming the run to abort."""

    run_id: str
691
+
692
+
693
class RunAbortResponse(BaseModel):
    """Acknowledgement of an abort request for ``run_id``."""

    ok: bool
    run_id: str
696
+
697
+
698
class RunStatusResponse(BaseModel):
    """Status snapshot of a run; ``finished_at`` stays ``None`` until one is supplied."""

    run_id: str
    status: str
    started_at: datetime
    finished_at: datetime | None = None
703
+
704
+
705
+ @router.post("/rollout", response_model=RolloutResponse)
706
+ async def execute_rollout(
707
+ request: RolloutRequest,
708
+ req: Request,
709
+ ) -> RolloutResponse:
710
+ """Execute a rollout with coordinated environment and policy steps."""
711
+ # Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
712
+ try:
713
+ _env_params = {}
714
+ if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
715
+ _env_params = dict(request.env.config.get("env_params") or {})
716
+ max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
717
+ assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
718
+ except Exception as _mse:
719
+ raise HTTPException(
720
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
721
+ detail={
722
+ "error": "invalid_env_params",
723
+ "message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
724
+ },
725
+ )
726
+ # Truncate incoming ops to the enforced cap (each step is [agent, env])
727
+ ops_seq: List[str] = list(request.ops or [])
728
+ allowed_ops = max(0, int(max_steps_per_episode) * 2)
729
+ if len(ops_seq) > allowed_ops:
730
+ try:
731
+ logger.info(
732
+ "ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
733
+ str(len(ops_seq)),
734
+ str(allowed_ops),
735
+ )
736
+ except Exception:
737
+ pass
738
+ ops_seq = ops_seq[:allowed_ops]
739
+ # Simple API key auth for inbound rollout
740
+ header_key = req.headers.get("x-api-key")
741
+ env_key = os.getenv("ENVIRONMENT_API_KEY")
742
+ dev_key = os.getenv("dev_environment_api_key")
743
+ # Accept either ENVIRONMENT_API_KEY or dev_environment_api_key
744
+ expected_keys = [k for k in (env_key, dev_key) if k]
745
+ if not expected_keys:
746
+ missing = []
747
+ if not env_key:
748
+ missing.append("ENVIRONMENT_API_KEY")
749
+ if not dev_key:
750
+ missing.append("dev_environment_api_key")
751
+ msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
752
+ logger.error(msg)
753
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
754
+ if not header_key:
755
+ raise HTTPException(
756
+ status_code=status.HTTP_401_UNAUTHORIZED,
757
+ detail="Invalid or missing API key: X-API-Key header not provided",
758
+ )
759
+ if header_key not in expected_keys:
760
+ # Do not leak secrets; include short prefix for diagnostics
761
+ exp_src = env_key if env_key else (dev_key or "")
762
+ exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
763
+ got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
764
+ raise HTTPException(
765
+ status_code=status.HTTP_401_UNAUTHORIZED,
766
+ detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
767
+ )
768
+
769
+ # Log contextual fields for traceability
770
+ if request.training_session_id:
771
+ logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
772
+ if request.synth_base_url:
773
+ logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
774
+
775
+ # Log masked OpenAI API key presence for diagnostics
776
+ try:
777
+ _oa = os.getenv("OPENAI_API_KEY")
778
+ if _oa:
779
+ _pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
780
+ logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
781
+ else:
782
+ logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
783
+ except Exception:
784
+ pass
785
+
786
+ # Make synth_base_url available for outbound calls in this app
787
+ try:
788
+ task_app = req.app.state.task_app
789
+ if request.synth_base_url:
790
+ setattr(task_app, "synth_base_url", request.synth_base_url)
791
+ except Exception:
792
+ pass
793
+
794
+ tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
795
+ tracer_instance = None
796
+ if callable(tracer_factory):
797
+ try:
798
+ tracer_instance = tracer_factory()
799
+ except Exception as exc:
800
+ logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
801
+ tracing_context = RolloutTracingContext(tracer_instance, request, req)
802
+ await tracing_context.start_session()
803
+
804
+ # Register run
805
+ registry.register_run(request.run_id)
806
+
807
+ # Track resources created during this rollout so we can guarantee cleanup
808
+ created_env_id: str | None = None
809
+ created_policy_id: str | None = None
810
+ env_seed_used: int | None = None
811
+
812
+ try:
813
+ # Initialize deterministic seed early for the entire rollout
814
+ seed_value: Optional[int] = None
815
+ try:
816
+ if request.env and request.env.seed is not None:
817
+ seed_value = int(request.env.seed)
818
+ else:
819
+ # Derive a stable seed from run_id
820
+ import hashlib as _hashlib # local import to avoid global deps
821
+
822
+ _digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
823
+ # Use lower 32 bits to fit common RNG ranges
824
+ seed_value = int(_digest[:8], 16)
825
+ except Exception:
826
+ # Fallback to time-based seed if anything goes wrong
827
+ try:
828
+ seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
829
+ except Exception:
830
+ seed_value = 42
831
+
832
+ _seed_info = _set_global_seed(int(seed_value))
833
+ try:
834
+ logger.info(
835
+ "ROLL_OUT: RNG seeded seed=%s libs=%s",
836
+ str(_seed_info.get("seed")),
837
+ ",".join(_seed_info.get("libs", [])),
838
+ )
839
+ except Exception:
840
+ pass
841
+ # Resolve or create environment
842
+ if request.env.env_id:
843
+ env_handle = registry.get_env(request.env.env_id)
844
+ if not env_handle:
845
+ raise HTTPException(
846
+ status_code=404,
847
+ detail=f"Environment {request.env.env_id} not found",
848
+ )
849
+ env_id = request.env.env_id
850
+ else:
851
+ # Create new environment
852
+ from .environment_routes import create_environment, EnvCreateRequest
853
+
854
+ if not request.env.env_name:
855
+ raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
856
+
857
+ # Propagate training_session_id via env config for downstream usage
858
+ _env_config = dict(request.env.config or {})
859
+ if request.training_session_id is not None:
860
+ _env_config.setdefault(
861
+ "training_session_id", request.training_session_id
862
+ )
863
+ env_response = await create_environment(
864
+ EnvCreateRequest(
865
+ env_name=request.env.env_name,
866
+ config=_env_config,
867
+ seed=request.env.seed,
868
+ rl_run_id=request.run_id,
869
+ )
870
+ )
871
+ env_id = env_response.env_id
872
+ env_handle = registry.get_env(env_id)
873
+ created_env_id = env_id
874
+
875
+ tracing_context.update_metadata(env_id=env_id)
876
+
877
+ # Resolve or create policy
878
+ if request.policy.policy_id:
879
+ policy_handle = registry.get_policy(request.policy.policy_id)
880
+ if not policy_handle:
881
+ raise HTTPException(
882
+ status_code=404,
883
+ detail=f"Policy {request.policy.policy_id} not found",
884
+ )
885
+ policy_id = request.policy.policy_id
886
+ else:
887
+ # Create new policy
888
+ from .policy_routes import create_policy, PolicyCreateRequest
889
+
890
+ if not request.policy.policy_name:
891
+ raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
892
+
893
+ # Propagate training_session_id and synth_base_url via policy config
894
+ _policy_config = dict(request.policy.config or {})
895
+ if request.training_session_id is not None:
896
+ _policy_config.setdefault(
897
+ "training_session_id", request.training_session_id
898
+ )
899
+ if request.synth_base_url is not None:
900
+ _policy_config.setdefault("synth_base_url", request.synth_base_url)
901
+ policy_response = await create_policy(
902
+ PolicyCreateRequest(
903
+ policy_name=request.policy.policy_name,
904
+ config=_policy_config,
905
+ rl_run_id=request.run_id,
906
+ bound_env_id=env_id,
907
+ ),
908
+ req,
909
+ )
910
+ policy_id = policy_response.policy_id
911
+ policy_handle = registry.get_policy(policy_id)
912
+ created_policy_id = policy_id
913
+
914
+ tracing_context.update_metadata(policy_id=policy_id)
915
+
916
+ # Bind policy to environment if not already bound
917
+ if policy_handle and not policy_handle.bound_env_id:
918
+ policy_handle.bound_env_id = env_id
919
+
920
+ # Record seed bound to environment for end-of-rollout verification/logging
921
+ try:
922
+ env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
923
+ except Exception:
924
+ env_seed_used = None
925
+ tracing_context.update_metadata(env_seed=env_seed_used)
926
+
927
+ # Initialize trajectory
928
+ trajectory_steps = []
929
+ pending_tool_calls = None
930
+ current_obs = env_handle.last_observation
931
+ total_reward = 0.0
932
+ ops_executed = 0
933
+ last_agent_response_ts: float | None = None
934
+ last_policy_meta: Dict[str, Any] | None = None
935
+ last_env_step_ms: float | None = None
936
+ last_env_step_completed_ts: float | None = None
937
+
938
+ # Stepwise reward configuration (Crafter shaping; gate on explicit enable)
939
+ step_rewards_cfg_raw: Dict[str, Any] = {}
940
+ try:
941
+ if isinstance(request.policy.config, dict):
942
+ step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
943
+ except Exception:
944
+ step_rewards_cfg_raw = {}
945
+ if not step_rewards_cfg_raw:
946
+ try:
947
+ if isinstance(request.env.config, dict):
948
+ step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
949
+ except Exception:
950
+ step_rewards_cfg_raw = {}
951
+
952
+ step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
953
+ step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
954
+ try:
955
+ step_rewards_indicator_lambda = float(
956
+ step_rewards_cfg_raw.get("indicator_lambda") or 0.0
957
+ )
958
+ except Exception:
959
+ step_rewards_indicator_lambda = 0.0
960
+ try:
961
+ step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
962
+ except Exception:
963
+ step_rewards_beta = 0.0
964
+ step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
965
+
966
+ def _extract_achievements(obs: Any) -> Dict[str, bool]:
967
+ if not isinstance(obs, dict):
968
+ return {}
969
+ ach = obs.get("achievements_status")
970
+ if isinstance(ach, dict):
971
+ return {str(k): bool(v) for k, v in ach.items()}
972
+ return {}
973
+
974
+ def _summarize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
975
+ if not tool_calls:
976
+ return []
977
+ try:
978
+ items = (
979
+ tool_calls
980
+ if isinstance(tool_calls, list)
981
+ else list(tool_calls) # tolerates tuples or pydantic lists
982
+ )
983
+ except Exception:
984
+ return []
985
+ summary: List[Dict[str, Any]] = []
986
+ for tc in items:
987
+ tool_name = None
988
+ args: Any = {}
989
+ if isinstance(tc, dict):
990
+ tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
991
+ raw_args = tc.get("arguments") or tc.get("args") or {}
992
+ else:
993
+ tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
994
+ raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
995
+ args = raw_args
996
+ if isinstance(raw_args, str):
997
+ try:
998
+ args = json.loads(raw_args)
999
+ except Exception:
1000
+ args = raw_args
1001
+ summary.append({"tool": tool_name, "args": args})
1002
+ return summary
1003
+
1004
+ decision_samples: List[Dict[str, Any]] = []
1005
+ decision_index = 0
1006
+ decision_open = False
1007
+ session_trace = None
1008
+ finalized = False
1009
+ prev_achievements = _extract_achievements(current_obs)
1010
+ # Track episode-level achievements that have been seen as true at any point so far
1011
+ episode_seen_achievements: set[str] = set(
1012
+ [k for k, v in (prev_achievements or {}).items() if bool(v)]
1013
+ )
1014
+ stepwise_indicator_sum = 0.0
1015
+ stepwise_reward_sum = 0.0
1016
+ stepwise_new_achievements_total = 0
1017
+ final_achievement_count = sum(1 for v in prev_achievements.values() if v)
1018
+
1019
+ # Execute ops sequence (capped by env_params.max_steps_per_episode)
1020
+ for op_idx, op in enumerate(ops_seq):
1021
+ # Check for abort
1022
+ if registry.is_run_aborted(request.run_id):
1023
+ logger.info(f"Run {request.run_id} aborted at op {op_idx}")
1024
+ break
1025
+
1026
+ # Check safety limits
1027
+ if ops_executed >= request.safety.max_ops:
1028
+ logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
1029
+ break
1030
+
1031
+ if op == "agent":
1032
+ # Policy step
1033
+ from .policy_routes import step_policy, PolicyStepRequest
1034
+
1035
+ if not decision_open:
1036
+ await tracing_context.start_decision(decision_index)
1037
+ decision_open = True
1038
+
1039
+ agent_request_start = _time.perf_counter()
1040
+ if last_agent_response_ts is not None and last_policy_meta is not None:
1041
+ try:
1042
+ timing_prev = last_policy_meta.setdefault("timing", {})
1043
+ decision_ms = max(
1044
+ 0.0,
1045
+ (agent_request_start - float(last_agent_response_ts)) * 1000.0,
1046
+ )
1047
+ # Update timing on prior policy meta (kept by previous env step)
1048
+ timing_prev["decision_ms"] = decision_ms
1049
+ if last_env_step_ms is not None:
1050
+ timing_prev["env_step_ms"] = float(last_env_step_ms)
1051
+ timing_prev["overhead_ms"] = max(
1052
+ 0.0, decision_ms - float(last_env_step_ms)
1053
+ )
1054
+ else:
1055
+ timing_prev.setdefault("overhead_ms", 0.0)
1056
+ timing_prev["decision_ready_s"] = agent_request_start
1057
+ # Also backfill the last appended trajectory step so the trainer
1058
+ # can always see decision_ms without relying on shared dict refs.
1059
+ if trajectory_steps:
1060
+ try:
1061
+ _last = trajectory_steps[-1]
1062
+ _info = dict(_last.info or {})
1063
+ _meta = dict(_info.get("meta") or {})
1064
+ _timing = dict(_meta.get("timing") or {})
1065
+ _timing["decision_ms"] = decision_ms
1066
+ if last_env_step_ms is not None:
1067
+ _timing.setdefault("env_step_ms", float(last_env_step_ms))
1068
+ _timing.setdefault("overhead_ms", max(0.0, decision_ms - float(last_env_step_ms)))
1069
+ else:
1070
+ _timing.setdefault("overhead_ms", 0.0)
1071
+ _meta["timing"] = _timing
1072
+ _info["meta"] = _meta
1073
+ _last.info = _info
1074
+ except Exception:
1075
+ pass
1076
+ except Exception:
1077
+ pass
1078
+ last_env_step_ms = None
1079
+ last_env_step_completed_ts = None
1080
+
1081
+ # Build metadata for policy (carry previous tool_calls and env result)
1082
+ metadata = {}
1083
+ if pending_tool_calls:
1084
+ metadata["prev_tool_calls"] = pending_tool_calls
1085
+ if len(trajectory_steps) > 0:
1086
+ last_step = trajectory_steps[-1]
1087
+ # Prefer the last executed tool calls to seed history
1088
+ if last_step.tool_calls:
1089
+ metadata["prev_tool_calls"] = last_step.tool_calls
1090
+ # Provide a compact env result snapshot
1091
+ metadata["prev_env_result"] = {
1092
+ "observation": last_step.obs,
1093
+ "reward": last_step.reward,
1094
+ "done": last_step.done,
1095
+ "truncated": last_step.truncated,
1096
+ "info": last_step.info,
1097
+ }
1098
+
1099
+ # Log compact metadata summary to confirm history threading
1100
+ try:
1101
+ _prev_calls = (
1102
+ metadata["prev_tool_calls"]
1103
+ if isinstance(metadata, dict) and "prev_tool_calls" in metadata
1104
+ else None
1105
+ )
1106
+ _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
1107
+ _first_guess = None
1108
+ if _count > 0 and isinstance(_prev_calls[0], dict):
1109
+ _args = (
1110
+ _prev_calls[0]["arguments"]
1111
+ if "arguments" in _prev_calls[0]
1112
+ else None
1113
+ )
1114
+ if isinstance(_args, str):
1115
+ import json as _json
1116
+
1117
+ try:
1118
+ _args = _json.loads(_args)
1119
+ except Exception:
1120
+ _args = {}
1121
+ if isinstance(_args, dict):
1122
+ _first_guess = (
1123
+ _args["guess"] if "guess" in _args else None
1124
+ ) or (_args["word"] if "word" in _args else None)
1125
+ logger.info(
1126
+ "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
1127
+ _count,
1128
+ _first_guess,
1129
+ str("prev_env_result" in metadata),
1130
+ )
1131
+ except Exception:
1132
+ pass
1133
+
1134
+ try:
1135
+ policy_response = await step_policy(
1136
+ PolicyStepRequest(
1137
+ policy_id=policy_id,
1138
+ observation=current_obs,
1139
+ metadata=metadata,
1140
+ ),
1141
+ req,
1142
+ )
1143
+ except Exception as _pe:
1144
+ # Do not 500 the rollout; finalize with partial trajectory
1145
+ try:
1146
+ logger.warning(
1147
+ "POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
1148
+ request.run_id,
1149
+ str(op_idx),
1150
+ str(_pe),
1151
+ )
1152
+ except Exception:
1153
+ pass
1154
+
1155
+ # Build partial trajectory and return HTTP 200
1156
+ trajectory = RolloutTrajectory(
1157
+ env_id=env_id,
1158
+ policy_id=policy_id,
1159
+ steps=trajectory_steps,
1160
+ final={
1161
+ "observation": current_obs,
1162
+ "rollout_status": "partial_policy_error",
1163
+ "error": str(_pe),
1164
+ "at_op": op,
1165
+ },
1166
+ length=len(trajectory_steps),
1167
+ decision_samples=decision_samples if step_rewards_active else None,
1168
+ )
1169
+ metrics = RolloutMetrics(
1170
+ episode_returns=[total_reward],
1171
+ mean_return=total_reward,
1172
+ num_steps=len(trajectory_steps),
1173
+ num_episodes=1,
1174
+ )
1175
+ aborted = registry.is_run_aborted(request.run_id)
1176
+ if not aborted:
1177
+ registry.complete_run(request.run_id)
1178
+ if decision_open:
1179
+ await tracing_context.end_decision()
1180
+ decision_open = False
1181
+ if not finalized:
1182
+ session_trace = await tracing_context.finalize(
1183
+ total_reward=total_reward,
1184
+ achievement_state=prev_achievements,
1185
+ total_steps=len(trajectory_steps),
1186
+ )
1187
+ finalized = True
1188
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1189
+ return RolloutResponse(
1190
+ run_id=request.run_id,
1191
+ trajectories=[trajectory],
1192
+ branches={},
1193
+ metrics=metrics,
1194
+ aborted=aborted,
1195
+ ops_executed=ops_executed,
1196
+ trace=trace_payload,
1197
+ )
1198
+
1199
+ agent_response_ts = _time.perf_counter()
1200
+ if isinstance(policy_response.meta, dict):
1201
+ try:
1202
+ timing_cur = policy_response.meta.setdefault("timing", {})
1203
+ timing_cur["agent_request_start_s"] = agent_request_start
1204
+ timing_cur["agent_response_s"] = agent_response_ts
1205
+ if "inference_ms" in policy_response.meta:
1206
+ try:
1207
+ timing_cur.setdefault(
1208
+ "inference_ms",
1209
+ float(policy_response.meta["inference_ms"]),
1210
+ )
1211
+ timing_cur.setdefault(
1212
+ "inference_s",
1213
+ float(policy_response.meta["inference_ms"]) / 1000.0,
1214
+ )
1215
+ except Exception:
1216
+ pass
1217
+ except Exception:
1218
+ pass
1219
+ last_policy_meta = policy_response.meta
1220
+ else:
1221
+ last_policy_meta = None
1222
+ last_agent_response_ts = agent_response_ts
1223
+
1224
+ pending_tool_calls = policy_response.tool_calls
1225
+ await tracing_context.record_tool_invocation(pending_tool_calls)
1226
+ ops_executed += 1
1227
+
1228
+ elif op == "env":
1229
+ if not pending_tool_calls:
1230
+ # Treat absence of tool calls as a soft terminal condition; yield partial trajectory
1231
+ try:
1232
+ logger.warning(
1233
+ "NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
1234
+ request.run_id,
1235
+ str(op_idx),
1236
+ )
1237
+ except Exception:
1238
+ pass
1239
+ term_step = RolloutStep(
1240
+ obs=current_obs,
1241
+ tool_calls=[],
1242
+ reward=None,
1243
+ done=True,
1244
+ truncated=False,
1245
+ info={
1246
+ "terminated": True,
1247
+ "reason": "no_tool_calls",
1248
+ },
1249
+ )
1250
+ trajectory_steps.append(term_step)
1251
+ trajectory = RolloutTrajectory(
1252
+ env_id=env_id,
1253
+ policy_id=policy_id,
1254
+ steps=trajectory_steps,
1255
+ final={
1256
+ "observation": current_obs,
1257
+ "rollout_status": "partial_no_tool_calls",
1258
+ "at_op": op,
1259
+ },
1260
+ length=len(trajectory_steps),
1261
+ decision_samples=decision_samples if step_rewards_active else None,
1262
+ )
1263
+ metrics = RolloutMetrics(
1264
+ episode_returns=[total_reward],
1265
+ mean_return=total_reward,
1266
+ num_steps=len(trajectory_steps),
1267
+ num_episodes=1,
1268
+ )
1269
+ aborted = registry.is_run_aborted(request.run_id)
1270
+ if not aborted:
1271
+ registry.complete_run(request.run_id)
1272
+ if decision_open:
1273
+ await tracing_context.end_decision()
1274
+ decision_open = False
1275
+ if not finalized:
1276
+ session_trace = await tracing_context.finalize(
1277
+ total_reward=total_reward,
1278
+ achievement_state=prev_achievements,
1279
+ total_steps=len(trajectory_steps),
1280
+ )
1281
+ finalized = True
1282
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1283
+ return RolloutResponse(
1284
+ run_id=request.run_id,
1285
+ trajectories=[trajectory],
1286
+ branches={},
1287
+ metrics=metrics,
1288
+ aborted=aborted,
1289
+ ops_executed=ops_executed,
1290
+ trace=trace_payload,
1291
+ )
1292
+
1293
+ # Environment step
1294
+ from .environment_routes import step_environment, EnvStepRequest
1295
+
1296
+ env_step_error: Exception | None = None
1297
+ env_response = None
1298
+ env_step_start = _time.perf_counter()
1299
+ try:
1300
+ env_response = await step_environment(
1301
+ EnvStepRequest(
1302
+ env_id=env_id,
1303
+ tool_calls=pending_tool_calls,
1304
+ )
1305
+ )
1306
+ except Exception as _ee:
1307
+ env_step_error = _ee
1308
+ env_step_end = _time.perf_counter()
1309
+ env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
1310
+ last_env_step_ms = env_step_duration_ms
1311
+ last_env_step_completed_ts = env_step_end
1312
+ if last_policy_meta is not None:
1313
+ try:
1314
+ timing_env = last_policy_meta.setdefault("timing", {})
1315
+ timing_env["env_step_ms"] = env_step_duration_ms
1316
+ timing_env["env_step_end_s"] = env_step_end
1317
+ except Exception:
1318
+ pass
1319
+
1320
+ if env_step_error is not None:
1321
+ # Invalid action or environment rejection — terminate episode early with partial trajectory
1322
+ try:
1323
+ logger.warning(
1324
+ "ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
1325
+ request.run_id,
1326
+ str(op_idx),
1327
+ str(env_step_error),
1328
+ )
1329
+ except Exception:
1330
+ pass
1331
+
1332
+ term_step = RolloutStep(
1333
+ obs=current_obs,
1334
+ tool_calls=pending_tool_calls,
1335
+ reward=None,
1336
+ done=True,
1337
+ truncated=False,
1338
+ info={
1339
+ "terminated": True,
1340
+ "reason": "invalid_action",
1341
+ "error": str(env_step_error),
1342
+ },
1343
+ )
1344
+ trajectory_steps.append(term_step)
1345
+ # Build partial response
1346
+ trajectory = RolloutTrajectory(
1347
+ env_id=env_id,
1348
+ policy_id=policy_id,
1349
+ steps=trajectory_steps,
1350
+ final={
1351
+ "observation": current_obs,
1352
+ "rollout_status": "partial_invalid_action",
1353
+ "error": str(env_step_error),
1354
+ "at_op": op,
1355
+ },
1356
+ length=len(trajectory_steps),
1357
+ decision_samples=decision_samples if step_rewards_active else None,
1358
+ )
1359
+ metrics = RolloutMetrics(
1360
+ episode_returns=[total_reward],
1361
+ mean_return=total_reward,
1362
+ num_steps=len(trajectory_steps),
1363
+ num_episodes=1,
1364
+ )
1365
+ aborted = registry.is_run_aborted(request.run_id)
1366
+ if not aborted:
1367
+ registry.complete_run(request.run_id)
1368
+ if (
1369
+ last_policy_meta is not None
1370
+ and last_agent_response_ts is not None
1371
+ and "decision_ms" not in last_policy_meta.get("timing", {})
1372
+ ):
1373
+ try:
1374
+ timing_last = last_policy_meta.setdefault("timing", {})
1375
+ decision_ms = max(
1376
+ 0.0,
1377
+ (env_step_end - float(last_agent_response_ts)) * 1000.0,
1378
+ )
1379
+ timing_last["decision_ms"] = decision_ms
1380
+ timing_last.setdefault("overhead_ms", max(0.0, decision_ms - env_step_duration_ms))
1381
+ except Exception:
1382
+ pass
1383
+ if decision_open:
1384
+ await tracing_context.end_decision()
1385
+ decision_open = False
1386
+ if not finalized:
1387
+ session_trace = await tracing_context.finalize(
1388
+ total_reward=total_reward,
1389
+ achievement_state=prev_achievements,
1390
+ total_steps=len(trajectory_steps),
1391
+ )
1392
+ finalized = True
1393
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1394
+ return RolloutResponse(
1395
+ run_id=request.run_id,
1396
+ trajectories=[trajectory],
1397
+ branches={},
1398
+ metrics=metrics,
1399
+ aborted=aborted,
1400
+ ops_executed=ops_executed,
1401
+ trace=trace_payload,
1402
+ )
1403
+
1404
+ # Reaching here means env step succeeded
1405
+ assert env_response is not None
1406
+
1407
+ # Record step, including policy meta if present for timing/tokens observability
1408
+ _info = env_response.info if isinstance(env_response.info, dict) else {}
1409
+ # Attach policy meta from the immediately preceding agent step
1410
+ try:
1411
+ prev_meta = {}
1412
+ if "policy_response" in locals() and isinstance(
1413
+ policy_response.meta, dict
1414
+ ): # type: ignore[name-defined]
1415
+ prev_meta = policy_response.meta
1416
+ if prev_meta:
1417
+ _info = dict(_info)
1418
+ _info["meta"] = prev_meta
1419
+ except Exception:
1420
+ pass
1421
+
1422
+ event_metadata = {
1423
+ "op_index": op_idx,
1424
+ }
1425
+ event_id = await tracing_context.record_environment_event(
1426
+ env_handle=env_handle,
1427
+ prev_obs=current_obs,
1428
+ env_response=env_response,
1429
+ next_obs=getattr(env_response, "observation", None),
1430
+ metadata=event_metadata,
1431
+ )
1432
+
1433
+ decision_index += 1
1434
+ next_obs = env_response.observation
1435
+ new_achievement_state = _extract_achievements(next_obs)
1436
+ final_achievement_count = sum(
1437
+ 1 for _, unlocked in new_achievement_state.items() if unlocked
1438
+ )
1439
+ indicator_val = 0
1440
+ reward_stepwise = 0.0
1441
+ decision_rewards_meta: Dict[str, Any] | None = None
1442
+ if step_rewards_active:
1443
+ decision_actions = _summarize_tool_calls(pending_tool_calls)
1444
+ stepwise_info, decision_record, stats = compute_stepwise_reward(
1445
+ prev_achievements or {},
1446
+ new_achievement_state,
1447
+ decision_index,
1448
+ decision_actions,
1449
+ step_rewards_indicator_lambda,
1450
+ )
1451
+ indicator_val = int(stats.get("indicator", 0.0))
1452
+ reward_stepwise = float(stats.get("reward", 0.0))
1453
+ stepwise_indicator_sum += float(stats.get("indicator", 0.0))
1454
+ stepwise_reward_sum += reward_stepwise
1455
+ stepwise_new_achievements_total += int(
1456
+ stats.get("new_achievements_count", 0.0)
1457
+ )
1458
+ if not isinstance(_info, dict):
1459
+ _info = {}
1460
+ else:
1461
+ _info = dict(_info)
1462
+ _info["stepwise"] = stepwise_info
1463
+ # Compute decision-level rewards (absolute vs unique) and attach to metadata
1464
+ try:
1465
+ turned_true = set(stepwise_info.get("new_achievements") or [])
1466
+ seen_before = set(episode_seen_achievements)
1467
+ new_unique = sorted(list(turned_true - seen_before))
1468
+ ach_delta = int(len(turned_true))
1469
+ unique_delta = int(len(new_unique))
1470
+ # Prepare stable lists for logging/metadata
1471
+ all_list = sorted(list(turned_true))
1472
+ # Ensure nested meta exists
1473
+ meta_block = _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
1474
+ decision_rewards = {
1475
+ "turn": int(decision_index),
1476
+ "ach_delta": ach_delta,
1477
+ "unique_delta": unique_delta,
1478
+ "all": all_list,
1479
+ "unique": new_unique,
1480
+ }
1481
+ decision_rewards_meta = decision_rewards
1482
+ meta_block["decision_rewards"] = decision_rewards
1483
+ _info["meta"] = meta_block
1484
+ # Update episode-level seen set after attributing uniqueness to this decision
1485
+ episode_seen_achievements.update(turned_true)
1486
+ except Exception:
1487
+ # Best-effort; do not block rollout on metadata computation
1488
+ pass
1489
+ decision_samples.append(decision_record)
1490
+ prev_achievements = new_achievement_state
1491
+
1492
+ await tracing_context.record_decision_reward(
1493
+ event_id=event_id,
1494
+ decision_meta=decision_rewards_meta,
1495
+ )
1496
+
1497
+ step = RolloutStep(
1498
+ obs=_summarize_observation_for_storage(env_handle, current_obs),
1499
+ tool_calls=pending_tool_calls,
1500
+ reward=env_response.reward,
1501
+ done=env_response.done,
1502
+ truncated=env_response.truncated,
1503
+ info=_info,
1504
+ )
1505
+ trajectory_steps.append(step)
1506
+
1507
+ if env_response.reward is not None:
1508
+ total_reward += env_response.reward
1509
+
1510
+ # Update state
1511
+ current_obs = next_obs
1512
+ pending_tool_calls = None
1513
+ ops_executed += 1
1514
+
1515
+ # Handle episode end
1516
+ if env_response.done:
1517
+ if request.on_done == "reset":
1518
+ # Reset environment
1519
+ from .environment_routes import (
1520
+ reset_environment,
1521
+ EnvResetRequest,
1522
+ )
1523
+
1524
+ reset_response = await reset_environment(
1525
+ EnvResetRequest(env_id=env_id)
1526
+ )
1527
+ current_obs = reset_response.observation
1528
+ elif request.on_done == "terminate":
1529
+ break
1530
+
1531
+ if decision_open:
1532
+ await tracing_context.end_decision()
1533
+ decision_open = False
1534
+
1535
+ else:
1536
+ logger.warning(f"Unknown op: {op}")
1537
+
1538
+ if (
1539
+ last_policy_meta is not None
1540
+ and last_agent_response_ts is not None
1541
+ and "timing" in last_policy_meta
1542
+ and isinstance(last_policy_meta["timing"], dict)
1543
+ and "decision_ms" not in last_policy_meta["timing"]
1544
+ ):
1545
+ try:
1546
+ final_now = last_env_step_completed_ts or _time.perf_counter()
1547
+ final_decision_ms = max(
1548
+ 0.0, (final_now - float(last_agent_response_ts)) * 1000.0
1549
+ )
1550
+ timing_final = last_policy_meta.setdefault("timing", {})
1551
+ timing_final["decision_ms"] = final_decision_ms
1552
+ if last_env_step_ms is not None:
1553
+ timing_final.setdefault(
1554
+ "env_step_ms", float(last_env_step_ms)
1555
+ )
1556
+ timing_final.setdefault(
1557
+ "overhead_ms",
1558
+ max(0.0, final_decision_ms - float(last_env_step_ms)),
1559
+ )
1560
+ else:
1561
+ timing_final.setdefault("overhead_ms", 0.0)
1562
+ except Exception:
1563
+ pass
1564
+
1565
+ # Build trajectory
1566
+ trajectory = RolloutTrajectory(
1567
+ env_id=env_id,
1568
+ policy_id=policy_id,
1569
+ steps=trajectory_steps,
1570
+ final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
1571
+ length=len(trajectory_steps),
1572
+ decision_samples=decision_samples if step_rewards_active else None,
1573
+ )
1574
+
1575
+ # Build metrics
1576
+ metrics = RolloutMetrics(
1577
+ episode_returns=[total_reward],
1578
+ mean_return=total_reward,
1579
+ num_steps=len(trajectory_steps),
1580
+ num_episodes=1,
1581
+ )
1582
+
1583
+ # Environment-specific: Log summary if available
1584
+ try:
1585
+ # Check if this is a Wordle environment and use Wordle helpers (lazy import)
1586
+ try:
1587
+ from .envs.wordle.environment import WordleEnvironmentWrapper as _WordleWrapper
1588
+ from .envs.wordle.helpers import (
1589
+ get_wordle_rollout_summary,
1590
+ log_wordle_rollout_summary,
1591
+ )
1592
+ except Exception:
1593
+ _WordleWrapper = None # type: ignore
1594
+ get_wordle_rollout_summary = None # type: ignore
1595
+ log_wordle_rollout_summary = None # type: ignore
1596
+
1597
+ is_wordle = _WordleWrapper is not None and isinstance(env_handle.env, _WordleWrapper)
1598
+ if is_wordle:
1599
+ # Convert trajectory steps to expected format
1600
+ formatted_steps = []
1601
+ for step in trajectory_steps:
1602
+ formatted_steps.append({"tool_calls": step.tool_calls or []})
1603
+
1604
+ if get_wordle_rollout_summary is not None and log_wordle_rollout_summary is not None:
1605
+ summary = get_wordle_rollout_summary(
1606
+ formatted_steps, current_obs, env_handle
1607
+ )
1608
+ log_wordle_rollout_summary(request.run_id, summary)
1609
+ except ImportError:
1610
+ # Wordle helpers not available, skip Wordle-specific logging
1611
+ pass
1612
+ except Exception as e:
1613
+ logger.warning(f"Failed to generate environment-specific summary: {e}")
1614
+
1615
+ # Mark run as completed
1616
+ aborted = registry.is_run_aborted(request.run_id)
1617
+ if not aborted:
1618
+ registry.complete_run(request.run_id)
1619
+ if decision_open:
1620
+ await tracing_context.end_decision()
1621
+ decision_open = False
1622
+ if not finalized:
1623
+ session_trace = await tracing_context.finalize(
1624
+ total_reward=total_reward,
1625
+ achievement_state=prev_achievements,
1626
+ total_steps=len(trajectory_steps),
1627
+ )
1628
+ finalized = True
1629
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1630
+
1631
+ return RolloutResponse(
1632
+ run_id=request.run_id,
1633
+ trajectories=[trajectory],
1634
+ branches={},
1635
+ metrics=metrics,
1636
+ aborted=aborted,
1637
+ ops_executed=ops_executed,
1638
+ trace=trace_payload,
1639
+ )
1640
+
1641
+ except Exception as e:
1642
+ logger.error(f"Rollout failed for run {request.run_id}: {e}")
1643
+ registry.abort_run(request.run_id)
1644
+ if decision_open:
1645
+ try:
1646
+ await tracing_context.end_decision()
1647
+ except Exception:
1648
+ pass
1649
+ decision_open = False
1650
+ if not finalized:
1651
+ try:
1652
+ session_trace = await tracing_context.finalize(
1653
+ total_reward=total_reward,
1654
+ achievement_state=prev_achievements,
1655
+ total_steps=len(trajectory_steps),
1656
+ )
1657
+ except Exception:
1658
+ session_trace = None
1659
+ finalized = True
1660
+ raise HTTPException(status_code=500, detail=str(e))
1661
+ finally:
1662
+ # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
1663
+ try:
1664
+ if created_env_id:
1665
+ from .environment_routes import terminate_environment, EnvTerminateRequest
1666
+
1667
+ await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
1668
+ logger.info(
1669
+ "ROLL_OUT: terminated environment env_id=%s seed=%s",
1670
+ str(created_env_id),
1671
+ str(env_seed_used) if env_seed_used is not None else "unknown",
1672
+ )
1673
+ # Verify removal from registry
1674
+ try:
1675
+ _post = registry.get_env(created_env_id)
1676
+ logger.info(
1677
+ "ROLL_OUT: env_killed=%s (post_lookup=%s)",
1678
+ str(_post is None),
1679
+ str(_post),
1680
+ )
1681
+ except Exception:
1682
+ pass
1683
+ except Exception as _te:
1684
+ logger.warning(
1685
+ f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}"
1686
+ )
1687
+
1688
+ # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
1689
+ try:
1690
+ if created_policy_id:
1691
+ from .policy_routes import terminate_policy, PolicyTerminateRequest
1692
+
1693
+ await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
1694
+ logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
1695
+ except Exception:
1696
+ pass
1697
+
1698
+ if not finalized:
1699
+ try:
1700
+ session_trace = await tracing_context.finalize(
1701
+ total_reward=total_reward,
1702
+ achievement_state=prev_achievements,
1703
+ total_steps=len(trajectory_steps),
1704
+ )
1705
+ except Exception:
1706
+ session_trace = None
1707
+ finalized = True
1708
+
1709
+ try:
1710
+ _clear_seed_side_effects()
1711
+ logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
1712
+ except Exception:
1713
+ pass
1714
+
1715
+
1716
@router.post("/run/abort", response_model=RunAbortResponse)
async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
    """Abort the rollout identified by ``request.run_id``.

    Delegates to ``registry.abort_run``. A falsy result from the registry is
    reported to the client as HTTP 404; otherwise an ``ok=True`` response
    echoing the run id is returned.

    Raises:
        HTTPException: 404 when the registry did not report a successful abort.
    """
    run_id = request.run_id
    if registry.abort_run(run_id):
        return RunAbortResponse(ok=True, run_id=run_id)
    raise HTTPException(
        status_code=404,
        detail=f"Run {run_id} not found",
    )
1731
+
1732
+
1733
@router.get("/run/status/{run_id}", response_model=RunStatusResponse)
async def get_run_status(run_id: str) -> RunStatusResponse:
    """Report the lifecycle status and timestamps recorded for ``run_id``.

    Looks the run up via ``registry.get_run`` and copies the handle's
    ``status``, ``started_at`` and ``finished_at`` attributes into the
    response model.

    Raises:
        HTTPException: 404 when the registry lookup yields a falsy handle.
    """
    handle = registry.get_run(run_id)
    if handle:
        return RunStatusResponse(
            run_id=run_id,
            status=handle.status,
            started_at=handle.started_at,
            finished_at=handle.finished_at,
        )
    raise HTTPException(status_code=404, detail=f"Run {run_id} not found")