nvidia-nat-openpipe-art 1.4.0a20260116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,659 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import logging
18
+ import math
19
+ import uuid
20
+ from datetime import datetime
21
+ from pathlib import Path
22
+ from typing import Any
23
+
24
+ from nat.data_models.finetuning import FinetuneConfig
25
+ from nat.data_models.finetuning import TrainingJobRef
26
+ from nat.data_models.finetuning import TrainingJobStatus
27
+ from nat.data_models.finetuning import TrainingStatusEnum
28
+ from nat.data_models.finetuning import TrajectoryCollection
29
+ from nat.finetuning.interfaces.finetuning_runner import Trainer
30
+
31
+ from .config import ARTTrainerConfig
32
+
33
+ # Configure matplotlib for non-interactive backend
34
+ try:
35
+ import matplotlib
36
+
37
+ matplotlib.use('Agg')
38
+ import matplotlib.pyplot as plt
39
+
40
+ MATPLOTLIB_AVAILABLE = True
41
+ except ImportError:
42
+ MATPLOTLIB_AVAILABLE = False
43
+ plt = None
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ class ARTTrainer(Trainer):
49
+ """
50
+ Concrete implementation of Trainer for the OpenPipe ART backend.
51
+
52
+ This runner orchestrates the finetuning process using:
53
+ - ARTTrajectoryBuilder to collect trajectories from evaluations
54
+ - ARTTrainerAdapter to submit trajectories to the ART training backend
55
+ """
56
+
57
+ def __init__(self, trainer_config: ARTTrainerConfig, **kwargs) -> None:
58
+ """
59
+ Initialize the OpenPipe ART Runner.
60
+
61
+ Args:
62
+ trainer_config: Configuration for the ART trainer backend
63
+ """
64
+ super().__init__(trainer_config)
65
+
66
+ # Type hint for the specific config
67
+ self.trainer_config: ARTTrainerConfig = trainer_config
68
+
69
+ # Track job references
70
+ self._job_refs: list[TrainingJobRef] = []
71
+ self._run_id: str | None = None
72
+
73
+ # Track rewards for plotting
74
+ self._reward_history: list[dict] = []
75
+ self._validation_history: list[dict] = []
76
+
77
+ async def initialize(self, run_config: FinetuneConfig) -> None:
78
+ """
79
+ Initialize the runner and its components.
80
+
81
+ This will:
82
+ - Initialize the TrainerAdapter and verify connectivity
83
+ - Prepare the TrajectoryBuilder for collecting trajectories
84
+ """
85
+ logger.info("Initializing OpenPipe ART Runner")
86
+
87
+ await super().initialize(run_config)
88
+
89
+ # Generate a unique run ID
90
+ self._run_id = f"art_run_{uuid.uuid4().hex[:8]}"
91
+
92
+ logger.info(f"OpenPipe ART Runner initialized with run ID: {self._run_id}")
93
+
94
+ async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef | None:
95
+ """
96
+ Run a single epoch of training.
97
+
98
+ Args:
99
+ epoch: The current epoch number (0-indexed)
100
+ run_id: Unique identifier for this training run
101
+
102
+ Returns:
103
+ TrainingJobRef | None: Reference to the submitted training job, or None if no
+ trajectories were collected or none survived curriculum filtering
104
+ """
105
+ logger.info(f"Starting epoch {epoch + 1} for run {run_id}")
106
+
107
+ # Start the trajectory builder for this epoch
108
+ epoch_meta = {
109
+ "epoch": epoch,
110
+ "run_id": run_id,
111
+ "trainer_config": self.trainer_config.model_dump(),
112
+ }
113
+
114
+ # Check if we should run validation
115
+ if (self.run_config.run_configuration.validation_dataset
116
+ and epoch % self.run_config.run_configuration.validation_interval == 0):
117
+ logger.info(f"Running validation at epoch {epoch + 1}")
118
+ validation_metrics = await self.run_validation_evaluation(epoch, self._run_id)
119
+
120
+ # Store validation metrics
121
+ validation_info = {
122
+ "epoch": epoch,
123
+ "timestamp": datetime.now().isoformat(),
124
+ "avg_reward": validation_metrics.get("avg_reward", 0.0),
125
+ "min_reward": validation_metrics.get("min_reward", 0.0),
126
+ "max_reward": validation_metrics.get("max_reward", 0.0),
127
+ "num_examples": validation_metrics.get("num_examples", 0),
128
+ }
129
+ self._validation_history.append(validation_info)
130
+
131
+ await self.trajectory_builder.start_run(run_id=run_id, meta=epoch_meta)
132
+
133
+ # Finalize and get trajectories
134
+ trajectory_collection = await self.trajectory_builder.finalize(run_id=run_id, meta=epoch_meta)
135
+
136
+ if not trajectory_collection.trajectories:
137
+ logger.warning(f"No trajectories collected for epoch {epoch}")
138
+ # Nothing to submit for this epoch
139
+ return None
140
+
141
+ # Calculate metrics from the original trajectories (before curriculum filtering)
142
+ # trajectory_collection.trajectories is a list of lists
143
+ # Each inner list contains trajectories for a specific example
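+ # e.g. trajectory_collection.trajectories == [[traj_a1, traj_a2], [traj_b1, traj_b2, traj_b3]]
+ # would mean two evaluated examples with 2 and 3 rollouts respectively (names illustrative)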
144
+ all_rewards = []
145
+ total_trajectories = 0
146
+ group_stats = []
147
+
148
+ for trajectory_list in trajectory_collection.trajectories:
149
+ group_rewards = []
150
+ for trajectory in trajectory_list:
151
+ total_trajectories += 1
152
+ if hasattr(trajectory, 'reward'):
153
+ reward = trajectory.reward
154
+ all_rewards.append(reward)
155
+ group_rewards.append(reward)
156
+
157
+ if group_rewards:
158
+ avg_group_reward = sum(group_rewards) / len(group_rewards)
159
+ variance = sum((r - avg_group_reward)**2 for r in group_rewards) / len(group_rewards)
160
+ group_stats.append({"avg_reward": avg_group_reward, "variance": variance, "size": len(group_rewards)})
161
+
162
+ logger.info(f"Collected {total_trajectories} trajectories in {len(trajectory_collection.trajectories)} "
163
+ f"groups for epoch {epoch}")
164
+
165
+ # Calculate reward statistics from all trajectories
166
+ if all_rewards:
167
+ avg_reward = sum(all_rewards) / len(all_rewards)
168
+ min_reward = min(all_rewards)
169
+ max_reward = max(all_rewards)
170
+ else:
171
+ avg_reward = min_reward = max_reward = 0.0
172
+
173
+ # Apply curriculum learning to filter trajectories
174
+ filtered_collection = self.apply_curriculum_learning(trajectory_collection, epoch)
175
+
176
+ # Calculate metrics after curriculum filtering
177
+ filtered_trajectories = 0
178
+ filtered_rewards = []
179
+ for trajectory_list in filtered_collection.trajectories:
180
+ for trajectory in trajectory_list:
181
+ filtered_trajectories += 1
182
+ if hasattr(trajectory, 'reward'):
183
+ filtered_rewards.append(trajectory.reward)
184
+
185
+ if filtered_rewards:
186
+ filtered_avg_reward = sum(filtered_rewards) / len(filtered_rewards)
187
+ filtered_min_reward = min(filtered_rewards)
188
+ filtered_max_reward = max(filtered_rewards)
189
+ else:
190
+ filtered_avg_reward = filtered_min_reward = filtered_max_reward = 0.0
191
+
192
+ # Log progress with both original and filtered metrics
193
+ metrics = {
+ "avg_reward": avg_reward,
+ "min_reward": min_reward,
+ "max_reward": max_reward,
+ "num_trajectories": total_trajectories,
+ "num_groups": len(trajectory_collection.trajectories),
+ # Curriculum metrics
+ "filtered_trajectories": filtered_trajectories,
+ "filtered_groups": len(filtered_collection.trajectories),
+ "filtered_avg_reward": filtered_avg_reward,
+ "filtered_min_reward": filtered_min_reward,
+ "filtered_max_reward": filtered_max_reward,
+ "curriculum_percentile": self._curriculum_state["current_percentile"] if self.curriculum_config.enabled else 1.0,
+ }
217
+
218
+ # Log group statistics if curriculum learning is enabled
219
+ if self.curriculum_config.enabled and group_stats:
220
+ sorted_groups = sorted(group_stats, key=lambda x: x["avg_reward"], reverse=True)
221
+ logger.info("Group reward distribution - Top: %.4f, Median: %.4f, Bottom: %.4f",
222
+ sorted_groups[0]["avg_reward"],
223
+ sorted_groups[len(sorted_groups) // 2]["avg_reward"],
224
+ sorted_groups[-1]["avg_reward"])
225
+
226
+ self.log_progress(epoch, metrics)
227
+
228
+ # Check if we have trajectories after filtering
229
+ if not filtered_collection.trajectories:
230
+ logger.warning(f"No trajectories remaining after curriculum filtering for epoch {epoch}")
231
+ return None
232
+
233
+ # Submit filtered trajectories to trainer
234
+ job_ref = await self.trainer_adapter.submit(filtered_collection)
235
+ self._job_refs.append(job_ref)
236
+
237
+ logger.info(f"Submitted training job for epoch {epoch}: {job_ref}")
238
+
239
+ return job_ref
240
+
241
+ async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
242
+ """
243
+ Run the complete finetuning workflow for the specified number of epochs.
244
+
245
+ Args:
246
+ num_epochs: Number of epochs to train
247
+
248
+ Returns:
249
+ list[TrainingJobStatus]: Status of all training jobs
250
+ """
251
+ if not self._run_id:
252
+ raise RuntimeError("Runner not initialized. Did you forget to call initialize(...)?")
253
+
254
+ logger.info(f"Starting finetuning run with {num_epochs} epochs")
255
+
256
+ job_statuses = []
257
+
258
+ for epoch in range(num_epochs):
259
+ try:
260
+ # Run the epoch
261
+ job_ref = await self.run_epoch(epoch, self._run_id)
262
+
263
+ # Wait for completion before starting next epoch
264
+ if job_ref:
265
+ status = await self.trainer_adapter.wait_until_complete(job_ref)
266
+ job_statuses.append(status)
267
+
268
+ # Check if training failed
269
+ if status.status == TrainingStatusEnum.FAILED:
270
+ logger.error(f"Training failed at epoch {epoch}: {status.message}")
271
+ break
272
+ else:
273
+ # No trajectories collected, create a dummy status
274
+ job_statuses.append(
275
+ TrainingJobStatus(run_id=self._run_id,
276
+ backend="openpipe-art",
277
+ status=TrainingStatusEnum.COMPLETED,
278
+ message="No trajectories to train on",
279
+ metadata={"epoch": epoch}))
280
+
281
+ logger.info(f"Completed epoch {epoch + 1}/{num_epochs}")
282
+
283
+ except Exception as e:
284
+ logger.error(f"Error during epoch {epoch}: {e}")
285
+ job_statuses.append(
286
+ TrainingJobStatus(run_id=self._run_id,
287
+ backend="openpipe-art",
288
+ status=TrainingStatusEnum.FAILED,
289
+ message=str(e),
290
+ metadata={"epoch": epoch}))
291
+ break
292
+
293
+ logger.info(f"Finetuning run completed. Processed {len(job_statuses)} epochs")
294
+ return job_statuses
295
+
296
+ async def get_metrics(self, run_id: str) -> dict[str, Any]:
297
+ """
298
+ Get training metrics for a specific run.
299
+
300
+ Args:
301
+ run_id: The run identifier
302
+
303
+ Returns:
304
+ dict: Metrics from the training run
305
+ """
306
+ metrics = {"run_id": run_id, "total_epochs": len(self._job_refs), "jobs": []}
307
+
308
+ for job_ref in self._job_refs:
309
+ try:
310
+ status = await self.trainer_adapter.status(job_ref)
311
+ metrics["jobs"].append({"job_ref": job_ref.model_dump(), "status": status.model_dump()})
312
+ except Exception as e:
313
+ logger.error(f"Failed to get status for job {job_ref}: {e}")
314
+ metrics["jobs"].append({"job_ref": job_ref.model_dump(), "error": str(e)})
315
+
316
+ return metrics
317
+
318
+ async def cleanup(self) -> None:
319
+ """
320
+ Clean up any resources used by the runner.
321
+ """
322
+ logger.info("Cleaning up OpenPipe ART Runner resources")
323
+
324
+ # Cleanup trajectory builder tasks
325
+ if hasattr(self.trajectory_builder, 'evaluation_runs'):
326
+ for run_id, task in self.trajectory_builder.evaluation_runs.items():
327
+ if not task.done():
328
+ logger.info(f"Cancelling evaluation task for run {run_id}")
329
+ task.cancel()
330
+
331
+ # Cleanup trainer adapter tasks
332
+ if hasattr(self.trainer_adapter, 'training_jobs'):
333
+ for job_id, task in self.trainer_adapter.training_jobs.items():
334
+ if not task.done():
335
+ logger.info(f"Cancelling training task for job {job_id}")
336
+ task.cancel()
337
+
338
+ logger.info("OpenPipe ART Runner cleanup completed")
339
+
340
+ def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
341
+ """
342
+ Log training progress and create visualizations.
343
+
344
+ Args:
345
+ epoch: Current epoch number
346
+ metrics: Dictionary of metrics to log
347
+ output_dir: Optional output directory override
348
+ """
349
+ # Use provided output_dir or default
350
+ out_dir = Path(output_dir) if output_dir else self.run_config.output_dir
351
+ out_dir.mkdir(parents=True, exist_ok=True)
352
+
353
+ # Extract and store reward info
354
+ reward_info = {
355
+ "epoch": epoch,
356
+ "timestamp": datetime.now().isoformat(),
357
+ "avg_reward": metrics.get("avg_reward", 0.0),
358
+ "min_reward": metrics.get("min_reward", 0.0),
359
+ "max_reward": metrics.get("max_reward", 0.0),
360
+ "num_trajectories": metrics.get("num_trajectories", 0),
361
+ }
362
+ self._reward_history.append(reward_info)
363
+
364
+ # Create plots
365
+ self._create_reward_plot(epoch, out_dir)
366
+
367
+ # Log metrics to JSON file
368
+ self._log_metrics_to_file(epoch, metrics, out_dir)
369
+
370
+ logger.info("Epoch %d progress logged - Avg Reward: %.4f, Trajectories: %d",
371
+ epoch,
372
+ reward_info["avg_reward"],
373
+ reward_info["num_trajectories"])
374
+
375
+ def apply_curriculum_learning(self, trajectory_collection: TrajectoryCollection,
376
+ epoch: int) -> TrajectoryCollection:
377
+ """
378
+ Apply curriculum learning to filter trajectory groups based on difficulty.
379
+
380
+ This method:
381
+ 1. Sorts trajectory groups by average reward (difficulty)
382
+ 2. Filters out groups with no reward variance (no learning signal)
383
+ 3. Selects appropriate groups based on curriculum progression
384
+ 4. Expands curriculum at specified intervals
385
+
386
+ Args:
387
+ trajectory_collection: The complete collection of trajectories
388
+ epoch: Current epoch number
389
+
390
+ Returns:
391
+ TrajectoryCollection: Filtered trajectories for training
392
+ """
393
+ if not self.curriculum_config.enabled:
394
+ # Curriculum learning disabled, return all trajectories
395
+ return trajectory_collection
396
+
397
+ if len(trajectory_collection.trajectories) == 1:
398
+ # Only one group, so we just take a random subsample of it if specified
399
+ if self.curriculum_config.random_subsample is not None:
400
+ import random
401
+ fraction = self.curriculum_config.random_subsample
402
+ trajectory_group = trajectory_collection.trajectories[0]
403
+ max_required_trajectories = int(math.ceil(len(trajectory_group) * fraction))
404
+ if len(trajectory_group) > max_required_trajectories:
405
+ selected_trajectories = random.sample(trajectory_group, max_required_trajectories)
406
+ logger.info("After random subsampling %.2f, using %d trajectories from single group",
407
+ fraction,
408
+ len(selected_trajectories))
409
+ return TrajectoryCollection(trajectories=[selected_trajectories],
410
+ run_id=trajectory_collection.run_id)
411
+
412
+ return trajectory_collection
413
+
414
+ # Calculate statistics for each trajectory group
415
+ group_stats = []
416
+ for group_idx, trajectory_group in enumerate(trajectory_collection.trajectories):
417
+ if not trajectory_group:
418
+ continue
419
+
420
+ rewards = [t.reward for t in trajectory_group]
421
+ avg_reward = sum(rewards) / len(rewards)
422
+ variance = sum((r - avg_reward)**2 for r in rewards) / len(rewards)
423
+ max_diff = max(rewards) - min(rewards)
424
+
425
+ # Skip groups with insufficient reward variance (no learning signal)
426
+ if max_diff < self.curriculum_config.min_reward_diff:
427
+ logger.info("Skipping trajectory group %d with max_diff %.6f < %.6f (no learning signal)",
428
+ group_idx,
429
+ max_diff,
430
+ self.curriculum_config.min_reward_diff)
431
+ continue
432
+
433
+ group_stats.append({
434
+ "index": group_idx, "avg_reward": avg_reward, "variance": variance, "trajectories": trajectory_group
435
+ })
436
+
437
+ if not group_stats:
438
+ logger.warning("No trajectory groups with sufficient variance found")
439
+ return TrajectoryCollection(trajectories=[], run_id=trajectory_collection.run_id)
440
+
441
+ # Sort groups by average reward (difficulty)
442
+ group_stats.sort(key=lambda x: x["avg_reward"], reverse=not self.curriculum_config.sort_ascending)
443
+
444
+ # Record the total number of candidate groups on the first pass
445
+ if self._curriculum_state["total_groups"] == 0:
446
+ self._curriculum_state["total_groups"] = len(group_stats)
447
+
448
+ # Check if we should expand the curriculum
449
+ epochs_since_expansion = epoch - self._curriculum_state["last_expansion_epoch"]
450
+ should_expand = (epochs_since_expansion >= self.curriculum_config.expansion_interval
451
+ and self._curriculum_state["current_percentile"] < 1.0)
452
+
453
+ if should_expand:
454
+ # Expand curriculum by increment_percentile
455
+ old_percentile = self._curriculum_state["current_percentile"]
456
+ self._curriculum_state["current_percentile"] = min(
457
+ 1.0, old_percentile + self.curriculum_config.increment_percentile)
458
+ self._curriculum_state["last_expansion_epoch"] = epoch
459
+
460
+ logger.info("Expanding curriculum at epoch %d: %.1f%% -> %.1f%% of trajectory groups",
461
+ epoch,
462
+ old_percentile * 100,
463
+ self._curriculum_state["current_percentile"] * 100)
464
+
465
+ # Calculate number of groups to include
466
+ num_groups_to_include = max(
467
+ 1, # Always include at least one group
468
+ int(math.ceil(len(group_stats) * self._curriculum_state["current_percentile"])))
469
+
470
+ # Select the appropriate groups
471
+ selected_groups = group_stats[:num_groups_to_include]
472
+
473
+ # Track which groups are included
474
+ included_indices = {g["index"] for g in selected_groups}
475
+ new_groups = included_indices - self._curriculum_state["included_groups"]
476
+ if new_groups:
477
+ logger.info("Adding %d new trajectory groups to curriculum at epoch %d", len(new_groups), epoch)
478
+ self._curriculum_state["included_groups"] = included_indices
479
+
480
+ # Log curriculum statistics
481
+ selected_trajectories = [g["trajectories"] for g in selected_groups]
482
+ total_trajectories = sum(len(traj_list) for traj_list in selected_trajectories)
483
+
484
+ logger.info(
485
+ "Curriculum learning at epoch %d: Using %d/%d groups (%.1f%%), "
486
+ "%d total trajectories. Avg reward range: [%.4f, %.4f]",
487
+ epoch,
488
+ len(selected_groups),
489
+ len(group_stats),
490
+ self._curriculum_state["current_percentile"] * 100,
491
+ total_trajectories,
492
+ selected_groups[-1]["avg_reward"] if selected_groups else 0,
493
+ selected_groups[0]["avg_reward"] if selected_groups else 0)
494
+
495
+ if self.curriculum_config.random_subsample is not None:
496
+ # Randomly select only a fraction of trajectory groups to use
497
+ import random
498
+ fraction = self.curriculum_config.random_subsample
499
+ # The cap is computed from all candidate groups, not just the currently selected ones
500
+ max_required_groups = int(math.ceil(len(group_stats) * fraction))
501
+ # Now select at most that many groups from selected groups
502
+ if len(selected_groups) > max_required_groups:
503
+ selected_groups = random.sample(selected_groups, max_required_groups)
504
+ # Rebuild selected trajectories
505
+ selected_trajectories = [g["trajectories"] for g in selected_groups]
506
+ logger.info("After random subsampling %.2f, using %d trajectory groups", fraction, len(selected_groups))
507
+
508
+ return TrajectoryCollection(trajectories=selected_trajectories, run_id=trajectory_collection.run_id)
509
+
510
+ def _create_reward_plot(self, epoch: int, output_dir: Path) -> None:
511
+ """Create PNG plot showing reward progression and curriculum learning status."""
512
+ if not self._reward_history:
513
+ return
514
+
515
+ if not MATPLOTLIB_AVAILABLE:
516
+ logger.warning("Matplotlib not available, skipping plot generation")
517
+ return
518
+
519
+ # Create figure with potentially two y-axes
520
+ fig, ax = plt.subplots(figsize=(12, 7))
521
+
522
+ # Plot training rewards
523
+ epochs = [r["epoch"] for r in self._reward_history]
524
+ avg_rewards = [r["avg_reward"] for r in self._reward_history]
525
+
526
+ ax.plot(epochs, avg_rewards, 'b-', linewidth=2, label='Training Average Reward')
527
+ ax.scatter(epochs, avg_rewards, s=50, c='blue', zorder=5)
528
+
529
+ # Plot filtered average rewards if curriculum learning is enabled
530
+ if self.curriculum_config.enabled and any("filtered_avg_reward" in r for r in self._reward_history):
531
+ filtered_avg_rewards = [r.get("filtered_avg_reward", r["avg_reward"]) for r in self._reward_history]
532
+ ax.plot(epochs, filtered_avg_rewards, 'g:', linewidth=2, label='Filtered Avg Reward (Curriculum)')
533
+ ax.scatter(epochs, filtered_avg_rewards, s=30, c='green', zorder=4)
534
+
535
+ # Plot validation rewards if available
536
+ val_epochs = []
537
+ val_avg_rewards = []
538
+ if self._validation_history:
539
+ val_epochs = [r["epoch"] for r in self._validation_history]
540
+ val_avg_rewards = [r["avg_reward"] for r in self._validation_history]
541
+
542
+ ax.plot(val_epochs, val_avg_rewards, 'r--', linewidth=2, label='Validation Average Reward')
543
+ ax.scatter(val_epochs, val_avg_rewards, s=50, c='red', zorder=5)
544
+
545
+ # Combine all rewards for y-axis range calculation
546
+ all_rewards = avg_rewards + val_avg_rewards
547
+ else:
548
+ all_rewards = avg_rewards
549
+
550
+ # Calculate y-axis range with margin
551
+ if all_rewards:
552
+ min_avg = min(all_rewards)
553
+ max_avg = max(all_rewards)
554
+ # Add 10% margin on each side
555
+ range_margin = (max_avg - min_avg) * 0.1
556
+ # If all rewards are the same, use a fixed margin
557
+ if range_margin == 0:
558
+ range_margin = abs(min_avg) * 0.1 if min_avg != 0 else 0.1
559
+ ax.set_ylim(min_avg - range_margin, max_avg + range_margin)
560
+
561
+ # Add curriculum learning progression on secondary y-axis if enabled
562
+ if self.curriculum_config.enabled:
563
+ ax2 = ax.twinx()
564
+ curriculum_percentiles = [r.get("curriculum_percentile", 1.0) * 100 for r in self._reward_history]
565
+ ax2.plot(epochs, curriculum_percentiles, 'm-.', linewidth=1.5, label='Curriculum %', alpha=0.7)
566
+ ax2.set_ylabel('Curriculum Percentile (%)', fontsize=11, color='m')
567
+ ax2.set_ylim(0, 105)
568
+ ax2.tick_params(axis='y', labelcolor='m')
569
+ ax2.grid(False)
570
+
571
+ # Add shaded regions to indicate curriculum expansions
572
+ expansion_epochs = []
573
+ for i in range(1, len(curriculum_percentiles)):
574
+ if curriculum_percentiles[i] > curriculum_percentiles[i - 1]:
575
+ expansion_epochs.append(epochs[i])
576
+
577
+ for exp_epoch in expansion_epochs:
578
+ ax.axvline(x=exp_epoch, color='purple', linestyle=':', alpha=0.3, linewidth=1)
579
+
580
+ # Formatting
581
+ ax.set_xlabel('Epoch', fontsize=12)
582
+ ax.set_ylabel('Reward', fontsize=12)
583
+
584
+ title = f'Training Progress - Epoch {epoch}'
585
+ if self.curriculum_config.enabled:
586
+ title += f' (Curriculum Learning: {self._curriculum_state["current_percentile"]*100:.1f}%)'
587
+ ax.set_title(title, fontsize=14)
588
+
589
+ ax.grid(True, alpha=0.3)
590
+ ax.legend(loc='upper left')
591
+
592
+ # Set integer x-axis ticks
593
+ ax.set_xticks(epochs)
594
+
595
+ # Add value annotations for training (reduced to avoid clutter)
596
+ # Only annotate every 5th epoch if there are more than 10 epochs
597
+ annotation_epochs = epochs if len(epochs) <= 10 else epochs[::5]
598
+
599
+ for e in annotation_epochs:
600
+ idx = epochs.index(e)
601
+ ax.annotate(f'{avg_rewards[idx]:.3f}', (e, avg_rewards[idx]),
602
+ textcoords="offset points",
603
+ xytext=(0, 10),
604
+ ha='center',
605
+ fontsize=8,
606
+ color='blue')
607
+
608
+ # Add value annotations for validation (sparse)
609
+ if self._validation_history:
610
+ val_annotation_epochs = val_epochs if len(val_epochs) <= 5 else val_epochs[::2]
611
+ for e in val_annotation_epochs:
612
+ idx = val_epochs.index(e)
613
+ ax.annotate(f'{val_avg_rewards[idx]:.3f}', (e, val_avg_rewards[idx]),
614
+ textcoords="offset points",
615
+ xytext=(0, -15),
616
+ ha='center',
617
+ fontsize=8,
618
+ color='red')
619
+
620
+ # Save plot
621
+ plot_path = output_dir / "reward_plot.png"
622
+ plt.tight_layout()
623
+ plt.savefig(plot_path, dpi=150, bbox_inches='tight')
624
+ plt.close(fig)
625
+
626
+ logger.debug("Saved reward plot to %s", plot_path)
627
+
628
+ def _log_metrics_to_file(self, epoch: int, metrics: dict[str, Any], output_dir: Path) -> None:
629
+ """Log metrics to JSON file."""
630
+ # Create metrics log file
631
+ metrics_file = output_dir / "training_metrics.jsonl"
632
+
633
+ # Prepare log entry
634
+ log_entry = {"epoch": epoch, "timestamp": datetime.now().isoformat(), "run_id": self._run_id, **metrics}
635
+
636
+ # Add curriculum learning state if enabled
637
+ if self.curriculum_config.enabled:
638
+ log_entry["curriculum_state"] = self.get_curriculum_state()
639
+
640
+ # Append to file
641
+ with open(metrics_file, 'a', encoding='utf-8') as f:
642
+ f.write(json.dumps(log_entry) + '\n')
643
+
644
+ # Also save reward history separately
645
+ history_file = output_dir / "reward_history.json"
646
+ with open(history_file, 'w', encoding='utf-8') as f:
647
+ json.dump(self._reward_history, f, indent=2)
648
+
649
+ # Save validation history if available
650
+ if self._validation_history:
651
+ val_history_file = output_dir / "validation_history.json"
652
+ with open(val_history_file, 'w', encoding='utf-8') as f:
653
+ json.dump(self._validation_history, f, indent=2)
654
+
655
+ # Save curriculum learning history if enabled
656
+ if self.curriculum_config.enabled:
657
+ curriculum_file = output_dir / "curriculum_state.json"
658
+ with open(curriculum_file, 'w', encoding='utf-8') as f:
659
+ json.dump(self.get_curriculum_state(), f, indent=2)