nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/customizer/__init__.py +43 -0
- nat/plugins/customizer/dpo/__init__.py +44 -0
- nat/plugins/customizer/dpo/config.py +360 -0
- nat/plugins/customizer/dpo/register.py +157 -0
- nat/plugins/customizer/dpo/trainer.py +424 -0
- nat/plugins/customizer/dpo/trainer_adapter.py +550 -0
- nat/plugins/customizer/dpo/trajectory_builder.py +767 -0
- nat/plugins/customizer/register.py +23 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/METADATA +45 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/RECORD +16 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/WHEEL +5 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/entry_points.txt +2 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/top_level.txt +1 -0
nat/plugins/customizer/dpo/trainer_adapter.py

@@ -0,0 +1,550 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NeMo Customizer TrainerAdapter for DPO/SFT training.

This module provides a TrainerAdapter implementation that interfaces with
NeMo Customizer for submitting and monitoring training jobs.
"""

import asyncio
import json
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any

import httpx
from huggingface_hub import HfApi
from nemo_microservices import NeMoMicroservices

from nat.data_models.finetuning import DPOItem
from nat.data_models.finetuning import FinetuneConfig
from nat.data_models.finetuning import OpenAIMessage
from nat.data_models.finetuning import TrainingJobRef
from nat.data_models.finetuning import TrainingJobStatus
from nat.data_models.finetuning import TrainingStatusEnum
from nat.data_models.finetuning import TrajectoryCollection
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter

from .config import NeMoCustomizerTrainerAdapterConfig

logger = logging.getLogger(__name__)

class NeMoCustomizerTrainerAdapter(TrainerAdapter):
    """
    TrainerAdapter for NeMo Customizer backend.

    This adapter:
    1. Converts trajectories to JSONL format for DPO training
    2. Uploads datasets to NeMo Datastore via HuggingFace Hub API
    3. Submits customization jobs to NeMo Customizer
    4. Monitors job progress and status
    5. Optionally deploys trained models
    """

    def __init__(self, adapter_config: NeMoCustomizerTrainerAdapterConfig):
        super().__init__(adapter_config)

        self.adapter_config: NeMoCustomizerTrainerAdapterConfig = adapter_config

        # Initialize NeMo Microservices client
        self._entity_client: NeMoMicroservices | None = None
        self._hf_api: HfApi | None = None

        # Track active jobs
        self._active_jobs: dict[str, str] = {}  # run_id -> job_id mapping
        self._job_output_models: dict[str, str] = {}  # run_id -> output_model mapping

        logger.info(f"Initialized NeMoCustomizerTrainerAdapter for namespace: {adapter_config.namespace}")

    @property
    def entity_client(self) -> NeMoMicroservices:
        """Lazy initialization of NeMo Microservices client."""
        if self._entity_client is None:
            self._entity_client = NeMoMicroservices(base_url=self.adapter_config.entity_host)
        return self._entity_client

    @property
    def hf_api(self) -> HfApi:
        """Lazy initialization of HuggingFace API client."""
        if self._hf_api is None:
            self._hf_api = HfApi(
                endpoint=f"{self.adapter_config.datastore_host}/v1/hf",
                token=self.adapter_config.hf_token or "",
            )
        return self._hf_api

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """Initialize the trainer adapter."""
        await super().initialize(run_config)

        if self.adapter_config.create_namespace_if_missing:
            await self._ensure_namespaces_exist()

        health = await self.is_healthy()
        if not health:
            raise ConnectionError(f"Failed to connect to NeMo Customizer at {self.adapter_config.entity_host}")

        logger.info("Successfully initialized NeMo Customizer TrainerAdapter")

    async def _ensure_namespaces_exist(self) -> None:
        """Create namespaces in entity store and datastore if they don't exist."""
        namespace = self.adapter_config.namespace

        # Create namespace in entity store
        try:
            self.entity_client.namespaces.create(
                id=namespace,
                description=f"NAT finetuning namespace: {namespace}",
            )
            logger.info(f"Created namespace '{namespace}' in Entity Store")
        except Exception as e:
            logger.debug(f"Namespace '{namespace}' may already exist in Entity Store: {e}")

        # Create namespace in datastore via HTTP
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    f"{self.adapter_config.datastore_host}/v1/datastore/namespaces",
                    data={"namespace": namespace},
                )
                if resp.status_code in (200, 201):
                    logger.info(f"Created namespace '{namespace}' in Datastore")
                elif resp.status_code in (409, 422):
                    logger.debug(f"Namespace '{namespace}' already exists in Datastore")
                else:
                    logger.warning(f"Unexpected response creating namespace in Datastore: {resp.status_code}")
        except Exception as e:
            logger.warning(f"Error creating namespace in Datastore: {e}")

    async def is_healthy(self) -> bool:
        """Check if NeMo Customizer services are reachable."""
        return True

    def _format_prompt(self, prompt: list[OpenAIMessage] | str) -> list[dict[str, str]] | str:
        """
        Format prompt based on configuration.

        Args:
            prompt: Original prompt (string or list of OpenAI messages)

        Returns:
            Formatted prompt based on use_full_message_history setting
        """
        if self.adapter_config.use_full_message_history:
            # Return full message history as list of dicts
            if isinstance(prompt, str):
                return [{"role": "user", "content": prompt}]
            else:
                return [{"role": msg.role, "content": msg.content} for msg in prompt]
        # Return only last message content as string
        elif isinstance(prompt, str):
            return prompt
        elif prompt:
            return prompt[-1].content
        else:
            return ""

    def _trajectory_to_dpo_jsonl(self, trajectories: TrajectoryCollection) -> tuple[str, str]:
        """
        Convert trajectory collection to JSONL format for DPO training.

        Returns:
            Tuple of (training_jsonl, validation_jsonl) content strings
        """
        all_items: list[dict[str, Any]] = []

        for trajectory_group in trajectories.trajectories:
            for trajectory in trajectory_group:
                for episode_item in trajectory.episode:
                    if isinstance(episode_item, DPOItem):
                        formatted_prompt = self._format_prompt(episode_item.prompt)
                        dpo_record = {
                            "prompt": formatted_prompt,
                            "chosen_response": episode_item.chosen_response,
                            "rejected_response": episode_item.rejected_response,
                        }
                        all_items.append(dpo_record)

        if not all_items:
            raise ValueError("No DPO items found in trajectories")

        # Split into training (80%) and validation (20%)
        split_idx = max(1, int(len(all_items) * 0.8))
        training_items = all_items[:split_idx]
        validation_items = all_items[split_idx:] if split_idx < len(all_items) else all_items[-1:]

        training_jsonl = "\n".join(json.dumps(item) for item in training_items)
        validation_jsonl = "\n".join(json.dumps(item) for item in validation_items)

        logger.info(f"Converted {len(all_items)} DPO items: "
                    f"{len(training_items)} training, {len(validation_items)} validation")

        return training_jsonl, validation_jsonl

    async def _setup_dataset(self, run_id: str, training_jsonl: str, validation_jsonl: str) -> str:
        """
        Create dataset repository and upload JSONL files.

        Args:
            run_id: Unique identifier for this training run
            training_jsonl: Training data in JSONL format
            validation_jsonl: Validation data in JSONL format

        Returns:
            Name of the dataset created and registered for this run
        """
        dataset_name = f"{self.adapter_config.dataset_name}"
        repo_id = f"{self.adapter_config.namespace}/{dataset_name}"

        # Create dataset repo in datastore
        self.hf_api.create_repo(repo_id, repo_type="dataset", exist_ok=True)

        # Register dataset in entity store
        try:
            self.entity_client.datasets.create(
                name=dataset_name,
                namespace=self.adapter_config.namespace,
                files_url=f"hf://datasets/{repo_id}",
                description=f"NAT DPO training dataset for run {run_id}",
            )
        except Exception as e:
            logger.debug(f"Dataset may already exist: {e}")

        # Determine output directory for dataset files
        if self.adapter_config.dataset_output_dir:
            # Use configured output directory (create if needed, preserve files)
            output_dir = Path(self.adapter_config.dataset_output_dir) / run_id
            output_dir.mkdir(parents=True, exist_ok=True)
            use_temp_dir = False
            logger.info(f"Saving dataset files to: {output_dir}")
        else:
            # Use temporary directory (will be cleaned up)
            use_temp_dir = True

        def write_and_upload_files(base_dir: Path) -> None:
            train_path = base_dir / "training_file.jsonl"
            val_path = base_dir / "validation_file.jsonl"

            train_path.write_text(training_jsonl)
            val_path.write_text(validation_jsonl)

            self.hf_api.upload_file(
                path_or_fileobj=str(train_path),
                path_in_repo="training/training_file.jsonl",
                repo_id=repo_id,
                repo_type="dataset",
                revision="main",
                commit_message=f"Training file for run {run_id}",
            )

            self.hf_api.upload_file(
                path_or_fileobj=str(val_path),
                path_in_repo="validation/validation_file.jsonl",
                repo_id=repo_id,
                repo_type="dataset",
                revision="main",
                commit_message=f"Validation file for run {run_id}",
            )

        if use_temp_dir:
            with tempfile.TemporaryDirectory() as tmpdir:
                write_and_upload_files(Path(tmpdir))
        else:
            write_and_upload_files(output_dir)

        logger.info(f"Created and uploaded dataset: {repo_id}")
        return dataset_name

    async def submit(self, trajectories: TrajectoryCollection) -> TrainingJobRef:
        """
        Submit trajectories for training.

        Args:
            trajectories: Collection of trajectories containing DPO items

        Returns:
            Reference to the submitted training job
        """
        run_id = trajectories.run_id

        if run_id in self._active_jobs:
            raise ValueError(f"Training job for run {run_id} already exists")

        # Convert trajectories to JSONL
        training_jsonl, validation_jsonl = self._trajectory_to_dpo_jsonl(trajectories)

        # Upload dataset
        dataset_name = await self._setup_dataset(run_id, training_jsonl, validation_jsonl)

        # Prepare hyperparameters
        hyperparams = self.adapter_config.hyperparameters.model_dump()

        # Submit customization job
        job = self.entity_client.customization.jobs.create(
            config=self.adapter_config.customization_config,
            dataset={
                "name": dataset_name,
                "namespace": self.adapter_config.namespace,
            },
            hyperparameters=hyperparams,
        )

        job_id = job.id
        self._active_jobs[run_id] = job_id
        self._job_output_models[run_id] = job.output_model

        logger.info(f"Submitted customization job {job_id} for run {run_id}. "
                    f"Output model: {job.output_model}")

        return TrainingJobRef(
            run_id=run_id,
            backend="nemo-customizer",
            metadata={
                "job_id": job_id,
                "output_model": job.output_model,
                "dataset_name": dataset_name,
            },
        )

    async def status(self, ref: TrainingJobRef) -> TrainingJobStatus:
        """Get the status of a training job."""
        job_id = self._active_jobs.get(ref.run_id)
        if job_id is None:
            # Try to get from metadata
            job_id = ref.metadata.get("job_id") if ref.metadata else None

        if job_id is None:
            raise ValueError(f"No training job found for run {ref.run_id}")

        try:
            job_status = self.entity_client.customization.jobs.status(job_id)

            # Map NeMo status to TrainingStatusEnum
            status_map = {
                "created": TrainingStatusEnum.PENDING,
                "pending": TrainingStatusEnum.PENDING,
                "running": TrainingStatusEnum.RUNNING,
                "completed": TrainingStatusEnum.COMPLETED,
                "failed": TrainingStatusEnum.FAILED,
                "cancelled": TrainingStatusEnum.CANCELED,
                "canceled": TrainingStatusEnum.CANCELED,
            }

            status = status_map.get(job_status.status.lower(), TrainingStatusEnum.RUNNING)
            progress = getattr(job_status, "percentage_done", None)

            message = f"Status: {job_status.status}"
            if hasattr(job_status, "epochs_completed"):
                message += f", Epochs: {job_status.epochs_completed}"

            return TrainingJobStatus(
                run_id=ref.run_id,
                backend=ref.backend,
                status=status,
                progress=progress,
                message=message,
                metadata={
                    "job_id": job_id,
                    "nemo_status": job_status.status,
                    "output_model": self._job_output_models.get(ref.run_id),
                },
            )

        except Exception as e:
            logger.error(f"Error getting job status: {e}")
            return TrainingJobStatus(
                run_id=ref.run_id,
                backend=ref.backend,
                status=TrainingStatusEnum.FAILED,
                message=f"Error getting status: {e}",
            )

    async def wait_until_complete(self, ref: TrainingJobRef, poll_interval: float | None = None) -> TrainingJobStatus:
        """Wait for training job to complete."""
        interval = poll_interval or self.adapter_config.poll_interval_seconds

        last_status: str | None = None
        last_progress: float | None = None

        while True:
            status = await self.status(ref)

            # Log when status changes
            current_status = status.status.value
            if current_status != last_status:
                logger.info(f"Job {ref.run_id}: Status -> '{current_status}'")
                last_status = current_status

            # Log when progress changes (skip when the backend reports no progress value)
            current_progress = status.progress
            if current_progress is not None and current_progress != last_progress:
                logger.info(f"Job {ref.run_id}: Progress {current_progress:.1f}%")
                last_progress = current_progress

            if status.status in (
                    TrainingStatusEnum.COMPLETED,
                    TrainingStatusEnum.FAILED,
                    TrainingStatusEnum.CANCELED,
            ):
                # Handle deployment if configured
                if (status.status == TrainingStatusEnum.COMPLETED and self.adapter_config.deploy_on_completion):
                    await self._deploy_model(ref)

                # Clean up active job tracking
                self._active_jobs.pop(ref.run_id, None)

                return status

            await asyncio.sleep(interval)

    async def _deploy_model(self, ref: TrainingJobRef) -> None:
        """Deploy the trained model and wait until deployment is ready."""
        output_model = self._job_output_models.get(ref.run_id)
        if not output_model:
            logger.warning(f"No output model found for run {ref.run_id}, skipping deployment")
            return

        deploy_config = self.adapter_config.deployment_config
        namespace = self.adapter_config.namespace

        try:
            # Create deployment configuration
            config_name = f"nat-deploy-config-{ref.run_id}"
            dep_config = self.entity_client.deployment.configs.create(
                name=config_name,
                namespace=namespace,
                description=deploy_config.description,
                model=output_model,
                nim_deployment={
                    "image_name": deploy_config.image_name,
                    "image_tag": deploy_config.image_tag,
                    "gpu": deploy_config.gpu,
                },
            )

            # Create model deployment
            deployment_name = (deploy_config.deployment_name or f"nat-deployment-{ref.run_id}")
            self.entity_client.deployment.model_deployments.create(
                name=deployment_name,
                namespace=namespace,
                description=deploy_config.description,
                config=f"{dep_config.namespace}/{dep_config.name}",
            )

            logger.info(f"Created deployment '{deployment_name}' for model {output_model}")

            # Wait for deployment to be ready
            await self._wait_for_deployment_ready(namespace, deployment_name)

        except Exception as e:
            logger.error(f"Failed to deploy model: {e}")
            raise

    async def _wait_for_deployment_ready(
        self,
        namespace: str,
        deployment_name: str,
        poll_interval: float | None = None,
        timeout: float | None = None,
    ) -> None:
        """
        Wait for a model deployment to become ready.

        Args:
            namespace: Namespace of the deployment
            deployment_name: Name of the deployment
            poll_interval: Seconds between status checks (default: adapter config poll_interval_seconds)
            timeout: Maximum seconds to wait (default: adapter config deployment_timeout_seconds)
        """
        interval = poll_interval or self.adapter_config.poll_interval_seconds
        max_wait = timeout or self.adapter_config.deployment_timeout_seconds

        logger.info(f"Waiting for deployment '{deployment_name}' to be ready...")

        last_status: str | None = None
        elapsed = 0.0

        while elapsed < max_wait:
            try:
                # Get all deployments and find ours
                deployments = self.entity_client.deployment.model_deployments.list().data
                deployment = None
                for dep in deployments:
                    if dep.name == deployment_name and dep.namespace == namespace:
                        deployment = dep
                        break

                if deployment is None:
                    logger.warning(f"Deployment '{deployment_name}' not found in namespace '{namespace}'")
                    await asyncio.sleep(interval)
                    elapsed += interval
                    continue

                # Check status
                status_details = getattr(deployment, "status_details", None)
                current_status = status_details.status if status_details else "unknown"
                description = status_details.description if status_details else ""

                # Log status changes
                if current_status != last_status:
                    logger.info(f"Deployment '{deployment_name}': Status -> '{current_status}'")
                    if description:
                        logger.info(f"Deployment '{deployment_name}': {description.strip()}")
                    last_status = current_status

                # Check if ready
                if current_status.lower() == "ready":
                    logger.info(f"Deployment '{deployment_name}' is ready!")
                    return

                # Check for failure states
                if current_status.lower() in ("failed", "error"):
                    raise RuntimeError(
                        f"Deployment '{deployment_name}' failed with status '{current_status}': {description}")

            except RuntimeError:
                raise
            except Exception as e:
                logger.warning(f"Error checking deployment status: {e}")

            await asyncio.sleep(interval)
            elapsed += interval

        raise TimeoutError(f"Deployment '{deployment_name}' did not become ready within {max_wait} seconds")

    def log_progress(self, ref: TrainingJobRef, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """Log training progress to file."""
        out_dir = Path(output_dir) if output_dir else Path("./.tmp/nat/finetuning/trainer_adapter")
        out_dir.mkdir(parents=True, exist_ok=True)

        log_file = out_dir / f"nemo_customizer_{ref.run_id}.jsonl"

        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "run_id": ref.run_id,
            "backend": ref.backend,
            "config": {
                "namespace": self.adapter_config.namespace,
                "customization_config": self.adapter_config.customization_config,
            },
            **metrics,
        }

        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry) + "\n")

        logger.debug(f"Logged progress for job {ref.run_id}")