nvidia-nat-openpipe-art 1.4.0a20260109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-openpipe-art might be problematic. Click here for more details.

@@ -0,0 +1,339 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import json
18
+ import logging
19
+ from datetime import datetime
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ import art
24
+ import httpx
25
+
26
+ from nat.data_models.finetuning import EpisodeItem
27
+ from nat.data_models.finetuning import EpisodeItemRole
28
+ from nat.data_models.finetuning import FinetuneConfig
29
+ from nat.data_models.finetuning import TrainingJobRef
30
+ from nat.data_models.finetuning import TrainingJobStatus
31
+ from nat.data_models.finetuning import TrainingStatusEnum
32
+ from nat.data_models.finetuning import Trajectory
33
+ from nat.data_models.finetuning import TrajectoryCollection
34
+ from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
35
+
36
+ from .config import ARTTrainerAdapterConfig
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
class ARTTrainerAdapter(TrainerAdapter):
    """
    Adapter for the ART Trainer backend.

    Converts NAT trajectories into ART trajectory groups, submits them to a
    remote ART backend as background ``asyncio`` tasks, and reports on the
    state of those tasks.
    """

    def __init__(self, adapter_config: ARTTrainerAdapterConfig):
        """
        Build the ART backend client and the trainable model from the config.

        Args:
            adapter_config: Adapter configuration; its ``backend`` section
                supplies the connection and model settings and its
                ``training`` section is passed to ART as trainer arguments.
        """
        super().__init__(adapter_config)

        self.adapter_config: ARTTrainerAdapterConfig = adapter_config

        self.remote_backend: art.Backend = art.Backend(
            base_url=f"http://{adapter_config.backend.ip}:{adapter_config.backend.port}")

        self._model_internal_config: art.dev.InternalModelConfig = art.dev.InternalModelConfig(
            init_args=self.adapter_config.backend.init_args,
            engine_args=self.adapter_config.backend.engine_args,
            torchtune_args=self.adapter_config.backend.torchtune_args,
            trainer_args=self.adapter_config.training)

        self.model: art.TrainableModel = art.TrainableModel(
            name=self.adapter_config.backend.name,
            project=self.adapter_config.backend.project,
            base_model=self.adapter_config.backend.base_model,
            _internal_config=self._model_internal_config,
        )

        # run_id -> background training task. Entries are added by submit()
        # and removed by status() once the task has finished.
        self._training_jobs: dict[str, asyncio.Task[None]] = {}

        logger.info("Initialized ARTTrainerAdapter with model: %s", self.model)

    @property
    def training_jobs(self) -> dict[str, asyncio.Task[None]]:
        """Return the mapping of run_id to its background training task."""
        return self._training_jobs

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Register the model with the ART backend and verify it is reachable.

        Args:
            run_config: Finetuning run configuration, forwarded to the base class.

        Raises:
            ConnectionError: If the health check fails after registration.
        """
        await super().initialize(run_config)

        await self.model.register(self.remote_backend, _openai_client_config=self.adapter_config.backend.server_config)

        if not await self.is_healthy():
            raise ConnectionError("Failed to connect to ART backend.")

        logger.info("Successfully registered with ART backend.")

    async def is_healthy(self) -> bool:
        """
        Probe the ``/v1/models`` endpoint of the backend host.

        Returns:
            True if the request succeeds with a 2xx status; False on any
            transport error or non-success HTTP status.
        """
        # NOTE(review): the port is hard-coded to 8000 here while the ART backend
        # client uses ``backend.port`` -- confirm this is the intended endpoint.
        url = f"http://{self.adapter_config.backend.ip}:8000/v1/models"
        headers = {"Authorization": f"Bearer {self.adapter_config.backend.api_key}"}

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(url, headers=headers)
                # Without this, an error response (e.g. 401/500) would be
                # reported as healthy; HTTPStatusError is an httpx.HTTPError.
                response.raise_for_status()
        except httpx.HTTPError as e:
            logger.error("Health check failed: %s", e)
            return False

        return True

    async def _validate_episode_order(self, traj: Trajectory):
        """
        Validate the ordering of the EpisodeItems in ``traj.episode``:

        - The episode is not empty
        - The first EpisodeItem.role is SYSTEM or USER
        - No two consecutive EpisodeItem.role are both ASSISTANT

        Args:
            traj: Trajectory to validate

        Raises:
            ValueError: If any of the above conditions are not met.
        """
        episode = traj.episode

        if not episode:
            raise ValueError("Trajectory episode is empty.")

        if episode[0].role not in {EpisodeItemRole.USER, EpisodeItemRole.SYSTEM}:
            raise ValueError("The first message in the trajectory must be from 'user' or 'system'.")

        # NOTE(review): a check that the last message is from the assistant was
        # disabled in the original code; confirm whether it should be enforced.

        for previous, current in zip(episode, episode[1:]):
            if previous.role == current.role == EpisodeItemRole.ASSISTANT:
                raise ValueError("Consecutive assistant messages from the same role found in trajectory.")

    async def _construct_trajectory_groups(self, trajectory_lists: list[list[Trajectory]]) -> list[art.TrajectoryGroup]:
        """
        Convert list of lists of NAT Trajectory to list of ART TrajectoryGroup.

        Trajectories that fail conversion are logged and skipped; inner lists
        that end up with no converted trajectory produce no group.

        Args:
            trajectory_lists: List of lists of NAT Trajectory (each inner list
                contains trajectories for one example).

        Returns:
            List of ART TrajectoryGroup.

        Raises:
            ValueError: If any trajectory fails episode-order validation.
        """

        from openai.types.chat.chat_completion import Choice

        # ---------- helpers ----------
        def _as_text(obj: Any) -> str:
            # Non-string content is serialised to JSON text.
            return obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)

        chat_role_names = {EpisodeItemRole.USER: "user", EpisodeItemRole.SYSTEM: "system"}

        def _to_chat_msg(d: EpisodeItem) -> dict:
            # NOTE(review): every role other than USER/SYSTEM (including tool or
            # function roles, if present) is emitted as "assistant" -- confirm.
            return {"role": chat_role_names.get(d.role, "assistant"), "content": _as_text(d.content)}

        output_trajectory_groups = []

        for trajectory_list in trajectory_lists:
            art_trajectories = []

            for traj in trajectory_list:
                episode = traj.episode

                # Validation errors propagate; only conversion errors are skipped.
                await self._validate_episode_order(traj)

                try:
                    t = art.Trajectory(messages_and_choices=[_to_chat_msg(episode[0])], reward=traj.reward)

                    for msg in episode[1:]:
                        if msg.role == EpisodeItemRole.ASSISTANT:
                            # Assistant turns are wrapped as Choices so their logprobs are carried along.
                            t.messages_and_choices.append(
                                Choice(index=0, logprobs=msg.logprobs, message=_to_chat_msg(msg), finish_reason="stop"))
                        else:
                            t.messages_and_choices.append(_to_chat_msg(msg))

                    # Sanity check: round-trip through art.Trajectory.model_validate()
                    t.model_validate(t.model_dump())

                    art_trajectories.append(t)

                except Exception as e:
                    logger.error("Error constructing trajectory: %s. Skipping.", e)
                    continue

            # Create TrajectoryGroup for this list of trajectories
            if art_trajectories:
                output_trajectory_groups.append(art.TrajectoryGroup(trajectories=art_trajectories))

        return output_trajectory_groups

    async def submit(self, trajectories: TrajectoryCollection) -> TrainingJobRef:
        """
        Submit trajectories to ART backend for training.

        The training call runs as a background task registered under
        ``trajectories.run_id``; this method returns without awaiting it.

        Args:
            trajectories: TrajectoryCollection with list of lists of NAT Trajectory.

        Returns:
            TrainingJobRef: Reference to the submitted training job.

        Raises:
            ValueError: If no valid trajectory group could be built, or if a
                training job with the same run_id is already registered.
        """

        trajectory_groups = await self._construct_trajectory_groups(trajectories.trajectories)
        if not trajectory_groups:
            raise ValueError("No valid trajectory groups to submit.")

        run_id = trajectories.run_id

        # Raised explicitly (not asserted) so the guard survives ``python -O``.
        if run_id in self.training_jobs:
            raise ValueError(f"Training job with run_id {run_id} already exists.")

        # Delete old remote checkpoints (best effort: a failure only warns)
        if self.adapter_config.backend.delete_old_checkpoints:
            try:
                logger.info("Deleting old checkpoints on ART backend...")
                await self.model.delete_checkpoints()
            except Exception as e:
                logger.warning("Failed to delete old checkpoints: %s", e)

        # Submit new trajectories
        task = asyncio.create_task(
            self.model.train(trajectory_groups=trajectory_groups,
                             verbose=False,
                             config=art.types.TrainConfig(
                                 beta=getattr(self.adapter_config.training, "beta", 0),
                                 learning_rate=getattr(self.adapter_config.training, "learning_rate", 5e-5),
                             )),
            name=f"art-train:{run_id}",
        )

        # Log the outcome when the task finishes. The registry entry itself is
        # removed later by status().
        def _on_done(t: asyncio.Task, rid: str = run_id) -> None:
            if t.cancelled():
                logger.info("Training %s was cancelled.", rid)
            elif (exc := t.exception()) is not None:
                logger.error("Training %s failed", rid, exc_info=exc)
            else:
                logger.info("Training %s completed successfully.", rid)

        task.add_done_callback(_on_done)

        self.training_jobs[run_id] = task

        total_trajectories = sum(len(group.trajectories) for group in trajectory_groups)
        logger.info("Submitted %s trajectories in %s groups for training with run_id %s.",
                    total_trajectories,
                    len(trajectory_groups),
                    run_id)

        return TrainingJobRef(run_id=run_id, backend="openpipe-art")

    async def status(self, ref: TrainingJobRef) -> TrainingJobStatus:
        """
        Report the state of a submitted training job.

        A finished job (completed, failed or cancelled) is removed from the
        registry by this call, so a later call for the same run_id raises.

        Args:
            ref: Reference returned by ``submit``.

        Returns:
            TrainingJobStatus describing the current state of the job.

        Raises:
            ValueError: If no job is registered under ``ref.run_id``.
        """
        task = self.training_jobs.get(ref.run_id)
        if task is None:
            raise ValueError(f"No training job found with run_id {ref.run_id}.")

        progress = None

        if not task.done():
            status = TrainingStatusEnum.RUNNING
            message = "Training is in progress."
        else:
            if task.cancelled():
                status = TrainingStatusEnum.CANCELED
                message = "Training was cancelled."
            elif (exc := task.exception()) is not None:
                status = TrainingStatusEnum.FAILED
                message = f"Training failed with error: {exc!r}"
            else:
                status = TrainingStatusEnum.COMPLETED
                progress = 100.0
                message = "Training completed successfully."

            self.training_jobs.pop(ref.run_id, None)  # Clean up completed job

        return TrainingJobStatus(
            run_id=ref.run_id,
            backend=ref.backend,
            status=status,
            progress=progress,
            message=message,
        )

    async def wait_until_complete(self, ref: TrainingJobRef, poll_interval: float = 10.0) -> TrainingJobStatus:
        """
        Wait for a submitted training job to finish and return its final status.

        Args:
            ref: Reference returned by ``submit``.
            poll_interval: Maximum number of seconds to block per wait cycle.

        Returns:
            The TrainingJobStatus produced by ``status`` once the job is done.

        Raises:
            ValueError: If no job is registered under ``ref.run_id``.
        """
        task = self.training_jobs.get(ref.run_id)
        if task is None:
            raise ValueError(f"No training job found with run_id {ref.run_id}.")

        while not task.done():
            # asyncio.wait returns as soon as the task finishes (or after the
            # timeout) and neither raises the task's exception nor cancels it,
            # so completion is noticed without waiting out a full interval.
            await asyncio.wait({task}, timeout=poll_interval)

        return await self.status(ref)

    def log_progress(self, ref: TrainingJobRef, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log training adapter progress.

        Appends one JSON line to ``trainer_adapter_<run_id>.jsonl`` in the
        output directory, creating the directory if needed.

        Args:
            ref: Training job reference
            metrics: Dictionary of metrics to log
            output_dir: Optional output directory override
        """
        # Use default output directory if not provided
        out_dir = Path(output_dir) if output_dir else Path("./.tmp/nat/finetuning/trainer_adapter")
        out_dir.mkdir(parents=True, exist_ok=True)

        # Create log file for trainer adapter
        log_file = out_dir / f"trainer_adapter_{ref.run_id}.jsonl"

        # Prepare log entry; keys in ``metrics`` override the fixed fields on clash
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "run_id": ref.run_id,
            "backend": ref.backend,
            "trainer_config": {
                "base_model": self.adapter_config.backend.base_model,
                "project": self.adapter_config.backend.project,
                "name": self.adapter_config.backend.name,
            },
            **metrics
        }

        # Append to log file
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry) + '\n')

        logger.info("Trainer adapter progress logged for job %s: status=%s",
                    ref.run_id,
                    metrics.get("status", "unknown"))
@@ -0,0 +1,333 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import asyncio
18
+ import json
19
+ import logging
20
+ from datetime import datetime
21
+ from pathlib import Path
22
+ from typing import Any
23
+
24
+ from nat.data_models.finetuning import EpisodeItem
25
+ from nat.data_models.finetuning import EpisodeItemRole
26
+ from nat.data_models.finetuning import Trajectory
27
+ from nat.data_models.finetuning import TrajectoryCollection
28
+ from nat.data_models.intermediate_step import IntermediateStep
29
+ from nat.data_models.intermediate_step import IntermediateStepCategory
30
+ from nat.eval.config import EvaluationRunOutput
31
+ from nat.eval.evaluator.evaluator_model import EvalInputItem
32
+ from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
33
+ from nat.finetuning.utils.parsers.base_parser import parse_to_openai_messages
34
+
35
+ from .config import ARTTrajectoryBuilderConfig
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class ARTTrajectoryBuilder(TrajectoryBuilder):
    """
    Trajectory builder for the ART backend.

    Launches ``num_generations`` evaluation runs per run_id and turns their
    intermediate steps and reward results into NAT trajectories grouped by
    example ID.
    """

    def __init__(
        self,
        trajectory_builder_config: ARTTrajectoryBuilderConfig,
    ):
        """
        Args:
            trajectory_builder_config: Configuration for this builder.
        """
        super().__init__(trajectory_builder_config=trajectory_builder_config)
        # run_id -> evaluation tasks started by start_run() and consumed by finalize()
        self.evaluation_runs: dict[str, list[asyncio.Task[EvaluationRunOutput]]] = {}

    @property
    def num_generations(self) -> int:
        """Number of evaluation runs started per run_id, as set in the config."""
        return self.trajectory_builder_config.num_generations

    async def start_run(self, run_id: str, meta: dict | None = None) -> None:
        """
        Start multiple evaluation runs to collect trajectories.

        Args:
            run_id (str): The ID of the run.
            meta (dict): Metadata for the run.

        Raises:
            ValueError: If evaluation runs are already registered for ``run_id``.
        """

        if run_id in self.evaluation_runs:
            raise ValueError(f"Run {run_id} is already in progress.")

        logger.info("Starting %d evaluation runs for run_id: %s", self.num_generations, run_id)
        tasks = []

        for gen_idx in range(self.num_generations):
            task = asyncio.create_task(self.run_eval(), name=f"eval-run-{run_id}-gen-{gen_idx}")

            # ``generation_index`` is bound as a default so each callback keeps
            # its own generation number rather than the loop's final value.
            def _on_done(t: asyncio.Task[EvaluationRunOutput], generation_index: int = gen_idx) -> None:
                if t.cancelled():
                    logger.info("Evaluation run for run_id: %s, generation: %d was cancelled.",
                                run_id,
                                generation_index)
                elif exc := t.exception():
                    logger.error(
                        "Evaluation run for run_id: %s, generation: %d failed with exception: %s",
                        run_id,
                        generation_index,
                        exc,
                    )
                else:
                    logger.info(
                        "Evaluation run for run_id: %s, generation: %d completed successfully.",
                        run_id,
                        generation_index,
                    )

            task.add_done_callback(_on_done)
            tasks.append(task)

        self.evaluation_runs[run_id] = tasks

    def _abandon_run(self, run_id: str) -> None:
        """Drop the bookkeeping for ``run_id`` and cancel any of its tasks still running."""
        for task in self.evaluation_runs.pop(run_id, []):
            if not task.done():
                task.cancel()

    def _select_target_steps(self, steps: list[IntermediateStep]) -> list[IntermediateStep]:
        """Keep the steps emitted by the configured target functions (and target model, if set)."""
        target_functions = self.run_config.target_functions
        target_model = self.run_config.target_model

        selected = []
        for step in steps:
            if step.function_ancestry.function_name not in target_functions:
                continue
            # If target model is specified, drop LLM steps produced by any other model
            if (target_model and step.event_category == IntermediateStepCategory.LLM
                    and step.payload.name != target_model):
                continue
            selected.append(step)

        return selected

    @staticmethod
    def _episode_from_steps(steps: list[IntermediateStep]) -> list[EpisodeItem]:
        """Parse intermediate steps into EpisodeItems; returns an empty list if parsing raises ValueError."""
        # Map OpenAI role names to EpisodeItemRole; unknown names map to OTHER
        role_mapping = {
            "user": EpisodeItemRole.USER,
            "assistant": EpisodeItemRole.ASSISTANT,
            "system": EpisodeItemRole.SYSTEM,
            "tool": EpisodeItemRole.TOOL,
            "function": EpisodeItemRole.FUNCTION,
            "human": EpisodeItemRole.USER,
            "ai": EpisodeItemRole.ASSISTANT,
        }
        metadata_keys = ("tool_call_id", "tool_calls", "function_call", "name")

        episode_items = []

        try:
            # Use the base parser to convert to OpenAI messages
            for msg in parse_to_openai_messages(steps):
                role = role_mapping.get(msg.get("role"), EpisodeItemRole.OTHER)
                logprobs = msg.get("logprobs")

                # For assistant messages, skip if no logprobs
                if role == EpisodeItemRole.ASSISTANT and not logprobs:
                    logger.debug("Skipping assistant message without logprobs")
                    continue

                # Carry over tool/function specific attributes when present
                metadata = {key: msg[key] for key in metadata_keys if key in msg}

                episode_items.append(
                    EpisodeItem(
                        role=role,
                        content=msg.get("content", ""),
                        logprobs=logprobs,
                        metadata=metadata if metadata else None,
                    ))

        except ValueError as e:
            logger.warning(
                "Failed to parse trajectory using base parser: %s. "
                "Falling back to empty episode.", str(e))
            # Return empty list on parse failure
            return []

        return episode_items

    async def finalize(self, run_id: str, meta: dict | None = None) -> TrajectoryCollection:
        """
        Waits for all evaluation runs to finalize and builds trajectories from
        the episode items, grouping them by example ID.

        If waiting on the evaluation runs fails, the run is abandoned: its
        bookkeeping entry is removed and unfinished tasks are cancelled before
        the error is re-raised.

        Args:
            run_id (str): The ID of the run.
            meta (dict): Metadata for the run.

        Returns:
            TrajectoryCollection: The collection of built trajectories grouped by example.

        Raises:
            ValueError: If no evaluation runs are registered for ``run_id``.
        """
        from nat.eval.evaluator.evaluator_model import EvalOutputItem

        if run_id not in self.evaluation_runs:
            raise ValueError(f"No evaluation runs found for run_id: {run_id}")

        # Wait for all evaluation runs to complete
        tasks = self.evaluation_runs[run_id]
        try:
            eval_results = await asyncio.gather(*tasks)
        except BaseException:
            # Otherwise the run would stay registered as "in progress" forever
            # and its remaining tasks would keep running unobserved.
            self._abandon_run(run_id)
            raise

        # Dictionary to group trajectories by example ID
        trajectories_by_id: dict[str, list[Trajectory]] = {}

        # Process each evaluation result
        for gen_idx, eval_result in enumerate(eval_results):
            reward_results: list[EvalOutputItem] | None = None
            for metric_name, metric_value in eval_result.evaluation_results:
                if metric_name == self.run_config.reward_function.name:
                    reward_results = metric_value.eval_output_items
                    break

            if not reward_results:
                logger.warning("No reward results found for run_id: %s, generation: %d", run_id, gen_idx)
                continue

            logger.info("Building trajectories for run_id: %s, generation: %d", run_id, gen_idx)

            # Create a mapping of id to input item for quick lookup
            input_items_map = {item.id: item for item in eval_result.eval_input.eval_input_items}

            for reward_item in reward_results:
                # Find the corresponding input item
                input_item: EvalInputItem | None = input_items_map.get(reward_item.id)
                if not input_item:
                    logger.warning(
                        "No input item found for reward item id: %s",
                        reward_item.id,
                    )
                    continue

                filtered_trajectory = self._select_target_steps(input_item.trajectory)
                if not filtered_trajectory:
                    logger.warning(
                        "No trajectory steps found for target function '%s' in item id: %s",
                        self.run_config.target_functions,
                        reward_item.id,
                    )
                    continue

                # Parse episode from intermediate steps; skip the item if nothing was parsed
                episode = self._episode_from_steps(filtered_trajectory)
                if not episode:
                    continue

                # Ensure we have at least a user and assistant message
                if len(episode) < 2:
                    logger.warning(
                        "Episode for item %s has less than 2 messages, skipping",
                        reward_item.id,
                    )
                    continue

                # At least one assistant message with logprobs is required for training
                if not any(item.role == EpisodeItemRole.ASSISTANT and item.logprobs for item in episode):
                    logger.warning(
                        "Episode for item %s has no assistant messages with "
                        "logprobs, skipping as it cannot be used for training",
                        reward_item.id,
                    )
                    continue

                # Create trajectory
                trajectory = Trajectory(
                    episode=episode,
                    reward=(await self.compute_reward(reward_item, meta=meta)),
                    metadata={
                        "id": reward_item.id,
                        "reasoning": reward_item.reasoning,
                        "run_id": run_id,
                        "generation": gen_idx,
                    },
                )

                # Group by example ID
                trajectories_by_id.setdefault(reward_item.id, []).append(trajectory)

        # Clean up completed runs
        self.evaluation_runs.pop(run_id, None)

        # Convert dictionary to list of lists, maintaining insertion order
        trajectories_list = list(trajectories_by_id.values())

        total_trajectories = sum(len(traj_list) for traj_list in trajectories_list)
        logger.info("Built %d trajectories across %d examples for run_id: %s",
                    total_trajectories,
                    len(trajectories_list),
                    run_id)

        if not trajectories_list:
            logger.warning("No trajectories were built for run_id: %s", run_id)
            return TrajectoryCollection(trajectories=[], run_id=run_id)

        if self.num_generations == 1:
            # With a single generation, merge all examples into one group
            flat_trajectories = [traj for sublist in trajectories_list for traj in sublist]
            return TrajectoryCollection(trajectories=[flat_trajectories], run_id=run_id)

        return TrajectoryCollection(trajectories=trajectories_list, run_id=run_id)

    def log_progress(self, run_id: str, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log trajectory building progress.

        Appends one JSON line to ``trajectory_builder_<run_id>.jsonl`` in the
        output directory, creating the directory if needed.

        Args:
            run_id: The training run ID
            metrics: Dictionary of metrics to log
            output_dir: Optional output directory override
        """
        # Use default output directory if not provided
        out_dir = Path(output_dir) if output_dir else Path("./.tmp/nat/finetuning/trajectory_builder")
        out_dir.mkdir(parents=True, exist_ok=True)

        # Create log file for trajectory builder
        log_file = out_dir / f"trajectory_builder_{run_id}.jsonl"

        # Prepare log entry; keys in ``metrics`` override the fixed fields on clash
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "run_id": run_id,
            "num_generations": self.num_generations,
            **metrics
        }

        # Append to log file
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry) + '\n')

        logger.debug("Trajectory builder progress logged for run %s: %d trajectories",
                     run_id,
                     metrics.get("num_trajectories", 0))