nvidia-nat-openpipe-art 1.4.0a20260116 (py3-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/plugins/openpipe/__init__.py +0 -0
- nat/plugins/openpipe/config.py +77 -0
- nat/plugins/openpipe/register.py +71 -0
- nat/plugins/openpipe/trainer.py +659 -0
- nat/plugins/openpipe/trainer_adapter.py +339 -0
- nat/plugins/openpipe/trajectory_builder.py +333 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/METADATA +46 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/RECORD +14 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/WHEEL +5 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/entry_points.txt +2 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_openpipe_art-1.4.0a20260116.dist-info/top_level.txt +1 -0
nat/plugins/openpipe/trainer.py

@@ -0,0 +1,659 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import math
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any

from nat.data_models.finetuning import FinetuneConfig
from nat.data_models.finetuning import TrainingJobRef
from nat.data_models.finetuning import TrainingJobStatus
from nat.data_models.finetuning import TrainingStatusEnum
from nat.data_models.finetuning import TrajectoryCollection
from nat.finetuning.interfaces.finetuning_runner import Trainer

from .config import ARTTrainerConfig

# Configure matplotlib for a non-interactive backend
try:
    import matplotlib

    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    plt = None

logger = logging.getLogger(__name__)


class ARTTrainer(Trainer):
    """
    Concrete implementation of Trainer for the OpenPipe ART backend.

    This runner orchestrates the finetuning process using:
    - ARTTrajectoryBuilder to collect trajectories from evaluations
    - ARTTrainerAdapter to submit trajectories to the ART training backend
    """

    def __init__(self, trainer_config: ARTTrainerConfig, **kwargs) -> None:
        """
        Initialize the OpenPipe ART Runner.

        Args:
            trainer_config: Configuration for the ART trainer backend
        """
        super().__init__(trainer_config)

        # Type hint for the specific config
        self.trainer_config: ARTTrainerConfig = trainer_config

        # Track job references
        self._job_refs: list[TrainingJobRef] = []
        self._run_id: str | None = None

        # Track rewards for plotting
        self._reward_history: list[dict] = []
        self._validation_history: list[dict] = []

    async def initialize(self, run_config: FinetuneConfig) -> None:
        """
        Initialize the runner and its components.

        This will:
        - Initialize the TrainerAdapter and verify connectivity
        - Prepare the TrajectoryBuilder for collecting trajectories
        """
        logger.info("Initializing OpenPipe ART Runner")

        await super().initialize(run_config)

        # Generate a unique run ID
        self._run_id = f"art_run_{uuid.uuid4().hex[:8]}"

        logger.info(f"OpenPipe ART Runner initialized with run ID: {self._run_id}")

    async def run_epoch(self, epoch: int, run_id: str) -> TrainingJobRef | None:
        """
        Run a single epoch of training.

        Args:
            epoch: The current epoch number (0-indexed)
            run_id: Unique identifier for this training run

        Returns:
            TrainingJobRef: Reference to the submitted training job, or None if no
                trajectories were available for training
        """
        logger.info(f"Starting epoch {epoch + 1} for run {run_id}")

        # Start the trajectory builder for this epoch
        epoch_meta = {
            "epoch": epoch,
            "run_id": run_id,
            "trainer_config": self.trainer_config.model_dump(),
        }

        # Check if we should run validation
        if (self.run_config.run_configuration.validation_dataset
                and epoch % self.run_config.run_configuration.validation_interval == 0):
            logger.info(f"Running validation at epoch {epoch + 1}")
            validation_metrics = await self.run_validation_evaluation(epoch, self._run_id)

            # Store validation metrics
            validation_info = {
                "epoch": epoch,
                "timestamp": datetime.now().isoformat(),
                "avg_reward": validation_metrics.get("avg_reward", 0.0),
                "min_reward": validation_metrics.get("min_reward", 0.0),
                "max_reward": validation_metrics.get("max_reward", 0.0),
                "num_examples": validation_metrics.get("num_examples", 0),
            }
            self._validation_history.append(validation_info)

        await self.trajectory_builder.start_run(run_id=run_id, meta=epoch_meta)

        # Finalize and get trajectories
        trajectory_collection = await self.trajectory_builder.finalize(run_id=run_id, meta=epoch_meta)

        if not trajectory_collection.trajectories:
            logger.warning(f"No trajectories collected for epoch {epoch}")
            # Nothing to submit for this epoch
            return None

        # Calculate metrics from the original trajectories (before curriculum filtering).
        # trajectory_collection.trajectories is a list of lists;
        # each inner list contains trajectories for a specific example.
        all_rewards = []
        total_trajectories = 0
        group_stats = []

        for trajectory_list in trajectory_collection.trajectories:
            group_rewards = []
            for trajectory in trajectory_list:
                total_trajectories += 1
                if hasattr(trajectory, 'reward'):
                    reward = trajectory.reward
                    all_rewards.append(reward)
                    group_rewards.append(reward)

            if group_rewards:
                avg_group_reward = sum(group_rewards) / len(group_rewards)
                variance = sum((r - avg_group_reward)**2 for r in group_rewards) / len(group_rewards)
                group_stats.append({"avg_reward": avg_group_reward, "variance": variance, "size": len(group_rewards)})

        logger.info(f"Collected {total_trajectories} trajectories in {len(trajectory_collection.trajectories)} "
                    f"groups for epoch {epoch}")

        # Calculate reward statistics from all trajectories
        if all_rewards:
            avg_reward = sum(all_rewards) / len(all_rewards)
            min_reward = min(all_rewards)
            max_reward = max(all_rewards)
        else:
            avg_reward = min_reward = max_reward = 0.0

        # Apply curriculum learning to filter trajectories
        filtered_collection = self.apply_curriculum_learning(trajectory_collection, epoch)

        # Calculate metrics after curriculum filtering
        filtered_trajectories = 0
        filtered_rewards = []
        for trajectory_list in filtered_collection.trajectories:
            for trajectory in trajectory_list:
                filtered_trajectories += 1
                if hasattr(trajectory, 'reward'):
                    filtered_rewards.append(trajectory.reward)

        if filtered_rewards:
            filtered_avg_reward = sum(filtered_rewards) / len(filtered_rewards)
            filtered_min_reward = min(filtered_rewards)
            filtered_max_reward = max(filtered_rewards)
        else:
            filtered_avg_reward = filtered_min_reward = filtered_max_reward = 0.0

        # Log progress with both original and filtered metrics
        metrics = {
            "avg_reward": avg_reward,
            "min_reward": min_reward,
            "max_reward": max_reward,
            "num_trajectories": total_trajectories,
            "num_groups": len(trajectory_collection.trajectories),
            # Curriculum metrics
            "filtered_trajectories": filtered_trajectories,
            "filtered_groups": len(filtered_collection.trajectories),
            "filtered_avg_reward": filtered_avg_reward,
            "filtered_min_reward": filtered_min_reward,
            "filtered_max_reward": filtered_max_reward,
            "curriculum_percentile": self._curriculum_state["current_percentile"] if self.curriculum_config.enabled else 1.0,
        }

        # Log group statistics if curriculum learning is enabled
        if self.curriculum_config.enabled and group_stats:
            sorted_groups = sorted(group_stats, key=lambda x: x["avg_reward"], reverse=True)
            logger.info("Group reward distribution - Top: %.4f, Median: %.4f, Bottom: %.4f",
                        sorted_groups[0]["avg_reward"],
                        sorted_groups[len(sorted_groups) // 2]["avg_reward"],
                        sorted_groups[-1]["avg_reward"])

        self.log_progress(epoch, metrics)

        # Check if we have trajectories after filtering
        if not filtered_collection.trajectories:
            logger.warning(f"No trajectories remaining after curriculum filtering for epoch {epoch}")
            return None

        # Submit filtered trajectories to trainer
        job_ref = await self.trainer_adapter.submit(filtered_collection)
        self._job_refs.append(job_ref)

        logger.info(f"Submitted training job for epoch {epoch}: {job_ref}")

        return job_ref

    async def run(self, num_epochs: int) -> list[TrainingJobStatus]:
        """
        Run the complete finetuning workflow for the specified number of epochs.

        Args:
            num_epochs: Number of epochs to train

        Returns:
            list[TrainingJobStatus]: Status of all training jobs
        """
        if not self._run_id:
            raise RuntimeError("Runner not initialized. Did you forget to call initialize(...)?")

        logger.info(f"Starting finetuning run with {num_epochs} epochs")

        job_statuses = []

        for epoch in range(num_epochs):
            try:
                # Run the epoch
                job_ref = await self.run_epoch(epoch, self._run_id)

                # Wait for completion before starting next epoch
                if job_ref:
                    status = await self.trainer_adapter.wait_until_complete(job_ref)
                    job_statuses.append(status)

                    # Check if training failed
                    if status.status == TrainingStatusEnum.FAILED:
                        logger.error(f"Training failed at epoch {epoch}: {status.message}")
                        break
                else:
                    # No trajectories collected, create a dummy status
                    job_statuses.append(
                        TrainingJobStatus(run_id=self._run_id,
                                          backend="openpipe-art",
                                          status=TrainingStatusEnum.COMPLETED,
                                          message="No trajectories to train on",
                                          metadata={"epoch": epoch}))

                logger.info(f"Completed epoch {epoch + 1}/{num_epochs}")

            except Exception as e:
                logger.error(f"Error during epoch {epoch}: {e}")
                job_statuses.append(
                    TrainingJobStatus(run_id=self._run_id,
                                      backend="openpipe-art",
                                      status=TrainingStatusEnum.FAILED,
                                      message=str(e),
                                      metadata={"epoch": epoch}))
                break

        logger.info(f"Finetuning run completed. Processed {len(job_statuses)} epochs")
        return job_statuses

    async def get_metrics(self, run_id: str) -> dict[str, Any]:
        """
        Get training metrics for a specific run.

        Args:
            run_id: The run identifier

        Returns:
            dict: Metrics from the training run
        """
        metrics = {"run_id": run_id, "total_epochs": len(self._job_refs), "jobs": []}

        for job_ref in self._job_refs:
            try:
                status = await self.trainer_adapter.status(job_ref)
                metrics["jobs"].append({"job_ref": job_ref.model_dump(), "status": status.model_dump()})
            except Exception as e:
                logger.error(f"Failed to get status for job {job_ref}: {e}")
                metrics["jobs"].append({"job_ref": job_ref.model_dump(), "error": str(e)})

        return metrics

    async def cleanup(self) -> None:
        """
        Clean up any resources used by the runner.
        """
        logger.info("Cleaning up OpenPipe ART Runner resources")

        # Cleanup trajectory builder tasks
        if hasattr(self.trajectory_builder, 'evaluation_runs'):
            for run_id, task in self.trajectory_builder.evaluation_runs.items():
                if not task.done():
                    logger.info(f"Cancelling evaluation task for run {run_id}")
                    task.cancel()

        # Cleanup trainer adapter tasks
        if hasattr(self.trainer_adapter, 'training_jobs'):
            for job_id, task in self.trainer_adapter.training_jobs.items():
                if not task.done():
                    logger.info(f"Cancelling training task for job {job_id}")
                    task.cancel()

        logger.info("OpenPipe ART Runner cleanup completed")

    def log_progress(self, epoch: int, metrics: dict[str, Any], output_dir: str | None = None) -> None:
        """
        Log training progress and create visualizations.

        Args:
            epoch: Current epoch number
            metrics: Dictionary of metrics to log
            output_dir: Optional output directory override
        """
        # Use provided output_dir or default
        out_dir = Path(output_dir) if output_dir else self.run_config.output_dir
        out_dir.mkdir(parents=True, exist_ok=True)

        # Extract and store reward info
        reward_info = {
            "epoch": epoch,
            "timestamp": datetime.now().isoformat(),
            "avg_reward": metrics.get("avg_reward", 0.0),
            "min_reward": metrics.get("min_reward", 0.0),
            "max_reward": metrics.get("max_reward", 0.0),
            "num_trajectories": metrics.get("num_trajectories", 0),
        }
        self._reward_history.append(reward_info)

        # Create plots
        self._create_reward_plot(epoch, out_dir)

        # Log metrics to JSON file
        self._log_metrics_to_file(epoch, metrics, out_dir)

        logger.info("Epoch %d progress logged - Avg Reward: %.4f, Trajectories: %d",
                    epoch,
                    reward_info["avg_reward"],
                    reward_info["num_trajectories"])

    def apply_curriculum_learning(self, trajectory_collection: TrajectoryCollection,
                                  epoch: int) -> TrajectoryCollection:
        """
        Apply curriculum learning to filter trajectory groups based on difficulty.

        This method:
        1. Sorts trajectory groups by average reward (difficulty)
        2. Filters out groups with no reward variance (no learning signal)
        3. Selects appropriate groups based on curriculum progression
        4. Expands the curriculum at specified intervals

        Args:
            trajectory_collection: The complete collection of trajectories
            epoch: Current epoch number

        Returns:
            TrajectoryCollection: Filtered trajectories for training
        """
        if not self.curriculum_config.enabled:
            # Curriculum learning disabled, return all trajectories
            return trajectory_collection

        if len(trajectory_collection.trajectories) == 1:
            # Only one group, so the only option is a random subsample if specified
            if self.curriculum_config.random_subsample is not None:
                import random
                fraction = self.curriculum_config.random_subsample
                trajectory_group = trajectory_collection.trajectories[0]
                max_required_trajectories = int(math.ceil(len(trajectory_group) * fraction))
                if len(trajectory_group) > max_required_trajectories:
                    selected_trajectories = random.sample(trajectory_group, max_required_trajectories)
                    logger.info("After random subsampling %.2f, using %d trajectories from single group",
                                fraction,
                                len(selected_trajectories))
                    return TrajectoryCollection(trajectories=[selected_trajectories],
                                                run_id=trajectory_collection.run_id)

            return trajectory_collection

        # Calculate statistics for each trajectory group
        group_stats = []
        for group_idx, trajectory_group in enumerate(trajectory_collection.trajectories):
            if not trajectory_group:
                continue

            rewards = [t.reward for t in trajectory_group]
            avg_reward = sum(rewards) / len(rewards)
            variance = sum((r - avg_reward)**2 for r in rewards) / len(rewards)
            max_diff = max(rewards) - min(rewards)

            # Skip groups with insufficient reward variance (no learning signal)
            if max_diff < self.curriculum_config.min_reward_diff:
                logger.info("Skipping trajectory group %d with max_diff %.6f < %.6f (no learning signal)",
                            group_idx,
                            max_diff,
                            self.curriculum_config.min_reward_diff)
                continue

            group_stats.append({
                "index": group_idx,
                "avg_reward": avg_reward,
                "variance": variance,
                "trajectories": trajectory_group,
            })

        if not group_stats:
            logger.warning("No trajectory groups with sufficient variance found")
            return TrajectoryCollection(trajectories=[], run_id=trajectory_collection.run_id)

        # Sort groups by average reward (difficulty)
        group_stats.sort(key=lambda x: x["avg_reward"], reverse=not self.curriculum_config.sort_ascending)

        # Store the total number of groups on first use
        if self._curriculum_state["total_groups"] == 0:
            self._curriculum_state["total_groups"] = len(group_stats)

        # Check if we should expand the curriculum
        epochs_since_expansion = epoch - self._curriculum_state["last_expansion_epoch"]
        should_expand = (epochs_since_expansion >= self.curriculum_config.expansion_interval
                         and self._curriculum_state["current_percentile"] < 1.0)

        if should_expand:
            # Expand curriculum by increment_percentile
            old_percentile = self._curriculum_state["current_percentile"]
            self._curriculum_state["current_percentile"] = min(
                1.0, old_percentile + self.curriculum_config.increment_percentile)
            self._curriculum_state["last_expansion_epoch"] = epoch

            logger.info("Expanding curriculum at epoch %d: %.1f%% -> %.1f%% of trajectory groups",
                        epoch,
                        old_percentile * 100,
                        self._curriculum_state["current_percentile"] * 100)

        # Calculate number of groups to include
        num_groups_to_include = max(
            1,  # Always include at least one group
            int(math.ceil(len(group_stats) * self._curriculum_state["current_percentile"])))

        # Select the appropriate groups
        selected_groups = group_stats[:num_groups_to_include]

        # Track which groups are included
        included_indices = {g["index"] for g in selected_groups}
        new_groups = included_indices - self._curriculum_state["included_groups"]
        if new_groups:
            logger.info("Adding %d new trajectory groups to curriculum at epoch %d", len(new_groups), epoch)
        self._curriculum_state["included_groups"] = included_indices

        # Log curriculum statistics
        selected_trajectories = [g["trajectories"] for g in selected_groups]
        total_trajectories = sum(len(traj_list) for traj_list in selected_trajectories)

        logger.info(
            "Curriculum learning at epoch %d: Using %d/%d groups (%.1f%%), "
            "%d total trajectories. Avg reward range: [%.4f, %.4f]",
            epoch,
            len(selected_groups),
            len(group_stats),
            self._curriculum_state["current_percentile"] * 100,
            total_trajectories,
            selected_groups[-1]["avg_reward"] if selected_groups else 0,
            selected_groups[0]["avg_reward"] if selected_groups else 0)

        if self.curriculum_config.random_subsample is not None:
            # Randomly select only a fraction of trajectory groups to use
            import random
            fraction = self.curriculum_config.random_subsample
            # Max required groups is the theoretical max based on fraction
            max_required_groups = int(math.ceil(len(group_stats) * fraction))
            # Now select at most that many groups from the selected groups
            if len(selected_groups) > max_required_groups:
                selected_groups = random.sample(selected_groups, max_required_groups)
                # Rebuild selected trajectories
                selected_trajectories = [g["trajectories"] for g in selected_groups]
                logger.info("After random subsampling %.2f, using %d trajectory groups", fraction, len(selected_groups))

        return TrajectoryCollection(trajectories=selected_trajectories, run_id=trajectory_collection.run_id)

    def _create_reward_plot(self, epoch: int, output_dir: Path) -> None:
        """Create a PNG plot showing reward progression and curriculum learning status."""
        if not self._reward_history:
            return

        if not MATPLOTLIB_AVAILABLE:
            logger.warning("Matplotlib not available, skipping plot generation")
            return

        # Create figure with potentially two y-axes
        fig, ax = plt.subplots(figsize=(12, 7))

        # Plot training rewards
        epochs = [r["epoch"] for r in self._reward_history]
        avg_rewards = [r["avg_reward"] for r in self._reward_history]

        ax.plot(epochs, avg_rewards, 'b-', linewidth=2, label='Training Average Reward')
        ax.scatter(epochs, avg_rewards, s=50, c='blue', zorder=5)

        # Plot filtered average rewards if curriculum learning is enabled
        if self.curriculum_config.enabled and any("filtered_avg_reward" in r for r in self._reward_history):
            filtered_avg_rewards = [r.get("filtered_avg_reward", r["avg_reward"]) for r in self._reward_history]
            ax.plot(epochs, filtered_avg_rewards, 'g:', linewidth=2, label='Filtered Avg Reward (Curriculum)')
            ax.scatter(epochs, filtered_avg_rewards, s=30, c='green', zorder=4)

        # Plot validation rewards if available
        val_epochs = []
        val_avg_rewards = []
        if self._validation_history:
            val_epochs = [r["epoch"] for r in self._validation_history]
            val_avg_rewards = [r["avg_reward"] for r in self._validation_history]

            ax.plot(val_epochs, val_avg_rewards, 'r--', linewidth=2, label='Validation Average Reward')
            ax.scatter(val_epochs, val_avg_rewards, s=50, c='red', zorder=5)

            # Combine all rewards for y-axis range calculation
            all_rewards = avg_rewards + val_avg_rewards
        else:
            all_rewards = avg_rewards

        # Calculate y-axis range with margin
        if all_rewards:
            min_avg = min(all_rewards)
            max_avg = max(all_rewards)
            # Add 10% margin on each side
            range_margin = (max_avg - min_avg) * 0.1
            # If all rewards are the same, use a fixed margin
            if range_margin == 0:
                range_margin = abs(min_avg) * 0.1 if min_avg != 0 else 0.1
            ax.set_ylim(min_avg - range_margin, max_avg + range_margin)

        # Add curriculum learning progression on a secondary y-axis if enabled
        if self.curriculum_config.enabled:
            ax2 = ax.twinx()
            curriculum_percentiles = [r.get("curriculum_percentile", 1.0) * 100 for r in self._reward_history]
            ax2.plot(epochs, curriculum_percentiles, 'm-.', linewidth=1.5, label='Curriculum %', alpha=0.7)
            ax2.set_ylabel('Curriculum Percentile (%)', fontsize=11, color='m')
            ax2.set_ylim(0, 105)
            ax2.tick_params(axis='y', labelcolor='m')
            ax2.grid(False)

            # Add vertical markers to indicate curriculum expansions
            expansion_epochs = []
            for i in range(1, len(curriculum_percentiles)):
                if curriculum_percentiles[i] > curriculum_percentiles[i - 1]:
                    expansion_epochs.append(epochs[i])

            for exp_epoch in expansion_epochs:
                ax.axvline(x=exp_epoch, color='purple', linestyle=':', alpha=0.3, linewidth=1)

        # Formatting
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Reward', fontsize=12)

        title = f'Training Progress - Epoch {epoch}'
        if self.curriculum_config.enabled:
            title += f' (Curriculum Learning: {self._curriculum_state["current_percentile"]*100:.1f}%)'
        ax.set_title(title, fontsize=14)

        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper left')

        # Set integer x-axis ticks
        ax.set_xticks(epochs)

        # Add value annotations for training (reduced to avoid clutter):
        # only annotate every 5th epoch if there are more than 10 epochs
        annotation_epochs = epochs if len(epochs) <= 10 else epochs[::5]

        for e in annotation_epochs:
            idx = epochs.index(e)
            ax.annotate(f'{avg_rewards[idx]:.3f}', (e, avg_rewards[idx]),
                        textcoords="offset points",
                        xytext=(0, 10),
                        ha='center',
                        fontsize=8,
                        color='blue')

        # Add value annotations for validation (sparse)
        if self._validation_history:
            val_annotation_epochs = val_epochs if len(val_epochs) <= 5 else val_epochs[::2]
            for e in val_annotation_epochs:
                idx = val_epochs.index(e)
                ax.annotate(f'{val_avg_rewards[idx]:.3f}', (e, val_avg_rewards[idx]),
                            textcoords="offset points",
                            xytext=(0, -15),
                            ha='center',
                            fontsize=8,
                            color='red')

        # Save plot
        plot_path = output_dir / "reward_plot.png"
        plt.tight_layout()
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

        logger.debug("Saved reward plot to %s", plot_path)

    def _log_metrics_to_file(self, epoch: int, metrics: dict[str, Any], output_dir: Path) -> None:
        """Log metrics to JSON file."""
        # Create metrics log file
        metrics_file = output_dir / "training_metrics.jsonl"

        # Prepare log entry
        log_entry = {"epoch": epoch, "timestamp": datetime.now().isoformat(), "run_id": self._run_id, **metrics}

        # Add curriculum learning state if enabled
        if self.curriculum_config.enabled:
            log_entry["curriculum_state"] = self.get_curriculum_state()

        # Append to file
        with open(metrics_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(log_entry) + '\n')

        # Also save reward history separately
        history_file = output_dir / "reward_history.json"
        with open(history_file, 'w', encoding='utf-8') as f:
            json.dump(self._reward_history, f, indent=2)

        # Save validation history if available
        if self._validation_history:
            val_history_file = output_dir / "validation_history.json"
            with open(val_history_file, 'w', encoding='utf-8') as f:
                json.dump(self._validation_history, f, indent=2)

        # Save curriculum learning state if enabled
        if self.curriculum_config.enabled:
            curriculum_file = output_dir / "curriculum_state.json"
            with open(curriculum_file, 'w', encoding='utf-8') as f:
                json.dump(self.get_curriculum_state(), f, indent=2)
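
For orientation, the sketch below shows one way the `ARTTrainer` defined in this file might be driven, using only the methods visible above (`initialize`, `run`, and `cleanup`). It is a minimal, hypothetical example: the argument-free construction of `ARTTrainerConfig` and `FinetuneConfig` is an assumption, since their fields live in `config.py` and `nat.data_models.finetuning`, which are not shown in this diff.

```python
# Hypothetical driver sketch; not part of the package. Config construction is
# assumed, while the ARTTrainer method calls mirror what trainer.py defines.
import asyncio

from nat.data_models.finetuning import FinetuneConfig
from nat.plugins.openpipe.config import ARTTrainerConfig
from nat.plugins.openpipe.trainer import ARTTrainer


async def main() -> None:
    trainer_config = ARTTrainerConfig()   # assumption: defaults suffice
    run_config = FinetuneConfig()         # assumption: defaults suffice

    trainer = ARTTrainer(trainer_config)
    await trainer.initialize(run_config)  # generates the art_run_<id> run ID
    try:
        # One trajectory-collection and training-job cycle per epoch
        statuses = await trainer.run(num_epochs=3)
        for status in statuses:
            print(status.status, status.message)
    finally:
        await trainer.cleanup()           # cancels any outstanding tasks


if __name__ == "__main__":
    asyncio.run(main())
```

In the installed plugin, the `entry_points.txt` and `register.py` listed above presumably wire this trainer into NAT's finetuning runner, so a hand-written driver like this would normally not be needed.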
|