nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NeMo Customizer TrainerAdapter for DPO/SFT training.

This module provides a TrainerAdapter implementation that interfaces with
NeMo Customizer for submitting and monitoring training jobs.
"""

import asyncio
import json
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any

import httpx
from huggingface_hub import HfApi
from nemo_microservices import NeMoMicroservices

from nat.data_models.finetuning import DPOItem
from nat.data_models.finetuning import FinetuneConfig
from nat.data_models.finetuning import OpenAIMessage
from nat.data_models.finetuning import TrainingJobRef
from nat.data_models.finetuning import TrainingJobStatus
from nat.data_models.finetuning import TrainingStatusEnum
from nat.data_models.finetuning import TrajectoryCollection
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter

from .config import NeMoCustomizerTrainerAdapterConfig

logger = logging.getLogger(__name__)


class NeMoCustomizerTrainerAdapter(TrainerAdapter):
    """
    TrainerAdapter for NeMo Customizer backend.

    This adapter:
    1. Converts trajectories to JSONL format for DPO training
    2. Uploads datasets to NeMo Datastore via HuggingFace Hub API
    3. Submits customization jobs to NeMo Customizer
    4. Monitors job progress and status
    5. Optionally deploys trained models
    """

    def __init__(self, adapter_config: NeMoCustomizerTrainerAdapterConfig):
        super().__init__(adapter_config)

        self.adapter_config: NeMoCustomizerTrainerAdapterConfig = adapter_config

        # NeMo Microservices clients (created lazily on first use)
        self._entity_client: NeMoMicroservices | None = None
        self._hf_api: HfApi | None = None

        # Track active jobs
        self._active_jobs: dict[str, str] = {}  # run_id -> job_id mapping
        self._job_output_models: dict[str, str] = {}  # run_id -> output_model mapping

        logger.info(f"Initialized NeMoCustomizerTrainerAdapter for namespace: {adapter_config.namespace}")

    @property
    def entity_client(self) -> NeMoMicroservices:
        """Lazy initialization of the NeMo Microservices client."""
        if self._entity_client is None:
            self._entity_client = NeMoMicroservices(base_url=self.adapter_config.entity_host)
        return self._entity_client

    @property
    def hf_api(self) -> HfApi:
        """Lazy initialization of the HuggingFace API client."""
        if self._hf_api is None:
            self._hf_api = HfApi(
                endpoint=f"{self.adapter_config.datastore_host}/v1/hf",
                token=self.adapter_config.hf_token or "",
            )
        return self._hf_api

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """Initialize the trainer adapter."""
        await super().initialize(run_config)

        if self.adapter_config.create_namespace_if_missing:
            await self._ensure_namespaces_exist()

        health = await self.is_healthy()
        if not health:
            raise ConnectionError(f"Failed to connect to NeMo Customizer at {self.adapter_config.entity_host}")

        logger.info("Successfully initialized NeMo Customizer TrainerAdapter")

    async def _ensure_namespaces_exist(self) -> None:
        """Create namespaces in the entity store and datastore if they don't exist."""
        namespace = self.adapter_config.namespace

        # Create namespace in entity store
        try:
            self.entity_client.namespaces.create(
                id=namespace,
                description=f"NAT finetuning namespace: {namespace}",
            )
            logger.info(f"Created namespace '{namespace}' in Entity Store")
        except Exception as e:
            logger.debug(f"Namespace '{namespace}' may already exist in Entity Store: {e}")

        # Create namespace in datastore via HTTP
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    f"{self.adapter_config.datastore_host}/v1/datastore/namespaces",
                    data={"namespace": namespace},
                )
                if resp.status_code in (200, 201):
                    logger.info(f"Created namespace '{namespace}' in Datastore")
                elif resp.status_code in (409, 422):
                    logger.debug(f"Namespace '{namespace}' already exists in Datastore")
                else:
                    logger.warning(f"Unexpected response creating namespace in Datastore: {resp.status_code}")
        except Exception as e:
            logger.warning(f"Error creating namespace in Datastore: {e}")

    async def is_healthy(self) -> bool:
        """Check if NeMo Customizer services are reachable."""
        # NOTE: currently a stub that always reports healthy, so the
        # connectivity check in initialize() cannot fail; a real probe of the
        # entity store and datastore endpoints could be added here.
        return True

    def _format_prompt(self, prompt: list[OpenAIMessage] | str) -> list[dict[str, str]] | str:
        """
        Format prompt based on configuration.

        Args:
            prompt: Original prompt (string or list of OpenAI messages)

        Returns:
            Formatted prompt based on the use_full_message_history setting
        """
        if self.adapter_config.use_full_message_history:
            # Return full message history as a list of dicts
            if isinstance(prompt, str):
                return [{"role": "user", "content": prompt}]
            else:
                return [{"role": msg.role, "content": msg.content} for msg in prompt]
        # Return only the last message content as a string
        elif isinstance(prompt, str):
            return prompt
        elif prompt:
            return prompt[-1].content
        else:
            return ""

    def _trajectory_to_dpo_jsonl(self, trajectories: TrajectoryCollection) -> tuple[str, str]:
        """
        Convert trajectory collection to JSONL format for DPO training.

        Returns:
            Tuple of (training_jsonl, validation_jsonl) content strings
        """
        all_items: list[dict[str, Any]] = []

        for trajectory_group in trajectories.trajectories:
            for trajectory in trajectory_group:
                for episode_item in trajectory.episode:
                    if isinstance(episode_item, DPOItem):
                        formatted_prompt = self._format_prompt(episode_item.prompt)
                        dpo_record = {
                            "prompt": formatted_prompt,
                            "chosen_response": episode_item.chosen_response,
                            "rejected_response": episode_item.rejected_response,
                        }
                        all_items.append(dpo_record)

        if not all_items:
            raise ValueError("No DPO items found in trajectories")

        # Split into training (80%) and validation (20%); with too few items,
        # fall back to reusing the last training item as validation.
        split_idx = max(1, int(len(all_items) * 0.8))
        training_items = all_items[:split_idx]
        validation_items = all_items[split_idx:] if split_idx < len(all_items) else all_items[-1:]

        training_jsonl = "\n".join(json.dumps(item) for item in training_items)
        validation_jsonl = "\n".join(json.dumps(item) for item in validation_items)

        logger.info(f"Converted {len(all_items)} DPO items: "
                    f"{len(training_items)} training, {len(validation_items)} validation")

        return training_jsonl, validation_jsonl
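
    # Each emitted JSONL record follows the DPO schema built above
    # (illustrative values):
    #   {"prompt": "What is 2+2?", "chosen_response": "4", "rejected_response": "5"}
    # With use_full_message_history=True, "prompt" is the list-of-messages form instead.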

    async def _setup_dataset(self, run_id: str, training_jsonl: str, validation_jsonl: str) -> str:
        """
        Create dataset repository and upload JSONL files.

        Args:
            run_id: Unique identifier for this training run
            training_jsonl: Training data in JSONL format
            validation_jsonl: Validation data in JSONL format

        Returns:
            Name of the dataset registered in the Entity Store
        """
        dataset_name = self.adapter_config.dataset_name
        repo_id = f"{self.adapter_config.namespace}/{dataset_name}"

        # Create dataset repo in datastore
        self.hf_api.create_repo(repo_id, repo_type="dataset", exist_ok=True)

        # Register dataset in entity store
        try:
            self.entity_client.datasets.create(
                name=dataset_name,
                namespace=self.adapter_config.namespace,
                files_url=f"hf://datasets/{repo_id}",
                description=f"NAT DPO training dataset for run {run_id}",
            )
        except Exception as e:
            logger.debug(f"Dataset may already exist: {e}")

        def write_and_upload_files(base_dir: Path) -> None:
            train_path = base_dir / "training_file.jsonl"
            val_path = base_dir / "validation_file.jsonl"

            train_path.write_text(training_jsonl)
            val_path.write_text(validation_jsonl)

            self.hf_api.upload_file(
                path_or_fileobj=str(train_path),
                path_in_repo="training/training_file.jsonl",
                repo_id=repo_id,
                repo_type="dataset",
                revision="main",
                commit_message=f"Training file for run {run_id}",
            )

            self.hf_api.upload_file(
                path_or_fileobj=str(val_path),
                path_in_repo="validation/validation_file.jsonl",
                repo_id=repo_id,
                repo_type="dataset",
                revision="main",
                commit_message=f"Validation file for run {run_id}",
            )

        if self.adapter_config.dataset_output_dir:
            # Use configured output directory (create if needed, preserve files)
            output_dir = Path(self.adapter_config.dataset_output_dir) / run_id
            output_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Saving dataset files to: {output_dir}")
            write_and_upload_files(output_dir)
        else:
            # Use temporary directory (cleaned up after upload)
            with tempfile.TemporaryDirectory() as tmpdir:
                write_and_upload_files(Path(tmpdir))

        logger.info(f"Created and uploaded dataset: {repo_id}")
        return dataset_name
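
    # After upload, the dataset repo (files_url hf://datasets/<namespace>/<dataset_name>)
    # contains:
    #   training/training_file.jsonl
    #   validation/validation_file.jsonl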

    async def submit(self, trajectories: TrajectoryCollection) -> TrainingJobRef:
        """
        Submit trajectories for training.

        Args:
            trajectories: Collection of trajectories containing DPO items

        Returns:
            Reference to the submitted training job
        """
        run_id = trajectories.run_id

        if run_id in self._active_jobs:
            raise ValueError(f"Training job for run {run_id} already exists")

        # Convert trajectories to JSONL
        training_jsonl, validation_jsonl = self._trajectory_to_dpo_jsonl(trajectories)

        # Upload dataset
        dataset_name = await self._setup_dataset(run_id, training_jsonl, validation_jsonl)

        # Prepare hyperparameters
        hyperparams = self.adapter_config.hyperparameters.model_dump()

        # Submit customization job
        job = self.entity_client.customization.jobs.create(
            config=self.adapter_config.customization_config,
            dataset={
                "name": dataset_name,
                "namespace": self.adapter_config.namespace,
            },
            hyperparameters=hyperparams,
        )

        job_id = job.id
        self._active_jobs[run_id] = job_id
        self._job_output_models[run_id] = job.output_model

        logger.info(f"Submitted customization job {job_id} for run {run_id}. "
                    f"Output model: {job.output_model}")

        return TrainingJobRef(
            run_id=run_id,
            backend="nemo-customizer",
            metadata={
                "job_id": job_id,
                "output_model": job.output_model,
                "dataset_name": dataset_name,
            },
        )
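
    # The returned reference is self-contained: status() can re-attach to the job
    # from ref.metadata["job_id"] alone, e.g. (illustrative values):
    #   TrainingJobRef(run_id="run-42", backend="nemo-customizer",
    #                  metadata={"job_id": "cust-abc123",
    #                            "output_model": "my-namespace/model-v1",
    #                            "dataset_name": "nat-dpo-dataset"})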

    async def status(self, ref: TrainingJobRef) -> TrainingJobStatus:
        """Get the status of a training job."""
        job_id = self._active_jobs.get(ref.run_id)
        if job_id is None:
            # Try to get from metadata
            job_id = ref.metadata.get("job_id") if ref.metadata else None

        if job_id is None:
            raise ValueError(f"No training job found for run {ref.run_id}")

        try:
            job_status = self.entity_client.customization.jobs.status(job_id)

            # Map NeMo status to TrainingStatusEnum
            status_map = {
                "created": TrainingStatusEnum.PENDING,
                "pending": TrainingStatusEnum.PENDING,
                "running": TrainingStatusEnum.RUNNING,
                "completed": TrainingStatusEnum.COMPLETED,
                "failed": TrainingStatusEnum.FAILED,
                "cancelled": TrainingStatusEnum.CANCELED,
                "canceled": TrainingStatusEnum.CANCELED,
            }

            status = status_map.get(job_status.status.lower(), TrainingStatusEnum.RUNNING)
            progress = getattr(job_status, "percentage_done", None)

            message = f"Status: {job_status.status}"
            if hasattr(job_status, "epochs_completed"):
                message += f", Epochs: {job_status.epochs_completed}"

            return TrainingJobStatus(
                run_id=ref.run_id,
                backend=ref.backend,
                status=status,
                progress=progress,
                message=message,
                metadata={
                    "job_id": job_id,
                    "nemo_status": job_status.status,
                    "output_model": self._job_output_models.get(ref.run_id),
                },
            )

        except Exception as e:
            logger.error(f"Error getting job status: {e}")
            return TrainingJobStatus(
                run_id=ref.run_id,
                backend=ref.backend,
                status=TrainingStatusEnum.FAILED,
                message=f"Error getting status: {e}",
            )

    async def wait_until_complete(self, ref: TrainingJobRef, poll_interval: float | None = None) -> TrainingJobStatus:
        """Wait for training job to complete."""
        interval = poll_interval or self.adapter_config.poll_interval_seconds

        last_status: str | None = None
        last_progress: float | None = None

        while True:
            status = await self.status(ref)

            # Log when status changes
            current_status = status.status.value
            if current_status != last_status:
                logger.info(f"Job {ref.run_id}: Status -> '{current_status}'")
                last_status = current_status

            # Log when progress changes (progress may be None before training starts)
            current_progress = status.progress
            if current_progress is not None and current_progress != last_progress:
                logger.info(f"Job {ref.run_id}: Progress {current_progress:.1f}%")
                last_progress = current_progress

            if status.status in (
                    TrainingStatusEnum.COMPLETED,
                    TrainingStatusEnum.FAILED,
                    TrainingStatusEnum.CANCELED,
            ):
                # Handle deployment if configured
                if status.status == TrainingStatusEnum.COMPLETED and self.adapter_config.deploy_on_completion:
                    await self._deploy_model(ref)

                # Clean up active job tracking
                self._active_jobs.pop(ref.run_id, None)

                return status

            await asyncio.sleep(interval)

    async def _deploy_model(self, ref: TrainingJobRef) -> None:
        """Deploy the trained model and wait until deployment is ready."""
        output_model = self._job_output_models.get(ref.run_id)
        if not output_model:
            logger.warning(f"No output model found for run {ref.run_id}, skipping deployment")
            return

        deploy_config = self.adapter_config.deployment_config
        namespace = self.adapter_config.namespace

        try:
            # Create deployment configuration
            config_name = f"nat-deploy-config-{ref.run_id}"
            dep_config = self.entity_client.deployment.configs.create(
                name=config_name,
                namespace=namespace,
                description=deploy_config.description,
                model=output_model,
                nim_deployment={
                    "image_name": deploy_config.image_name,
                    "image_tag": deploy_config.image_tag,
                    "gpu": deploy_config.gpu,
                },
            )

            # Create model deployment
            deployment_name = deploy_config.deployment_name or f"nat-deployment-{ref.run_id}"
            self.entity_client.deployment.model_deployments.create(
                name=deployment_name,
                namespace=namespace,
                description=deploy_config.description,
                config=f"{dep_config.namespace}/{dep_config.name}",
            )

            logger.info(f"Created deployment '{deployment_name}' for model {output_model}")

            # Wait for deployment to be ready
            await self._wait_for_deployment_ready(namespace, deployment_name)

        except Exception as e:
            logger.error(f"Failed to deploy model: {e}")
            raise

    async def _wait_for_deployment_ready(
        self,
        namespace: str,
        deployment_name: str,
        poll_interval: float | None = None,
        timeout: float | None = None,
    ) -> None:
        """
        Wait for a model deployment to become ready.

        Args:
            namespace: Namespace of the deployment
            deployment_name: Name of the deployment
            poll_interval: Seconds between status checks (default: adapter config poll_interval_seconds)
            timeout: Maximum seconds to wait (default: adapter config deployment_timeout_seconds)
        """
        interval = poll_interval or self.adapter_config.poll_interval_seconds
        max_wait = timeout or self.adapter_config.deployment_timeout_seconds

        logger.info(f"Waiting for deployment '{deployment_name}' to be ready...")

        last_status: str | None = None
        elapsed = 0.0

        while elapsed < max_wait:
            try:
                # Get all deployments and find ours
                deployments = self.entity_client.deployment.model_deployments.list().data
                deployment = None
                for dep in deployments:
                    if dep.name == deployment_name and dep.namespace == namespace:
                        deployment = dep
                        break

                if deployment is None:
                    logger.warning(f"Deployment '{deployment_name}' not found in namespace '{namespace}'")
                    await asyncio.sleep(interval)
                    elapsed += interval
                    continue

                # Check status
                status_details = getattr(deployment, "status_details", None)
                current_status = status_details.status if status_details else "unknown"
                description = status_details.description if status_details else ""

                # Log status changes
                if current_status != last_status:
                    logger.info(f"Deployment '{deployment_name}': Status -> '{current_status}'")
                    if description:
                        logger.info(f"Deployment '{deployment_name}': {description.strip()}")
                    last_status = current_status

                # Check if ready
                if current_status.lower() == "ready":
                    logger.info(f"Deployment '{deployment_name}' is ready!")
                    return

                # Check for failure states
                if current_status.lower() in ("failed", "error"):
                    raise RuntimeError(
                        f"Deployment '{deployment_name}' failed with status '{current_status}': {description}")

            except RuntimeError:
                raise
            except Exception as e:
                logger.warning(f"Error checking deployment status: {e}")

            await asyncio.sleep(interval)
            elapsed += interval

        raise TimeoutError(f"Deployment '{deployment_name}' did not become ready within {max_wait} seconds")

    def log_progress(self, ref: TrainingJobRef, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """Log training progress to file."""
        out_dir = Path(output_dir) if output_dir else Path("./.tmp/nat/finetuning/trainer_adapter")
        out_dir.mkdir(parents=True, exist_ok=True)

        log_file = out_dir / f"nemo_customizer_{ref.run_id}.jsonl"

        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "run_id": ref.run_id,
            "backend": ref.backend,
            "config": {
                "namespace": self.adapter_config.namespace,
                "customization_config": self.adapter_config.customization_config,
            },
            **metrics,
        }

        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry) + "\n")

        logger.debug(f"Logged progress for job {ref.run_id}")
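
    # Each appended line is one JSON object; an entry might look like
    # (illustrative values, with metrics keys supplied by the caller):
    #   {"timestamp": "2025-01-01T12:00:00", "run_id": "run-42", "backend": "nemo-customizer",
    #    "config": {"namespace": "my-namespace", "customization_config": "..."},
    #    "loss": 0.42}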