synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +107 -12
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1749 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import json
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
from fastapi import APIRouter, HTTPException, Request, status
|
|
10
|
+
import os
|
|
11
|
+
import time as _time
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from synth_ai.lm.vendors.base import BaseLMResponse
|
|
14
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
15
|
+
from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
|
|
16
|
+
from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
|
|
17
|
+
from synth_ai.task.tracing_utils import unique_sft_path
|
|
18
|
+
|
|
19
|
+
from .registry import registry
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# --- Seeding utilities (robust, optional deps) ---
|
|
24
|
+
def _set_global_seed(seed_value: int) -> Dict[str, Any]:
|
|
25
|
+
"""Set global RNG seeds across common libraries; return details for logging/restoration.
|
|
26
|
+
|
|
27
|
+
Returns a dict containing which libraries were seeded and prior states if obtainable.
|
|
28
|
+
"""
|
|
29
|
+
seeded: Dict[str, Any] = {"seed": int(seed_value), "libs": []}
|
|
30
|
+
try:
|
|
31
|
+
import random as _random # type: ignore
|
|
32
|
+
_random.seed(seed_value)
|
|
33
|
+
seeded["libs"].append("random")
|
|
34
|
+
except Exception:
|
|
35
|
+
pass
|
|
36
|
+
try:
|
|
37
|
+
import numpy as _np # type: ignore
|
|
38
|
+
_np.random.seed(seed_value)
|
|
39
|
+
seeded["libs"].append("numpy")
|
|
40
|
+
except Exception:
|
|
41
|
+
pass
|
|
42
|
+
try:
|
|
43
|
+
import torch as _torch # type: ignore
|
|
44
|
+
if hasattr(_torch, "manual_seed"):
|
|
45
|
+
_torch.manual_seed(seed_value)
|
|
46
|
+
seeded["libs"].append("torch")
|
|
47
|
+
# Make CUDA deterministic if present (best-effort)
|
|
48
|
+
try:
|
|
49
|
+
if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
|
|
50
|
+
_torch.cuda.manual_seed_all(seed_value)
|
|
51
|
+
seeded.setdefault("cuda", True)
|
|
52
|
+
except Exception:
|
|
53
|
+
pass
|
|
54
|
+
# CUDNN deterministic flags (optional)
|
|
55
|
+
try:
|
|
56
|
+
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
57
|
+
_torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
|
|
58
|
+
_torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
|
|
59
|
+
except Exception:
|
|
60
|
+
pass
|
|
61
|
+
except Exception:
|
|
62
|
+
pass
|
|
63
|
+
return seeded
|
|
64
|
+
|
|
65
|
+
def _clear_seed_side_effects() -> None:
|
|
66
|
+
"""Best-effort cleanup to avoid global deterministic side-effects between requests."""
|
|
67
|
+
# We cannot truly restore prior RNG states without capturing them; we just avoid
|
|
68
|
+
# leaving aggressive deterministic flags enabled where it matters.
|
|
69
|
+
try:
|
|
70
|
+
import torch as _torch # type: ignore
|
|
71
|
+
try:
|
|
72
|
+
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
73
|
+
# Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
|
|
74
|
+
# We'll keep deterministic False to avoid global impact; benchmark left False for stability.
|
|
75
|
+
_torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
|
|
76
|
+
except Exception:
|
|
77
|
+
pass
|
|
78
|
+
except Exception:
|
|
79
|
+
pass
|
|
80
|
+
|
|
81
|
+
router = APIRouter()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class RolloutEnvSpec(BaseModel):
    """Environment section of a rollout request (the type of ``RolloutRequest.env``)."""

    # Both identifiers are optional. NOTE(review): presumably env_id refers to an
    # already-created environment and env_name to a type to create — confirm in the route handler.
    env_id: Optional[str] = None
    env_name: Optional[str] = None
    # Free-form environment configuration; its schema is not defined in this model.
    config: Dict[str, Any] = {}
    # Optional seed; RolloutTracingContext copies it into the session metadata.
    seed: Optional[int] = None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class RolloutPolicySpec(BaseModel):
    """Policy section of a rollout request (the type of ``RolloutRequest.policy``)."""

    # Both identifiers are optional; RolloutTracingContext records them in trace metadata.
    policy_id: Optional[str] = None
    policy_name: Optional[str] = None
    # Free-form policy configuration; its schema is not defined in this model.
    config: Dict[str, Any] = {}
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class RolloutBranchConfig(BaseModel):
    """Optional branching settings of a rollout request (``RolloutRequest.branch``).

    The defaults (0 / None / False) are all falsy.
    NOTE(review): presumably that means "no branching" — the consuming logic
    is outside this section; confirm there.
    """

    branch_every_n_steps: int = 0
    branch_on_condition: Optional[str] = None
    max_branches: int = 0
    branch_policy: bool = False
    branch_env: bool = False
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class RolloutRecordConfig(BaseModel):
    """Recording options of a rollout request (``RolloutRequest.record``)."""

    trajectories: bool = True
    logprobs: bool = False
    value: bool = False
    # When False, RolloutTracingContext.build_trace_payload returns None.
    return_trace: bool = False
    # Lower-cased by RolloutTracingContext; "full" yields the whole session dict,
    # any other value yields the compact summary payload.
    trace_format: str = "compact"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class RolloutSafetyConfig(BaseModel):
    """Safety limits of a rollout request (``RolloutRequest.safety``).

    NOTE(review): enforcement happens outside this section; the field names
    suggest an op-count cap and a wall-clock cap in seconds — confirm at the
    point of use.
    """

    max_ops: int = 100000
    max_time_s: float = 3600.0
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class RolloutRequest(BaseModel):
    """Request body describing one rollout: environment, policy, ops and options."""

    # Also used as the tracing session id by RolloutTracingContext.start_session.
    run_id: str
    env: RolloutEnvSpec
    policy: RolloutPolicySpec
    ops: List[str]  # ["agent", "env", ...]
    record: RolloutRecordConfig = RolloutRecordConfig()
    on_done: str = "reset"  # "reset" | "terminate"
    branch: Optional[RolloutBranchConfig] = None
    safety: RolloutSafetyConfig = RolloutSafetyConfig()
    # Optional run/session context; both values are copied into trace metadata.
    training_session_id: Optional[str] = None
    synth_base_url: Optional[str] = None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class RolloutStep(BaseModel):
    """One step of a trajectory: observation, tool calls and step outcome fields."""

    obs: Dict[str, Any]
    tool_calls: List[Dict[str, Any]]
    # Everything below is optional and defaults to "not provided" (None) or False.
    reward: Optional[float] = None
    done: bool = False
    truncated: Optional[bool] = None
    logprob: Optional[float] = None
    value: Optional[float] = None
    info: Optional[Dict[str, Any]] = None
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class RolloutTrajectory(BaseModel):
    """A sequence of ``RolloutStep`` items for one environment/policy pair."""

    env_id: str
    policy_id: str
    steps: List[RolloutStep]
    final: Optional[Dict[str, Any]] = None
    # Required and not derived from ``steps`` by this model; the producer must set it.
    length: int
    # NOTE(review): presumably the per-decision dicts built by
    # compute_stepwise_reward ("decision_sample") — confirm at the call site.
    decision_samples: Optional[List[Dict[str, Any]]] = None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def compute_stepwise_reward(
    prev_achievements: Dict[str, bool],
    new_achievements: Dict[str, bool],
    decision_index: int,
    actions_summary: List[Dict[str, Any]],
    indicator_lambda: float,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, float]]:
    """Derive the stepwise reward records for one decision.

    An achievement counts as newly unlocked when it is truthy in
    ``new_achievements`` and not truthy in ``prev_achievements``. The
    indicator is 1 if at least one achievement was unlocked, else 0, and the
    reward is ``float(indicator_lambda) * indicator``.

    Args:
        prev_achievements: Achievement flags before the decision (falsy -> {}).
        new_achievements: Achievement flags after the decision (falsy -> {}).
        decision_index: Index of the decision, echoed into the outputs.
        actions_summary: Action descriptions, stored as-is in the sample.
        indicator_lambda: Reward paid when the indicator fires.

    Returns:
        ``(stepwise_info, decision_sample, stats)`` — three dicts carrying the
        indicator, the reward and the unlocked achievements (or their count).
    """
    before = prev_achievements if prev_achievements else {}
    after = new_achievements if new_achievements else {}

    # Preserve the iteration order of the "after" mapping.
    newly_unlocked = []
    for achievement, is_set in after.items():
        if not is_set:
            continue
        if before.get(achievement, False):
            continue
        newly_unlocked.append(achievement)

    fired = 1 if newly_unlocked else 0
    step_reward = float(indicator_lambda) * fired

    info = dict(
        decision_index=decision_index,
        indicator=fired,
        new_achievements=newly_unlocked,
        reward=step_reward,
    )
    sample = dict(
        decision_index=decision_index,
        indicator=fired,
        r_i=step_reward,
        actions=actions_summary,
    )
    summary = dict(
        indicator=float(fired),
        reward=step_reward,
        new_achievements_count=float(len(newly_unlocked)),
    )
    return info, sample, summary
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RolloutMetrics(BaseModel):
    """Aggregate metrics returned with a rollout (``RolloutResponse.metrics``)."""

    episode_returns: List[float]
    # Not computed by this model; the producer supplies the value.
    mean_return: float
    num_steps: int
    num_episodes: int = 0
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class RolloutResponse(BaseModel):
    """Response body of a rollout: trajectories, metrics and an optional trace."""

    run_id: str
    trajectories: List[RolloutTrajectory]
    branches: Dict[str, List[str]] = {}
    metrics: RolloutMetrics
    aborted: bool = False
    ops_executed: int = 0
    # NOTE(review): presumably the dict produced by
    # RolloutTracingContext.build_trace_payload — confirm at the call site.
    trace: Dict[str, Any] | None = None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class RolloutTracingContext:
    """Helper managing tracing_v3 recording and optional SFT dumps for a rollout.

    One instance is created per rollout request. It wraps an optional
    ``SessionTracer``; when no tracer is given (or the session fails to
    start) tracing is disabled and the recording methods become no-ops,
    while the in-memory summaries (``lm_calls_summary``,
    ``decision_rewards``, ``sft_records``) are still maintained.
    Tracer failures are logged and swallowed so tracing never aborts a rollout.
    """

    def __init__(
        self,
        tracer: SessionTracer | None,
        request: RolloutRequest,
        fastapi_request: Request,
    ) -> None:
        """Capture request context and register this object on the request state.

        Args:
            tracer: Session tracer to record into, or ``None`` to disable tracing.
            request: The rollout request; run id, env/policy names and record
                options are read from it.
            fastapi_request: Incoming request; ``app.state.sft_output_dir`` is
                read from it and ``state.rollout_tracing`` / ``state.rollout_run_id``
                are set on it.
        """
        self.tracer = tracer
        self.enabled = tracer is not None
        self.request = request
        self.fastapi_request = fastapi_request
        self.run_id = request.run_id
        self.current_step_id: str | None = None
        self.current_turn: int | None = None
        self.lm_calls_summary: list[dict[str, Any]] = []
        self.decision_rewards: list[dict[str, Any]] = []
        self.sft_records: list[dict[str, Any]] = []
        self.latest_system_messages: list[str] = []
        self.latest_user_messages: list[str] = []
        self.trace_format = (getattr(request.record, "trace_format", "compact") or "compact").lower()
        self.return_trace = bool(getattr(request.record, "return_trace", False))
        # SFT dumping is switched on by the app exposing ``sft_output_dir`` on its state.
        self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
        self.session_trace = None
        self.metadata_updates: dict[str, Any] = {}
        self.policy_name = request.policy.policy_name or ""
        self.env_name = request.env.env_name or ""
        self.metadata_base: dict[str, Any] = {
            "run_id": self.run_id,
            "policy_name": self.policy_name,
            "policy_id": request.policy.policy_id,
            "env_name": self.env_name,
            "env_id": request.env.env_id,
            "seed": request.env.seed,
            "training_session_id": request.training_session_id,
            "synth_base_url": request.synth_base_url,
        }

        # Expose context for downstream calls inside this request lifecycle
        fastapi_request.state.rollout_tracing = self
        fastapi_request.state.rollout_run_id = self.run_id

    async def start_session(self) -> None:
        """Initialize the tracer and open a session keyed by the run id.

        An initialization failure is only logged; a failure to start the
        session disables tracing for the rest of the rollout.
        """
        if not self.enabled or self.tracer is None:
            return
        try:
            await self.tracer.initialize()
        except Exception as exc:
            logger.debug("TRACING_INIT_FAIL: %s", exc)
        try:
            await self.tracer.start_session(session_id=self.run_id, metadata=dict(self.metadata_base))
        except Exception as exc:
            logger.warning("TRACING_START_FAIL: %s", exc)
            self.enabled = False
            self.tracer = None

    async def start_decision(self, turn_number: int) -> None:
        """Mark the start of decision ``turn_number`` and open a tracer timestep."""
        # Turn/step ids are tracked even when tracing is disabled; the summaries use them.
        self.current_turn = turn_number
        self.current_step_id = f"decision_{turn_number}"
        if not self.enabled or self.tracer is None:
            return
        try:
            await self.tracer.start_timestep(step_id=self.current_step_id, turn_number=turn_number)
        except Exception as exc:
            logger.debug("TRACING_STEP_START_FAIL: %s", exc)

    async def end_decision(self) -> None:
        """Close the current tracer timestep and clear ``current_step_id``."""
        if not self.enabled or self.tracer is None:
            return
        try:
            await self.tracer.end_timestep(step_id=self.current_step_id)
        except Exception as exc:
            logger.debug("TRACING_STEP_END_FAIL: %s", exc)
        finally:
            self.current_step_id = None

    def _message_metadata(self) -> dict[str, Any]:
        """Return the turn/step metadata attached to every recorded message."""
        return {
            "turn": self.current_turn,
            "step_id": self.current_step_id,
        }

    async def record_policy_prompts(
        self,
        system_messages: list[str],
        user_messages: list[str],
    ) -> None:
        """Remember the latest prompts and record each one as a tracer message.

        The prompts are cached (copied) regardless of tracing state because
        ``record_llm_call`` uses them to build SFT records.
        """
        self.latest_system_messages = list(system_messages)
        self.latest_user_messages = list(user_messages)
        if not self.enabled or self.tracer is None:
            return
        for msg in system_messages:
            try:
                await self.tracer.record_message(
                    content=msg,
                    message_type="policy_system_prompt",
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
        for msg in user_messages:
            try:
                await self.tracer.record_message(
                    content=msg,
                    message_type="policy_user_prompt",
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug("TRACING_USER_MSG_FAIL: %s", exc)

    def _content_to_text(self, content: Any) -> str:
        """Flatten message content to plain text.

        Strings pass through; a list contributes the string ``"text"`` (or,
        failing that, ``"content"``) value of each dict segment, concatenated
        without separator; ``None`` becomes ``""``; anything else is ``str()``-ed.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for seg in content:
                if isinstance(seg, dict):
                    text_val = seg.get("text") or seg.get("content")
                    if isinstance(text_val, str):
                        parts.append(text_val)
            return "".join(parts)
        if content is None:
            return ""
        return str(content)

    def _safe_json(self, payload: Any, limit: int = 4000) -> str:
        """Serialize ``payload`` to JSON (``str()`` on failure), truncated to ``limit`` chars plus "…"."""
        try:
            text = json.dumps(payload, ensure_ascii=False)
        except Exception:
            text = str(payload)
        if len(text) > limit:
            return text[:limit] + "…"
        return text

    async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
        """Record the policy's tool calls as one JSON-encoded tracer message."""
        if tool_calls is None:
            return
        if self.enabled and self.tracer is not None:
            try:
                await self.tracer.record_message(
                    content=self._safe_json(tool_calls),
                    message_type="policy_tool_call",
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)

    async def _record_event(self, event: Any) -> int | None:
        """Record ``event`` with the tracer; return its id, or ``None`` if disabled or on failure."""
        if not self.enabled or self.tracer is None:
            return None
        try:
            return await self.tracer.record_event(event)
        except Exception as exc:
            logger.debug("TRACING_EVENT_FAIL: %s", exc)
            return None

    async def record_llm_call(
        self,
        *,
        inference_request: dict[str, Any],
        inference_response: dict[str, Any],
        tool_calls: list[dict[str, Any]] | None,
        provider: str,
        model_name: str,
        started_at: datetime,
        completed_at: datetime,
        latency_ms: int | None,
    ) -> None:
        """Record one LM call as a tracer event, a summary row and (optionally) an SFT record.

        Args:
            inference_request: Request payload; ``messages``, ``temperature``
                and ``tools`` are read from it.
            inference_response: Response payload; ``usage`` and the first
                entry of ``choices`` are read from it.
            tool_calls: Parsed tool calls; only their count is summarized.
            provider: Provider name stored on the event.
            model_name: Model name stored on the event.
            started_at: Call start time.
            completed_at: Call end time; its ``timestamp()`` is the event time.
                NOTE(review): a naive datetime here is interpreted as local
                time by ``timestamp()`` — confirm callers pass what they intend.
            latency_ms: Call latency in milliseconds, if known.
        """
        usage = inference_response.get("usage") or {}
        # Accept both naming conventions for token counts.
        input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
        output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
        total_tokens = usage.get("total_tokens")
        cost_usd = (
            usage.get("cost_usd")
            or usage.get("cost")
            or usage.get("total_cost")
        )

        assistant_message = None
        choices = inference_response.get("choices") or []
        if choices:
            assistant_message = choices[0].get("message") or {}
        assistant_content = assistant_message.get("content") if isinstance(assistant_message, dict) else None

        raw_response = self._content_to_text(assistant_content)
        if not raw_response:
            # No textual content (e.g. tool-call-only reply): keep a truncated dump instead.
            raw_response = self._safe_json(inference_response, limit=2000)

        base_response = BaseLMResponse(
            raw_response=raw_response,
            tool_calls=assistant_message.get("tool_calls") if isinstance(assistant_message, dict) else None,
            usage=usage or None,
            api_type="chat_completions",
        )

        request_messages = inference_request.get("messages") or []
        try:
            temperature = float(inference_request.get("temperature"))
        except Exception:
            # Missing or non-numeric temperature is recorded as 0.0.
            temperature = 0.0

        call_record = create_llm_call_record_from_response(
            response=base_response,
            model_name=model_name,
            provider=provider,
            messages=request_messages,
            temperature=temperature,
            request_params=inference_request,
            tools=inference_request.get("tools"),
            started_at=started_at,
            completed_at=completed_at,
            latency_ms=latency_ms,
        )

        event_metadata = {
            "policy_id": self.request.policy.policy_id,
            "turn": self.current_turn,
            "run_id": self.run_id,
        }

        event = LMCAISEvent(
            system_instance_id=f"policy:{self.policy_name or 'unknown'}",
            time_record=TimeRecord(event_time=completed_at.timestamp()),
            model_name=model_name,
            provider=provider,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            total_tokens=total_tokens,
            cost_usd=cost_usd,
            latency_ms=latency_ms,
            call_records=[call_record],
            metadata=event_metadata,
        )

        await self._record_event(event)

        self.lm_calls_summary.append(
            {
                "turn": self.current_turn,
                "model": model_name,
                "provider": provider,
                "total_tokens": total_tokens,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "latency_ms": latency_ms,
                "tool_calls": len(tool_calls or []),
            }
        )

        if self.sft_output_dir is not None:
            assistant_text = self._content_to_text(assistant_content)
            record = {
                "run_id": self.run_id,
                "turn": self.current_turn,
                "model": model_name,
                "provider": provider,
                "dialogue": (
                    [{"role": "system", "content": s} for s in self.latest_system_messages]
                    + [{"role": "user", "content": u} for u in self.latest_user_messages]
                ),
                "assistant": {
                    "content": assistant_text,
                    "tool_calls": assistant_message.get("tool_calls") if isinstance(assistant_message, dict) else [],
                },
                # Naive UTC wall-clock string (no offset suffix); kept for format stability.
                "timestamp": datetime.utcnow().isoformat(),
            }
            self.sft_records.append(record)

    async def record_environment_event(
        self,
        *,
        env_handle: Any,
        prev_obs: Dict[str, Any] | None,
        env_response: Any,
        next_obs: Dict[str, Any] | None,
        metadata: Dict[str, Any] | None = None,
    ) -> int | None:
        """Record an environment transition as an ``EnvironmentEvent``.

        Observations are summarized with ``_summarize_observation_for_storage``
        (a summary that fails becomes ``None``). ``reward``, ``done`` and
        ``truncated`` are read from ``env_response`` via ``getattr`` with
        defaults; a missing or non-numeric reward is recorded as 0.0.

        Returns:
            The tracer's event id, or ``None`` when tracing is disabled or
            recording failed.
        """
        if not self.enabled or self.tracer is None:
            return None

        try:
            prev_summary = _summarize_observation_for_storage(env_handle, prev_obs or {}) if prev_obs is not None else None
        except Exception:
            prev_summary = None
        try:
            next_summary = _summarize_observation_for_storage(env_handle, next_obs or {}) if next_obs is not None else None
        except Exception:
            next_summary = None

        reward_val = getattr(env_response, "reward", None)
        try:
            reward_float = float(reward_val) if reward_val is not None else 0.0
        except Exception:
            reward_float = 0.0

        event = EnvironmentEvent(
            system_instance_id=f"environment:{self.env_name or 'unknown'}",
            # time.time() is a true epoch value; datetime.utcnow().timestamp() would
            # treat the naive UTC value as local time and skew it by the UTC offset.
            time_record=TimeRecord(event_time=_time.time()),
            reward=reward_float,
            terminated=bool(getattr(env_response, "done", False)),
            truncated=bool(getattr(env_response, "truncated", False)),
            system_state_before=prev_summary,
            system_state_after=next_summary,
            metadata={
                "turn": self.current_turn,
                "run_id": self.run_id,
                **(metadata or {}),
            },
        )

        return await self._record_event(event)

    async def record_decision_reward(
        self,
        *,
        event_id: int | None,
        decision_meta: Dict[str, Any] | None,
    ) -> None:
        """Store per-decision achievement deltas and attach them to a tracer event.

        Args:
            event_id: Id of the event to attach rewards to; ``None`` skips the
                tracer call but the in-memory summary row is still appended.
            decision_meta: Dict with optional keys ``ach_delta``,
                ``unique_delta``, ``all`` and ``unique``; missing or ``None``
                values are treated as 0 / empty list.
        """
        decision_meta = decision_meta or {}
        # ``or 0`` so an explicit None does not raise TypeError in int().
        ach_delta = int(decision_meta.get("ach_delta") or 0)
        unique_delta = int(decision_meta.get("unique_delta") or 0)
        all_ach = list(decision_meta.get("all") or [])
        unique_ach = list(decision_meta.get("unique") or [])

        self.decision_rewards.append(
            {
                "turn": self.current_turn,
                "ach_delta": ach_delta,
                "unique_delta": unique_delta,
                "achievements": all_ach,
                "unique_achievements": unique_ach,
            }
        )

        if not self.enabled or self.tracer is None or event_id is None:
            return
        try:
            await self.tracer.record_event_reward(
                event_id=event_id,
                turn_number=self.current_turn,
                reward_value=float(ach_delta),
                reward_type="achievement_delta",
                annotation={"achievements": all_ach},
                source="environment",
            )
            # The unique-achievement reward is only recorded when it is non-zero.
            if unique_delta:
                await self.tracer.record_event_reward(
                    event_id=event_id,
                    turn_number=self.current_turn,
                    reward_value=float(unique_delta),
                    reward_type="unique_achievement_delta",
                    annotation={"achievements": unique_ach},
                    source="environment",
                )
        except Exception as exc:
            logger.debug("TRACING_REWARD_FAIL: %s", exc)

    def update_metadata(self, **kwargs: Any) -> None:
        """Merge the non-``None`` keyword values into ``metadata_updates``."""
        self.metadata_updates.update({k: v for k, v in kwargs.items() if v is not None})

    async def finalize(
        self,
        *,
        total_reward: float,
        achievement_state: Dict[str, bool] | None,
        total_steps: int,
    ) -> Any:
        """Record the outcome, end and close the session, and flush SFT records.

        Args:
            total_reward: Episode return; passed to the tracer as ``int(total_reward)``.
            achievement_state: Final achievement flags; truthy keys become
                ``final_achievements`` in the metadata unless already set.
            total_steps: Number of steps taken.

        Returns:
            The session trace returned by ``tracer.end_session()``, or ``None``
            when tracing is disabled or ending the session failed.
        """
        final_achievements = [key for key, val in (achievement_state or {}).items() if val]
        self.metadata_updates.setdefault("final_achievements", final_achievements)
        if self.enabled and self.tracer is not None:
            try:
                await self.tracer.record_outcome_reward(
                    total_reward=int(total_reward),
                    achievements_count=len(final_achievements),
                    total_steps=int(total_steps),
                    reward_metadata=dict(self.metadata_updates),
                )
            except Exception as exc:
                logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
            try:
                self.session_trace = await self.tracer.end_session()
                if self.session_trace is not None:
                    self.session_trace.metadata.update(self.metadata_updates)
            except Exception as exc:
                logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
                self.session_trace = None
            try:
                await self.tracer.close()
            except Exception:
                pass

        if self.sft_records and self.sft_output_dir:
            self.write_sft_records()

        # Clear context from request state to avoid leaks
        self.fastapi_request.state.rollout_tracing = None

        return self.session_trace

    def write_sft_records(self) -> None:
        """Write buffered SFT records as JSON lines, then clear the buffer.

        The target path comes from ``unique_sft_path``. Write failures are
        logged as warnings; the buffer is cleared whether or not the write
        succeeded.
        """
        if not self.sft_output_dir or not self.sft_records:
            return
        try:
            path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
            path.parent.mkdir(parents=True, exist_ok=True)
            with path.open("w", encoding="utf-8") as fh:
                for record in self.sft_records:
                    json.dump(record, fh, ensure_ascii=False)
                    fh.write("\n")
            logger.info("SFT_WRITTEN: %s", path)
        except Exception as exc:
            logger.warning("SFT_WRITE_FAIL: %s", exc)
        finally:
            self.sft_records.clear()

    def build_trace_payload(self, session_trace: Any) -> Dict[str, Any] | None:
        """Build the trace dict returned to the client, or ``None`` if not requested.

        With ``trace_format == "full"`` the whole ``session_trace.to_dict()``
        is returned (metadata merged with ``metadata_updates``); otherwise a
        compact summary with counts, LM-call rows and decision rewards.
        """
        if not self.return_trace or session_trace is None:
            return None
        if self.trace_format == "full":
            payload = session_trace.to_dict()
            payload.setdefault("metadata", {}).update(self.metadata_updates)
            return payload
        metadata = dict(session_trace.metadata)
        metadata.update(self.metadata_updates)
        return {
            "session_id": session_trace.session_id,
            "created_at": session_trace.created_at.isoformat(),
            "metadata": metadata,
            "events_count": len(session_trace.event_history),
            "messages_count": len(session_trace.markov_blanket_message_history),
            "lm_calls": self.lm_calls_summary,
            "decision_rewards": self.decision_rewards,
        }
|
|
643
|
+
def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, Any]) -> Dict[str, Any]:
    """Return a compact dict for trajectory storage instead of the raw observation.

    Strategy, in order:

    - For Crafter (``env_handle.env`` is a ``CrafterEnvironmentWrapper``), use
      the same text summary used for the policy user prompt.
    - For other dict observations (and ``None``, treated as an empty dict),
      keep a minimal JSON subset: position, health, and up to ten truthy
      inventory / achievement keys.
    - For anything else, or if the steps above fail, store a plain string
      preview capped at 10000 characters.

    Args:
        env_handle: Environment handle; only its ``env`` attribute is inspected.
        observation: The raw observation. Normally a dict, but non-dict values
            are tolerated.

    Returns:
        A dict with a single ``"text"`` key. Never raises.
    """
    # Try Crafter-specific formatter. The import is optional so that the
    # generic path still works when the Crafter package is unavailable.
    try:
        from .envs.crafter.environment import CrafterEnvironmentWrapper as _CrafterWrapper  # type: ignore
    except Exception:
        _CrafterWrapper = None  # type: ignore

    if _CrafterWrapper is not None and isinstance(getattr(env_handle, "env", None), _CrafterWrapper):
        try:
            from .envs.crafter.shared import format_observation as _fmt  # type: ignore

            return {"text": _fmt(observation or {})}
        except Exception:
            # Best effort: fall through to the generic summary below.
            pass

    # Generic fallback: extract a few small fields if present; avoid huge arrays.
    # Only dict-like observations have fields to extract. Running this for
    # other types used to produce a JSON blob of nulls and made the string
    # preview below unreachable for them.
    fields = {} if observation is None else observation
    if isinstance(fields, dict):
        try:
            inv = fields.get("inventory")
            ach = fields.get("achievements_status")
            health = inv.get("health") if isinstance(inv, dict) else None
            summary = {
                "position": fields.get("player_position"),
                "health": health,
                "inventory_keys": (
                    sorted(k for k, v in inv.items() if v)[:10] if isinstance(inv, dict) else None
                ),
                "achievements_unlocked": (
                    sorted(k for k, v in ach.items() if v)[:10] if isinstance(ach, dict) else None
                ),
            }
            return {"text": json.dumps(summary, ensure_ascii=False)}
        except Exception:
            # e.g. a non-JSON-serializable position or unsortable keys.
            pass

    # Last resort: plain string preview
    try:
        return {"text": str(observation)[:10000]}
    except Exception:
        return {"text": ""}
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
class RunAbortRequest(BaseModel):
    # Request model carrying the identifier of a run to abort.
    # NOTE(review): presumably the body of an abort endpoint defined later in
    # this file — confirm against the route that consumes it.

    # Identifier of the run to abort.
    run_id: str
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
class RunAbortResponse(BaseModel):
    # Response model for an abort request.

    # Outcome flag. NOTE(review): presumably True when the abort was accepted —
    # confirm against the handler that builds this response.
    ok: bool
    # Identifier of the run the response refers to.
    run_id: str
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
class RunStatusResponse(BaseModel):
    # Response model describing the state of a single run.

    # Identifier of the run.
    run_id: str
    # Status label. NOTE(review): the set of allowed values is not visible
    # here — confirm against the registry.
    status: str
    # When the run started.
    started_at: datetime
    # When the run finished; defaults to None.
    # NOTE(review): presumably None while the run is still in progress — confirm.
    finished_at: Optional[datetime] = None
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
@router.post("/rollout", response_model=RolloutResponse)
|
|
706
|
+
async def execute_rollout(
|
|
707
|
+
request: RolloutRequest,
|
|
708
|
+
req: Request,
|
|
709
|
+
) -> RolloutResponse:
|
|
710
|
+
"""Execute a rollout with coordinated environment and policy steps."""
|
|
711
|
+
# Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
|
|
712
|
+
try:
|
|
713
|
+
_env_params = {}
|
|
714
|
+
if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
|
|
715
|
+
_env_params = dict(request.env.config.get("env_params") or {})
|
|
716
|
+
max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
|
|
717
|
+
assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
|
|
718
|
+
except Exception as _mse:
|
|
719
|
+
raise HTTPException(
|
|
720
|
+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
721
|
+
detail={
|
|
722
|
+
"error": "invalid_env_params",
|
|
723
|
+
"message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
|
|
724
|
+
},
|
|
725
|
+
)
|
|
726
|
+
# Truncate incoming ops to the enforced cap (each step is [agent, env])
|
|
727
|
+
ops_seq: List[str] = list(request.ops or [])
|
|
728
|
+
allowed_ops = max(0, int(max_steps_per_episode) * 2)
|
|
729
|
+
if len(ops_seq) > allowed_ops:
|
|
730
|
+
try:
|
|
731
|
+
logger.info(
|
|
732
|
+
"ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
|
|
733
|
+
str(len(ops_seq)),
|
|
734
|
+
str(allowed_ops),
|
|
735
|
+
)
|
|
736
|
+
except Exception:
|
|
737
|
+
pass
|
|
738
|
+
ops_seq = ops_seq[:allowed_ops]
|
|
739
|
+
# Simple API key auth for inbound rollout
|
|
740
|
+
header_key = req.headers.get("x-api-key")
|
|
741
|
+
env_key = os.getenv("ENVIRONMENT_API_KEY")
|
|
742
|
+
dev_key = os.getenv("dev_environment_api_key")
|
|
743
|
+
# Accept either ENVIRONMENT_API_KEY or dev_environment_api_key
|
|
744
|
+
expected_keys = [k for k in (env_key, dev_key) if k]
|
|
745
|
+
if not expected_keys:
|
|
746
|
+
missing = []
|
|
747
|
+
if not env_key:
|
|
748
|
+
missing.append("ENVIRONMENT_API_KEY")
|
|
749
|
+
if not dev_key:
|
|
750
|
+
missing.append("dev_environment_api_key")
|
|
751
|
+
msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
|
|
752
|
+
logger.error(msg)
|
|
753
|
+
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
|
|
754
|
+
if not header_key:
|
|
755
|
+
raise HTTPException(
|
|
756
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
757
|
+
detail="Invalid or missing API key: X-API-Key header not provided",
|
|
758
|
+
)
|
|
759
|
+
if header_key not in expected_keys:
|
|
760
|
+
# Do not leak secrets; include short prefix for diagnostics
|
|
761
|
+
exp_src = env_key if env_key else (dev_key or "")
|
|
762
|
+
exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
|
|
763
|
+
got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
|
|
764
|
+
raise HTTPException(
|
|
765
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
766
|
+
detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
|
|
767
|
+
)
|
|
768
|
+
|
|
769
|
+
# Log contextual fields for traceability
|
|
770
|
+
if request.training_session_id:
|
|
771
|
+
logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
|
|
772
|
+
if request.synth_base_url:
|
|
773
|
+
logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
|
|
774
|
+
|
|
775
|
+
# Log masked OpenAI API key presence for diagnostics
|
|
776
|
+
try:
|
|
777
|
+
_oa = os.getenv("OPENAI_API_KEY")
|
|
778
|
+
if _oa:
|
|
779
|
+
_pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
|
|
780
|
+
logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
|
|
781
|
+
else:
|
|
782
|
+
logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
|
|
783
|
+
except Exception:
|
|
784
|
+
pass
|
|
785
|
+
|
|
786
|
+
# Make synth_base_url available for outbound calls in this app
|
|
787
|
+
try:
|
|
788
|
+
task_app = req.app.state.task_app
|
|
789
|
+
if request.synth_base_url:
|
|
790
|
+
setattr(task_app, "synth_base_url", request.synth_base_url)
|
|
791
|
+
except Exception:
|
|
792
|
+
pass
|
|
793
|
+
|
|
794
|
+
tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
|
|
795
|
+
tracer_instance = None
|
|
796
|
+
if callable(tracer_factory):
|
|
797
|
+
try:
|
|
798
|
+
tracer_instance = tracer_factory()
|
|
799
|
+
except Exception as exc:
|
|
800
|
+
logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
|
|
801
|
+
tracing_context = RolloutTracingContext(tracer_instance, request, req)
|
|
802
|
+
await tracing_context.start_session()
|
|
803
|
+
|
|
804
|
+
# Register run
|
|
805
|
+
registry.register_run(request.run_id)
|
|
806
|
+
|
|
807
|
+
# Track resources created during this rollout so we can guarantee cleanup
|
|
808
|
+
created_env_id: str | None = None
|
|
809
|
+
created_policy_id: str | None = None
|
|
810
|
+
env_seed_used: int | None = None
|
|
811
|
+
|
|
812
|
+
try:
|
|
813
|
+
# Initialize deterministic seed early for the entire rollout
|
|
814
|
+
seed_value: Optional[int] = None
|
|
815
|
+
try:
|
|
816
|
+
if request.env and request.env.seed is not None:
|
|
817
|
+
seed_value = int(request.env.seed)
|
|
818
|
+
else:
|
|
819
|
+
# Derive a stable seed from run_id
|
|
820
|
+
import hashlib as _hashlib # local import to avoid global deps
|
|
821
|
+
|
|
822
|
+
_digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
|
|
823
|
+
# Use lower 32 bits to fit common RNG ranges
|
|
824
|
+
seed_value = int(_digest[:8], 16)
|
|
825
|
+
except Exception:
|
|
826
|
+
# Fallback to time-based seed if anything goes wrong
|
|
827
|
+
try:
|
|
828
|
+
seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
|
|
829
|
+
except Exception:
|
|
830
|
+
seed_value = 42
|
|
831
|
+
|
|
832
|
+
_seed_info = _set_global_seed(int(seed_value))
|
|
833
|
+
try:
|
|
834
|
+
logger.info(
|
|
835
|
+
"ROLL_OUT: RNG seeded seed=%s libs=%s",
|
|
836
|
+
str(_seed_info.get("seed")),
|
|
837
|
+
",".join(_seed_info.get("libs", [])),
|
|
838
|
+
)
|
|
839
|
+
except Exception:
|
|
840
|
+
pass
|
|
841
|
+
# Resolve or create environment
|
|
842
|
+
if request.env.env_id:
|
|
843
|
+
env_handle = registry.get_env(request.env.env_id)
|
|
844
|
+
if not env_handle:
|
|
845
|
+
raise HTTPException(
|
|
846
|
+
status_code=404,
|
|
847
|
+
detail=f"Environment {request.env.env_id} not found",
|
|
848
|
+
)
|
|
849
|
+
env_id = request.env.env_id
|
|
850
|
+
else:
|
|
851
|
+
# Create new environment
|
|
852
|
+
from .environment_routes import create_environment, EnvCreateRequest
|
|
853
|
+
|
|
854
|
+
if not request.env.env_name:
|
|
855
|
+
raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
|
|
856
|
+
|
|
857
|
+
# Propagate training_session_id via env config for downstream usage
|
|
858
|
+
_env_config = dict(request.env.config or {})
|
|
859
|
+
if request.training_session_id is not None:
|
|
860
|
+
_env_config.setdefault(
|
|
861
|
+
"training_session_id", request.training_session_id
|
|
862
|
+
)
|
|
863
|
+
env_response = await create_environment(
|
|
864
|
+
EnvCreateRequest(
|
|
865
|
+
env_name=request.env.env_name,
|
|
866
|
+
config=_env_config,
|
|
867
|
+
seed=request.env.seed,
|
|
868
|
+
rl_run_id=request.run_id,
|
|
869
|
+
)
|
|
870
|
+
)
|
|
871
|
+
env_id = env_response.env_id
|
|
872
|
+
env_handle = registry.get_env(env_id)
|
|
873
|
+
created_env_id = env_id
|
|
874
|
+
|
|
875
|
+
tracing_context.update_metadata(env_id=env_id)
|
|
876
|
+
|
|
877
|
+
# Resolve or create policy
|
|
878
|
+
if request.policy.policy_id:
|
|
879
|
+
policy_handle = registry.get_policy(request.policy.policy_id)
|
|
880
|
+
if not policy_handle:
|
|
881
|
+
raise HTTPException(
|
|
882
|
+
status_code=404,
|
|
883
|
+
detail=f"Policy {request.policy.policy_id} not found",
|
|
884
|
+
)
|
|
885
|
+
policy_id = request.policy.policy_id
|
|
886
|
+
else:
|
|
887
|
+
# Create new policy
|
|
888
|
+
from .policy_routes import create_policy, PolicyCreateRequest
|
|
889
|
+
|
|
890
|
+
if not request.policy.policy_name:
|
|
891
|
+
raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
|
|
892
|
+
|
|
893
|
+
# Propagate training_session_id and synth_base_url via policy config
|
|
894
|
+
_policy_config = dict(request.policy.config or {})
|
|
895
|
+
if request.training_session_id is not None:
|
|
896
|
+
_policy_config.setdefault(
|
|
897
|
+
"training_session_id", request.training_session_id
|
|
898
|
+
)
|
|
899
|
+
if request.synth_base_url is not None:
|
|
900
|
+
_policy_config.setdefault("synth_base_url", request.synth_base_url)
|
|
901
|
+
policy_response = await create_policy(
|
|
902
|
+
PolicyCreateRequest(
|
|
903
|
+
policy_name=request.policy.policy_name,
|
|
904
|
+
config=_policy_config,
|
|
905
|
+
rl_run_id=request.run_id,
|
|
906
|
+
bound_env_id=env_id,
|
|
907
|
+
),
|
|
908
|
+
req,
|
|
909
|
+
)
|
|
910
|
+
policy_id = policy_response.policy_id
|
|
911
|
+
policy_handle = registry.get_policy(policy_id)
|
|
912
|
+
created_policy_id = policy_id
|
|
913
|
+
|
|
914
|
+
tracing_context.update_metadata(policy_id=policy_id)
|
|
915
|
+
|
|
916
|
+
# Bind policy to environment if not already bound
|
|
917
|
+
if policy_handle and not policy_handle.bound_env_id:
|
|
918
|
+
policy_handle.bound_env_id = env_id
|
|
919
|
+
|
|
920
|
+
# Record seed bound to environment for end-of-rollout verification/logging
|
|
921
|
+
try:
|
|
922
|
+
env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
|
|
923
|
+
except Exception:
|
|
924
|
+
env_seed_used = None
|
|
925
|
+
tracing_context.update_metadata(env_seed=env_seed_used)
|
|
926
|
+
|
|
927
|
+
# Initialize trajectory
|
|
928
|
+
trajectory_steps = []
|
|
929
|
+
pending_tool_calls = None
|
|
930
|
+
current_obs = env_handle.last_observation
|
|
931
|
+
total_reward = 0.0
|
|
932
|
+
ops_executed = 0
|
|
933
|
+
last_agent_response_ts: float | None = None
|
|
934
|
+
last_policy_meta: Dict[str, Any] | None = None
|
|
935
|
+
last_env_step_ms: float | None = None
|
|
936
|
+
last_env_step_completed_ts: float | None = None
|
|
937
|
+
|
|
938
|
+
# Stepwise reward configuration (Crafter shaping; gate on explicit enable)
|
|
939
|
+
step_rewards_cfg_raw: Dict[str, Any] = {}
|
|
940
|
+
try:
|
|
941
|
+
if isinstance(request.policy.config, dict):
|
|
942
|
+
step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
|
|
943
|
+
except Exception:
|
|
944
|
+
step_rewards_cfg_raw = {}
|
|
945
|
+
if not step_rewards_cfg_raw:
|
|
946
|
+
try:
|
|
947
|
+
if isinstance(request.env.config, dict):
|
|
948
|
+
step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
|
|
949
|
+
except Exception:
|
|
950
|
+
step_rewards_cfg_raw = {}
|
|
951
|
+
|
|
952
|
+
step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
|
|
953
|
+
step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
|
|
954
|
+
try:
|
|
955
|
+
step_rewards_indicator_lambda = float(
|
|
956
|
+
step_rewards_cfg_raw.get("indicator_lambda") or 0.0
|
|
957
|
+
)
|
|
958
|
+
except Exception:
|
|
959
|
+
step_rewards_indicator_lambda = 0.0
|
|
960
|
+
try:
|
|
961
|
+
step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
|
|
962
|
+
except Exception:
|
|
963
|
+
step_rewards_beta = 0.0
|
|
964
|
+
step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
|
|
965
|
+
|
|
966
|
+
def _extract_achievements(obs: Any) -> Dict[str, bool]:
|
|
967
|
+
if not isinstance(obs, dict):
|
|
968
|
+
return {}
|
|
969
|
+
ach = obs.get("achievements_status")
|
|
970
|
+
if isinstance(ach, dict):
|
|
971
|
+
return {str(k): bool(v) for k, v in ach.items()}
|
|
972
|
+
return {}
|
|
973
|
+
|
|
974
|
+
def _summarize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
|
|
975
|
+
if not tool_calls:
|
|
976
|
+
return []
|
|
977
|
+
try:
|
|
978
|
+
items = (
|
|
979
|
+
tool_calls
|
|
980
|
+
if isinstance(tool_calls, list)
|
|
981
|
+
else list(tool_calls) # tolerates tuples or pydantic lists
|
|
982
|
+
)
|
|
983
|
+
except Exception:
|
|
984
|
+
return []
|
|
985
|
+
summary: List[Dict[str, Any]] = []
|
|
986
|
+
for tc in items:
|
|
987
|
+
tool_name = None
|
|
988
|
+
args: Any = {}
|
|
989
|
+
if isinstance(tc, dict):
|
|
990
|
+
tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
|
|
991
|
+
raw_args = tc.get("arguments") or tc.get("args") or {}
|
|
992
|
+
else:
|
|
993
|
+
tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
|
|
994
|
+
raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
|
|
995
|
+
args = raw_args
|
|
996
|
+
if isinstance(raw_args, str):
|
|
997
|
+
try:
|
|
998
|
+
args = json.loads(raw_args)
|
|
999
|
+
except Exception:
|
|
1000
|
+
args = raw_args
|
|
1001
|
+
summary.append({"tool": tool_name, "args": args})
|
|
1002
|
+
return summary
|
|
1003
|
+
|
|
1004
|
+
decision_samples: List[Dict[str, Any]] = []
|
|
1005
|
+
decision_index = 0
|
|
1006
|
+
decision_open = False
|
|
1007
|
+
session_trace = None
|
|
1008
|
+
finalized = False
|
|
1009
|
+
prev_achievements = _extract_achievements(current_obs)
|
|
1010
|
+
# Track episode-level achievements that have been seen as true at any point so far
|
|
1011
|
+
episode_seen_achievements: set[str] = set(
|
|
1012
|
+
[k for k, v in (prev_achievements or {}).items() if bool(v)]
|
|
1013
|
+
)
|
|
1014
|
+
stepwise_indicator_sum = 0.0
|
|
1015
|
+
stepwise_reward_sum = 0.0
|
|
1016
|
+
stepwise_new_achievements_total = 0
|
|
1017
|
+
final_achievement_count = sum(1 for v in prev_achievements.values() if v)
|
|
1018
|
+
|
|
1019
|
+
# Execute ops sequence (capped by env_params.max_steps_per_episode)
|
|
1020
|
+
for op_idx, op in enumerate(ops_seq):
|
|
1021
|
+
# Check for abort
|
|
1022
|
+
if registry.is_run_aborted(request.run_id):
|
|
1023
|
+
logger.info(f"Run {request.run_id} aborted at op {op_idx}")
|
|
1024
|
+
break
|
|
1025
|
+
|
|
1026
|
+
# Check safety limits
|
|
1027
|
+
if ops_executed >= request.safety.max_ops:
|
|
1028
|
+
logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
|
|
1029
|
+
break
|
|
1030
|
+
|
|
1031
|
+
if op == "agent":
|
|
1032
|
+
# Policy step
|
|
1033
|
+
from .policy_routes import step_policy, PolicyStepRequest
|
|
1034
|
+
|
|
1035
|
+
if not decision_open:
|
|
1036
|
+
await tracing_context.start_decision(decision_index)
|
|
1037
|
+
decision_open = True
|
|
1038
|
+
|
|
1039
|
+
agent_request_start = _time.perf_counter()
|
|
1040
|
+
if last_agent_response_ts is not None and last_policy_meta is not None:
|
|
1041
|
+
try:
|
|
1042
|
+
timing_prev = last_policy_meta.setdefault("timing", {})
|
|
1043
|
+
decision_ms = max(
|
|
1044
|
+
0.0,
|
|
1045
|
+
(agent_request_start - float(last_agent_response_ts)) * 1000.0,
|
|
1046
|
+
)
|
|
1047
|
+
# Update timing on prior policy meta (kept by previous env step)
|
|
1048
|
+
timing_prev["decision_ms"] = decision_ms
|
|
1049
|
+
if last_env_step_ms is not None:
|
|
1050
|
+
timing_prev["env_step_ms"] = float(last_env_step_ms)
|
|
1051
|
+
timing_prev["overhead_ms"] = max(
|
|
1052
|
+
0.0, decision_ms - float(last_env_step_ms)
|
|
1053
|
+
)
|
|
1054
|
+
else:
|
|
1055
|
+
timing_prev.setdefault("overhead_ms", 0.0)
|
|
1056
|
+
timing_prev["decision_ready_s"] = agent_request_start
|
|
1057
|
+
# Also backfill the last appended trajectory step so the trainer
|
|
1058
|
+
# can always see decision_ms without relying on shared dict refs.
|
|
1059
|
+
if trajectory_steps:
|
|
1060
|
+
try:
|
|
1061
|
+
_last = trajectory_steps[-1]
|
|
1062
|
+
_info = dict(_last.info or {})
|
|
1063
|
+
_meta = dict(_info.get("meta") or {})
|
|
1064
|
+
_timing = dict(_meta.get("timing") or {})
|
|
1065
|
+
_timing["decision_ms"] = decision_ms
|
|
1066
|
+
if last_env_step_ms is not None:
|
|
1067
|
+
_timing.setdefault("env_step_ms", float(last_env_step_ms))
|
|
1068
|
+
_timing.setdefault("overhead_ms", max(0.0, decision_ms - float(last_env_step_ms)))
|
|
1069
|
+
else:
|
|
1070
|
+
_timing.setdefault("overhead_ms", 0.0)
|
|
1071
|
+
_meta["timing"] = _timing
|
|
1072
|
+
_info["meta"] = _meta
|
|
1073
|
+
_last.info = _info
|
|
1074
|
+
except Exception:
|
|
1075
|
+
pass
|
|
1076
|
+
except Exception:
|
|
1077
|
+
pass
|
|
1078
|
+
last_env_step_ms = None
|
|
1079
|
+
last_env_step_completed_ts = None
|
|
1080
|
+
|
|
1081
|
+
# Build metadata for policy (carry previous tool_calls and env result)
|
|
1082
|
+
metadata = {}
|
|
1083
|
+
if pending_tool_calls:
|
|
1084
|
+
metadata["prev_tool_calls"] = pending_tool_calls
|
|
1085
|
+
if len(trajectory_steps) > 0:
|
|
1086
|
+
last_step = trajectory_steps[-1]
|
|
1087
|
+
# Prefer the last executed tool calls to seed history
|
|
1088
|
+
if last_step.tool_calls:
|
|
1089
|
+
metadata["prev_tool_calls"] = last_step.tool_calls
|
|
1090
|
+
# Provide a compact env result snapshot
|
|
1091
|
+
metadata["prev_env_result"] = {
|
|
1092
|
+
"observation": last_step.obs,
|
|
1093
|
+
"reward": last_step.reward,
|
|
1094
|
+
"done": last_step.done,
|
|
1095
|
+
"truncated": last_step.truncated,
|
|
1096
|
+
"info": last_step.info,
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
# Log compact metadata summary to confirm history threading
|
|
1100
|
+
try:
|
|
1101
|
+
_prev_calls = (
|
|
1102
|
+
metadata["prev_tool_calls"]
|
|
1103
|
+
if isinstance(metadata, dict) and "prev_tool_calls" in metadata
|
|
1104
|
+
else None
|
|
1105
|
+
)
|
|
1106
|
+
_count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
|
|
1107
|
+
_first_guess = None
|
|
1108
|
+
if _count > 0 and isinstance(_prev_calls[0], dict):
|
|
1109
|
+
_args = (
|
|
1110
|
+
_prev_calls[0]["arguments"]
|
|
1111
|
+
if "arguments" in _prev_calls[0]
|
|
1112
|
+
else None
|
|
1113
|
+
)
|
|
1114
|
+
if isinstance(_args, str):
|
|
1115
|
+
import json as _json
|
|
1116
|
+
|
|
1117
|
+
try:
|
|
1118
|
+
_args = _json.loads(_args)
|
|
1119
|
+
except Exception:
|
|
1120
|
+
_args = {}
|
|
1121
|
+
if isinstance(_args, dict):
|
|
1122
|
+
_first_guess = (
|
|
1123
|
+
_args["guess"] if "guess" in _args else None
|
|
1124
|
+
) or (_args["word"] if "word" in _args else None)
|
|
1125
|
+
logger.info(
|
|
1126
|
+
"POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
|
|
1127
|
+
_count,
|
|
1128
|
+
_first_guess,
|
|
1129
|
+
str("prev_env_result" in metadata),
|
|
1130
|
+
)
|
|
1131
|
+
except Exception:
|
|
1132
|
+
pass
|
|
1133
|
+
|
|
1134
|
+
try:
|
|
1135
|
+
policy_response = await step_policy(
|
|
1136
|
+
PolicyStepRequest(
|
|
1137
|
+
policy_id=policy_id,
|
|
1138
|
+
observation=current_obs,
|
|
1139
|
+
metadata=metadata,
|
|
1140
|
+
),
|
|
1141
|
+
req,
|
|
1142
|
+
)
|
|
1143
|
+
except Exception as _pe:
|
|
1144
|
+
# Do not 500 the rollout; finalize with partial trajectory
|
|
1145
|
+
try:
|
|
1146
|
+
logger.warning(
|
|
1147
|
+
"POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
|
|
1148
|
+
request.run_id,
|
|
1149
|
+
str(op_idx),
|
|
1150
|
+
str(_pe),
|
|
1151
|
+
)
|
|
1152
|
+
except Exception:
|
|
1153
|
+
pass
|
|
1154
|
+
|
|
1155
|
+
# Build partial trajectory and return HTTP 200
|
|
1156
|
+
trajectory = RolloutTrajectory(
|
|
1157
|
+
env_id=env_id,
|
|
1158
|
+
policy_id=policy_id,
|
|
1159
|
+
steps=trajectory_steps,
|
|
1160
|
+
final={
|
|
1161
|
+
"observation": current_obs,
|
|
1162
|
+
"rollout_status": "partial_policy_error",
|
|
1163
|
+
"error": str(_pe),
|
|
1164
|
+
"at_op": op,
|
|
1165
|
+
},
|
|
1166
|
+
length=len(trajectory_steps),
|
|
1167
|
+
decision_samples=decision_samples if step_rewards_active else None,
|
|
1168
|
+
)
|
|
1169
|
+
metrics = RolloutMetrics(
|
|
1170
|
+
episode_returns=[total_reward],
|
|
1171
|
+
mean_return=total_reward,
|
|
1172
|
+
num_steps=len(trajectory_steps),
|
|
1173
|
+
num_episodes=1,
|
|
1174
|
+
)
|
|
1175
|
+
aborted = registry.is_run_aborted(request.run_id)
|
|
1176
|
+
if not aborted:
|
|
1177
|
+
registry.complete_run(request.run_id)
|
|
1178
|
+
if decision_open:
|
|
1179
|
+
await tracing_context.end_decision()
|
|
1180
|
+
decision_open = False
|
|
1181
|
+
if not finalized:
|
|
1182
|
+
session_trace = await tracing_context.finalize(
|
|
1183
|
+
total_reward=total_reward,
|
|
1184
|
+
achievement_state=prev_achievements,
|
|
1185
|
+
total_steps=len(trajectory_steps),
|
|
1186
|
+
)
|
|
1187
|
+
finalized = True
|
|
1188
|
+
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1189
|
+
return RolloutResponse(
|
|
1190
|
+
run_id=request.run_id,
|
|
1191
|
+
trajectories=[trajectory],
|
|
1192
|
+
branches={},
|
|
1193
|
+
metrics=metrics,
|
|
1194
|
+
aborted=aborted,
|
|
1195
|
+
ops_executed=ops_executed,
|
|
1196
|
+
trace=trace_payload,
|
|
1197
|
+
)
|
|
1198
|
+
|
|
1199
|
+
agent_response_ts = _time.perf_counter()
|
|
1200
|
+
if isinstance(policy_response.meta, dict):
|
|
1201
|
+
try:
|
|
1202
|
+
timing_cur = policy_response.meta.setdefault("timing", {})
|
|
1203
|
+
timing_cur["agent_request_start_s"] = agent_request_start
|
|
1204
|
+
timing_cur["agent_response_s"] = agent_response_ts
|
|
1205
|
+
if "inference_ms" in policy_response.meta:
|
|
1206
|
+
try:
|
|
1207
|
+
timing_cur.setdefault(
|
|
1208
|
+
"inference_ms",
|
|
1209
|
+
float(policy_response.meta["inference_ms"]),
|
|
1210
|
+
)
|
|
1211
|
+
timing_cur.setdefault(
|
|
1212
|
+
"inference_s",
|
|
1213
|
+
float(policy_response.meta["inference_ms"]) / 1000.0,
|
|
1214
|
+
)
|
|
1215
|
+
except Exception:
|
|
1216
|
+
pass
|
|
1217
|
+
except Exception:
|
|
1218
|
+
pass
|
|
1219
|
+
last_policy_meta = policy_response.meta
|
|
1220
|
+
else:
|
|
1221
|
+
last_policy_meta = None
|
|
1222
|
+
last_agent_response_ts = agent_response_ts
|
|
1223
|
+
|
|
1224
|
+
pending_tool_calls = policy_response.tool_calls
|
|
1225
|
+
await tracing_context.record_tool_invocation(pending_tool_calls)
|
|
1226
|
+
ops_executed += 1
|
|
1227
|
+
|
|
1228
|
+
elif op == "env":
|
|
1229
|
+
if not pending_tool_calls:
|
|
1230
|
+
# Treat absence of tool calls as a soft terminal condition; yield partial trajectory
|
|
1231
|
+
try:
|
|
1232
|
+
logger.warning(
|
|
1233
|
+
"NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
|
|
1234
|
+
request.run_id,
|
|
1235
|
+
str(op_idx),
|
|
1236
|
+
)
|
|
1237
|
+
except Exception:
|
|
1238
|
+
pass
|
|
1239
|
+
term_step = RolloutStep(
|
|
1240
|
+
obs=current_obs,
|
|
1241
|
+
tool_calls=[],
|
|
1242
|
+
reward=None,
|
|
1243
|
+
done=True,
|
|
1244
|
+
truncated=False,
|
|
1245
|
+
info={
|
|
1246
|
+
"terminated": True,
|
|
1247
|
+
"reason": "no_tool_calls",
|
|
1248
|
+
},
|
|
1249
|
+
)
|
|
1250
|
+
trajectory_steps.append(term_step)
|
|
1251
|
+
trajectory = RolloutTrajectory(
|
|
1252
|
+
env_id=env_id,
|
|
1253
|
+
policy_id=policy_id,
|
|
1254
|
+
steps=trajectory_steps,
|
|
1255
|
+
final={
|
|
1256
|
+
"observation": current_obs,
|
|
1257
|
+
"rollout_status": "partial_no_tool_calls",
|
|
1258
|
+
"at_op": op,
|
|
1259
|
+
},
|
|
1260
|
+
length=len(trajectory_steps),
|
|
1261
|
+
decision_samples=decision_samples if step_rewards_active else None,
|
|
1262
|
+
)
|
|
1263
|
+
metrics = RolloutMetrics(
|
|
1264
|
+
episode_returns=[total_reward],
|
|
1265
|
+
mean_return=total_reward,
|
|
1266
|
+
num_steps=len(trajectory_steps),
|
|
1267
|
+
num_episodes=1,
|
|
1268
|
+
)
|
|
1269
|
+
aborted = registry.is_run_aborted(request.run_id)
|
|
1270
|
+
if not aborted:
|
|
1271
|
+
registry.complete_run(request.run_id)
|
|
1272
|
+
if decision_open:
|
|
1273
|
+
await tracing_context.end_decision()
|
|
1274
|
+
decision_open = False
|
|
1275
|
+
if not finalized:
|
|
1276
|
+
session_trace = await tracing_context.finalize(
|
|
1277
|
+
total_reward=total_reward,
|
|
1278
|
+
achievement_state=prev_achievements,
|
|
1279
|
+
total_steps=len(trajectory_steps),
|
|
1280
|
+
)
|
|
1281
|
+
finalized = True
|
|
1282
|
+
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1283
|
+
return RolloutResponse(
|
|
1284
|
+
run_id=request.run_id,
|
|
1285
|
+
trajectories=[trajectory],
|
|
1286
|
+
branches={},
|
|
1287
|
+
metrics=metrics,
|
|
1288
|
+
aborted=aborted,
|
|
1289
|
+
ops_executed=ops_executed,
|
|
1290
|
+
trace=trace_payload,
|
|
1291
|
+
)
|
|
1292
|
+
|
|
1293
|
+
# Environment step
|
|
1294
|
+
from .environment_routes import step_environment, EnvStepRequest
|
|
1295
|
+
|
|
1296
|
+
env_step_error: Exception | None = None
|
|
1297
|
+
env_response = None
|
|
1298
|
+
env_step_start = _time.perf_counter()
|
|
1299
|
+
try:
|
|
1300
|
+
env_response = await step_environment(
|
|
1301
|
+
EnvStepRequest(
|
|
1302
|
+
env_id=env_id,
|
|
1303
|
+
tool_calls=pending_tool_calls,
|
|
1304
|
+
)
|
|
1305
|
+
)
|
|
1306
|
+
except Exception as _ee:
|
|
1307
|
+
env_step_error = _ee
|
|
1308
|
+
env_step_end = _time.perf_counter()
|
|
1309
|
+
env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
|
|
1310
|
+
last_env_step_ms = env_step_duration_ms
|
|
1311
|
+
last_env_step_completed_ts = env_step_end
|
|
1312
|
+
if last_policy_meta is not None:
|
|
1313
|
+
try:
|
|
1314
|
+
timing_env = last_policy_meta.setdefault("timing", {})
|
|
1315
|
+
timing_env["env_step_ms"] = env_step_duration_ms
|
|
1316
|
+
timing_env["env_step_end_s"] = env_step_end
|
|
1317
|
+
except Exception:
|
|
1318
|
+
pass
|
|
1319
|
+
|
|
1320
|
+
if env_step_error is not None:
|
|
1321
|
+
# Invalid action or environment rejection — terminate episode early with partial trajectory
|
|
1322
|
+
try:
|
|
1323
|
+
logger.warning(
|
|
1324
|
+
"ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
|
|
1325
|
+
request.run_id,
|
|
1326
|
+
str(op_idx),
|
|
1327
|
+
str(env_step_error),
|
|
1328
|
+
)
|
|
1329
|
+
except Exception:
|
|
1330
|
+
pass
|
|
1331
|
+
|
|
1332
|
+
term_step = RolloutStep(
|
|
1333
|
+
obs=current_obs,
|
|
1334
|
+
tool_calls=pending_tool_calls,
|
|
1335
|
+
reward=None,
|
|
1336
|
+
done=True,
|
|
1337
|
+
truncated=False,
|
|
1338
|
+
info={
|
|
1339
|
+
"terminated": True,
|
|
1340
|
+
"reason": "invalid_action",
|
|
1341
|
+
"error": str(env_step_error),
|
|
1342
|
+
},
|
|
1343
|
+
)
|
|
1344
|
+
trajectory_steps.append(term_step)
|
|
1345
|
+
# Build partial response
|
|
1346
|
+
trajectory = RolloutTrajectory(
|
|
1347
|
+
env_id=env_id,
|
|
1348
|
+
policy_id=policy_id,
|
|
1349
|
+
steps=trajectory_steps,
|
|
1350
|
+
final={
|
|
1351
|
+
"observation": current_obs,
|
|
1352
|
+
"rollout_status": "partial_invalid_action",
|
|
1353
|
+
"error": str(env_step_error),
|
|
1354
|
+
"at_op": op,
|
|
1355
|
+
},
|
|
1356
|
+
length=len(trajectory_steps),
|
|
1357
|
+
decision_samples=decision_samples if step_rewards_active else None,
|
|
1358
|
+
)
|
|
1359
|
+
metrics = RolloutMetrics(
|
|
1360
|
+
episode_returns=[total_reward],
|
|
1361
|
+
mean_return=total_reward,
|
|
1362
|
+
num_steps=len(trajectory_steps),
|
|
1363
|
+
num_episodes=1,
|
|
1364
|
+
)
|
|
1365
|
+
aborted = registry.is_run_aborted(request.run_id)
|
|
1366
|
+
if not aborted:
|
|
1367
|
+
registry.complete_run(request.run_id)
|
|
1368
|
+
if (
|
|
1369
|
+
last_policy_meta is not None
|
|
1370
|
+
and last_agent_response_ts is not None
|
|
1371
|
+
and "decision_ms" not in last_policy_meta.get("timing", {})
|
|
1372
|
+
):
|
|
1373
|
+
try:
|
|
1374
|
+
timing_last = last_policy_meta.setdefault("timing", {})
|
|
1375
|
+
decision_ms = max(
|
|
1376
|
+
0.0,
|
|
1377
|
+
(env_step_end - float(last_agent_response_ts)) * 1000.0,
|
|
1378
|
+
)
|
|
1379
|
+
timing_last["decision_ms"] = decision_ms
|
|
1380
|
+
timing_last.setdefault("overhead_ms", max(0.0, decision_ms - env_step_duration_ms))
|
|
1381
|
+
except Exception:
|
|
1382
|
+
pass
|
|
1383
|
+
if decision_open:
|
|
1384
|
+
await tracing_context.end_decision()
|
|
1385
|
+
decision_open = False
|
|
1386
|
+
if not finalized:
|
|
1387
|
+
session_trace = await tracing_context.finalize(
|
|
1388
|
+
total_reward=total_reward,
|
|
1389
|
+
achievement_state=prev_achievements,
|
|
1390
|
+
total_steps=len(trajectory_steps),
|
|
1391
|
+
)
|
|
1392
|
+
finalized = True
|
|
1393
|
+
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1394
|
+
return RolloutResponse(
|
|
1395
|
+
run_id=request.run_id,
|
|
1396
|
+
trajectories=[trajectory],
|
|
1397
|
+
branches={},
|
|
1398
|
+
metrics=metrics,
|
|
1399
|
+
aborted=aborted,
|
|
1400
|
+
ops_executed=ops_executed,
|
|
1401
|
+
trace=trace_payload,
|
|
1402
|
+
)
|
|
1403
|
+
|
|
1404
|
+
# Reaching here means env step succeeded
|
|
1405
|
+
assert env_response is not None
|
|
1406
|
+
|
|
1407
|
+
# Record step, including policy meta if present for timing/tokens observability
|
|
1408
|
+
_info = env_response.info if isinstance(env_response.info, dict) else {}
|
|
1409
|
+
# Attach policy meta from the immediately preceding agent step
|
|
1410
|
+
try:
|
|
1411
|
+
prev_meta = {}
|
|
1412
|
+
if "policy_response" in locals() and isinstance(
|
|
1413
|
+
policy_response.meta, dict
|
|
1414
|
+
): # type: ignore[name-defined]
|
|
1415
|
+
prev_meta = policy_response.meta
|
|
1416
|
+
if prev_meta:
|
|
1417
|
+
_info = dict(_info)
|
|
1418
|
+
_info["meta"] = prev_meta
|
|
1419
|
+
except Exception:
|
|
1420
|
+
pass
|
|
1421
|
+
|
|
1422
|
+
event_metadata = {
|
|
1423
|
+
"op_index": op_idx,
|
|
1424
|
+
}
|
|
1425
|
+
event_id = await tracing_context.record_environment_event(
|
|
1426
|
+
env_handle=env_handle,
|
|
1427
|
+
prev_obs=current_obs,
|
|
1428
|
+
env_response=env_response,
|
|
1429
|
+
next_obs=getattr(env_response, "observation", None),
|
|
1430
|
+
metadata=event_metadata,
|
|
1431
|
+
)
|
|
1432
|
+
|
|
1433
|
+
decision_index += 1
|
|
1434
|
+
next_obs = env_response.observation
|
|
1435
|
+
new_achievement_state = _extract_achievements(next_obs)
|
|
1436
|
+
final_achievement_count = sum(
|
|
1437
|
+
1 for _, unlocked in new_achievement_state.items() if unlocked
|
|
1438
|
+
)
|
|
1439
|
+
indicator_val = 0
|
|
1440
|
+
reward_stepwise = 0.0
|
|
1441
|
+
decision_rewards_meta: Dict[str, Any] | None = None
|
|
1442
|
+
if step_rewards_active:
|
|
1443
|
+
decision_actions = _summarize_tool_calls(pending_tool_calls)
|
|
1444
|
+
stepwise_info, decision_record, stats = compute_stepwise_reward(
|
|
1445
|
+
prev_achievements or {},
|
|
1446
|
+
new_achievement_state,
|
|
1447
|
+
decision_index,
|
|
1448
|
+
decision_actions,
|
|
1449
|
+
step_rewards_indicator_lambda,
|
|
1450
|
+
)
|
|
1451
|
+
indicator_val = int(stats.get("indicator", 0.0))
|
|
1452
|
+
reward_stepwise = float(stats.get("reward", 0.0))
|
|
1453
|
+
stepwise_indicator_sum += float(stats.get("indicator", 0.0))
|
|
1454
|
+
stepwise_reward_sum += reward_stepwise
|
|
1455
|
+
stepwise_new_achievements_total += int(
|
|
1456
|
+
stats.get("new_achievements_count", 0.0)
|
|
1457
|
+
)
|
|
1458
|
+
if not isinstance(_info, dict):
|
|
1459
|
+
_info = {}
|
|
1460
|
+
else:
|
|
1461
|
+
_info = dict(_info)
|
|
1462
|
+
_info["stepwise"] = stepwise_info
|
|
1463
|
+
# Compute decision-level rewards (absolute vs unique) and attach to metadata
|
|
1464
|
+
try:
|
|
1465
|
+
turned_true = set(stepwise_info.get("new_achievements") or [])
|
|
1466
|
+
seen_before = set(episode_seen_achievements)
|
|
1467
|
+
new_unique = sorted(list(turned_true - seen_before))
|
|
1468
|
+
ach_delta = int(len(turned_true))
|
|
1469
|
+
unique_delta = int(len(new_unique))
|
|
1470
|
+
# Prepare stable lists for logging/metadata
|
|
1471
|
+
all_list = sorted(list(turned_true))
|
|
1472
|
+
# Ensure nested meta exists
|
|
1473
|
+
meta_block = _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
|
|
1474
|
+
decision_rewards = {
|
|
1475
|
+
"turn": int(decision_index),
|
|
1476
|
+
"ach_delta": ach_delta,
|
|
1477
|
+
"unique_delta": unique_delta,
|
|
1478
|
+
"all": all_list,
|
|
1479
|
+
"unique": new_unique,
|
|
1480
|
+
}
|
|
1481
|
+
decision_rewards_meta = decision_rewards
|
|
1482
|
+
meta_block["decision_rewards"] = decision_rewards
|
|
1483
|
+
_info["meta"] = meta_block
|
|
1484
|
+
# Update episode-level seen set after attributing uniqueness to this decision
|
|
1485
|
+
episode_seen_achievements.update(turned_true)
|
|
1486
|
+
except Exception:
|
|
1487
|
+
# Best-effort; do not block rollout on metadata computation
|
|
1488
|
+
pass
|
|
1489
|
+
decision_samples.append(decision_record)
|
|
1490
|
+
prev_achievements = new_achievement_state
|
|
1491
|
+
|
|
1492
|
+
await tracing_context.record_decision_reward(
|
|
1493
|
+
event_id=event_id,
|
|
1494
|
+
decision_meta=decision_rewards_meta,
|
|
1495
|
+
)
|
|
1496
|
+
|
|
1497
|
+
step = RolloutStep(
|
|
1498
|
+
obs=_summarize_observation_for_storage(env_handle, current_obs),
|
|
1499
|
+
tool_calls=pending_tool_calls,
|
|
1500
|
+
reward=env_response.reward,
|
|
1501
|
+
done=env_response.done,
|
|
1502
|
+
truncated=env_response.truncated,
|
|
1503
|
+
info=_info,
|
|
1504
|
+
)
|
|
1505
|
+
trajectory_steps.append(step)
|
|
1506
|
+
|
|
1507
|
+
if env_response.reward is not None:
|
|
1508
|
+
total_reward += env_response.reward
|
|
1509
|
+
|
|
1510
|
+
# Update state
|
|
1511
|
+
current_obs = next_obs
|
|
1512
|
+
pending_tool_calls = None
|
|
1513
|
+
ops_executed += 1
|
|
1514
|
+
|
|
1515
|
+
# Handle episode end
|
|
1516
|
+
if env_response.done:
|
|
1517
|
+
if request.on_done == "reset":
|
|
1518
|
+
# Reset environment
|
|
1519
|
+
from .environment_routes import (
|
|
1520
|
+
reset_environment,
|
|
1521
|
+
EnvResetRequest,
|
|
1522
|
+
)
|
|
1523
|
+
|
|
1524
|
+
reset_response = await reset_environment(
|
|
1525
|
+
EnvResetRequest(env_id=env_id)
|
|
1526
|
+
)
|
|
1527
|
+
current_obs = reset_response.observation
|
|
1528
|
+
elif request.on_done == "terminate":
|
|
1529
|
+
break
|
|
1530
|
+
|
|
1531
|
+
if decision_open:
|
|
1532
|
+
await tracing_context.end_decision()
|
|
1533
|
+
decision_open = False
|
|
1534
|
+
|
|
1535
|
+
else:
|
|
1536
|
+
logger.warning(f"Unknown op: {op}")
|
|
1537
|
+
|
|
1538
|
+
if (
|
|
1539
|
+
last_policy_meta is not None
|
|
1540
|
+
and last_agent_response_ts is not None
|
|
1541
|
+
and "timing" in last_policy_meta
|
|
1542
|
+
and isinstance(last_policy_meta["timing"], dict)
|
|
1543
|
+
and "decision_ms" not in last_policy_meta["timing"]
|
|
1544
|
+
):
|
|
1545
|
+
try:
|
|
1546
|
+
final_now = last_env_step_completed_ts or _time.perf_counter()
|
|
1547
|
+
final_decision_ms = max(
|
|
1548
|
+
0.0, (final_now - float(last_agent_response_ts)) * 1000.0
|
|
1549
|
+
)
|
|
1550
|
+
timing_final = last_policy_meta.setdefault("timing", {})
|
|
1551
|
+
timing_final["decision_ms"] = final_decision_ms
|
|
1552
|
+
if last_env_step_ms is not None:
|
|
1553
|
+
timing_final.setdefault(
|
|
1554
|
+
"env_step_ms", float(last_env_step_ms)
|
|
1555
|
+
)
|
|
1556
|
+
timing_final.setdefault(
|
|
1557
|
+
"overhead_ms",
|
|
1558
|
+
max(0.0, final_decision_ms - float(last_env_step_ms)),
|
|
1559
|
+
)
|
|
1560
|
+
else:
|
|
1561
|
+
timing_final.setdefault("overhead_ms", 0.0)
|
|
1562
|
+
except Exception:
|
|
1563
|
+
pass
|
|
1564
|
+
|
|
1565
|
+
# Build trajectory
|
|
1566
|
+
trajectory = RolloutTrajectory(
|
|
1567
|
+
env_id=env_id,
|
|
1568
|
+
policy_id=policy_id,
|
|
1569
|
+
steps=trajectory_steps,
|
|
1570
|
+
final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
|
|
1571
|
+
length=len(trajectory_steps),
|
|
1572
|
+
decision_samples=decision_samples if step_rewards_active else None,
|
|
1573
|
+
)
|
|
1574
|
+
|
|
1575
|
+
# Build metrics
|
|
1576
|
+
metrics = RolloutMetrics(
|
|
1577
|
+
episode_returns=[total_reward],
|
|
1578
|
+
mean_return=total_reward,
|
|
1579
|
+
num_steps=len(trajectory_steps),
|
|
1580
|
+
num_episodes=1,
|
|
1581
|
+
)
|
|
1582
|
+
|
|
1583
|
+
# Environment-specific: Log summary if available
|
|
1584
|
+
try:
|
|
1585
|
+
# Check if this is a Wordle environment and use Wordle helpers (lazy import)
|
|
1586
|
+
try:
|
|
1587
|
+
from .envs.wordle.environment import WordleEnvironmentWrapper as _WordleWrapper
|
|
1588
|
+
from .envs.wordle.helpers import (
|
|
1589
|
+
get_wordle_rollout_summary,
|
|
1590
|
+
log_wordle_rollout_summary,
|
|
1591
|
+
)
|
|
1592
|
+
except Exception:
|
|
1593
|
+
_WordleWrapper = None # type: ignore
|
|
1594
|
+
get_wordle_rollout_summary = None # type: ignore
|
|
1595
|
+
log_wordle_rollout_summary = None # type: ignore
|
|
1596
|
+
|
|
1597
|
+
is_wordle = _WordleWrapper is not None and isinstance(env_handle.env, _WordleWrapper)
|
|
1598
|
+
if is_wordle:
|
|
1599
|
+
# Convert trajectory steps to expected format
|
|
1600
|
+
formatted_steps = []
|
|
1601
|
+
for step in trajectory_steps:
|
|
1602
|
+
formatted_steps.append({"tool_calls": step.tool_calls or []})
|
|
1603
|
+
|
|
1604
|
+
if get_wordle_rollout_summary is not None and log_wordle_rollout_summary is not None:
|
|
1605
|
+
summary = get_wordle_rollout_summary(
|
|
1606
|
+
formatted_steps, current_obs, env_handle
|
|
1607
|
+
)
|
|
1608
|
+
log_wordle_rollout_summary(request.run_id, summary)
|
|
1609
|
+
except ImportError:
|
|
1610
|
+
# Wordle helpers not available, skip Wordle-specific logging
|
|
1611
|
+
pass
|
|
1612
|
+
except Exception as e:
|
|
1613
|
+
logger.warning(f"Failed to generate environment-specific summary: {e}")
|
|
1614
|
+
|
|
1615
|
+
# Mark run as completed
|
|
1616
|
+
aborted = registry.is_run_aborted(request.run_id)
|
|
1617
|
+
if not aborted:
|
|
1618
|
+
registry.complete_run(request.run_id)
|
|
1619
|
+
if decision_open:
|
|
1620
|
+
await tracing_context.end_decision()
|
|
1621
|
+
decision_open = False
|
|
1622
|
+
if not finalized:
|
|
1623
|
+
session_trace = await tracing_context.finalize(
|
|
1624
|
+
total_reward=total_reward,
|
|
1625
|
+
achievement_state=prev_achievements,
|
|
1626
|
+
total_steps=len(trajectory_steps),
|
|
1627
|
+
)
|
|
1628
|
+
finalized = True
|
|
1629
|
+
trace_payload = tracing_context.build_trace_payload(session_trace)
|
|
1630
|
+
|
|
1631
|
+
return RolloutResponse(
|
|
1632
|
+
run_id=request.run_id,
|
|
1633
|
+
trajectories=[trajectory],
|
|
1634
|
+
branches={},
|
|
1635
|
+
metrics=metrics,
|
|
1636
|
+
aborted=aborted,
|
|
1637
|
+
ops_executed=ops_executed,
|
|
1638
|
+
trace=trace_payload,
|
|
1639
|
+
)
|
|
1640
|
+
|
|
1641
|
+
except Exception as e:
|
|
1642
|
+
logger.error(f"Rollout failed for run {request.run_id}: {e}")
|
|
1643
|
+
registry.abort_run(request.run_id)
|
|
1644
|
+
if decision_open:
|
|
1645
|
+
try:
|
|
1646
|
+
await tracing_context.end_decision()
|
|
1647
|
+
except Exception:
|
|
1648
|
+
pass
|
|
1649
|
+
decision_open = False
|
|
1650
|
+
if not finalized:
|
|
1651
|
+
try:
|
|
1652
|
+
session_trace = await tracing_context.finalize(
|
|
1653
|
+
total_reward=total_reward,
|
|
1654
|
+
achievement_state=prev_achievements,
|
|
1655
|
+
total_steps=len(trajectory_steps),
|
|
1656
|
+
)
|
|
1657
|
+
except Exception:
|
|
1658
|
+
session_trace = None
|
|
1659
|
+
finalized = True
|
|
1660
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1661
|
+
finally:
|
|
1662
|
+
# Ensure any environment created for this rollout is terminated (no reuse across rollouts)
|
|
1663
|
+
try:
|
|
1664
|
+
if created_env_id:
|
|
1665
|
+
from .environment_routes import terminate_environment, EnvTerminateRequest
|
|
1666
|
+
|
|
1667
|
+
await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
|
|
1668
|
+
logger.info(
|
|
1669
|
+
"ROLL_OUT: terminated environment env_id=%s seed=%s",
|
|
1670
|
+
str(created_env_id),
|
|
1671
|
+
str(env_seed_used) if env_seed_used is not None else "unknown",
|
|
1672
|
+
)
|
|
1673
|
+
# Verify removal from registry
|
|
1674
|
+
try:
|
|
1675
|
+
_post = registry.get_env(created_env_id)
|
|
1676
|
+
logger.info(
|
|
1677
|
+
"ROLL_OUT: env_killed=%s (post_lookup=%s)",
|
|
1678
|
+
str(_post is None),
|
|
1679
|
+
str(_post),
|
|
1680
|
+
)
|
|
1681
|
+
except Exception:
|
|
1682
|
+
pass
|
|
1683
|
+
except Exception as _te:
|
|
1684
|
+
logger.warning(
|
|
1685
|
+
f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}"
|
|
1686
|
+
)
|
|
1687
|
+
|
|
1688
|
+
# Best-effort policy cleanup if we created one (avoid reuse across rollouts)
|
|
1689
|
+
try:
|
|
1690
|
+
if created_policy_id:
|
|
1691
|
+
from .policy_routes import terminate_policy, PolicyTerminateRequest
|
|
1692
|
+
|
|
1693
|
+
await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
|
|
1694
|
+
logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
|
|
1695
|
+
except Exception:
|
|
1696
|
+
pass
|
|
1697
|
+
|
|
1698
|
+
if not finalized:
|
|
1699
|
+
try:
|
|
1700
|
+
session_trace = await tracing_context.finalize(
|
|
1701
|
+
total_reward=total_reward,
|
|
1702
|
+
achievement_state=prev_achievements,
|
|
1703
|
+
total_steps=len(trajectory_steps),
|
|
1704
|
+
)
|
|
1705
|
+
except Exception:
|
|
1706
|
+
session_trace = None
|
|
1707
|
+
finalized = True
|
|
1708
|
+
|
|
1709
|
+
try:
|
|
1710
|
+
_clear_seed_side_effects()
|
|
1711
|
+
logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
|
|
1712
|
+
except Exception:
|
|
1713
|
+
pass
|
|
1714
|
+
|
|
1715
|
+
|
|
1716
|
+
@router.post("/run/abort", response_model=RunAbortResponse)
async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
    """Request that the run named by ``request.run_id`` be aborted.

    The abort is delegated to ``registry.abort_run``. A truthy result is
    reported back as ``RunAbortResponse(ok=True, ...)``.

    Raises:
        HTTPException: status 404 when the registry call returns a falsy
            value for the given run id.
    """
    abort_accepted = registry.abort_run(request.run_id)

    # Success path first; anything falsy from the registry becomes a 404.
    if abort_accepted:
        return RunAbortResponse(ok=True, run_id=request.run_id)

    raise HTTPException(
        status_code=404,
        detail=f"Run {request.run_id} not found",
    )
|
|
1731
|
+
|
|
1732
|
+
|
|
1733
|
+
@router.get("/run/status/{run_id}", response_model=RunStatusResponse)
async def get_run_status(run_id: str) -> RunStatusResponse:
    """Report the status and timestamps recorded for ``run_id``.

    The run handle is looked up via ``registry.get_run``; its ``status``,
    ``started_at`` and ``finished_at`` attributes are copied into the
    response unchanged.

    Raises:
        HTTPException: status 404 when the registry lookup yields a falsy
            handle for ``run_id``.
    """
    handle = registry.get_run(run_id)

    # Success path first; a falsy lookup result becomes a 404.
    if handle:
        return RunStatusResponse(
            run_id=run_id,
            status=handle.status,
            started_at=handle.started_at,
            finished_at=handle.finished_at,
        )

    raise HTTPException(
        status_code=404,
        detail=f"Run {run_id} not found",
    )
|