nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nat/meta/pypi.md ADDED
@@ -0,0 +1,23 @@
1
+ <!--
2
+ SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ SPDX-License-Identifier: Apache-2.0
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ -->
17
+
18
+ ![NVIDIA NeMo Agent Toolkit](https://media.githubusercontent.com/media/NVIDIA/NeMo-Agent-Toolkit/refs/heads/main/docs/source/_static/banner.png "NeMo Agent toolkit banner image")
19
+
20
+ # NVIDIA NeMo Agent Toolkit Subpackage
21
+ This is a subpackage for NeMo Customizer integration in the NeMo Agent toolkit.
22
+
23
+ For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
@@ -0,0 +1,43 @@
1
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NeMo Customizer plugin for NAT finetuning.

This plugin provides a trajectory builder, a trainer, and a trainer adapter
for finetuning workflows that use the NeMo Customizer backend.

Available components:
- DPO Trajectory Builder: Collects preference pairs from scored TTC candidates
- NeMo Customizer Trainer: Runs the trajectory builder one or more times and
  submits the aggregated data as a single NeMo Customizer training job
- NeMo Customizer TrainerAdapter: Submits DPO/SFT jobs to NeMo Customizer
"""

from .dpo import DPOSpecificHyperparameters
from .dpo import DPOTrajectoryBuilder
from .dpo import DPOTrajectoryBuilderConfig
from .dpo import NeMoCustomizerHyperparameters
from .dpo import NeMoCustomizerTrainer
from .dpo import NeMoCustomizerTrainerAdapter
from .dpo import NeMoCustomizerTrainerAdapterConfig
from .dpo import NeMoCustomizerTrainerConfig
from .dpo import NIMDeploymentConfig

# Re-export the same public surface as the ``dpo`` subpackage, so the trainer
# and its config are importable from the package root like the other components.
__all__ = [
    # Trajectory Builder
    "DPOTrajectoryBuilder",
    "DPOTrajectoryBuilderConfig",
    # Trainer
    "NeMoCustomizerTrainer",
    "NeMoCustomizerTrainerConfig",
    # TrainerAdapter
    "NeMoCustomizerTrainerAdapter",
    "NeMoCustomizerTrainerAdapterConfig",
    "NeMoCustomizerHyperparameters",
    "DPOSpecificHyperparameters",
    "NIMDeploymentConfig",
]
@@ -0,0 +1,44 @@
1
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DPO (Direct Preference Optimization) building blocks for NAT.

Exposed here:
- ``DPOTrajectoryBuilder`` / ``DPOTrajectoryBuilderConfig``: gather preference
  data from scored TTC intermediate steps.
- ``NeMoCustomizerTrainer`` / ``NeMoCustomizerTrainerConfig``: drive data
  collection and submit the resulting training job.
- ``NeMoCustomizerTrainerAdapter`` / ``NeMoCustomizerTrainerAdapterConfig``:
  send DPO training jobs to NeMo Customizer, together with the hyperparameter
  and NIM deployment sub-configs.
"""

# Configuration models.
from .config import DPOSpecificHyperparameters
from .config import DPOTrajectoryBuilderConfig
from .config import NeMoCustomizerHyperparameters
from .config import NeMoCustomizerTrainerAdapterConfig
from .config import NeMoCustomizerTrainerConfig
from .config import NIMDeploymentConfig
# Runtime implementations.
from .trainer import NeMoCustomizerTrainer
from .trainer_adapter import NeMoCustomizerTrainerAdapter
from .trajectory_builder import DPOTrajectoryBuilder

__all__ = [
    # Trajectory Builder
    "DPOTrajectoryBuilderConfig",
    "DPOTrajectoryBuilder",
    # Trainer
    "NeMoCustomizerTrainerConfig",
    "NeMoCustomizerTrainer",
    # TrainerAdapter
    "NeMoCustomizerTrainerAdapterConfig",
    "NeMoCustomizerTrainerAdapter",
    "NeMoCustomizerHyperparameters",
    "DPOSpecificHyperparameters",
    "NIMDeploymentConfig",
]
@@ -0,0 +1,360 @@
1
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Configuration classes for DPO training with NeMo Customizer.

This module provides configuration for:
1. DPO Trajectory Builder - collecting preference data from workflows
2. NeMo Customizer Trainer - running the trajectory builder one or more times
   and submitting a single training job to NeMo Customizer
3. NeMo Customizer TrainerAdapter - submitting DPO/SFT training jobs, plus the
   nested hyperparameter and NIM deployment settings it uses
"""
22
+
23
+ from typing import Literal
24
+
25
+ from pydantic import BaseModel
26
+ from pydantic import Field
27
+ from pydantic import model_validator
28
+
29
+ from nat.data_models.finetuning import TrainerAdapterConfig
30
+ from nat.data_models.finetuning import TrainerConfig
31
+ from nat.data_models.finetuning import TrajectoryBuilderConfig
32
+
33
+
34
class DPOTrajectoryBuilderConfig(TrajectoryBuilderConfig, name="dpo_traj_builder"):
    """
    Settings for the DPO (Direct Preference Optimization) trajectory builder.

    The builder consumes TTC_END intermediate steps that carry ``TTCEventData``.
    It reads ``turn_id``, ``candidate_index``, ``score``, ``input`` (the prompt)
    and ``output`` (the response) directly from that structured model, so no
    dictionary-key mapping needs to be configured here.

    Candidates sharing a ``turn_id`` are grouped, and preference pairs are
    formed from the score differences between them.

    Example YAML configuration::

        trajectory_builders:
          dpo_builder:
            _type: dpo_traj_builder
            ttc_step_name: dpo_candidate_move
            exhaustive_pairs: true
            min_score_diff: 0.05
            max_pairs_per_turn: 5
    """

    # --- Which intermediate steps are collected ---
    ttc_step_name: str = Field(
        default="dpo_candidate_move",
        description=("Name of the TTC intermediate step to collect. "
                     "The builder filters for TTC_END events with this name."),
    )

    # --- How preference pairs are generated ---
    exhaustive_pairs: bool = Field(
        default=True,
        description=("If True, generate all pairwise comparisons where "
                     "score(A) > score(B). If False, only generate best vs worst pair."),
    )
    min_score_diff: float = Field(
        default=0.0,
        ge=0.0,
        description=("Minimum score difference required to create a preference "
                     "pair. Pairs with smaller differences are filtered out."),
    )
    max_pairs_per_turn: int | None = Field(
        default=None,
        ge=1,
        description=("Maximum number of preference pairs to generate per turn. "
                     "If None, no limit. Pairs sorted by score difference (highest first)."),
    )

    # --- How the trajectory reward is derived ---
    reward_from_score_diff: bool = Field(
        default=True,
        description=("If True, compute trajectory reward as score difference "
                     "(chosen - rejected). If False, use chosen score directly as reward."),
    )

    # --- Input sanity rules ---
    require_multiple_candidates: bool = Field(
        default=True,
        description=("If True, skip turns with only one candidate (no preference "
                     "signal). If False, include single-candidate turns."),
    )

    @model_validator(mode="after")
    def validate_config(self) -> "DPOTrajectoryBuilderConfig":
        """Check cross-field consistency and return the validated model."""
        # Field(ge=1) above already rejects values below 1; this guard is a
        # defensive second check.
        limit = self.max_pairs_per_turn
        if limit is None or limit >= 1:
            return self
        raise ValueError("max_pairs_per_turn must be at least 1 if specified")
105
+
106
+
107
+ # =============================================================================
108
+ # NeMo Customizer Trainer Configuration
109
+ # =============================================================================
110
+
111
+
112
class NeMoCustomizerTrainerConfig(TrainerConfig, name="nemo_customizer_trainer"):
    """
    Settings for the NeMo Customizer trainer.

    Rather than iterating over epochs, this trainer invokes the trajectory
    builder ``num_runs`` times to accumulate DPO data. It then submits one
    training job to NeMo Customizer with the combined result.

    Example YAML configuration::

        trainers:
          nemo_dpo:
            _type: nemo_customizer_trainer
            num_runs: 5
            wait_for_completion: true
            deduplicate_pairs: true
            max_pairs: 10000
    """

    # --- Collecting data ---
    num_runs: int = Field(
        default=1,
        ge=1,
        description=("Number of times to run the trajectory builder to collect data. "
                     "Each run generates preference pairs from the evaluation dataset. "
                     "Multiple runs can increase dataset diversity."),
    )
    continue_on_collection_error: bool = Field(
        default=False,
        description=("If True, continue with remaining runs if one fails. "
                     "If False, stop immediately on first error."),
    )

    # --- Post-processing the collected pairs ---
    deduplicate_pairs: bool = Field(
        default=True,
        description=("If True, remove duplicate DPO pairs based on prompt+responses. "
                     "Useful when multiple runs may generate the same pairs."),
    )
    max_pairs: int | None = Field(
        default=None,
        ge=1,
        description=("Maximum number of DPO pairs to include in training. "
                     "If None, use all collected pairs. If set, randomly samples pairs."),
    )

    # --- Submitting the training job ---
    wait_for_completion: bool = Field(
        default=True,
        description=("If True, wait for the NeMo Customizer training job to complete. "
                     "If False, submit the job and return immediately."),
    )
166
+
167
+
168
+ # =============================================================================
169
+ # NeMo Customizer TrainerAdapter Configuration
170
+ # =============================================================================
171
+
172
+
173
class DPOSpecificHyperparameters(BaseModel):
    """
    Hyperparameters that only apply to DPO training in NeMo Customizer.

    All weights and penalties are constrained to be non-negative.
    """

    # Strength of the KL regularization toward the reference policy.
    ref_policy_kl_penalty: float = Field(
        default=0.1,
        ge=0.0,
        description="KL penalty coefficient for reference policy regularization.",
    )

    # Weight applied to the preference (DPO) loss term.
    preference_loss_weight: float = Field(
        default=1.0,
        ge=0.0,
        description="Scales the contribution of the preference loss",
    )

    # Averaged vs. summed sequence log-probabilities in the preference loss.
    preference_average_log_probs: bool = Field(
        default=False,
        description=("If True, use average log probabilities over sequence length "
                     "when computing preference loss. If False, use sum of log probabilities."),
    )

    # Weight applied to an auxiliary SFT loss term (0.0 disables it by default).
    sft_loss_weight: float = Field(
        default=0.0,
        ge=0.0,
        description="Scales the contribution of the supervised fine-tuning (SFT) loss. ",
    )
195
+
196
+
197
class NeMoCustomizerHyperparameters(BaseModel):
    """
    Training-job hyperparameters sent to NeMo Customizer.

    The fields correspond to the ``hyperparameters`` argument of
    ``client.customization.jobs.create()``.
    """

    # What kind of training to run, and over which weights.
    training_type: Literal["sft", "dpo"] = Field(
        default="dpo",
        description="Type of training: 'sft' for supervised fine-tuning, 'dpo' for direct preference optimization.",
    )
    finetuning_type: Literal["lora", "all_weights"] = Field(
        default="all_weights",
        description="Type of finetuning: 'lora' for LoRA adapters, 'all_weights' for full model.",
    )

    # Optimization schedule.
    epochs: int = Field(default=3, ge=1, description="Number of training epochs.")
    batch_size: int = Field(default=4, ge=1, description="Training batch size.")
    learning_rate: float = Field(default=5e-5, gt=0.0, description="Learning rate for optimizer.")

    # Nested DPO-only settings; a fresh instance is built per model.
    dpo: DPOSpecificHyperparameters = Field(
        default_factory=DPOSpecificHyperparameters,
        description="DPO-specific hyperparameters.",
    )
232
+
233
+
234
class NIMDeploymentConfig(BaseModel):
    """
    Settings for deploying the trained model as a NIM.

    Consulted only when the trainer adapter's ``deploy_on_completion`` flag
    is enabled.
    """

    # Container image to deploy.
    image_name: str = Field(default="nvcr.io/nim/meta/llama-3.1-8b-instruct",
                            description="NIM container image name.")
    image_tag: str = Field(default="latest", description="NIM container image tag.")

    # Compute resources.
    gpu: int = Field(default=1, ge=1, description="Number of GPUs for deployment.")

    # Deployment metadata.
    deployment_name: str | None = Field(
        default=None,
        description="Name for the deployment. If None, auto-generated from model name.",
    )
    description: str = Field(default="Fine-tuned model deployment",
                             description="Description for the deployment.")
262
+
263
+
264
class NeMoCustomizerTrainerAdapterConfig(TrainerAdapterConfig, name="nemo_customizer_trainer_adapter"):
    """
    Configuration for the NeMo Customizer TrainerAdapter.

    This adapter submits DPO/SFT training jobs to NeMo Customizer and
    optionally deploys the trained model.

    ``hf_token`` is excluded from the model's ``repr`` so the secret does not
    leak into logs or tracebacks that print the config. It is still present in
    ``model_dump()`` output.

    Example YAML configuration::

        trainer_adapters:
          nemo_customizer:
            _type: nemo_customizer_trainer_adapter
            entity_host: https://nmp.example.com
            datastore_host: https://datastore.example.com
            namespace: my-project
            customization_config: meta/llama-3.2-1b-instruct@v1.0.0+A100
            hyperparameters:
              training_type: dpo
              epochs: 5
              batch_size: 8
            use_full_message_history: true
            deploy_on_completion: true
    """

    # === Endpoint Configuration ===
    entity_host: str = Field(description="Base URL for NeMo Entity Store (e.g., https://nmp.example.com).", )
    datastore_host: str = Field(description="Base URL for NeMo Datastore (e.g., https://datastore.example.com).", )
    hf_token: str = Field(
        default="",
        # Secret: keep it out of __repr__/__str__ output.
        repr=False,
        description="HuggingFace token for datastore authentication. Can be empty if not required.",
    )

    # === Namespace and Dataset ===
    namespace: str = Field(description="Namespace for organizing resources (datasets, models, deployments).", )
    dataset_name: str = Field(
        default="nat-dpo",
        description="Name for the training dataset. Must be unique within namespace.",
    )
    dataset_output_dir: str | None = Field(
        default=None,
        description="Directory to save dataset JSONL files locally before upload. "
        "If None, uses a temporary directory that is deleted after upload. "
        "If specified, creates the directory if it doesn't exist and preserves files.",
    )
    create_namespace_if_missing: bool = Field(
        default=True,
        description="If True, create namespace in entity store and datastore if it doesn't exist.",
    )

    # === Customization Job Configuration ===
    customization_config: str = Field(description="Model configuration string for customization job "
                                      "(e.g., 'meta/llama-3.2-1b-instruct@v1.0.0+A100'). "
                                      "Available configs can be listed via the NeMo Customizer API.", )
    hyperparameters: NeMoCustomizerHyperparameters = Field(
        default_factory=NeMoCustomizerHyperparameters,
        description="Hyperparameters for the training job.",
    )

    # === Prompt Formatting ===
    use_full_message_history: bool = Field(
        default=False,
        description="If True, include full message history in prompt field as list of messages. "
        "If False, use only the last message content as a string. "
        "Full history format: [{\"role\": \"system\", \"content\": \"...\"}, ...]. "
        "Last message format: \"<content string>\".",
    )

    # === Deployment Configuration ===
    deploy_on_completion: bool = Field(
        default=False,
        description="If True, automatically deploy the trained model after job completion.",
    )
    deployment_config: NIMDeploymentConfig = Field(
        default_factory=NIMDeploymentConfig,
        description="Configuration for model deployment (used when deploy_on_completion=True).",
    )

    # === Polling Configuration ===
    poll_interval_seconds: float = Field(
        default=30.0,
        gt=0.0,
        description="Interval in seconds between job status checks.",
    )
    deployment_timeout_seconds: float = Field(
        default=1800.0,
        gt=0.0,
        description="Maximum time in seconds to wait for deployment to be ready. "
        "Default is 30 minutes (1800 seconds).",
    )

    @model_validator(mode="after")
    def validate_config(self) -> "NeMoCustomizerTrainerAdapterConfig":
        """
        Normalize host URLs after field validation.

        Returns:
            The same model instance with trailing slashes removed from
            ``entity_host`` and ``datastore_host``.
        """
        # Strip trailing slashes, presumably so request paths can be appended
        # without producing a doubled '/'.
        # NOTE(review): this assigns to fields inside an after-validator; if the
        # NAT base config enables ``validate_assignment`` that would re-run
        # validation — confirm the base model's settings.
        self.entity_host = self.entity_host.rstrip("/")
        self.datastore_host = self.datastore_host.rstrip("/")
        return self
@@ -0,0 +1,157 @@
1
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Registration module for DPO components.

This module registers the DPO trajectory builder, the NeMo Customizer trainer,
and the NeMo Customizer trainer adapter with NAT's finetuning harness:
- `_type: dpo_traj_builder` - DPO Trajectory Builder
- `_type: nemo_customizer_trainer` - NeMo Customizer Trainer
- `_type: nemo_customizer_trainer_adapter` - NeMo Customizer TrainerAdapter
"""
23
+
24
+ from nat.builder.builder import Builder
25
+ from nat.cli.register_workflow import register_trainer
26
+ from nat.cli.register_workflow import register_trainer_adapter
27
+ from nat.cli.register_workflow import register_trajectory_builder
28
+
29
+ from .config import DPOTrajectoryBuilderConfig
30
+ from .config import NeMoCustomizerTrainerAdapterConfig
31
+ from .config import NeMoCustomizerTrainerConfig
32
+ from .trainer import NeMoCustomizerTrainer
33
+ from .trainer_adapter import NeMoCustomizerTrainerAdapter
34
+ from .trajectory_builder import DPOTrajectoryBuilder
35
+
36
+
37
@register_trajectory_builder(config_type=DPOTrajectoryBuilderConfig)
async def dpo_trajectory_builder(config: DPOTrajectoryBuilderConfig, builder: Builder):
    """
    Factory for the DPO trajectory builder (``_type: dpo_traj_builder``).

    The produced builder gathers preference data from workflows that emit
    scored candidate intermediate steps (TTC_END events carrying TTCEventData).
    Its pipeline is:

    1. Run evaluation to collect intermediate steps.
    2. Keep only TTC_END steps whose name matches ``ttc_step_name``.
    3. Group the surviving candidates by ``turn_id``.
    4. Derive preference pairs from score differences.
    5. Emit trajectories made of DPOItem episodes.

    Example YAML configuration::

        trajectory_builders:
          dpo_builder:
            _type: dpo_traj_builder
            ttc_step_name: dpo_candidate_move
            exhaustive_pairs: true
            min_score_diff: 0.05
            max_pairs_per_turn: 5

        finetuning:
          enabled: true
          trajectory_builder: dpo_builder
          # ... other finetuning config

    Args:
        config: Settings for the trajectory builder.
        builder: The NAT workflow builder; not used by this factory.

    Yields:
        A DPOTrajectoryBuilder constructed from ``config``.
    """
    trajectory_builder = DPOTrajectoryBuilder(trajectory_builder_config=config)
    yield trajectory_builder
75
+
76
+
77
@register_trainer_adapter(config_type=NeMoCustomizerTrainerAdapterConfig)
async def nemo_customizer_trainer_adapter(config: NeMoCustomizerTrainerAdapterConfig, builder: Builder):
    """
    Factory for the NeMo Customizer trainer adapter (``_type: nemo_customizer_trainer_adapter``).

    The produced adapter sends DPO/SFT training jobs to NeMo Customizer and can
    optionally deploy the resulting model. Its responsibilities are:

    1. Convert trajectories into JSONL suitable for DPO training.
    2. Upload the dataset files to NeMo Datastore.
    3. Submit a customization job to NeMo Customizer.
    4. Track the job's progress and status.
    5. Deploy the trained model when requested.

    Example YAML configuration::

        trainer_adapters:
          nemo_customizer:
            _type: nemo_customizer_trainer_adapter
            entity_host: https://nmp.example.com
            datastore_host: https://datastore.example.com
            namespace: my-project
            customization_config: meta/llama-3.2-1b-instruct@v1.0.0+A100
            hyperparameters:
              training_type: dpo
              epochs: 5
              batch_size: 8
            use_full_message_history: true
            deploy_on_completion: true

        finetuning:
          enabled: true
          trainer_adapter: nemo_customizer
          # ... other finetuning config

    Args:
        config: Settings for the trainer adapter.
        builder: The NAT workflow builder; not used by this factory.

    Yields:
        A NeMoCustomizerTrainerAdapter constructed from ``config``.
    """
    adapter = NeMoCustomizerTrainerAdapter(adapter_config=config)
    yield adapter
121
+
122
+
123
@register_trainer(config_type=NeMoCustomizerTrainerConfig)
async def nemo_customizer_trainer(config: NeMoCustomizerTrainerConfig, builder: Builder):
    """
    Factory for the NeMo Customizer trainer (``_type: nemo_customizer_trainer``).

    The produced trainer coordinates DPO data collection and job submission.
    Instead of looping over epochs, it:

    1. Invokes the trajectory builder ``num_runs`` times to gather data.
    2. Merges every collected trajectory into one dataset.
    3. Hands that dataset to NeMo Customizer for training.
    4. Watches the training job until it finishes.

    Example YAML configuration::

        trainers:
          nemo_dpo:
            _type: nemo_customizer_trainer
            num_runs: 5
            wait_for_completion: true
            deduplicate_pairs: true
            max_pairs: 10000

        finetuning:
          enabled: true
          trainer: nemo_dpo
          # ... other finetuning config

    Args:
        config: Settings for the trainer.
        builder: The NAT workflow builder; not used by this factory.

    Yields:
        A NeMoCustomizerTrainer constructed from ``config``.
    """
    trainer = NeMoCustomizerTrainer(trainer_config=config)
    yield trainer