nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/customizer/__init__.py +43 -0
- nat/plugins/customizer/dpo/__init__.py +44 -0
- nat/plugins/customizer/dpo/config.py +360 -0
- nat/plugins/customizer/dpo/register.py +157 -0
- nat/plugins/customizer/dpo/trainer.py +424 -0
- nat/plugins/customizer/dpo/trainer_adapter.py +550 -0
- nat/plugins/customizer/dpo/trajectory_builder.py +767 -0
- nat/plugins/customizer/register.py +23 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/METADATA +45 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/RECORD +16 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/WHEEL +5 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/entry_points.txt +2 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_nemo_customizer-1.4.0a20251223.dist-info/top_level.txt +1 -0
nat/plugins/customizer/dpo/trajectory_builder.py

@@ -0,0 +1,767 @@

# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DPO (Direct Preference Optimization) Trajectory Builder.

This module provides a trajectory builder that collects preference data from
workflows that produce TTC_END intermediate steps with TTCEventData.

The builder:
1. Runs evaluation to collect intermediate steps
2. Filters for TTC_END steps with the configured name
3. Extracts data from TTCEventData (turn_id, candidate_index, score, input, output)
4. Groups candidates by turn_id
5. Generates preference pairs based on score differences
6. Builds trajectories with DPOItem episodes for DPO training
"""

from __future__ import annotations

import asyncio
import json
import logging
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from pathlib import Path
from typing import Any

from nat.data_models.finetuning import DPOItem
from nat.data_models.finetuning import OpenAIMessage
from nat.data_models.finetuning import Trajectory
from nat.data_models.finetuning import TrajectoryCollection
from nat.data_models.intermediate_step import IntermediateStep
from nat.data_models.intermediate_step import IntermediateStepCategory
from nat.data_models.intermediate_step import IntermediateStepType
from nat.data_models.intermediate_step import StreamEventData
from nat.data_models.intermediate_step import TTCEventData
from nat.eval.config import EvaluationRunOutput
from nat.eval.evaluator.evaluator_model import EvalInputItem
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder

from .config import DPOTrajectoryBuilderConfig

logger = logging.getLogger(__name__)

# Type alias for prompt which can be string or list of OpenAI messages
PromptType = list[OpenAIMessage] | str

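# For orientation, the TTC_END steps this builder consumes carry roughly the
# following TTCEventData fields -- an illustrative sketch based on how the
# fields are read below, not the authoritative schema (values are made up):
#
#   TTCEventData(
#       turn_id="turn-0",        # groups candidates competing for one prompt
#       candidate_index=1,       # position of this candidate within the turn
#       score=0.85,              # higher is better
#       input="<prompt>",        # str or list[OpenAIMessage]
#       output="<candidate>",    # the model's completion
#   )
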
# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class CandidateStep:
    """
    Parsed candidate from a TTC intermediate step.

    Represents a single candidate response that was generated and scored
    for a particular turn in the workflow.
    """

    example_id: str
    """Unique identifier for the dataset example."""

    turn_id: str
    """Identifier for the turn (groups candidates competing for the same prompt)."""

    candidate_index: int
    """Index of this candidate within the turn."""

    prompt: PromptType
    """Input prompt that produced this response (string or list of OpenAIMessage)."""

    response: str
    """Model's response/completion."""

    score: float
    """Score assigned to this candidate (higher is better)."""

    raw_metadata: dict[str, Any] = field(default_factory=dict)
    """Original metadata from the intermediate step."""


@dataclass
class PreferencePair:
    """
    A preference pair for DPO training.

    Represents a single (prompt, chosen, rejected) triple where the chosen
    response has a higher score than the rejected response.
    """

    example_id: str
    """Unique identifier for the dataset example."""

    turn_id: str
    """Identifier for the turn."""

    prompt: PromptType
    """Input prompt (same for both responses)."""

    chosen_response: str
    """Response that was preferred (higher score)."""

    rejected_response: str
    """Response that was not preferred (lower score)."""

    chosen_score: float
    """Score of the chosen response."""

    rejected_score: float
    """Score of the rejected response."""

    score_diff: float
    """Difference between chosen and rejected scores."""

    chosen_index: int
    """Candidate index of the chosen response."""

    rejected_index: int
    """Candidate index of the rejected response."""

    metadata: dict[str, Any] = field(default_factory=dict)
    """Additional metadata for the pair."""

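# A minimal sketch of the two pairing modes implemented by the builder below,
# assuming three candidates scored [0.9, 0.7, 0.5] and no min_score_diff
# filtering. The helper is illustrative only, not part of the builder's API:
def _sketch_pairing_modes() -> None:
    scores = [0.9, 0.7, 0.5]
    # Rank candidate indices by score, best first
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # exhaustive_pairs=True: every (chosen, rejected) ordering where the
    # chosen candidate outscores the rejected one
    exhaustive = [(ranked[i], ranked[j]) for i in range(len(ranked)) for j in range(i + 1, len(ranked))]
    assert exhaustive == [(0, 1), (0, 2), (1, 2)]
    # exhaustive_pairs=False: a single best-vs-worst pair per turn
    assert (ranked[0], ranked[-1]) == (0, 2)

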
# =============================================================================
# DPO Trajectory Builder
# =============================================================================


class DPOTrajectoryBuilder(TrajectoryBuilder):
    """
    Trajectory builder for DPO (Direct Preference Optimization) training.

    This builder collects preference pairs from workflows that produce TTC_END
    intermediate steps with TTCEventData. It uses the structured data model
    to extract turn_id, candidate_index, score, input (prompt), and output.

    Key features:
    - Uses TTCEventData model directly (no brittle dictionary key configuration)
    - Supports prompts as strings or list of OpenAIMessage
    - Exhaustive or best-vs-worst pair generation modes
    - Configurable score difference filtering
    - Grouping by example for curriculum learning
    - Builds trajectories with DPOItem episodes

    Example workflow integration::

        trajectory_builders:
          dpo_builder:
            _type: dpo_traj_builder
            ttc_step_name: dpo_candidate_move
            exhaustive_pairs: true
            min_score_diff: 0.05
    """

    def __init__(self, trajectory_builder_config: DPOTrajectoryBuilderConfig):
        """
        Initialize the DPO Trajectory Builder.

        Args:
            trajectory_builder_config: Configuration for the builder.
        """
        super().__init__(trajectory_builder_config=trajectory_builder_config)
        self.config: DPOTrajectoryBuilderConfig = trajectory_builder_config
        self.evaluation_runs: dict[str, asyncio.Task[EvaluationRunOutput]] = {}

        # Metrics tracking
        self._metrics: dict[str, Any] = {}

    # =========================================================================
    # TrajectoryBuilder Interface Implementation
    # =========================================================================

    async def start_run(self, run_id: str, meta: dict | None = None) -> None:
        """
        Start a single evaluation run to collect intermediate steps.

        Args:
            run_id: Unique identifier for this run.
            meta: Optional metadata for the run.

        Raises:
            ValueError: If a run with this ID is already in progress.
        """
        if run_id in self.evaluation_runs:
            raise ValueError(f"Run {run_id} is already in progress.")

        logger.info("Starting DPO evaluation run: %s", run_id)
        logger.info(
            "Configuration: step_name=%s, exhaustive=%s, min_diff=%.3f",
            self.config.ttc_step_name,
            self.config.exhaustive_pairs,
            self.config.min_score_diff,
        )

        # Create evaluation task
        task = asyncio.create_task(self.run_eval(), name=f"dpo-eval-{run_id}")

        def _on_done(t: asyncio.Task[EvaluationRunOutput]) -> None:
            if t.cancelled():
                logger.info("DPO evaluation run %s was cancelled.", run_id)
            elif exc := t.exception():
                logger.error("DPO evaluation run %s failed: %s", run_id, exc)
            else:
                logger.info("DPO evaluation run %s completed successfully.", run_id)

        task.add_done_callback(_on_done)
        self.evaluation_runs[run_id] = task

    async def finalize(self, run_id: str, meta: dict | None = None) -> TrajectoryCollection:
        """
        Wait for evaluation, collect TTC steps, and build DPO trajectories.

        This method:
        1. Waits for the evaluation run to complete
        2. Collects and groups candidates by turn_id using TTCEventData
        3. Generates preference pairs
        4. Builds trajectories with DPOItem episodes
        5. Groups trajectories by example for curriculum learning

        Args:
            run_id: Unique identifier for the run.
            meta: Optional metadata for the run.

        Returns:
            TrajectoryCollection with DPO preference trajectories.

        Raises:
            ValueError: If no run with this ID exists.
        """
        if run_id not in self.evaluation_runs:
            raise ValueError(f"No evaluation run found for run_id: {run_id}")

        # Wait for evaluation to complete
        logger.info("Waiting for DPO evaluation run %s to complete...", run_id)
        eval_result = await self.evaluation_runs[run_id]

        # Initialize metrics
        self._metrics = {
            "run_id": run_id,
            "total_examples": 0,
            "total_turns": 0,
            "total_candidates": 0,
            "total_pairs": 0,
            "total_trajectories": 0,
            "skipped_single_candidate": 0,
            "skipped_score_diff": 0,
        }

        # Step 1: Collect and group candidates
        candidates_by_turn = self._collect_candidates(eval_result)
        self._metrics["total_turns"] = len(candidates_by_turn)

        if not candidates_by_turn:
            logger.warning("No candidate steps found for run_id: %s", run_id)
            del self.evaluation_runs[run_id]
            return TrajectoryCollection(trajectories=[], run_id=run_id)

        # Step 2: Generate preference pairs
        pairs = self._generate_preference_pairs(candidates_by_turn)
        self._metrics["total_pairs"] = len(pairs)

        if not pairs:
            logger.warning("No preference pairs generated for run_id: %s", run_id)
            del self.evaluation_runs[run_id]
            return TrajectoryCollection(trajectories=[], run_id=run_id)

        # Step 3: Build trajectories with DPOItem episodes
        trajectories = self._build_trajectories(pairs)
        self._metrics["total_trajectories"] = len(trajectories)

        # Step 4: Group by example for curriculum learning
        grouped = self._group_by_example(trajectories)
        self._metrics["total_examples"] = len(grouped)

        # Log summary
        logger.info(
            "DPO trajectory building complete for run %s: "
            "%d examples, %d turns, %d candidates, %d pairs, %d trajectories",
            run_id,
            self._metrics["total_examples"],
            self._metrics["total_turns"],
            self._metrics["total_candidates"],
            self._metrics["total_pairs"],
            self._metrics["total_trajectories"],
        )

        if self._metrics["skipped_single_candidate"] > 0:
            logger.info(
                "Skipped %d turns with single candidate (no preference signal)",
                self._metrics["skipped_single_candidate"],
            )

        if self._metrics["skipped_score_diff"] > 0:
            logger.info(
                "Skipped %d pairs with score diff < %.3f",
                self._metrics["skipped_score_diff"],
                self.config.min_score_diff,
            )

        # Cleanup
        del self.evaluation_runs[run_id]

        return TrajectoryCollection(trajectories=grouped, run_id=run_id)

    def log_progress(self, run_id: str, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log trajectory building progress.

        Args:
            run_id: The training run ID.
            metrics: Dictionary of metrics to log.
            output_dir: Optional output directory override.
        """
        # Use default output directory if not provided
        out_dir = (Path(output_dir) if output_dir else Path("./.tmp/nat/finetuning/dpo_trajectory_builder"))
        out_dir.mkdir(parents=True, exist_ok=True)

        # Create log file
        log_file = out_dir / f"dpo_trajectory_builder_{run_id}.jsonl"

        # Prepare log entry
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "run_id": run_id,
            "config": {
                "ttc_step_name": self.config.ttc_step_name,
                "exhaustive_pairs": self.config.exhaustive_pairs,
                "min_score_diff": self.config.min_score_diff,
                "max_pairs_per_turn": self.config.max_pairs_per_turn,
            },
            **self._metrics,
            **metrics,
        }

        # Append to log file
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry) + "\n")

        logger.debug(
            "DPO trajectory builder progress logged for run %s: %d pairs",
            run_id,
            self._metrics.get("total_pairs", 0),
        )

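    # A resulting JSONL line looks roughly like this (values are illustrative,
    # not captured from a real run; keys mirror the log_entry built above):
    #
    #   {"timestamp": "2025-01-01T00:00:00", "run_id": "run-1",
    #    "config": {"ttc_step_name": "dpo_candidate_move", "exhaustive_pairs": true,
    #               "min_score_diff": 0.05, "max_pairs_per_turn": null},
    #    "total_pairs": 42, "total_trajectories": 42, ...}
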
    # =========================================================================
    # Internal Methods
    # =========================================================================

    def _collect_candidates(self, eval_result: EvaluationRunOutput) -> dict[str, list[CandidateStep]]:
        """
        Extract TTC_END intermediate steps and group by turn_id.

        This method:
        1. Iterates through all evaluation input items
        2. Filters for TTC_END steps with the configured name
        3. Extracts data from TTCEventData model directly
        4. Groups candidates by (example_id, turn_id)

        Args:
            eval_result: The evaluation run output.

        Returns:
            Dictionary mapping turn keys to lists of candidates.
        """
        candidates_by_turn: dict[str, list[CandidateStep]] = {}

        # Create mapping of example ID to input item
        input_items_map: dict[str, EvalInputItem] = {item.id: item for item in eval_result.eval_input.eval_input_items}

        for example_id, input_item in input_items_map.items():
            # Filter for TTC_END steps with matching name
            for step in input_item.trajectory:
                if not self._is_target_step(step):
                    continue

                # Parse candidate from TTCEventData
                candidate = self._parse_candidate(example_id, step)
                if candidate is None:
                    continue

                self._metrics["total_candidates"] = (self._metrics.get("total_candidates", 0) + 1)

                # Group by (example_id, turn_id)
                turn_key = f"{example_id}::{candidate.turn_id}"
                if turn_key not in candidates_by_turn:
                    candidates_by_turn[turn_key] = []
                candidates_by_turn[turn_key].append(candidate)

        logger.debug(
            "Collected %d candidates across %d turns",
            self._metrics.get("total_candidates", 0),
            len(candidates_by_turn),
        )

        return candidates_by_turn

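    # Shape of the mapping returned above, with illustrative keys and values:
    #
    #   {"example-7::turn-0": [CandidateStep(candidate_index=0, score=0.9, ...),
    #                          CandidateStep(candidate_index=1, score=0.4, ...)]}
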
    def _is_target_step(self, step: IntermediateStep) -> bool:
        """
        Check if an intermediate step is a target TTC step.

        Args:
            step: The intermediate step to check.

        Returns:
            True if this is a TTC_END step with the configured name.
        """
        return (step.event_category == IntermediateStepCategory.TTC and step.event_type == IntermediateStepType.TTC_END
                and step.payload.name == self.config.ttc_step_name)

    def _parse_candidate(self, example_id: str, step: IntermediateStep) -> CandidateStep | None:
        """
        Parse a CandidateStep from a TTC intermediate step using TTCEventData.

        Args:
            example_id: The example ID this step belongs to.
            step: The intermediate step to parse.

        Returns:
            CandidateStep if parsing succeeds, None otherwise.
        """
        # Get TTCEventData from step.payload.data
        data = step.payload.data
        if data is None:
            logger.warning("Step has no data field, skipping: %s", step.payload.UUID)
            return None

        # Validate that we have TTCEventData (or compatible dict/StreamEventData)
        # NOTE: When IntermediateStepPayload is serialized/deserialized, TTCEventData
        # becomes StreamEventData because the data field is typed as StreamEventData.
        # The TTC fields are preserved as extra fields due to extra="allow".
        if isinstance(data, TTCEventData):
            ttc_data = data
        elif isinstance(data, StreamEventData):
            # TTCEventData may have been deserialized as StreamEventData
            # Try to construct TTCEventData from the model dump
            try:
                data_dict = data.model_dump()
                ttc_data = TTCEventData(**data_dict)
            except Exception as e:
                logger.warning("Failed to parse TTCEventData from StreamEventData: %s", e)
                return None
        elif isinstance(data, dict):
            # Try to parse as TTCEventData
            try:
                ttc_data = TTCEventData(**data)
            except Exception as e:
                logger.warning("Failed to parse TTCEventData from dict: %s", e)
                return None
        else:
            logger.warning("Unexpected data type %s, expected TTCEventData", type(data))
            return None

        # Extract required fields from TTCEventData
        try:
            turn_id = ttc_data.turn_id
            if turn_id is None:
                logger.warning(
                    "TTCEventData missing turn_id, skipping: %s",
                    step.payload.UUID,
                )
                return None

            score = ttc_data.score
            if score is None:
                logger.warning(
                    "TTCEventData missing score, skipping: %s",
                    step.payload.UUID,
                )
                return None

            candidate_index = ttc_data.candidate_index or 0

            # Get prompt from TTCEventData.input
            # This can be a string or list of OpenAIMessage
            prompt = self._extract_prompt(ttc_data.input)

            # Get response from TTCEventData.output
            response = str(ttc_data.output) if ttc_data.output else ""

            # Get raw metadata for additional context
            raw_metadata = {}
            if step.payload.metadata:
                if hasattr(step.payload.metadata, "model_dump"):
                    raw_metadata = step.payload.metadata.model_dump()
                elif isinstance(step.payload.metadata, dict):
                    raw_metadata = step.payload.metadata

            return CandidateStep(
                example_id=str(example_id),
                turn_id=str(turn_id),
                candidate_index=int(candidate_index),
                prompt=prompt,
                response=response,
                score=float(score),
                raw_metadata=raw_metadata,
            )

        except (TypeError, ValueError) as e:
            logger.warning(
                "Failed to parse candidate from step %s: %s",
                step.payload.UUID,
                e,
            )
            return None

    def _extract_prompt(self, input_data: Any) -> PromptType:
        """
        Extract prompt from TTCEventData.input.

        Handles both string prompts and list of OpenAIMessage.

        Args:
            input_data: The input field from TTCEventData.

        Returns:
            String prompt or list of OpenAIMessage.
        """
        if input_data is None:
            return ""

        if isinstance(input_data, str):
            return input_data

        if isinstance(input_data, list):
            # Try to convert to list of OpenAIMessage
            messages: list[OpenAIMessage] = []
            for item in input_data:
                if isinstance(item, OpenAIMessage):
                    messages.append(item)
                elif isinstance(item, dict):
                    # Try to parse as OpenAIMessage
                    try:
                        messages.append(OpenAIMessage(**item))
                    except Exception:
                        # If parsing fails, convert entire input to string
                        return str(input_data)
                else:
                    # Unknown type, convert to string
                    return str(input_data)
            return messages

        # Fallback: convert to string
        return str(input_data)

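    # _extract_prompt normalization at a glance (inputs are illustrative; the
    # dict form mirrors what OpenAIMessage(**item) accepts above):
    #
    #   "Who won?"                          -> "Who won?" (string passthrough)
    #   [{"role": "user", "content": "hi"}] -> [OpenAIMessage(role="user", content="hi")]
    #   {"unexpected": "shape"}             -> "{'unexpected': 'shape'}" (str() fallback)
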
    def _generate_preference_pairs(self, candidates_by_turn: dict[str, list[CandidateStep]]) -> list[PreferencePair]:
        """
        Generate preference pairs from grouped candidates.

        If exhaustive_pairs=True:
            For candidates [A, B, C] with scores [0.9, 0.7, 0.5]:
            Pairs: (A>B), (A>C), (B>C) - all pairwise comparisons

        If exhaustive_pairs=False:
            For candidates [A, B, C] with scores [0.9, 0.7, 0.5]:
            Pairs: (A>C) only - best vs worst

        Args:
            candidates_by_turn: Dictionary mapping turn keys to candidate lists.

        Returns:
            List of preference pairs.
        """
        all_pairs: list[PreferencePair] = []

        for turn_key, candidates in candidates_by_turn.items():
            # Check if we have enough candidates
            if len(candidates) < 2:
                if self.config.require_multiple_candidates:
                    self._metrics["skipped_single_candidate"] = (self._metrics.get("skipped_single_candidate", 0) + 1)
                    logger.debug("Skipping turn %s with single candidate", turn_key)
                continue

            # Sort candidates by score (descending)
            sorted_candidates = sorted(candidates, key=lambda c: c.score, reverse=True)

            if self.config.exhaustive_pairs:
                pairs = self._generate_exhaustive_pairs(sorted_candidates)
            else:
                pairs = self._generate_best_vs_worst_pair(sorted_candidates)

            all_pairs.extend(pairs)

        logger.debug("Generated %d preference pairs", len(all_pairs))
        return all_pairs

    def _generate_exhaustive_pairs(self, sorted_candidates: list[CandidateStep]) -> list[PreferencePair]:
        """
        Generate all pairwise comparisons where score(chosen) > score(rejected).

        Args:
            sorted_candidates: Candidates sorted by score (descending).

        Returns:
            List of preference pairs, sorted by score difference (descending).
        """
        pairs: list[PreferencePair] = []

        for i, chosen in enumerate(sorted_candidates):
            for rejected in sorted_candidates[i + 1:]:
                score_diff = chosen.score - rejected.score

                # Apply minimum score difference filter
                if score_diff < self.config.min_score_diff:
                    self._metrics["skipped_score_diff"] = (self._metrics.get("skipped_score_diff", 0) + 1)
                    continue

                pairs.append(
                    PreferencePair(
                        example_id=chosen.example_id,
                        turn_id=chosen.turn_id,
                        prompt=chosen.prompt,
                        chosen_response=chosen.response,
                        rejected_response=rejected.response,
                        chosen_score=chosen.score,
                        rejected_score=rejected.score,
                        score_diff=score_diff,
                        chosen_index=chosen.candidate_index,
                        rejected_index=rejected.candidate_index,
                        metadata={
                            "chosen_raw_metadata": chosen.raw_metadata,
                            "rejected_raw_metadata": rejected.raw_metadata,
                        },
                    ))

        # Sort by score difference (highest first) and apply limit
        pairs.sort(key=lambda p: p.score_diff, reverse=True)

        if self.config.max_pairs_per_turn is not None:
            pairs = pairs[:self.config.max_pairs_per_turn]

        return pairs

    def _generate_best_vs_worst_pair(self, sorted_candidates: list[CandidateStep]) -> list[PreferencePair]:
        """
        Generate a single pair: best candidate vs worst candidate.

        Args:
            sorted_candidates: Candidates sorted by score (descending).

        Returns:
            List with at most one preference pair.
        """
        if len(sorted_candidates) < 2:
            return []

        chosen = sorted_candidates[0]  # Best
        rejected = sorted_candidates[-1]  # Worst

        score_diff = chosen.score - rejected.score

        # Apply minimum score difference filter
        if score_diff < self.config.min_score_diff:
            self._metrics["skipped_score_diff"] = (self._metrics.get("skipped_score_diff", 0) + 1)
            return []

        return [
            PreferencePair(
                example_id=chosen.example_id,
                turn_id=chosen.turn_id,
                prompt=chosen.prompt,
                chosen_response=chosen.response,
                rejected_response=rejected.response,
                chosen_score=chosen.score,
                rejected_score=rejected.score,
                score_diff=score_diff,
                chosen_index=chosen.candidate_index,
                rejected_index=rejected.candidate_index,
                metadata={
                    "num_candidates": len(sorted_candidates),
                },
            )
        ]

    def _build_trajectories(self, pairs: list[PreferencePair]) -> list[Trajectory]:
        """
        Convert preference pairs to Trajectory format with DPOItem episodes.

        Each trajectory contains:
        - episode: [DPOItem] with prompt, chosen_response, rejected_response
        - reward: score_diff (if reward_from_score_diff) or chosen_score
        - metadata: Contains pair information for tracking

        Args:
            pairs: List of preference pairs.

        Returns:
            List of trajectories with DPOItem episodes.
        """
        trajectories: list[Trajectory] = []

        for pair in pairs:
            # Create DPOItem from preference pair
            dpo_item = DPOItem(
                prompt=pair.prompt,
                chosen_response=pair.chosen_response,
                rejected_response=pair.rejected_response,
            )

            # Compute reward
            if self.config.reward_from_score_diff:
                reward = pair.score_diff
            else:
                reward = pair.chosen_score

            # Build trajectory with DPOItem episode
            trajectory = Trajectory(
                episode=[dpo_item],
                reward=reward,
                shaped_rewards=None,
                metadata={
                    # DPO-specific fields
                    "dpo_type": "preference_pair",
                    "score_diff": pair.score_diff,
                    # Tracking fields
                    "example_id": pair.example_id,
                    "turn_id": pair.turn_id,
                    "chosen_score": pair.chosen_score,
                    "rejected_score": pair.rejected_score,
                    "chosen_index": pair.chosen_index,
                    "rejected_index": pair.rejected_index,
                    # Additional metadata
                    **pair.metadata,
                },
            )

            trajectories.append(trajectory)

        return trajectories

    def _group_by_example(self, trajectories: list[Trajectory]) -> list[list[Trajectory]]:
        """
        Group trajectories by example ID for curriculum learning.

        This grouping enables:
        - Filtering by average reward per example
        - Expansion from easy to hard examples

        Args:
            trajectories: List of trajectories to group.

        Returns:
            List of trajectory lists, where each inner list contains
            trajectories for one example.
        """
        by_example: dict[str, list[Trajectory]] = {}

        for traj in trajectories:
            example_id = traj.metadata.get("example_id", "unknown")
            if example_id not in by_example:
                by_example[example_id] = []
            by_example[example_id].append(traj)

        return list(by_example.values())
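

# A minimal end-to-end sketch, assuming a populated DPOTrajectoryBuilderConfig
# and that the TrajectoryBuilder base class supplies run_eval(); the function
# name is illustrative, not part of the module's public API:
async def _sketch_build_collection(config: DPOTrajectoryBuilderConfig) -> TrajectoryCollection:
    builder = DPOTrajectoryBuilder(trajectory_builder_config=config)
    await builder.start_run("run-1")        # launches the evaluation task
    return await builder.finalize("run-1")  # waits, pairs candidates, builds trajectories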