nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,424 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ NeMo Customizer Trainer for DPO finetuning.
+
+ This module provides a Trainer implementation that orchestrates data collection
+ via trajectory builders and submits training jobs to NeMo Customizer.
+ """
+
+ import json
+ import logging
+ import uuid
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any
+
+ from nat.data_models.finetuning import FinetuneConfig
+ from nat.data_models.finetuning import TrainingJobRef
+ from nat.data_models.finetuning import TrainingJobStatus
+ from nat.data_models.finetuning import TrainingStatusEnum
+ from nat.data_models.finetuning import Trajectory
+ from nat.data_models.finetuning import TrajectoryCollection
+ from nat.finetuning.interfaces.finetuning_runner import Trainer
+
+ from .config import NeMoCustomizerTrainerConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class NeMoCustomizerTrainer(Trainer):
+     """
+     Trainer for NeMo Customizer DPO/SFT finetuning.
+
+     Unlike epoch-based trainers, this trainer:
+     1. Runs the trajectory builder multiple times (num_runs) to collect data
+     2. Aggregates all trajectories into a single dataset
+     3. Submits the dataset to NeMo Customizer for training
+     4. Monitors the training job until completion
+
+     The actual training epochs are handled by NeMo Customizer via hyperparameters.
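+
+     Example (illustrative sketch only; the trajectory builder / trainer adapter
+     wiring and the exact config fields are supplied by the surrounding
+     finetuning harness):
+
+         trainer = NeMoCustomizerTrainer(trainer_config=trainer_config)
+         await trainer.initialize(run_config)
+         statuses = await trainer.run(num_epochs=0)  # num_epochs is ignored; num_runs drives collection
+         final_metrics = await trainer.get_metrics(statuses[0].run_id)
+         await trainer.cleanup()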
+     """
+
+     def __init__(self, trainer_config: NeMoCustomizerTrainerConfig, **kwargs) -> None:
+         """
+         Initialize the NeMo Customizer Trainer.
+
+         Args:
+             trainer_config: Configuration for the trainer
+             **kwargs: Additional keyword arguments (currently unused)
+         """
+         super().__init__(trainer_config)
+
+         self.trainer_config: NeMoCustomizerTrainerConfig = trainer_config
+
+         # Track job references and metrics
+         self._job_ref: TrainingJobRef | None = None
+         self._run_id: str | None = None
+
+         # Track collected data across runs
+         self._all_trajectories: list[list[Trajectory]] = []
+         self._run_metrics: list[dict[str, Any]] = []
+
+         # Progress tracking
+         self._collection_history: list[dict[str, Any]] = []
+
+     async def initialize(self, run_config: FinetuneConfig) -> None:
+         """
+         Initialize the trainer and its components.
+
+         Note: Curriculum learning is not supported for DPO training.
+         """
+         logger.info("Initializing NeMo Customizer Trainer")
+
+         # Store run config but skip curriculum learning setup
+         self.run_config = run_config
+         self.trainer_config.reward = self.run_config.reward_function
+
+         # Disable curriculum learning for DPO
+         self.curriculum_config = None
+         self._curriculum_state = {
+             "current_percentile": 1.0,
+             "last_expansion_epoch": -1,
+             "total_groups": 0,
+             "included_groups": set(),
+         }
+
+         # Initialize components
+         await self.trajectory_builder.initialize(run_config)
+         await self.trainer_adapter.initialize(run_config)
+
+         # Generate unique run ID
+         self._run_id = f"nemo_dpo_{uuid.uuid4().hex[:8]}"
+
+         logger.info(f"NeMo Customizer Trainer initialized with run ID: {self._run_id}")
+
+     async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef | None:
+         """
+         Run a single data collection pass.
+
+         For NeMo Customizer, this collects trajectories without submitting
+         them for training. The actual submission happens in run().
+
+         Args:
+             epoch: The current run number (0-indexed)
+             run_id: Unique identifier for this training run
+
+         Returns:
+             None (trajectories are accumulated, not submitted per-run)
+         """
+         logger.info(f"Starting data collection run {epoch + 1}/{self.trainer_config.num_runs}")
+
+         run_meta = {
+             "run_number": epoch,
+             "run_id": run_id,
+             "trainer_config": self.trainer_config.model_dump(),
+         }
+
+         # Run trajectory builder
+         await self.trajectory_builder.start_run(run_id=f"{run_id}_run{epoch}", meta=run_meta)
+
+         # Finalize and get trajectories
+         trajectory_collection = await self.trajectory_builder.finalize(run_id=f"{run_id}_run{epoch}", meta=run_meta)
+
+         if not trajectory_collection.trajectories:
+             logger.warning(f"No trajectories collected for run {epoch}")
+             return None
+
+         # Calculate metrics for this run
+         run_rewards = []
+         num_trajectories = 0
+         num_dpo_pairs = 0
+
+         for trajectory_group in trajectory_collection.trajectories:
+             for trajectory in trajectory_group:
+                 num_trajectories += 1
+                 run_rewards.append(trajectory.reward)
+                 # Count DPO pairs (each item in the episode is one preference pair)
+                 num_dpo_pairs += len(trajectory.episode)
+
+         metrics = {
+             "run_number": epoch,
+             "num_trajectories": num_trajectories,
+             "num_dpo_pairs": num_dpo_pairs,
+             "avg_reward": sum(run_rewards) / len(run_rewards) if run_rewards else 0.0,
+             "min_reward": min(run_rewards) if run_rewards else 0.0,
+             "max_reward": max(run_rewards) if run_rewards else 0.0,
+             "timestamp": datetime.now().isoformat(),
+         }
+         self._run_metrics.append(metrics)
+
+         logger.info(f"Run {epoch + 1}: Collected {num_trajectories} trajectories, "
+                     f"{num_dpo_pairs} DPO pairs, avg reward: {metrics['avg_reward']:.4f}")
+
+         # Accumulate trajectories
+         self._all_trajectories.extend(trajectory_collection.trajectories)
+
+         # Log progress
+         self.log_progress(epoch, metrics)
+
+         return None  # No job submitted per-run
+
+     async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
+         """
+         Run the complete DPO data collection and training workflow.
+
+         Args:
+             num_epochs: Ignored for NeMo Customizer (uses trainer_config.num_runs)
+
+         Returns:
+             list[TrainingJobStatus]: Status of the training job
+         """
+         if not self._run_id:
+             raise RuntimeError("Trainer not initialized. Call initialize() first.")
+
+         num_runs = self.trainer_config.num_runs
+         logger.info(f"Starting NeMo Customizer DPO workflow with {num_runs} data collection runs")
+
+         # Phase 1: Collect data from multiple runs
+         for run_idx in range(num_runs):
+             try:
+                 await self.run_epoch(run_idx, self._run_id)
+             except Exception as e:
+                 logger.error(f"Error during data collection run {run_idx}: {e}")
+                 if not self.trainer_config.continue_on_collection_error:
+                     return [
+                         TrainingJobStatus(
+                             run_id=self._run_id,
+                             backend="nemo-customizer",
+                             status=TrainingStatusEnum.FAILED,
+                             message=f"Data collection failed at run {run_idx}: {e}",
+                             metadata={"run_number": run_idx},
+                         )
+                     ]
+
+         # Check if we have any data
+         if not self._all_trajectories:
+             logger.error("No trajectories collected from any run")
+             return [
+                 TrainingJobStatus(
+                     run_id=self._run_id,
+                     backend="nemo-customizer",
+                     status=TrainingStatusEnum.FAILED,
+                     message="No trajectories collected",
+                 )
+             ]
+
+         # Calculate total statistics
+         total_trajectories = len(self._all_trajectories)
+         total_dpo_pairs = sum(
+             len(traj.episode) for group in self._all_trajectories
+             for traj in (group if isinstance(group, list) else [group]))
+
+         logger.info(f"Data collection complete: {total_trajectories} trajectory groups, "
+                     f"~{total_dpo_pairs} total DPO pairs from {num_runs} runs")
+
+         # Phase 2: Submit aggregated trajectories for training
+         try:
+             trajectory_collection = TrajectoryCollection(
+                 trajectories=self._all_trajectories,
+                 run_id=self._run_id,
+             )
+
+             # Apply deduplication if configured
+             if self.trainer_config.deduplicate_pairs:
+                 trajectory_collection = self._deduplicate_trajectories(trajectory_collection)
+
+             # Apply sampling if configured
+             if self.trainer_config.max_pairs is not None:
+                 trajectory_collection = self._sample_trajectories(trajectory_collection, self.trainer_config.max_pairs)
+
+             self._job_ref = await self.trainer_adapter.submit(trajectory_collection)
+
+             logger.info(f"Submitted training job: {self._job_ref.metadata.get('job_id')}")
+
+         except Exception as e:
+             logger.error(f"Failed to submit training job: {e}")
+             return [
+                 TrainingJobStatus(
+                     run_id=self._run_id,
+                     backend="nemo-customizer",
+                     status=TrainingStatusEnum.FAILED,
+                     message=f"Failed to submit training job: {e}",
+                 )
+             ]
+
+         # Phase 3: Wait for training completion
+         if self.trainer_config.wait_for_completion:
+             logger.info("Waiting for training job to complete...")
+             final_status = await self.trainer_adapter.wait_until_complete(self._job_ref)
+
+             # Log final metrics
+             self._log_final_metrics(final_status)
+
+             return [final_status]
+         else:
+             # Return immediately with a running status
+             return [
+                 TrainingJobStatus(
+                     run_id=self._run_id,
+                     backend="nemo-customizer",
+                     status=TrainingStatusEnum.RUNNING,
+                     message="Training job submitted (not waiting for completion)",
+                     metadata=self._job_ref.metadata,
+                 )
+             ]
+
+     def _deduplicate_trajectories(self, collection: TrajectoryCollection) -> TrajectoryCollection:
+         """Remove duplicate DPO pairs based on prompt+responses."""
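+         # Strategy sketch: each episode item is expected to expose ``prompt``,
+         # ``chosen_response`` and ``rejected_response`` (defaults are used when
+         # missing). A trajectory is kept the first time one of its items yields
+         # an unseen (prompt, chosen, rejected) key, and dropped when every one of
+         # its keys has already been seen.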
+         seen = set()
+         unique_groups = []
+
+         for group in collection.trajectories:
+             unique_trajectories = []
+             for traj in group:
+                 for item in traj.episode:
+                     # Create a hashable key from prompt and responses
+                     prompt_str = (str(item.prompt) if hasattr(item, "prompt") else "")
+                     key = (
+                         prompt_str,
+                         getattr(item, "chosen_response", ""),
+                         getattr(item, "rejected_response", ""),
+                     )
+                     if key not in seen:
+                         seen.add(key)
+                         unique_trajectories.append(traj)
+                         break  # Only add trajectory once
+
+             if unique_trajectories:
+                 unique_groups.append(unique_trajectories)
+
+         original_count = sum(len(g) for g in collection.trajectories)
+         new_count = sum(len(g) for g in unique_groups)
+         logger.info(f"Deduplication: {original_count} -> {new_count} trajectories")
+
+         return TrajectoryCollection(trajectories=unique_groups, run_id=collection.run_id)
+
+     def _sample_trajectories(self, collection: TrajectoryCollection, max_pairs: int) -> TrajectoryCollection:
+         """Sample trajectories to limit dataset size."""
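+         # Note: sampling re-wraps each sampled trajectory as its own
+         # single-trajectory group, so the nested list-of-groups shape of
+         # TrajectoryCollection is preserved while original group boundaries are lost.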
+         import random
+
+         all_trajectories = []
+         for group in collection.trajectories:
+             all_trajectories.extend(group)
+
+         if len(all_trajectories) <= max_pairs:
+             return collection
+
+         # Sample randomly
+         sampled = random.sample(all_trajectories, max_pairs)
+         logger.info(f"Sampling: {len(all_trajectories)} -> {max_pairs} trajectories")
+
+         return TrajectoryCollection(
+             trajectories=[[t] for t in sampled],
+             run_id=collection.run_id,
+         )
+
+     async def get_metrics(self, run_id: str) -> dict[str, Any]:
+         """Get training metrics for the run."""
+         metrics = {
+             "run_id": run_id,
+             "num_collection_runs": len(self._run_metrics),
+             "collection_runs": self._run_metrics,
+             "total_trajectory_groups": len(self._all_trajectories),
+         }
+
+         if self._job_ref:
+             try:
+                 status = await self.trainer_adapter.status(self._job_ref)
+                 metrics["training_job"] = {
+                     "job_id": self._job_ref.metadata.get("job_id"),
+                     "status": status.status.value,
+                     "progress": status.progress,
+                     "message": status.message,
+                 }
+             except Exception as e:
+                 metrics["training_job"] = {"error": str(e)}
+
+         return metrics
+
+     async def cleanup(self) -> None:
+         """Clean up resources."""
+         logger.info("Cleaning up NeMo Customizer Trainer resources")
+
+         # Cancel any running trajectory builder tasks
+         if hasattr(self.trajectory_builder, "evaluation_runs"):
+             for run_id, task in self.trajectory_builder.evaluation_runs.items():
+                 if not task.done():
+                     logger.info(f"Cancelling evaluation task for run {run_id}")
+                     task.cancel()
+
+         # Clear accumulated data
+         self._all_trajectories.clear()
+         self._run_metrics.clear()
+
+         logger.info("NeMo Customizer Trainer cleanup completed")
+
+     def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
+         """Log data collection progress."""
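+         # Each line appended to data_collection_progress.jsonl is a JSON object of
+         # roughly this shape (values are illustrative):
+         #   {"run_number": 0, "timestamp": "2025-...", "run_id": "nemo_dpo_ab12cd34",
+         #    "num_trajectories": 8, "num_dpo_pairs": 24, "avg_reward": 0.61, ...}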
+         out_dir = Path(output_dir) if output_dir else self.run_config.output_dir
+         out_dir.mkdir(parents=True, exist_ok=True)
+
+         # Store in history
+         progress_entry = {
+             "run_number": epoch,
+             "timestamp": datetime.now().isoformat(),
+             "run_id": self._run_id,
+             **metrics,
+         }
+         self._collection_history.append(progress_entry)
+
+         # Append to the JSONL progress log
+         log_file = out_dir / "data_collection_progress.jsonl"
+         with open(log_file, "a", encoding="utf-8") as f:
+             f.write(json.dumps(progress_entry) + "\n")
+
+         # Save collection history
+         history_file = out_dir / "collection_history.json"
+         with open(history_file, "w", encoding="utf-8") as f:
+             json.dump(self._collection_history, f, indent=2)
+
+         logger.info(f"Run {epoch + 1}: {metrics.get('num_dpo_pairs', 0)} DPO pairs, "
+                     f"avg reward: {metrics.get('avg_reward', 0):.4f}")
+
+     def _log_final_metrics(self, final_status: TrainingJobStatus) -> None:
+         """Log final training metrics."""
+         out_dir = self.run_config.output_dir
+         out_dir.mkdir(parents=True, exist_ok=True)
+
+         final_metrics = {
+             "run_id": self._run_id,
+             "timestamp": datetime.now().isoformat(),
+             "status": final_status.status.value,
+             "message": final_status.message,
+             "progress": final_status.progress,
+             "num_collection_runs": len(self._run_metrics),
+             "total_trajectory_groups": len(self._all_trajectories),
+             "collection_summary": {
+                 "total_trajectories": sum(m.get("num_trajectories", 0) for m in self._run_metrics),
+                 "total_dpo_pairs": sum(m.get("num_dpo_pairs", 0) for m in self._run_metrics),
+                 "avg_reward": (sum(m.get("avg_reward", 0) for m in self._run_metrics) /
+                                len(self._run_metrics) if self._run_metrics else 0.0),
+             },
+             "job_metadata": self._job_ref.metadata if self._job_ref else None,
+         }
+
+         # Save final metrics
+         metrics_file = out_dir / "final_metrics.json"
+         with open(metrics_file, "w", encoding="utf-8") as f:
+             json.dump(final_metrics, f, indent=2)
+
+         logger.info(f"Training completed with status: {final_status.status.value}")