nvidia-nat-openpipe-art 1.4.0a20260116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/openpipe/__init__.py +0 -0
- nat/plugins/openpipe/config.py +77 -0
- nat/plugins/openpipe/register.py +71 -0
- nat/plugins/openpipe/trainer.py +659 -0
- nat/plugins/openpipe/trainer_adapter.py +339 -0
- nat/plugins/openpipe/trajectory_builder.py +333 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/METADATA +46 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/RECORD +14 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/WHEEL +5 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/entry_points.txt +2 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
from datetime import datetime
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import art
|
|
24
|
+
import httpx
|
|
25
|
+
|
|
26
|
+
from nat.data_models.finetuning import EpisodeItem
|
|
27
|
+
from nat.data_models.finetuning import EpisodeItemRole
|
|
28
|
+
from nat.data_models.finetuning import FinetuneConfig
|
|
29
|
+
from nat.data_models.finetuning import TrainingJobRef
|
|
30
|
+
from nat.data_models.finetuning import TrainingJobStatus
|
|
31
|
+
from nat.data_models.finetuning import TrainingStatusEnum
|
|
32
|
+
from nat.data_models.finetuning import Trajectory
|
|
33
|
+
from nat.data_models.finetuning import TrajectoryCollection
|
|
34
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
35
|
+
|
|
36
|
+
from .config import ARTTrainerAdapterConfig
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ARTTrainerAdapter(TrainerAdapter):
|
|
42
|
+
"""
|
|
43
|
+
Adapter for the ART Trainer backend.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self, adapter_config: ARTTrainerAdapterConfig):
    """
    Build the ART backend client and the trainable model from the adapter configuration.

    Args:
        adapter_config: Adapter configuration. Its ``backend`` section supplies the
            backend address and the model identity; its ``training`` section is
            passed to ART as the trainer arguments.
    """
    super().__init__(adapter_config)
    self.adapter_config: ARTTrainerAdapterConfig = adapter_config

    backend_cfg = adapter_config.backend

    # The backend is addressed over plain HTTP at the configured ip/port.
    backend_url = f"http://{backend_cfg.ip}:{backend_cfg.port}"
    self.remote_backend: art.Backend = art.Backend(base_url=backend_url)

    internal_config = art.dev.InternalModelConfig(
        init_args=backend_cfg.init_args,
        engine_args=backend_cfg.engine_args,
        torchtune_args=backend_cfg.torchtune_args,
        trainer_args=adapter_config.training,
    )
    self._model_internal_config: art.dev.InternalModelConfig = internal_config

    self.model: art.TrainableModel = art.TrainableModel(
        name=backend_cfg.name,
        project=backend_cfg.project,
        base_model=backend_cfg.base_model,
        _internal_config=internal_config,
    )

    # run_id -> asyncio task executing the training call for that run
    self._training_jobs: dict[str, asyncio.Task[None]] = {}

    logger.info(f"Initialized ARTTrainerAdapter with model: {self.model}")
|
|
70
|
+
|
|
71
|
+
@property
def training_jobs(self) -> dict[str, asyncio.Task[None]]:
    """Mapping of run ID to the asyncio task running that training job."""
    return self._training_jobs
|
|
74
|
+
|
|
75
|
+
async def initialize(self, run_config: FinetuneConfig) -> None:
    """
    Initialize the adapter, register the model with the ART backend and verify connectivity.

    Args:
        run_config: Finetuning run configuration, forwarded to the base class.

    Raises:
        ConnectionError: If the health check reports the backend as unreachable.
    """
    await super().initialize(run_config)

    server_config = self.adapter_config.backend.server_config
    await self.model.register(self.remote_backend, _openai_client_config=server_config)

    # The health check runs only after registration has been attempted.
    if not await self.is_healthy():
        raise ConnectionError("Failed to connect to ART backend.")

    logger.info("Successfully registered with ART backend.")
|
|
87
|
+
|
|
88
|
+
async def is_healthy(self) -> bool:
    """
    Probe the backend's ``/v1/models`` endpoint.

    Returns:
        bool: True if the GET request completes without an ``httpx.HTTPError``,
        False otherwise (the error is logged).
    """
    backend_cfg = self.adapter_config.backend
    # NOTE(review): the port is hard-coded to 8000 rather than taken from the
    # backend config - presumably the inference server port; confirm.
    models_url = f"http://{backend_cfg.ip}:8000/v1/models"
    auth_headers = {"Authorization": f"Bearer {backend_cfg.api_key}"}

    try:
        async with httpx.AsyncClient() as client:
            # Only transport-level success is checked; the response status
            # code is not inspected.
            await client.get(models_url, headers=auth_headers)
    except httpx.HTTPError as e:
        logger.error(f"Health check failed: {e}")
        return False

    return True
|
|
97
|
+
|
|
98
|
+
async def _validate_episode_order(self, traj: Trajectory):
    """
    Validate the ordering of roles in ``traj.episode``.

    The checks performed are:

    - The episode is not empty
    - The first EpisodeItem.role is SYSTEM or USER
    - No two consecutive EpisodeItem.role are both ASSISTANT

    The role of the last EpisodeItem is not checked, and consecutive items
    sharing any role other than ASSISTANT are accepted.

    Args:
        traj: Trajectory to validate

    Raises:
        ValueError: If any of the above conditions are not met.
    """
    items = traj.episode
    if not items:
        raise ValueError("Trajectory episode is empty.")

    allowed_first_roles = {EpisodeItemRole.USER, EpisodeItemRole.SYSTEM}
    if items[0].role not in allowed_first_roles:
        raise ValueError("The first message in the trajectory must be from 'user' or 'system'.")

    # Walk adjacent pairs and reject back-to-back assistant turns.
    for previous, current in zip(items, items[1:]):
        if current.role == EpisodeItemRole.ASSISTANT and current.role == previous.role:
            raise ValueError("Consecutive assistant messages from the same role found in trajectory.")
|
|
125
|
+
|
|
126
|
+
async def _construct_trajectory_groups(self, trajectory_lists: list[list[Trajectory]]) -> list[art.TrajectoryGroup]:
    """
    Convert list of lists of NAT Trajectory to list of ART TrajectoryGroup.

    Each inner list becomes one ART TrajectoryGroup. A trajectory whose
    conversion raises is logged and skipped; an inner list that yields no
    converted trajectory produces no group.

    Args:
        trajectory_lists: List of lists of NAT Trajectory (each inner list
            contains trajectories for one example).

    Returns:
        List of ART TrajectoryGroup.

    Raises:
        ValueError: If any trajectory fails episode-order validation.
    """

    from openai.types.chat.chat_completion import Choice

    # ---------- helpers ----------
    def _content_text(value: Any) -> str:
        # Strings pass through untouched; anything else is JSON-encoded.
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    def _chat_message(item: EpisodeItem) -> dict:
        # Every role other than USER and SYSTEM is emitted as "assistant".
        if item.role == EpisodeItemRole.USER:
            chat_role = "user"
        elif item.role == EpisodeItemRole.SYSTEM:
            chat_role = "system"
        else:
            chat_role = "assistant"
        return {"role": chat_role, "content": _content_text(item.content)}

    groups = []

    for trajectory_list in trajectory_lists:
        converted = []

        for traj in trajectory_list:
            episode = traj.episode
            reward = traj.reward

            # Validation happens outside the try block, so an invalid
            # episode aborts the whole conversion instead of being skipped.
            await self._validate_episode_order(traj)

            try:
                opening_message = _chat_message(episode[0])
                art_traj = art.Trajectory(messages_and_choices=[opening_message], reward=reward)

                for item in episode[1:]:
                    entry = _chat_message(item)
                    if item.role == EpisodeItemRole.ASSISTANT:
                        # Assistant turns are wrapped in a Choice so that
                        # their logprobs travel with the message.
                        entry = Choice(index=0, logprobs=item.logprobs, message=entry, finish_reason="stop")
                    art_traj.messages_and_choices.append(entry)

                # Sanity check: round-trip through art.Trajectory.model_validate()
                art_traj.model_validate(art_traj.model_dump())
            except Exception as e:
                logger.error(f"Error constructing trajectory: {e}. Skipping.")
            else:
                converted.append(art_traj)

        # Only non-empty groups are emitted.
        if converted:
            groups.append(art.TrajectoryGroup(trajectories=converted))

    return groups
|
|
201
|
+
|
|
202
|
+
async def submit(self, trajectories: TrajectoryCollection) -> TrainingJobRef:
    """
    Submit trajectories to ART backend for training.

    The training call is started as a background asyncio task and tracked in
    ``self.training_jobs`` under the collection's run ID; this method returns
    without waiting for training to finish.

    Args:
        trajectories: TrajectoryCollection with list of lists of NAT Trajectory.

    Returns:
        TrainingJobRef: Reference to the submitted training job.

    Raises:
        ValueError: If no valid trajectory group could be built, or if a
            training job with the same run ID is already tracked.
    """
    run_id = trajectories.run_id

    trajectory_groups = await self._construct_trajectory_groups(trajectories.trajectories)
    if not trajectory_groups:
        raise ValueError("No valid trajectory groups to submit.")

    # Raised explicitly instead of using `assert`: assertions are removed
    # under `python -O`, which would let a second job silently overwrite the
    # tracked task for this run ID.
    if run_id in self.training_jobs:
        raise ValueError(f"Training job with run_id {run_id} already exists.")

    # Delete old remote checkpoints (best effort: a failure is only logged)
    if self.adapter_config.backend.delete_old_checkpoints:
        try:
            logger.info("Deleting old checkpoints on ART backend...")
            await self.model.delete_checkpoints()
        except Exception as e:
            logger.warning(f"Failed to delete old checkpoints: {e}")

    # Submit new trajectories; defaults apply when the training config does
    # not define the attribute.
    training_cfg = self.adapter_config.training
    train_config = art.types.TrainConfig(
        beta=getattr(training_cfg, "beta", 0),
        learning_rate=getattr(training_cfg, "learning_rate", 5e-5),
    )
    task = asyncio.create_task(
        self.model.train(trajectory_groups=trajectory_groups, verbose=False, config=train_config),
        name=f"art-train:{run_id}",
    )

    # Log the outcome on completion. Retrieving the exception here also keeps
    # asyncio from reporting it as never retrieved.
    def _on_done(t: asyncio.Task, rid: str = run_id) -> None:
        if t.cancelled():
            logger.info(f"Training {rid} was cancelled.")
        elif (exc := t.exception()) is not None:
            logger.exception(f"Training {rid} failed", exc_info=exc)
        else:
            logger.info(f"Training {rid} completed successfully.")

    task.add_done_callback(_on_done)

    self.training_jobs[run_id] = task

    total_trajectories = sum(len(group.trajectories) for group in trajectory_groups)
    logger.info(f"Submitted {total_trajectories} trajectories in {len(trajectory_groups)} groups for "
                f"training with run_id {run_id}.")

    return TrainingJobRef(run_id=run_id, backend="openpipe-art")
|
|
257
|
+
|
|
258
|
+
async def status(self, ref: TrainingJobRef) -> TrainingJobStatus:
    """
    Report the state of a tracked training job.

    A job that has finished (completed, failed or cancelled) is removed from
    ``self.training_jobs`` as part of this call, so a later call for the same
    run ID raises ``ValueError``.

    Args:
        ref: Training job reference

    Returns:
        TrainingJobStatus: Status, progress and message for the job.

    Raises:
        ValueError: If no job is tracked under ``ref.run_id``.
    """
    task = self.training_jobs.get(ref.run_id)
    if task is None:
        raise ValueError(f"No training job found with run_id {ref.run_id}.")

    # Progress is only known (100%) for a successfully completed job.
    progress = None

    if not task.done():
        state = TrainingStatusEnum.RUNNING
        message = "Training is in progress."
    else:
        if task.cancelled():
            state = TrainingStatusEnum.CANCELED
            message = "Training was cancelled."
        elif (error := task.exception()) is not None:
            state = TrainingStatusEnum.FAILED
            message = f"Training failed with error: {error!r}"
        else:
            state = TrainingStatusEnum.COMPLETED
            progress = 100.0
            message = "Training completed successfully."

        # Clean up completed job
        self.training_jobs.pop(ref.run_id, None)

    return TrainingJobStatus(
        run_id=ref.run_id,
        backend=ref.backend,
        status=state,
        progress=progress,
        message=message,
    )
|
|
293
|
+
|
|
294
|
+
async def wait_until_complete(self, ref: TrainingJobRef, poll_interval: float = 10.0) -> TrainingJobStatus:
    """
    Wait for a tracked training job to finish and return its final status.

    Args:
        ref: Training job reference
        poll_interval: Maximum number of seconds to wait between checks of
            the task state.

    Returns:
        TrainingJobStatus: The status reported by ``status()`` once the task is done.

    Raises:
        ValueError: If no job is tracked under ``ref.run_id``.
    """
    task = self.training_jobs.get(ref.run_id)
    if task is None:
        raise ValueError(f"No training job found with run_id {ref.run_id}.")

    while not task.done():
        # asyncio.wait wakes up as soon as the task finishes instead of
        # always sleeping the full interval. On timeout it neither cancels
        # the task nor raises, and it never re-raises the task's exception.
        await asyncio.wait({task}, timeout=poll_interval)

    return await self.status(ref)
|
|
303
|
+
|
|
304
|
+
def log_progress(self, ref: TrainingJobRef, metrics: dict[str, Any], output_dir: str | None = None) -> None:
    """
    Append one JSON line describing trainer adapter progress to a per-run log file.

    Args:
        ref: Training job reference
        metrics: Dictionary of metrics to log; merged into the log entry
            last, so its keys override the built-in fields of the same name.
        output_dir: Optional output directory override
    """
    # Fall back to the default output directory when none is provided
    target_dir = Path(output_dir or "./.tmp/nat/finetuning/trainer_adapter")
    target_dir.mkdir(parents=True, exist_ok=True)

    # One JSONL file per run
    jsonl_path = target_dir / f"trainer_adapter_{ref.run_id}.jsonl"

    backend_cfg = self.adapter_config.backend
    record = {
        "timestamp": datetime.now().isoformat(),
        "run_id": ref.run_id,
        "backend": ref.backend,
        "trainer_config": {
            "base_model": backend_cfg.base_model,
            "project": backend_cfg.project,
            "name": backend_cfg.name,
        },
    }
    record.update(metrics)

    # Append mode keeps earlier entries for the same run
    with jsonl_path.open('a', encoding='utf-8') as handle:
        handle.write(json.dumps(record) + '\n')

    logger.info("Trainer adapter progress logged for job %s: status=%s",
                ref.run_id,
                metrics.get("status", "unknown"))
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
from datetime import datetime
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from nat.data_models.finetuning import EpisodeItem
|
|
25
|
+
from nat.data_models.finetuning import EpisodeItemRole
|
|
26
|
+
from nat.data_models.finetuning import Trajectory
|
|
27
|
+
from nat.data_models.finetuning import TrajectoryCollection
|
|
28
|
+
from nat.data_models.intermediate_step import IntermediateStep
|
|
29
|
+
from nat.data_models.intermediate_step import IntermediateStepCategory
|
|
30
|
+
from nat.eval.config import EvaluationRunOutput
|
|
31
|
+
from nat.eval.evaluator.evaluator_model import EvalInputItem
|
|
32
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
33
|
+
from nat.finetuning.utils.parsers.base_parser import parse_to_openai_messages
|
|
34
|
+
|
|
35
|
+
from .config import ARTTrajectoryBuilderConfig
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ARTTrajectoryBuilder(TrajectoryBuilder):
|
|
41
|
+
"""
|
|
42
|
+
Trajectory builder for the ART backend.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(self, trajectory_builder_config: ARTTrajectoryBuilderConfig):
    """
    Create the builder with no evaluation runs in progress.

    Args:
        trajectory_builder_config: Builder configuration, forwarded to the base class.
    """
    super().__init__(trajectory_builder_config=trajectory_builder_config)

    # run_id -> evaluation tasks started for that run, one per generation
    self.evaluation_runs: dict[str, list[asyncio.Task[EvaluationRunOutput]]] = {}
|
|
51
|
+
|
|
52
|
+
@property
def num_generations(self) -> int:
    """Number of generations, as set in the trajectory builder configuration."""
    builder_config = self.trajectory_builder_config
    return builder_config.num_generations
|
|
55
|
+
|
|
56
|
+
async def start_run(self, run_id: str, meta: dict | None = None) -> None:
    """
    Launch ``num_generations`` evaluation runs as background tasks for a run.

    The tasks are stored in ``self.evaluation_runs`` under ``run_id``.

    Args:
        run_id (str): The ID of the run.
        meta (dict): Metadata for the run; not used by this method.

    Raises:
        ValueError: If tasks are already registered under ``run_id``.
    """
    if run_id in self.evaluation_runs:
        raise ValueError(f"Run {run_id} is already in progress.")

    logger.info("Starting %d evaluation runs for run_id: %s", self.num_generations, run_id)

    def _report_outcome(t: asyncio.Task[EvaluationRunOutput], generation_index: int) -> None:
        # Done-callback body: log how the evaluation task ended.
        if t.cancelled():
            logger.info("Evaluation run for run_id: %s, generation: %d was cancelled.",
                        run_id,
                        generation_index)
        elif exc := t.exception():
            logger.error(
                "Evaluation run for run_id: %s, generation: %d failed with exception: %s",
                run_id,
                generation_index,
                exc,
            )
        else:
            logger.info(
                "Evaluation run for run_id: %s, generation: %d completed successfully.",
                run_id,
                generation_index,
            )

    launched = []
    for gen_idx in range(self.num_generations):
        eval_task = asyncio.create_task(self.run_eval(), name=f"eval-run-{run_id}-gen-{gen_idx}")
        # Bind the index as a default so each callback reports its own
        # generation rather than the loop's final value.
        eval_task.add_done_callback(lambda t, idx=gen_idx: _report_outcome(t, idx))
        launched.append(eval_task)

    self.evaluation_runs[run_id] = launched
|
|
97
|
+
|
|
98
|
+
async def finalize(self, run_id: str, meta: dict | None = None) -> TrajectoryCollection:
    """
    Waits for all evaluation runs to finalize and builds trajectories from
    the episode items, grouping them by example ID.

    For each generation, the reward items of the configured reward function
    are matched to their input items by ID. The intermediate steps of each
    input item are filtered to the configured target functions (and target
    model, if set), parsed into an episode, and turned into a Trajectory.
    Episodes that are empty, have fewer than two messages, or contain no
    assistant message with logprobs are skipped.

    Note:
        If any evaluation task raises, the exception propagates from
        ``asyncio.gather`` and the run stays registered in
        ``self.evaluation_runs``.

    Args:
        run_id (str): The ID of the run.
        meta (dict): Metadata for the run; forwarded to ``compute_reward``.

    Returns:
        TrajectoryCollection: The collection of built trajectories grouped by example.
        When ``num_generations`` is 1, all trajectories are returned as a single group.

    Raises:
        ValueError: If no evaluation runs are registered under ``run_id``.
    """
    from nat.eval.evaluator.evaluator_model import EvalOutputItem

    if run_id not in self.evaluation_runs:
        raise ValueError(f"No evaluation runs found for run_id: {run_id}")

    # ---------- helpers (built once, not per generation/message) ----------
    # Map OpenAI role to EpisodeItemRole; unknown roles map to OTHER.
    role_mapping = {
        "user": EpisodeItemRole.USER,
        "assistant": EpisodeItemRole.ASSISTANT,
        "system": EpisodeItemRole.SYSTEM,
        "tool": EpisodeItemRole.TOOL,
        "function": EpisodeItemRole.FUNCTION,
        "human": EpisodeItemRole.USER,
        "ai": EpisodeItemRole.ASSISTANT,
    }
    # Message attributes copied verbatim into the episode item metadata.
    metadata_keys = ("tool_call_id", "tool_calls", "function_call", "name")

    def _parse_trajectory_from_steps(steps: list[IntermediateStep]) -> list[EpisodeItem]:
        """Parse trajectory from intermediate steps using parser."""
        episode_items = []

        try:
            # Use the base parser to convert to OpenAI messages
            openai_messages = parse_to_openai_messages(steps)

            # Convert OpenAI messages to EpisodeItems
            for msg in openai_messages:
                role = role_mapping.get(msg.get("role"), EpisodeItemRole.OTHER)
                content = msg.get("content", "")
                logprobs = msg.get("logprobs")

                # For assistant messages, skip if no logprobs
                if role == EpisodeItemRole.ASSISTANT and not logprobs:
                    logger.debug("Skipping assistant message without logprobs")
                    continue

                # Build metadata from tool/function specific message attributes
                metadata = {key: msg[key] for key in metadata_keys if key in msg}

                episode_items.append(
                    EpisodeItem(
                        role=role,
                        content=content,
                        logprobs=logprobs,
                        metadata=metadata or None,
                    ))

        except ValueError as e:
            logger.warning(
                "Failed to parse trajectory using base parser: %s. "
                "Falling back to empty episode.", str(e))
            # Return empty list on parse failure (partial results are dropped)
            return []

        return episode_items

    def _is_target_step(step: IntermediateStep) -> bool:
        """Return True if the step belongs to a target function (and target model, if set)."""
        if step.function_ancestry.function_name not in self.run_config.target_functions:
            return False
        # If target model is specified, drop LLM steps from any other model
        target_model = self.run_config.target_model
        mismatched_llm = (target_model and step.event_category == IntermediateStepCategory.LLM
                          and step.payload.name != target_model)
        return not mismatched_llm

    # Wait for all evaluation runs to complete
    tasks = self.evaluation_runs[run_id]
    eval_results = await asyncio.gather(*tasks)

    # Dictionary to group trajectories by example ID
    trajectories_by_id: dict[str, list[Trajectory]] = {}

    # Process each evaluation result
    for gen_idx, eval_result in enumerate(eval_results):
        reward_results: list[EvalOutputItem] | None = None
        for metric_name, metric_value in eval_result.evaluation_results:
            if metric_name == self.run_config.reward_function.name:
                reward_results = metric_value.eval_output_items
                break

        if not reward_results:
            logger.warning("No reward results found for run_id: %s, generation: %d", run_id, gen_idx)
            continue

        logger.info("Building trajectories for run_id: %s, generation: %d", run_id, gen_idx)

        # Create a mapping of id to input item for quick lookup
        input_items_map = {item.id: item for item in eval_result.eval_input.eval_input_items}

        for reward_item in reward_results:
            # Find the corresponding input item
            input_item: EvalInputItem | None = input_items_map.get(reward_item.id)
            if not input_item:
                logger.warning(
                    "No input item found for reward item id: %s",
                    reward_item.id,
                )
                continue

            filtered_trajectory = [step for step in input_item.trajectory if _is_target_step(step)]

            if not filtered_trajectory:
                logger.warning(
                    "No trajectory steps found for target function '%s' in item id: %s",
                    self.run_config.target_functions,
                    reward_item.id,
                )
                continue

            # Parse episode from intermediate steps
            episode = _parse_trajectory_from_steps(filtered_trajectory)

            # Nothing usable was parsed from the steps: skip this item
            if not episode:
                continue

            # Ensure we have at least a user and assistant message
            if len(episode) < 2:
                logger.warning(
                    "Episode for item %s has less than 2 messages, skipping",
                    reward_item.id,
                )
                continue

            # Validate that at least one assistant message has logprobs
            # (required for training)
            has_valid_assistant = any(
                entry.role == EpisodeItemRole.ASSISTANT and entry.logprobs for entry in episode)

            if not has_valid_assistant:
                logger.warning(
                    "Episode for item %s has no assistant messages with "
                    "logprobs, skipping as it cannot be used for training",
                    reward_item.id,
                )
                continue

            # Create trajectory
            trajectory = Trajectory(
                episode=episode,
                reward=(await self.compute_reward(reward_item, meta=meta)),
                metadata={
                    "id": reward_item.id,
                    "reasoning": reward_item.reasoning,
                    "run_id": run_id,
                    "generation": gen_idx,
                },
            )

            # Group by example ID
            trajectories_by_id.setdefault(reward_item.id, []).append(trajectory)

    # Clean up completed runs
    self.evaluation_runs.pop(run_id, None)

    # Convert dictionary to list of lists, maintaining insertion order
    trajectories_list = list(trajectories_by_id.values())

    total_trajectories = sum(len(traj_list) for traj_list in trajectories_list)
    logger.info("Built %d trajectories across %d examples for run_id: %s",
                total_trajectories,
                len(trajectories_list),
                run_id)

    if not trajectories_list:
        logger.warning("No trajectories were built for run_id: %s", run_id)
        return TrajectoryCollection(trajectories=[], run_id=run_id)

    if self.num_generations == 1:
        # With a single generation, flatten every example into one group
        flat_trajectories = [traj for sublist in trajectories_list for traj in sublist]
        return TrajectoryCollection(trajectories=[flat_trajectories], run_id=run_id)

    return TrajectoryCollection(trajectories=trajectories_list, run_id=run_id)
|
|
302
|
+
|
|
303
|
+
def log_progress(self, run_id: str, metrics: dict[str, Any], output_dir: str | None = None) -> None:
    """
    Append one JSON line describing trajectory building progress to a per-run log file.

    Args:
        run_id: The training run ID
        metrics: Dictionary of metrics to log; merged into the log entry
            last, so its keys override the built-in fields of the same name.
        output_dir: Optional output directory override
    """
    # Fall back to the default output directory when none is provided
    target_dir = Path(output_dir or "./.tmp/nat/finetuning/trajectory_builder")
    target_dir.mkdir(parents=True, exist_ok=True)

    # One JSONL file per run
    jsonl_path = target_dir / f"trajectory_builder_{run_id}.jsonl"

    record = {
        "timestamp": datetime.now().isoformat(),
        "run_id": run_id,
        "num_generations": self.num_generations,
    }
    record.update(metrics)

    # Append mode keeps earlier entries for the same run
    with jsonl_path.open('a', encoding='utf-8') as handle:
        handle.write(json.dumps(record) + '\n')

    logger.debug("Trajectory builder progress logged for run %s: %d trajectories",
                 run_id,
                 metrics.get("num_trajectories", 0))
|