nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/customizer/__init__.py +43 -0
- nat/plugins/customizer/dpo/__init__.py +44 -0
- nat/plugins/customizer/dpo/config.py +360 -0
- nat/plugins/customizer/dpo/register.py +157 -0
- nat/plugins/customizer/dpo/trainer.py +424 -0
- nat/plugins/customizer/dpo/trainer_adapter.py +550 -0
- nat/plugins/customizer/dpo/trajectory_builder.py +767 -0
- nat/plugins/customizer/register.py +23 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/METADATA +45 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/RECORD +16 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/WHEEL +5 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/entry_points.txt +2 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
NeMo Customizer Trainer for DPO finetuning.
|
|
17
|
+
|
|
18
|
+
This module provides a Trainer implementation that orchestrates data collection
|
|
19
|
+
via trajectory builders and submits training jobs to NeMo Customizer.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import json
|
|
23
|
+
import logging
|
|
24
|
+
import uuid
|
|
25
|
+
from datetime import datetime
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import Any
|
|
28
|
+
|
|
29
|
+
from nat.data_models.finetuning import FinetuneConfig
|
|
30
|
+
from nat.data_models.finetuning import TrainingJobRef
|
|
31
|
+
from nat.data_models.finetuning import TrainingJobStatus
|
|
32
|
+
from nat.data_models.finetuning import TrainingStatusEnum
|
|
33
|
+
from nat.data_models.finetuning import Trajectory
|
|
34
|
+
from nat.data_models.finetuning import TrajectoryCollection
|
|
35
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
36
|
+
|
|
37
|
+
from .config import NeMoCustomizerTrainerConfig
|
|
38
|
+
|
|
39
|
+
# Module-level logger named after this module's import path; used by the trainer below.
logger = logging.getLogger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class NeMoCustomizerTrainer(Trainer):
    """
    Trainer for NeMo Customizer DPO/SFT finetuning.

    Unlike epoch-based trainers, this trainer:
    1. Runs the trajectory builder multiple times (``num_runs``) to collect data
    2. Aggregates all trajectories into a single dataset
    3. Optionally deduplicates and down-samples that dataset
    4. Submits the dataset to NeMo Customizer for training
    5. Optionally waits for the training job until completion

    The actual training epochs are handled by NeMo Customizer via hyperparameters.

    NOTE(review): ``self.trajectory_builder`` and ``self.trainer_adapter`` are used
    but never assigned in this class; presumably the base ``Trainer`` or the caller
    attaches them before ``initialize()`` is awaited — confirm.
    """

    def __init__(self, trainer_config: NeMoCustomizerTrainerConfig, **kwargs) -> None:
        """
        Initialize the NeMo Customizer Trainer.

        Args:
            trainer_config: Configuration for the trainer.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        super().__init__(trainer_config)

        self.trainer_config: NeMoCustomizerTrainerConfig = trainer_config

        # Reference to the submitted training job, and the ID shared by all collection runs.
        self._job_ref: TrainingJobRef | None = None
        self._run_id: str | None = None

        # Trajectory groups accumulated across collection runs, plus per-run summary metrics.
        self._all_trajectories: list[list[Trajectory]] = []
        self._run_metrics: list[dict[str, Any]] = []

        # Every entry recorded by log_progress(); re-written to collection_history.json each time.
        self._collection_history: list[dict[str, Any]] = []

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Initialize the trainer and its components.

        Stores ``run_config``, copies its reward function into
        ``trainer_config.reward``, initializes the trajectory builder and trainer
        adapter with ``run_config``, and generates a run ID of the form
        ``nemo_dpo_<8 hex chars>``.

        Note: Curriculum learning is not supported for DPO training.

        Args:
            run_config: The finetuning run configuration.
        """
        logger.info("Initializing NeMo Customizer Trainer")

        # Store run config but skip curriculum learning setup
        self.run_config = run_config
        self.trainer_config.reward = self.run_config.reward_function

        # Disable curriculum learning for DPO: every group is always included.
        self.curriculum_config = None
        self._curriculum_state = {
            "current_percentile": 1.0,
            "last_expansion_epoch": -1,
            "total_groups": 0,
            "included_groups": set(),
        }

        # Initialize components
        await self.trajectory_builder.initialize(run_config)
        await self.trainer_adapter.initialize(run_config)

        # Generate unique run ID
        self._run_id = f"nemo_dpo_{uuid.uuid4().hex[:8]}"

        logger.info(f"NeMo Customizer Trainer initialized with run ID: {self._run_id}")

    async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef | None:
        """
        Run a single data collection run.

        For NeMo Customizer, this collects trajectories without submitting
        to training. The actual submission happens in run().

        Args:
            epoch: The current run number (0-indexed).
            run_id: Unique identifier for this training run; the builder run ID is
                ``f"{run_id}_run{epoch}"``.

        Returns:
            None (trajectories are accumulated, not submitted per-run).
        """
        logger.info(f"Starting data collection run {epoch + 1}/{self.trainer_config.num_runs}")

        run_meta = {
            "run_number": epoch,
            "run_id": run_id,
            "trainer_config": self.trainer_config.model_dump(),
        }
        builder_run_id = f"{run_id}_run{epoch}"

        # Run trajectory builder, then finalize to get this run's trajectories.
        await self.trajectory_builder.start_run(run_id=builder_run_id, meta=run_meta)
        trajectory_collection = await self.trajectory_builder.finalize(run_id=builder_run_id, meta=run_meta)

        if not trajectory_collection.trajectories:
            logger.warning(f"No trajectories collected for run {epoch}")
            return None

        # Calculate metrics for this run
        run_rewards = []
        num_trajectories = 0
        num_dpo_pairs = 0

        for trajectory_group in trajectory_collection.trajectories:
            for trajectory in trajectory_group:
                num_trajectories += 1
                run_rewards.append(trajectory.reward)
                # Count DPO pairs (each trajectory is expected to hold one DPOItem)
                num_dpo_pairs += len(trajectory.episode)

        metrics = {
            "run_number": epoch,
            "num_trajectories": num_trajectories,
            "num_dpo_pairs": num_dpo_pairs,
            "avg_reward": sum(run_rewards) / len(run_rewards) if run_rewards else 0.0,
            "min_reward": min(run_rewards) if run_rewards else 0.0,
            "max_reward": max(run_rewards) if run_rewards else 0.0,
            "timestamp": datetime.now().isoformat(),
        }
        self._run_metrics.append(metrics)

        logger.info(f"Run {epoch + 1}: Collected {num_trajectories} trajectories, "
                    f"{num_dpo_pairs} DPO pairs, avg reward: {metrics['avg_reward']:.4f}")

        # Accumulate trajectory groups for the single submission in run().
        self._all_trajectories.extend(trajectory_collection.trajectories)

        # Log progress
        self.log_progress(epoch, metrics)

        return None  # No job submitted per-run

    def _make_status(self, status: TrainingStatusEnum, message: str, **extra: Any) -> TrainingJobStatus:
        """
        Build a TrainingJobStatus for this run on the ``nemo-customizer`` backend.

        Args:
            status: The job status to report.
            message: Human-readable status message.
            **extra: Additional TrainingJobStatus fields (e.g. ``metadata``), passed
                only when the caller supplies them.
        """
        return TrainingJobStatus(
            run_id=self._run_id,
            backend="nemo-customizer",
            status=status,
            message=message,
            **extra,
        )

    async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
        """
        Run the complete DPO data collection and training workflow.

        Phases:
        1. Call ``run_epoch`` ``trainer_config.num_runs`` times to collect data.
        2. Optionally deduplicate (``deduplicate_pairs``) and down-sample
           (``max_pairs``) the aggregated trajectories, then submit them via the
           trainer adapter.
        3. If ``wait_for_completion`` is set, wait for the job and write
           ``final_metrics.json``; otherwise return a RUNNING status immediately.

        Args:
            num_epochs: Ignored for NeMo Customizer (uses trainer_config.num_runs).

        Returns:
            list[TrainingJobStatus]: A single-element list with the job status. A
            FAILED status is returned (not raised) when a collection run fails and
            ``continue_on_collection_error`` is false, when no trajectories were
            collected, or when preparing/submitting the job fails.

        Raises:
            RuntimeError: If ``initialize()`` has not been awaited first.
        """
        if not self._run_id:
            raise RuntimeError("Trainer not initialized. Call initialize() first.")

        num_runs = self.trainer_config.num_runs
        logger.info(f"Starting NeMo Customizer DPO workflow with {num_runs} data collection runs")

        # Phase 1: Collect data from multiple runs
        for run_idx in range(num_runs):
            try:
                await self.run_epoch(run_idx, self._run_id)
            except Exception as e:
                # logger.exception keeps the traceback so collection failures are debuggable.
                logger.exception(f"Error during data collection run {run_idx}: {e}")
                if not self.trainer_config.continue_on_collection_error:
                    return [
                        self._make_status(
                            TrainingStatusEnum.FAILED,
                            f"Data collection failed at run {run_idx}: {e}",
                            metadata={"run_number": run_idx},
                        )
                    ]

        # Check if we have any data
        if not self._all_trajectories:
            logger.error("No trajectories collected from any run")
            return [self._make_status(TrainingStatusEnum.FAILED, "No trajectories collected")]

        # Calculate total statistics
        total_groups = len(self._all_trajectories)
        total_dpo_pairs = sum(
            len(traj.episode) for group in self._all_trajectories
            for traj in (group if isinstance(group, list) else [group]))

        logger.info(f"Data collection complete: {total_groups} trajectory groups, "
                    f"~{total_dpo_pairs} total DPO pairs from {num_runs} runs")

        # Phase 2: Submit aggregated trajectories for training
        try:
            trajectory_collection = TrajectoryCollection(
                trajectories=self._all_trajectories,
                run_id=self._run_id,
            )

            # Apply deduplication if configured
            if self.trainer_config.deduplicate_pairs:
                trajectory_collection = self._deduplicate_trajectories(trajectory_collection)

            # Apply sampling if configured
            if self.trainer_config.max_pairs is not None:
                trajectory_collection = self._sample_trajectories(trajectory_collection, self.trainer_config.max_pairs)

            self._job_ref = await self.trainer_adapter.submit(trajectory_collection)

            logger.info(f"Submitted training job: {self._job_ref.metadata.get('job_id')}")

        except Exception as e:
            logger.exception(f"Failed to submit training job: {e}")
            return [self._make_status(TrainingStatusEnum.FAILED, f"Failed to submit training job: {e}")]

        # Phase 3: Either return immediately with a RUNNING status, or wait for completion.
        if not self.trainer_config.wait_for_completion:
            return [
                self._make_status(
                    TrainingStatusEnum.RUNNING,
                    "Training job submitted (not waiting for completion)",
                    metadata=self._job_ref.metadata,
                )
            ]

        logger.info("Waiting for training job to complete...")
        final_status = await self.trainer_adapter.wait_until_complete(self._job_ref)

        # Log final metrics
        self._log_final_metrics(final_status)

        return [final_status]

    @staticmethod
    def _dpo_pair_key(item: Any) -> tuple[str, Any, Any]:
        """Return the hashable ``(prompt, chosen_response, rejected_response)`` identity of an episode item."""
        # Presumably the prompt is stringified because it may be an unhashable structure
        # (e.g. a list of chat messages) — confirm against the DPO item model.
        prompt_str = str(item.prompt) if hasattr(item, "prompt") else ""
        return (
            prompt_str,
            getattr(item, "chosen_response", ""),
            getattr(item, "rejected_response", ""),
        )

    def _deduplicate_trajectories(self, collection: TrajectoryCollection) -> TrajectoryCollection:
        """
        Remove duplicate DPO pairs based on prompt+responses.

        A trajectory is kept when at least one of its episode items has a pair key
        not seen earlier in the collection. All of a kept trajectory's keys are then
        marked as seen, so any later trajectory that repeats them is dropped.
        Trajectories with an empty episode have no pairs and are dropped, as are
        groups left empty after filtering.

        Args:
            collection: The aggregated trajectory collection.

        Returns:
            A new TrajectoryCollection with the same ``run_id``.
        """
        seen: set[tuple[str, Any, Any]] = set()
        unique_groups: list[list[Trajectory]] = []

        for group in collection.trajectories:
            unique_trajectories: list[Trajectory] = []
            for traj in group:
                keys = [self._dpo_pair_key(item) for item in traj.episode]
                # Record every pair of a kept trajectory (not just the first new one) so that
                # later trajectories repeating any of these pairs are recognized as duplicates.
                if any(key not in seen for key in keys):
                    seen.update(keys)
                    unique_trajectories.append(traj)

            if unique_trajectories:
                unique_groups.append(unique_trajectories)

        original_count = sum(len(g) for g in collection.trajectories)
        new_count = sum(len(g) for g in unique_groups)
        logger.info(f"Deduplication: {original_count} -> {new_count} trajectories")

        return TrajectoryCollection(trajectories=unique_groups, run_id=collection.run_id)

    def _sample_trajectories(self, collection: TrajectoryCollection, max_pairs: int) -> TrajectoryCollection:
        """
        Sample trajectories to limit dataset size.

        Counts trajectories, not episode items. That matches the DPO pair count
        under the one-DPOItem-per-trajectory expectation noted in ``run_epoch``.
        If no more than ``max_pairs`` trajectories exist, ``collection`` is returned
        unchanged. Otherwise ``max_pairs`` trajectories are drawn with
        ``random.sample``, which uses the global ``random`` state. Each sampled
        trajectory becomes its own single-element group, so the original grouping
        is discarded.

        Args:
            collection: The collection to sample from.
            max_pairs: Maximum number of trajectories to keep.

        Returns:
            ``collection`` itself, or a new TrajectoryCollection with the same ``run_id``.
        """
        import random

        all_trajectories = [traj for group in collection.trajectories for traj in group]

        if len(all_trajectories) <= max_pairs:
            return collection

        # Sample randomly
        sampled = random.sample(all_trajectories, max_pairs)
        logger.info(f"Sampling: {len(all_trajectories)} -> {max_pairs} trajectories")

        return TrajectoryCollection(
            trajectories=[[t] for t in sampled],
            run_id=collection.run_id,
        )

    async def get_metrics(self, run_id: str) -> dict[str, Any]:
        """
        Get training metrics for the run.

        Args:
            run_id: Echoed back in the result under ``"run_id"``.

        Returns:
            Collection statistics. Per-run entries are copies, so a later
            ``cleanup()`` does not empty the returned data. When a job has been
            submitted, this also includes a ``"training_job"`` entry holding the
            adapter's current status, or ``{"error": ...}`` if the status query raised.
        """
        metrics = {
            "run_id": run_id,
            "num_collection_runs": len(self._run_metrics),
            # Copy each entry: cleanup() calls .clear() on the internal list.
            "collection_runs": [dict(m) for m in self._run_metrics],
            "total_trajectory_groups": len(self._all_trajectories),
        }

        if self._job_ref:
            try:
                status = await self.trainer_adapter.status(self._job_ref)
                metrics["training_job"] = {
                    "job_id": self._job_ref.metadata.get("job_id"),
                    "status": status.status.value,
                    "progress": status.progress,
                    "message": status.message,
                }
            except Exception as e:
                # Best-effort: report the failure instead of raising from a metrics query.
                metrics["training_job"] = {"error": str(e)}

        return metrics

    async def cleanup(self) -> None:
        """
        Clean up resources.

        Cancels unfinished tasks in ``trajectory_builder.evaluation_runs`` (when the
        builder has that attribute). Then clears the accumulated trajectories, run
        metrics and collection history. Files already written to disk are left in place.
        """
        logger.info("Cleaning up NeMo Customizer Trainer resources")

        # Cancel any running trajectory builder tasks
        if hasattr(self.trajectory_builder, "evaluation_runs"):
            for run_id, task in self.trajectory_builder.evaluation_runs.items():
                if not task.done():
                    logger.info(f"Cancelling evaluation task for run {run_id}")
                    task.cancel()

        # Clear accumulated data
        self._all_trajectories.clear()
        self._run_metrics.clear()
        self._collection_history.clear()

        logger.info("NeMo Customizer Trainer cleanup completed")

    def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log data collection progress.

        Builds an entry from the run number, a timestamp, the trainer run ID and
        ``metrics``; keys in ``metrics`` override the first three. The entry is
        appended to ``data_collection_progress.jsonl``, and
        ``collection_history.json`` is rewritten with every entry recorded so far.

        Args:
            epoch: Zero-based collection run number.
            metrics: Per-run metrics to record.
            output_dir: Directory to write to; defaults to ``run_config.output_dir``.
        """
        # Path() accepts either str or Path, so a string output_dir in run_config also works.
        out_dir = Path(output_dir) if output_dir else Path(self.run_config.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        # Store in history
        progress_entry = {
            "run_number": epoch,
            "timestamp": datetime.now().isoformat(),
            "run_id": self._run_id,
            **metrics,
        }
        self._collection_history.append(progress_entry)

        # Log to JSON file
        log_file = out_dir / "data_collection_progress.jsonl"
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(progress_entry) + "\n")

        # Save collection history
        history_file = out_dir / "collection_history.json"
        with open(history_file, "w", encoding="utf-8") as f:
            json.dump(self._collection_history, f, indent=2)

        logger.info(f"Run {epoch + 1}: {metrics.get('num_dpo_pairs', 0)} DPO pairs, "
                    f"avg reward: {metrics.get('avg_reward', 0):.4f}")

    def _log_final_metrics(self, final_status: TrainingJobStatus) -> None:
        """
        Log final training metrics.

        Writes ``final_metrics.json`` to ``run_config.output_dir``. It contains the
        job status and a collection summary. The summary's ``avg_reward`` is the
        mean reward over all collected trajectories: each run's average is weighted
        by its trajectory count.

        Args:
            final_status: Status returned by the trainer adapter once the job finished.
        """
        out_dir = Path(self.run_config.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        total_trajectories = sum(m.get("num_trajectories", 0) for m in self._run_metrics)
        total_dpo_pairs = sum(m.get("num_dpo_pairs", 0) for m in self._run_metrics)
        # avg_reward * num_trajectories recovers each run's reward sum, so this is the
        # true overall mean rather than an unweighted mean of per-run means.
        total_reward = sum(m.get("avg_reward", 0.0) * m.get("num_trajectories", 0) for m in self._run_metrics)
        overall_avg_reward = total_reward / total_trajectories if total_trajectories else 0.0

        final_metrics = {
            "run_id": self._run_id,
            "timestamp": datetime.now().isoformat(),
            "status": final_status.status.value,
            "message": final_status.message,
            "progress": final_status.progress,
            "num_collection_runs": len(self._run_metrics),
            "total_trajectory_groups": len(self._all_trajectories),
            "collection_summary": {
                "total_trajectories": total_trajectories,
                "total_dpo_pairs": total_dpo_pairs,
                "avg_reward": overall_avg_reward,
            },
            "job_metadata": self._job_ref.metadata if self._job_ref else None,
        }

        # Save final metrics
        metrics_file = out_dir / "final_metrics.json"
        with open(metrics_file, "w", encoding="utf-8") as f:
            json.dump(final_metrics, f, indent=2)

        logger.info(f"Training completed with status: {final_status.status.value}")