nvidia-nat-nemo-customizer 1.4.0a20251223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,767 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
"""
DPO (Direct Preference Optimization) Trajectory Builder.

This module provides a trajectory builder that collects preference data from
workflows that produce TTC_END intermediate steps carrying TTCEventData.

The builder:
1. Runs evaluation to collect intermediate steps
2. Filters for TTC_END steps (category TTC) whose payload name equals the
   configured ``ttc_step_name``
3. Extracts data from TTCEventData (turn_id, candidate_index, score, input, output)
4. Groups candidates by (example_id, turn_id)
5. Generates preference pairs based on score differences
   (all pairs or best-vs-worst, filtered by ``min_score_diff``)
6. Builds trajectories with DPOItem episodes for DPO training
7. Groups the resulting trajectories by example
"""

from __future__ import annotations

import asyncio
import json
import logging
import math
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from pathlib import Path
from typing import Any

from nat.data_models.finetuning import DPOItem
from nat.data_models.finetuning import OpenAIMessage
from nat.data_models.finetuning import Trajectory
from nat.data_models.finetuning import TrajectoryCollection
from nat.data_models.intermediate_step import IntermediateStep
from nat.data_models.intermediate_step import IntermediateStepCategory
from nat.data_models.intermediate_step import IntermediateStepType
from nat.data_models.intermediate_step import StreamEventData
from nat.data_models.intermediate_step import TTCEventData
from nat.eval.config import EvaluationRunOutput
from nat.eval.evaluator.evaluator_model import EvalInputItem
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder

from .config import DPOTrajectoryBuilderConfig

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type alias for a prompt: either a plain string or a list of OpenAI chat messages.
PromptType = list[OpenAIMessage] | str

# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class CandidateStep:
    """
    One scored candidate response parsed from a TTC_END intermediate step.

    Candidates that share the same ``example_id`` and ``turn_id`` answered the
    same prompt and are ranked against each other when preference pairs are built.
    """

    # Identifier of the dataset example the step came from.
    example_id: str
    # Turn identifier; candidates with the same turn compete for the same prompt.
    turn_id: str
    # Position of this candidate within its turn.
    candidate_index: int
    # Prompt that produced the response: a plain string or a list of OpenAIMessage.
    prompt: PromptType
    # The model's response / completion text.
    response: str
    # Score assigned to the candidate; higher is better.
    score: float
    # Metadata taken from the originating intermediate step's payload.
    raw_metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class PreferencePair:
    """
    A single (prompt, chosen, rejected) example for DPO training.

    ``chosen_response`` is the candidate that scored higher than
    ``rejected_response`` for the same prompt of the same turn.
    """

    # Identifier of the dataset example the pair belongs to.
    example_id: str
    # Turn the two competing candidates came from.
    turn_id: str
    # Prompt shared by both responses.
    prompt: PromptType
    # The preferred (higher-scoring) response.
    chosen_response: str
    # The dispreferred (lower-scoring) response.
    rejected_response: str
    # Score of the preferred response.
    chosen_score: float
    # Score of the dispreferred response.
    rejected_score: float
    # chosen_score minus rejected_score.
    score_diff: float
    # candidate_index of the preferred response within its turn.
    chosen_index: int
    # candidate_index of the dispreferred response within its turn.
    rejected_index: int
    # Extra per-pair information carried into the trajectory metadata.
    metadata: dict[str, Any] = field(default_factory=dict)


# =============================================================================
# DPO Trajectory Builder
# =============================================================================


class DPOTrajectoryBuilder(TrajectoryBuilder):
    """
    Trajectory builder that turns scored TTC candidates into DPO preference data.

    During evaluation, workflows emit TTC_END intermediate steps carrying
    TTCEventData. This builder reads turn_id, candidate_index, score, input
    (the prompt) and output (the response) from those steps, pairs competing
    candidates of the same turn, and wraps every pair in a Trajectory whose
    single episode is a DPOItem.

    Highlights:
    - Reads the structured TTCEventData model instead of configurable dict keys
    - Accepts prompts as a plain string or a list of OpenAIMessage
    - Pair generation is either exhaustive or best-vs-worst
    - Pairs below a configurable score difference are filtered out
    - Trajectories are grouped per example (useful for curriculum learning)

    Example workflow integration::

        trajectory_builders:
          dpo_builder:
            _type: dpo_traj_builder
            ttc_step_name: dpo_candidate_move
            exhaustive_pairs: true
            min_score_diff: 0.05
    """

    def __init__(self, trajectory_builder_config: DPOTrajectoryBuilderConfig):
        """
        Create the builder from its configuration.

        Args:
            trajectory_builder_config: Settings controlling which TTC steps are
                read and how preference pairs are formed.
        """
        super().__init__(trajectory_builder_config=trajectory_builder_config)
        # Re-bind the config with its concrete DPO-specific static type.
        self.config: DPOTrajectoryBuilderConfig = trajectory_builder_config
        # Background evaluation tasks keyed by run_id; finalize() awaits them.
        self.evaluation_runs: dict[str, asyncio.Task[EvaluationRunOutput]] = {}
        # Per-run counters: reset by finalize() and merged into log_progress() records.
        self._metrics: dict[str, Any] = {}

+ # =========================================================================
186
+ # TrajectoryBuilder Interface Implementation
187
+ # =========================================================================
188
+
189
+ async def start_run(self, run_id: str, meta: dict | None = None) -> None:
190
+ """
191
+ Start a single evaluation run to collect intermediate steps.
192
+
193
+ Args:
194
+ run_id: Unique identifier for this run.
195
+ meta: Optional metadata for the run.
196
+
197
+ Raises:
198
+ ValueError: If a run with this ID is already in progress.
199
+ """
200
+ if run_id in self.evaluation_runs:
201
+ raise ValueError(f"Run {run_id} is already in progress.")
202
+
203
+ logger.info("Starting DPO evaluation run: %s", run_id)
204
+ logger.info(
205
+ "Configuration: step_name=%s, exhaustive=%s, min_diff=%.3f",
206
+ self.config.ttc_step_name,
207
+ self.config.exhaustive_pairs,
208
+ self.config.min_score_diff,
209
+ )
210
+
211
+ # Create evaluation task
212
+ task = asyncio.create_task(self.run_eval(), name=f"dpo-eval-{run_id}")
213
+
214
+ def _on_done(t: asyncio.Task[EvaluationRunOutput]) -> None:
215
+ if t.cancelled():
216
+ logger.info("DPO evaluation run %s was cancelled.", run_id)
217
+ elif exc := t.exception():
218
+ logger.error("DPO evaluation run %s failed: %s", run_id, exc)
219
+ else:
220
+ logger.info("DPO evaluation run %s completed successfully.", run_id)
221
+
222
+ task.add_done_callback(_on_done)
223
+ self.evaluation_runs[run_id] = task
224
+
225
+ async def finalize(self, run_id: str, meta: dict | None = None) -> TrajectoryCollection:
226
+ """
227
+ Wait for evaluation, collect TTC steps, and build DPO trajectories.
228
+
229
+ This method:
230
+ 1. Waits for the evaluation run to complete
231
+ 2. Collects and groups candidates by turn_id using TTCEventData
232
+ 3. Generates preference pairs
233
+ 4. Builds trajectories with DPOItem episodes
234
+ 5. Groups trajectories by example for curriculum learning
235
+
236
+ Args:
237
+ run_id: Unique identifier for the run.
238
+ meta: Optional metadata for the run.
239
+
240
+ Returns:
241
+ TrajectoryCollection with DPO preference trajectories.
242
+
243
+ Raises:
244
+ ValueError: If no run with this ID exists.
245
+ """
246
+ if run_id not in self.evaluation_runs:
247
+ raise ValueError(f"No evaluation run found for run_id: {run_id}")
248
+
249
+ # Wait for evaluation to complete
250
+ logger.info("Waiting for DPO evaluation run %s to complete...", run_id)
251
+ eval_result = await self.evaluation_runs[run_id]
252
+
253
+ # Initialize metrics
254
+ self._metrics = {
255
+ "run_id": run_id,
256
+ "total_examples": 0,
257
+ "total_turns": 0,
258
+ "total_candidates": 0,
259
+ "total_pairs": 0,
260
+ "total_trajectories": 0,
261
+ "skipped_single_candidate": 0,
262
+ "skipped_score_diff": 0,
263
+ }
264
+
265
+ # Step 1: Collect and group candidates
266
+ candidates_by_turn = self._collect_candidates(eval_result)
267
+ self._metrics["total_turns"] = len(candidates_by_turn)
268
+
269
+ if not candidates_by_turn:
270
+ logger.warning("No candidate steps found for run_id: %s", run_id)
271
+ del self.evaluation_runs[run_id]
272
+ return TrajectoryCollection(trajectories=[], run_id=run_id)
273
+
274
+ # Step 2: Generate preference pairs
275
+ pairs = self._generate_preference_pairs(candidates_by_turn)
276
+ self._metrics["total_pairs"] = len(pairs)
277
+
278
+ if not pairs:
279
+ logger.warning("No preference pairs generated for run_id: %s", run_id)
280
+ del self.evaluation_runs[run_id]
281
+ return TrajectoryCollection(trajectories=[], run_id=run_id)
282
+
283
+ # Step 3: Build trajectories with DPOItem episodes
284
+ trajectories = self._build_trajectories(pairs)
285
+ self._metrics["total_trajectories"] = len(trajectories)
286
+
287
+ # Step 4: Group by example for curriculum learning
288
+ grouped = self._group_by_example(trajectories)
289
+ self._metrics["total_examples"] = len(grouped)
290
+
291
+ # Log summary
292
+ logger.info(
293
+ "DPO trajectory building complete for run %s: "
294
+ "%d examples, %d turns, %d candidates, %d pairs, %d trajectories",
295
+ run_id,
296
+ self._metrics["total_examples"],
297
+ self._metrics["total_turns"],
298
+ self._metrics["total_candidates"],
299
+ self._metrics["total_pairs"],
300
+ self._metrics["total_trajectories"],
301
+ )
302
+
303
+ if self._metrics["skipped_single_candidate"] > 0:
304
+ logger.info(
305
+ "Skipped %d turns with single candidate (no preference signal)",
306
+ self._metrics["skipped_single_candidate"],
307
+ )
308
+
309
+ if self._metrics["skipped_score_diff"] > 0:
310
+ logger.info(
311
+ "Skipped %d pairs with score diff < %.3f",
312
+ self._metrics["skipped_score_diff"],
313
+ self.config.min_score_diff,
314
+ )
315
+
316
+ # Cleanup
317
+ del self.evaluation_runs[run_id]
318
+
319
+ return TrajectoryCollection(trajectories=grouped, run_id=run_id)
320
+
321
+ def log_progress(self, run_id: str, metrics: dict[str, Any], output_dir: str | None = None) -> None:
322
+ """
323
+ Log trajectory building progress.
324
+
325
+ Args:
326
+ run_id: The training run ID.
327
+ metrics: Dictionary of metrics to log.
328
+ output_dir: Optional output directory override.
329
+ """
330
+ # Use default output directory if not provided
331
+ out_dir = (Path(output_dir) if output_dir else Path("./.tmp/nat/finetuning/dpo_trajectory_builder"))
332
+ out_dir.mkdir(parents=True, exist_ok=True)
333
+
334
+ # Create log file
335
+ log_file = out_dir / f"dpo_trajectory_builder_{run_id}.jsonl"
336
+
337
+ # Prepare log entry
338
+ log_entry = {
339
+ "timestamp": datetime.now().isoformat(),
340
+ "run_id": run_id,
341
+ "config": {
342
+ "ttc_step_name": self.config.ttc_step_name,
343
+ "exhaustive_pairs": self.config.exhaustive_pairs,
344
+ "min_score_diff": self.config.min_score_diff,
345
+ "max_pairs_per_turn": self.config.max_pairs_per_turn,
346
+ },
347
+ **self._metrics,
348
+ **metrics,
349
+ }
350
+
351
+ # Append to log file
352
+ with open(log_file, "a", encoding="utf-8") as f:
353
+ f.write(json.dumps(log_entry) + "\n")
354
+
355
+ logger.debug(
356
+ "DPO trajectory builder progress logged for run %s: %d pairs",
357
+ run_id,
358
+ self._metrics.get("total_pairs", 0),
359
+ )
360
+
361
+ # =========================================================================
362
+ # Internal Methods
363
+ # =========================================================================
364
+
365
+ def _collect_candidates(self, eval_result: EvaluationRunOutput) -> dict[str, list[CandidateStep]]:
366
+ """
367
+ Extract TTC_END intermediate steps and group by turn_id.
368
+
369
+ This method:
370
+ 1. Iterates through all evaluation input items
371
+ 2. Filters for TTC_END steps with the configured name
372
+ 3. Extracts data from TTCEventData model directly
373
+ 4. Groups candidates by (example_id, turn_id)
374
+
375
+ Args:
376
+ eval_result: The evaluation run output.
377
+
378
+ Returns:
379
+ Dictionary mapping turn keys to lists of candidates.
380
+ """
381
+ candidates_by_turn: dict[str, list[CandidateStep]] = {}
382
+
383
+ # Create mapping of example ID to input item
384
+ input_items_map: dict[str, EvalInputItem] = {item.id: item for item in eval_result.eval_input.eval_input_items}
385
+
386
+ for example_id, input_item in input_items_map.items():
387
+ # Filter for TTC_END steps with matching name
388
+ for step in input_item.trajectory:
389
+ if not self._is_target_step(step):
390
+ continue
391
+
392
+ # Parse candidate from TTCEventData
393
+ candidate = self._parse_candidate(example_id, step)
394
+ if candidate is None:
395
+ continue
396
+
397
+ self._metrics["total_candidates"] = (self._metrics.get("total_candidates", 0) + 1)
398
+
399
+ # Group by (example_id, turn_id)
400
+ turn_key = f"{example_id}::{candidate.turn_id}"
401
+ if turn_key not in candidates_by_turn:
402
+ candidates_by_turn[turn_key] = []
403
+ candidates_by_turn[turn_key].append(candidate)
404
+
405
+ logger.debug(
406
+ "Collected %d candidates across %d turns",
407
+ self._metrics.get("total_candidates", 0),
408
+ len(candidates_by_turn),
409
+ )
410
+
411
+ return candidates_by_turn
412
+
413
+ def _is_target_step(self, step: IntermediateStep) -> bool:
414
+ """
415
+ Check if an intermediate step is a target TTC step.
416
+
417
+ Args:
418
+ step: The intermediate step to check.
419
+
420
+ Returns:
421
+ True if this is a TTC_END step with the configured name.
422
+ """
423
+ return (step.event_category == IntermediateStepCategory.TTC and step.event_type == IntermediateStepType.TTC_END
424
+ and step.payload.name == self.config.ttc_step_name)
425
+
426
+ def _parse_candidate(self, example_id: str, step: IntermediateStep) -> CandidateStep | None:
427
+ """
428
+ Parse a CandidateStep from a TTC intermediate step using TTCEventData.
429
+
430
+ Args:
431
+ example_id: The example ID this step belongs to.
432
+ step: The intermediate step to parse.
433
+
434
+ Returns:
435
+ CandidateStep if parsing succeeds, None otherwise.
436
+ """
437
+ # Get TTCEventData from step.payload.data
438
+ data = step.payload.data
439
+ if data is None:
440
+ logger.warning("Step has no data field, skipping: %s", step.payload.UUID)
441
+ return None
442
+
443
+ # Validate that we have TTCEventData (or compatible dict/StreamEventData)
444
+ # NOTE: When IntermediateStepPayload is serialized/deserialized, TTCEventData
445
+ # becomes StreamEventData because the data field is typed as StreamEventData.
446
+ # The TTC fields are preserved as extra fields due to extra="allow".
447
+ if isinstance(data, TTCEventData):
448
+ ttc_data = data
449
+ elif isinstance(data, StreamEventData):
450
+ # TTCEventData may have been deserialized as StreamEventData
451
+ # Try to construct TTCEventData from the model dump
452
+ try:
453
+ data_dict = data.model_dump()
454
+ ttc_data = TTCEventData(**data_dict)
455
+ except Exception as e:
456
+ logger.warning("Failed to parse TTCEventData from StreamEventData: %s", e)
457
+ return None
458
+ elif isinstance(data, dict):
459
+ # Try to parse as TTCEventData
460
+ try:
461
+ ttc_data = TTCEventData(**data)
462
+ except Exception as e:
463
+ logger.warning("Failed to parse TTCEventData from dict: %s", e)
464
+ return None
465
+ else:
466
+ logger.warning("Unexpected data type %s, expected TTCEventData", type(data))
467
+ return None
468
+
469
+ # Extract required fields from TTCEventData
470
+ try:
471
+ turn_id = ttc_data.turn_id
472
+ if turn_id is None:
473
+ logger.warning(
474
+ "TTCEventData missing turn_id, skipping: %s",
475
+ step.payload.UUID,
476
+ )
477
+ return None
478
+
479
+ score = ttc_data.score
480
+ if score is None:
481
+ logger.warning(
482
+ "TTCEventData missing score, skipping: %s",
483
+ step.payload.UUID,
484
+ )
485
+ return None
486
+
487
+ candidate_index = ttc_data.candidate_index or 0
488
+
489
+ # Get prompt from TTCEventData.input
490
+ # This can be a string or list of OpenAIMessage
491
+ prompt = self._extract_prompt(ttc_data.input)
492
+
493
+ # Get response from TTCEventData.output
494
+ response = str(ttc_data.output) if ttc_data.output else ""
495
+
496
+ # Get raw metadata for additional context
497
+ raw_metadata = {}
498
+ if step.payload.metadata:
499
+ if hasattr(step.payload.metadata, "model_dump"):
500
+ raw_metadata = step.payload.metadata.model_dump()
501
+ elif isinstance(step.payload.metadata, dict):
502
+ raw_metadata = step.payload.metadata
503
+
504
+ return CandidateStep(
505
+ example_id=str(example_id),
506
+ turn_id=str(turn_id),
507
+ candidate_index=int(candidate_index),
508
+ prompt=prompt,
509
+ response=response,
510
+ score=float(score),
511
+ raw_metadata=raw_metadata,
512
+ )
513
+
514
+ except (TypeError, ValueError) as e:
515
+ logger.warning(
516
+ "Failed to parse candidate from step %s: %s",
517
+ step.payload.UUID,
518
+ e,
519
+ )
520
+ return None
521
+
522
+ def _extract_prompt(self, input_data: Any) -> PromptType:
523
+ """
524
+ Extract prompt from TTCEventData.input.
525
+
526
+ Handles both string prompts and list of OpenAIMessage.
527
+
528
+ Args:
529
+ input_data: The input field from TTCEventData.
530
+
531
+ Returns:
532
+ String prompt or list of OpenAIMessage.
533
+ """
534
+ if input_data is None:
535
+ return ""
536
+
537
+ if isinstance(input_data, str):
538
+ return input_data
539
+
540
+ if isinstance(input_data, list):
541
+ # Try to convert to list of OpenAIMessage
542
+ messages: list[OpenAIMessage] = []
543
+ for item in input_data:
544
+ if isinstance(item, OpenAIMessage):
545
+ messages.append(item)
546
+ elif isinstance(item, dict):
547
+ # Try to parse as OpenAIMessage
548
+ try:
549
+ messages.append(OpenAIMessage(**item))
550
+ except Exception:
551
+ # If parsing fails, convert entire input to string
552
+ return str(input_data)
553
+ else:
554
+ # Unknown type, convert to string
555
+ return str(input_data)
556
+ return messages
557
+
558
+ # Fallback: convert to string
559
+ return str(input_data)
560
+
561
+ def _generate_preference_pairs(self, candidates_by_turn: dict[str, list[CandidateStep]]) -> list[PreferencePair]:
562
+ """
563
+ Generate preference pairs from grouped candidates.
564
+
565
+ If exhaustive_pairs=True:
566
+ For candidates [A, B, C] with scores [0.9, 0.7, 0.5]:
567
+ Pairs: (A>B), (A>C), (B>C) - all pairwise comparisons
568
+
569
+ If exhaustive_pairs=False:
570
+ For candidates [A, B, C] with scores [0.9, 0.7, 0.5]:
571
+ Pairs: (A>C) only - best vs worst
572
+
573
+ Args:
574
+ candidates_by_turn: Dictionary mapping turn keys to candidate lists.
575
+
576
+ Returns:
577
+ List of preference pairs.
578
+ """
579
+ all_pairs: list[PreferencePair] = []
580
+
581
+ for turn_key, candidates in candidates_by_turn.items():
582
+ # Check if we have enough candidates
583
+ if len(candidates) < 2:
584
+ if self.config.require_multiple_candidates:
585
+ self._metrics["skipped_single_candidate"] = (self._metrics.get("skipped_single_candidate", 0) + 1)
586
+ logger.debug("Skipping turn %s with single candidate", turn_key)
587
+ continue
588
+
589
+ # Sort candidates by score (descending)
590
+ sorted_candidates = sorted(candidates, key=lambda c: c.score, reverse=True)
591
+
592
+ if self.config.exhaustive_pairs:
593
+ pairs = self._generate_exhaustive_pairs(sorted_candidates)
594
+ else:
595
+ pairs = self._generate_best_vs_worst_pair(sorted_candidates)
596
+
597
+ all_pairs.extend(pairs)
598
+
599
+ logger.debug("Generated %d preference pairs", len(all_pairs))
600
+ return all_pairs
601
+
602
+ def _generate_exhaustive_pairs(self, sorted_candidates: list[CandidateStep]) -> list[PreferencePair]:
603
+ """
604
+ Generate all pairwise comparisons where score(chosen) > score(rejected).
605
+
606
+ Args:
607
+ sorted_candidates: Candidates sorted by score (descending).
608
+
609
+ Returns:
610
+ List of preference pairs, sorted by score difference (descending).
611
+ """
612
+ pairs: list[PreferencePair] = []
613
+
614
+ for i, chosen in enumerate(sorted_candidates):
615
+ for rejected in sorted_candidates[i + 1:]:
616
+ score_diff = chosen.score - rejected.score
617
+
618
+ # Apply minimum score difference filter
619
+ if score_diff < self.config.min_score_diff:
620
+ self._metrics["skipped_score_diff"] = (self._metrics.get("skipped_score_diff", 0) + 1)
621
+ continue
622
+
623
+ pairs.append(
624
+ PreferencePair(
625
+ example_id=chosen.example_id,
626
+ turn_id=chosen.turn_id,
627
+ prompt=chosen.prompt,
628
+ chosen_response=chosen.response,
629
+ rejected_response=rejected.response,
630
+ chosen_score=chosen.score,
631
+ rejected_score=rejected.score,
632
+ score_diff=score_diff,
633
+ chosen_index=chosen.candidate_index,
634
+ rejected_index=rejected.candidate_index,
635
+ metadata={
636
+ "chosen_raw_metadata": chosen.raw_metadata,
637
+ "rejected_raw_metadata": rejected.raw_metadata,
638
+ },
639
+ ))
640
+
641
+ # Sort by score difference (highest first) and apply limit
642
+ pairs.sort(key=lambda p: p.score_diff, reverse=True)
643
+
644
+ if self.config.max_pairs_per_turn is not None:
645
+ pairs = pairs[:self.config.max_pairs_per_turn]
646
+
647
+ return pairs
648
+
649
+ def _generate_best_vs_worst_pair(self, sorted_candidates: list[CandidateStep]) -> list[PreferencePair]:
650
+ """
651
+ Generate a single pair: best candidate vs worst candidate.
652
+
653
+ Args:
654
+ sorted_candidates: Candidates sorted by score (descending).
655
+
656
+ Returns:
657
+ List with at most one preference pair.
658
+ """
659
+ if len(sorted_candidates) < 2:
660
+ return []
661
+
662
+ chosen = sorted_candidates[0] # Best
663
+ rejected = sorted_candidates[-1] # Worst
664
+
665
+ score_diff = chosen.score - rejected.score
666
+
667
+ # Apply minimum score difference filter
668
+ if score_diff < self.config.min_score_diff:
669
+ self._metrics["skipped_score_diff"] = (self._metrics.get("skipped_score_diff", 0) + 1)
670
+ return []
671
+
672
+ return [
673
+ PreferencePair(
674
+ example_id=chosen.example_id,
675
+ turn_id=chosen.turn_id,
676
+ prompt=chosen.prompt,
677
+ chosen_response=chosen.response,
678
+ rejected_response=rejected.response,
679
+ chosen_score=chosen.score,
680
+ rejected_score=rejected.score,
681
+ score_diff=score_diff,
682
+ chosen_index=chosen.candidate_index,
683
+ rejected_index=rejected.candidate_index,
684
+ metadata={
685
+ "num_candidates": len(sorted_candidates),
686
+ },
687
+ )
688
+ ]
689
+
690
+ def _build_trajectories(self, pairs: list[PreferencePair]) -> list[Trajectory]:
691
+ """
692
+ Convert preference pairs to Trajectory format with DPOItem episodes.
693
+
694
+ Each trajectory contains:
695
+ - episode: [DPOItem] with prompt, chosen_response, rejected_response
696
+ - reward: score_diff (if reward_from_score_diff) or chosen_score
697
+ - metadata: Contains pair information for tracking
698
+
699
+ Args:
700
+ pairs: List of preference pairs.
701
+
702
+ Returns:
703
+ List of trajectories with DPOItem episodes.
704
+ """
705
+ trajectories: list[Trajectory] = []
706
+
707
+ for pair in pairs:
708
+ # Create DPOItem from preference pair
709
+ dpo_item = DPOItem(
710
+ prompt=pair.prompt,
711
+ chosen_response=pair.chosen_response,
712
+ rejected_response=pair.rejected_response,
713
+ )
714
+
715
+ # Compute reward
716
+ if self.config.reward_from_score_diff:
717
+ reward = pair.score_diff
718
+ else:
719
+ reward = pair.chosen_score
720
+
721
+ # Build trajectory with DPOItem episode
722
+ trajectory = Trajectory(
723
+ episode=[dpo_item],
724
+ reward=reward,
725
+ shaped_rewards=None,
726
+ metadata={
727
+ # DPO-specific fields
728
+ "dpo_type": "preference_pair",
729
+ "score_diff": pair.score_diff, # Tracking fields
730
+ "example_id": pair.example_id,
731
+ "turn_id": pair.turn_id,
732
+ "chosen_score": pair.chosen_score,
733
+ "rejected_score": pair.rejected_score,
734
+ "chosen_index": pair.chosen_index,
735
+ "rejected_index": pair.rejected_index, # Additional metadata
736
+ **pair.metadata,
737
+ },
738
+ )
739
+
740
+ trajectories.append(trajectory)
741
+
742
+ return trajectories
743
+
744
+ def _group_by_example(self, trajectories: list[Trajectory]) -> list[list[Trajectory]]:
745
+ """
746
+ Group trajectories by example ID for curriculum learning.
747
+
748
+ This grouping enables:
749
+ - Filtering by average reward per example
750
+ - Expansion from easy to hard examples
751
+
752
+ Args:
753
+ trajectories: List of trajectories to group.
754
+
755
+ Returns:
756
+ List of trajectory lists, where each inner list contains
757
+ trajectories for one example.
758
+ """
759
+ by_example: dict[str, list[Trajectory]] = {}
760
+
761
+ for traj in trajectories:
762
+ example_id = traj.metadata.get("example_id", "unknown")
763
+ if example_id not in by_example:
764
+ by_example[example_id] = []
765
+ by_example[example_id].append(traj)
766
+
767
+ return list(by_example.values())