adaptive_harmony-0.1.23-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/adaptive_harmony/common/env_grpo.py
@@ -0,0 +1,361 @@
import os
from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np

from adaptive_harmony import (
    CosineScheduler,
    DataSet,
    JobNotifier,
    Logger,
    StageNotifier,
    StringThread,
    TokenizedThread,
    TrainingModel,
)
from adaptive_harmony.common import RecipeCallback
from adaptive_harmony.common.checkpointing import CheckpointManager
from adaptive_harmony.core.utils import async_map, async_map_batch, get_minibatches, hash_hyperparams, log_args
from adaptive_harmony.environment import EnvironmentFactory, TrajectoryScore
from adaptive_harmony.metric_logger import StdoutLogger


def compute_advantages(
    scores: list[TrajectoryScore],
    logprobs: list[list[float]],
    samples: list[TokenizedThread],
    num_generated_turns: list[int],
) -> list[list[float]]:
    def get_assistant_lengths(samples: list[TokenizedThread], num_generated_turns: list[int]) -> list[list[int]]:
        # here the +1 is because we have a loss weight on the EOD token of a turn, which is not represented in the tokenized turn content
        # Keep only the last num_generated_turns assistant turns (those with weight>0) since logprobs_per_token only returns logprobs for them
        return [
            [len(turn.content) + 1 for turn in sample.get_turns() if turn.role == "assistant"][-num_gen:]
            for sample, num_gen in zip(samples, num_generated_turns)
        ]

    # FROM https://arxiv.org/pdf/2402.03300 -> Process Supervision RL with GRPO
    # HERE PADDING DOES NOT PLAY A ROLE IN ADVANTAGE COMPUTATION, SINCE NaNs ARE IGNORED.

    mapped_scores = [[turn_score.score for turn_score in score.scores] for score in scores]

    max_len = max(map(len, mapped_scores))

    # pad with np.nan instead of 0
    all_scores = np.full((len(mapped_scores), max_len), np.nan)
    for i, s in enumerate(mapped_scores):
        all_scores[i, : len(s)] = s

    # nan-aware mean and std
    mean = np.nanmean(all_scores)
    std = np.nanstd(all_scores) + 1e-8

    normalized_rewards = (all_scores - mean) / std

    # cumulative sum per row, ignoring nans
    score_level_advantage = np.where(
        np.isnan(normalized_rewards),
        np.nan,
        np.cumsum(np.nan_to_num(normalized_rewards)[:, ::-1], axis=1)[:, ::-1],
    )

    turn_level_advantage = [
        np.repeat(
            adv[: len(score.scores)],
            [turn_score.num_assistant_turns for turn_score in score.scores],
        )
        for adv, score in zip(score_level_advantage, scores)
    ]

    assistant_lengths = get_assistant_lengths(samples, num_generated_turns)
    assert all([len(lp) == sum(al) for lp, al in zip(logprobs, assistant_lengths)])

    token_level_advantage = [np.repeat(adv, al).tolist() for adv, al in zip(turn_level_advantage, assistant_lengths)]

    return token_level_advantage


@dataclass
class Sample:
    sample: TokenizedThread
    logprobs: list[float]
    ref_logprobs: list[float]
    advantage: list[float]
    kl_div: list[float]
    # for logging
    score: float
    gen_len: float


ENVGRPO_HYPERPARAMS = {
    "max_num_grpo_steps",
    "completions_per_sample",
    "lr",
    "lr_scheduler",
    "samples_per_batch",
    "samples_per_mini_batch",
    "mini_epochs_per_batch",
    "max_grad_norm",
    "clip_range",
    "kl_beta",
    "weight_decays",
    "skip_nan_gradients",
}


class ENVGRPO:
    @log_args
    @hash_hyperparams(include=ENVGRPO_HYPERPARAMS)
    def __init__(
        self,
        dataset: list[StringThread],
        model: TrainingModel,
        environment_factory: EnvironmentFactory,
        logger: Logger = StdoutLogger(),
        stage_notifier: StageNotifier = JobNotifier().stage_notifier("ENVGRPO Training"),
        callbacks: Sequence[RecipeCallback] = [],
        validation_dataset: list[StringThread] | None = None,
        validation_frequency: float = 0.2,
        max_num_grpo_steps: int | None = None,
        completions_per_sample=8,
        lr: float = 7.5e-7,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch=128,
        samples_per_mini_batch=128,
        mini_epochs_per_batch=1,
        max_grad_norm=1.0,
        clip_range=0.1,
        kl_beta=0.1,
        weight_decays: float = 0.0,
        skip_nan_gradients: bool = False,
        restart_from_checkpoint: str | None = None,
        checkpoint_frequency: float = 0.2,
    ):
        # Core components
        self.dataset = DataSet(dataset, allow_looping=True)
        self.model = model
        self.logger = logger
        self.stage_notifier = stage_notifier
        self.sample_index_counter = 0
        self.skip_nan_gradients = skip_nan_gradients
        # Validation data/params
        self.validation_dataset = validation_dataset
        self.validation_frequency = validation_frequency
        self.last_validation_percentage = -1.0  # Validation will run before training starts
        # GRPO HP's
        self.max_num_batches = max_num_grpo_steps
        self.completions_per_sample = completions_per_sample
        self.lr_schedule = lr_scheduler or CosineScheduler(lr)
        self.samples_per_batch = samples_per_batch // completions_per_sample
        self.samples_per_mini_batch = samples_per_mini_batch
        self.total_num_samples = (
            self.max_num_batches * self.samples_per_batch if self.max_num_batches else len(self.dataset)
        )
        self.max_grad_norm = max_grad_norm
        self.environment_factory = environment_factory
        self.clip_range = clip_range
        self.kl_beta = kl_beta
        self.weight_decays = weight_decays
        self.mini_epochs_per_batch = mini_epochs_per_batch

        self.num_batches_processed = 0
        self.callbacks = callbacks

        self.checkpoint_manager = CheckpointManager(
            recipe_name="ENVGRPO",
            dataset=self.dataset,
            threads_dataset=dataset,
            callbacks=self.callbacks,
            hyperparams_hash=self._hyperparams_hash,  # type: ignore
            job_id=os.environ.get("HARMONY_JOB_ID"),
            checkpoint_frequency=checkpoint_frequency,
            restart_from_checkpoint=restart_from_checkpoint,
        )

    @property
    def training_completion_percentage(self):
        return (
            self.dataset.completion_percentage()
            if self.max_num_batches is None
            else min(self.num_batches_processed / self.max_num_batches, 1.0)
        )

    async def gen_data(self, sample: StringThread) -> list[Sample]:
        async def generate_trajectory(
            prompt: StringThread,
        ) -> tuple[TokenizedThread, TrajectoryScore, int]:
            # this creates the environment for the first turn.
            environment = self.environment_factory.create_environment(prompt.metadata)
            prompt = await environment.bootstrap_prompt(prompt)

            # Count assistant turns in the context (before generation)
            nb_context_assistant_turns = sum(1 for turn in prompt.get_turns() if turn.role == "assistant")

            string_trajectory = await self.model.generate(prompt)  # generate the first response from the agent.
            num_generated_turns = 1
            # we loop until the environment returns a score.
            # notice how the environment can return a score or a tool or user response.
            while not isinstance(
                environment_response := await environment.react_to(string_trajectory),
                TrajectoryScore,
            ):
                for env_role, env_content in environment_response:
                    if not isinstance(env_content, str):
                        raise ValueError(f"env_content should be a str, got {env_content}")
                    if env_role == "user":
                        string_trajectory = string_trajectory.user(env_content)
                    elif env_role == "tool":
                        string_trajectory = string_trajectory.tool(env_content)
                    else:
                        raise ValueError
                string_trajectory = await self.model.generate(string_trajectory)
                num_generated_turns += 1

            tokenized_trajectory = (
                await self.model.tokenize_thread(string_trajectory)
            ).with_weight_assistant_turns_from_index(nb_context_assistant_turns)

            return tokenized_trajectory, environment_response, num_generated_turns

        assert self.model_ref is not None, "Calling `gen_data` before reference model has been set"

        trajs_and_scores = await async_map(generate_trajectory, [sample] * self.completions_per_sample)
        all_samples = [traj for traj, _, _ in trajs_and_scores]
        num_generated_turns_list = [num_turns for _, _, num_turns in trajs_and_scores]
        logprobs = await async_map(self.model.logprobs_per_token, all_samples)
        ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)

        all_trajectory_scores = [score for _, score, _ in trajs_and_scores]
        advantages = compute_advantages(all_trajectory_scores, logprobs, all_samples, num_generated_turns_list)

        kl = [
            (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
            for lp, ref_lp in zip(logprobs, ref_logprobs)
        ]

        samples = []
        for i in range(len(logprobs)):
            samples.append(
                Sample(
                    sample=all_samples[i],
                    logprobs=logprobs[i],
                    ref_logprobs=ref_logprobs[i],
                    advantage=advantages[i],
                    kl_div=kl[i],
                    score=all_trajectory_scores[i].cumulative_score,
                    gen_len=all_samples[i].len_last_turn(),
                )
            )
        return samples

    async def train_sample(self, sample: Sample):
        await self.model.train_grpo(
            sample.sample,
            sample.logprobs,
            sample.ref_logprobs,
            sample.advantage,
            self.clip_range,
            self.kl_beta,
        )

    async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
        self.num_batches_processed = checkpoint_data["num_batches_processed"]

        model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
        model_checkpoint_name = await self.model.load(f"model_registry://{model_checkpoint_name}")

        self.last_validation_percentage = checkpoint_data.get("last_validation_percentage", -1.0)
        self.sample_index_counter = checkpoint_data.get("sample_index_counter", 0)

        self.model.set_optim_step(checkpoint_data["optim_step"])

    async def _recipe_specific_checkpoint_saving(self) -> dict:
        progress_pct = int(self.training_completion_percentage * 100)
        model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-policy"
        await self.model.save(model_checkpoint_name, inference_only=False)

        return {
            "num_batches_processed": self.num_batches_processed,
            "model_checkpoint_name": model_checkpoint_name,
            "last_validation_percentage": self.last_validation_percentage,
            "sample_index_counter": self.sample_index_counter,
            "optim_step": self.model.get_optim_step(),
        }

    async def run(self):
        self.model_ref = await self.model.clone_inf()
        await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)

        self.stage_notifier.report_progress(
            tot_num_samples=self.total_num_samples,
            processed_num_samples=self.num_batches_processed * self.samples_per_batch,
            monitoring_link=self.logger.training_monitoring_link,
        )

        while self.training_completion_percentage < 1.0:
            self.num_batches_processed += 1

            for callback in self.callbacks:
                if logs := await callback.maybe_call(self.training_completion_percentage):
                    self.logger(logs)

            # Generate training samples
            data = await async_map_batch(self.gen_data, self.dataset, self.samples_per_batch)

            scorer_logs = {}
            for key, value in self.environment_factory.get_logs(clear=True).items():
                if "/" not in key:
                    key = f"environment/{key}"
                scorer_logs[key] = value
            batch_logs = {
                **scorer_logs,
                **self.get_train_batch_logs(data),
            }

            current_lr = self.lr_schedule(self.training_completion_percentage)

            # Train on generated samples
            flattened_data = sum([inner_list for inner_list in data], start=[])
            minibatches = get_minibatches(flattened_data, self.samples_per_mini_batch, self.mini_epochs_per_batch)
            for idx, mini_batch in enumerate(minibatches):
                await async_map(self.train_sample, mini_batch)
                optim_logs = await self.model.optim_step(
                    current_lr, wd=0.0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
                )
                if idx == len(minibatches) - 1:
                    # only log tables and full batch-related logs on the final minibatch
                    self.logger(optim_logs | batch_logs)
                else:
                    self.logger(optim_logs | dict(completion_percentage=self.training_completion_percentage))

            self.stage_notifier.report_progress(
                tot_num_samples=self.total_num_samples,
                processed_num_samples=self.num_batches_processed * self.samples_per_batch,
                monitoring_link=self.logger.training_monitoring_link,
            )

            if await self.checkpoint_manager.maybe_checkpoint(
                self.training_completion_percentage, self._recipe_specific_checkpoint_saving
            ):
                break

    def get_train_batch_logs(self, data: list[list[Sample]]) -> dict:
        return {
            **dict(
                completion_percentage=self.training_completion_percentage,
                score_mean=np.mean([[sample.score for sample in batch] for batch in data]).item(),
                percentage_no_advantages=np.mean(
                    [all(sample.advantage == batch[0].advantage for sample in batch) for batch in data]
                ).item(),
                score_std=np.std([[sample.score for sample in batch] for batch in data]).item(),
                kl_div=np.mean([[np.mean(sample.kl_div) for sample in batch] for batch in data]).item(),
                generation_length=np.mean([np.mean([sample.gen_len for sample in batch]) for batch in data]).item(),
                logprobs=np.mean(
                    np.concatenate([np.concatenate([sample.logprobs for sample in batch]) for batch in data])
                ).item(),
                ref_logprobs=np.mean(
                    np.concatenate([np.concatenate([sample.ref_logprobs for sample in batch]) for batch in data])
                ).item(),
            ),
        }
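The heart of compute_advantages above is the process-supervision normalization from the cited GRPO paper: every per-turn score in the batch is z-normalized against the group's nan-aware mean and standard deviation, and a turn's advantage is the reversed cumulative sum of the normalized scores from that turn to the end of its trajectory. Below is a minimal, self-contained NumPy sketch of just that step, using invented scores and omitting the token-level broadcasting the real function performs; all names here are illustrative, not part of the package.

import numpy as np

# Invented per-turn scores for a group of three trajectories of unequal length.
group_scores = [[1.0, 0.0, 1.0], [0.5, 0.5], [0.0]]

max_len = max(map(len, group_scores))
padded = np.full((len(group_scores), max_len), np.nan)
for i, s in enumerate(group_scores):
    padded[i, : len(s)] = s

# nan-aware z-normalization over the whole group
normalized = (padded - np.nanmean(padded)) / (np.nanstd(padded) + 1e-8)

# advantage of a turn = sum of normalized scores from that turn to the end of its row
suffix_sums = np.cumsum(np.nan_to_num(normalized)[:, ::-1], axis=1)[:, ::-1]
advantage = np.where(np.isnan(normalized), np.nan, suffix_sums)
print(advantage)  # NaN entries are padding; real turns carry the suffix sums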
--- /dev/null
+++ b/adaptive_harmony/common/grpo.py
@@ -0,0 +1,260 @@
import os
from dataclasses import dataclass
from typing import Callable, Sequence, TypeAlias

import numpy as np
from numpy.typing import NDArray

from adaptive_harmony import (
    CosineScheduler,
    DataSet,
    JobNotifier,
    Logger,
    StageNotifier,
    StringThread,
    TokenizedThread,
    TrainingModel,
)
from adaptive_harmony.common.callbacks import RecipeCallback
from adaptive_harmony.common.checkpointing import CheckpointManager
from adaptive_harmony.core.utils import async_map, async_map_batch, get_minibatches, hash_hyperparams, log_args
from adaptive_harmony.graders import BaseGrader
from adaptive_harmony.metric_logger import StdoutLogger

FloatArray: TypeAlias = NDArray[np.float32]


@dataclass
class Sample:
    sample: TokenizedThread
    logprobs: list[float]
    ref_logprobs: list[float]
    advantage: float
    score: float
    kl_div: list[float]
    gen_len: float


GRPO_HYPERPARAMS = {
    "max_num_grpo_steps",
    "completions_per_sample",
    "lr",
    "lr_scheduler",
    "samples_per_batch",
    "samples_per_mini_batch",
    "mini_epochs_per_batch",
    "max_grad_norm",
    "clip_range",
    "kl_beta",
    "weight_decay",
    "skip_nan_gradients",
}


class GRPO:
    @log_args
    @hash_hyperparams(include=GRPO_HYPERPARAMS)
    def __init__(
        self,
        dataset: list[StringThread],
        model: TrainingModel,
        grader: BaseGrader,
        logger: Logger = StdoutLogger(),
        stage_notifier: StageNotifier = JobNotifier().stage_notifier("GRPO Training"),
        callbacks: Sequence[RecipeCallback] = [],
        max_num_grpo_steps: int | None = None,
        completions_per_sample=8,
        lr: float = 7.5e-7,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch=128,
        samples_per_mini_batch=128,
        mini_epochs_per_batch=1,
        max_grad_norm=1.0,
        clip_range=0.1,
        kl_beta=0.01,
        weight_decay=0.0,
        skip_nan_gradients: bool = False,
        restart_from_checkpoint: str | None = None,
        checkpoint_frequency: float = 0.2,
    ):
        # Core components
        self.dataset = DataSet(dataset, allow_looping=True)
        self.model = model
        self.grader = grader
        self.scoring_fn = grader.score_float_value
        self.logger = logger
        self.stage_notifier = stage_notifier
        self.callbacks = callbacks
        self.skip_nan_gradients = skip_nan_gradients
        # GRPO HP's
        self.max_num_batches = max_num_grpo_steps
        self.completions_per_sample = completions_per_sample
        self.lr_schedule = lr_scheduler or CosineScheduler(lr)
        self.prompts_per_batch = samples_per_batch // completions_per_sample
        self.samples_per_mini_batch = samples_per_mini_batch
        self.total_num_samples = (
            self.max_num_batches * self.prompts_per_batch if self.max_num_batches else len(self.dataset)
        )
        self.max_grad_norm = max_grad_norm
        self.clip_range = clip_range
        self.kl_beta = kl_beta
        self.weight_decay = weight_decay
        self.mini_epochs_per_batch = mini_epochs_per_batch

        self.num_batches_processed = 0

        self.checkpoint_manager = CheckpointManager(
            recipe_name="GRPO",
            dataset=self.dataset,
            threads_dataset=dataset,
            callbacks=self.callbacks,
            hyperparams_hash=self._hyperparams_hash,  # type: ignore
            job_id=os.environ.get("HARMONY_JOB_ID"),
            checkpoint_frequency=checkpoint_frequency,
            restart_from_checkpoint=restart_from_checkpoint,
        )

    @property
    def training_completion_percentage(self):
        return (
            self.dataset.completion_percentage()
            if self.max_num_batches is None
            else min(self.num_batches_processed / self.max_num_batches, 1.0)
        )

    async def gen_data(self, sample: StringThread) -> list[Sample]:
        assert self.model_ref is not None, "Calling `gen_data` before reference model has been set"

        all_samples = await async_map(self.model.generate_tokens, [sample] * self.completions_per_sample)
        string_samples = await async_map(self.model.detokenize_thread, all_samples)
        all_scores = np.array(await async_map(self.scoring_fn, string_samples), dtype=np.float32)

        advantages: FloatArray = all_scores - all_scores.mean()
        advantages /= advantages.std() + 1e-8

        logprobs = await async_map(self.model.logprobs_per_token, all_samples)
        ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)
        kl = [
            (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
            for lp, ref_lp in zip(logprobs, ref_logprobs)
        ]

        samples = []
        for i in range(len(logprobs)):
            samples.append(
                Sample(
                    sample=all_samples[i],
                    logprobs=logprobs[i],
                    ref_logprobs=ref_logprobs[i],
                    advantage=advantages[i],
                    score=all_scores[i],
                    kl_div=kl[i],
                    gen_len=all_samples[i].len_last_turn(),
                )
            )
        return samples

    async def train_sample(self, sample: Sample):
        await self.model.train_grpo(
            sample.sample,
            sample.logprobs,
            sample.ref_logprobs,
            [sample.advantage] * len(sample.logprobs),
            self.clip_range,
            self.kl_beta,
        )

    async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
        self.num_batches_processed = checkpoint_data["num_batches_processed"]

        model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
        await self.model.load(f"model_registry://{model_checkpoint_name}")

        self.model.set_optim_step(checkpoint_data["optim_step"])

    async def _recipe_specific_checkpoint_saving(self) -> dict:
        progress_pct = int(self.training_completion_percentage * 100)
        model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-policy"
        model_checkpoint_name = await self.model.save(model_checkpoint_name, inference_only=False)

        return {
            "num_batches_processed": self.num_batches_processed,
            "model_checkpoint_name": model_checkpoint_name,
            "optim_step": self.model.get_optim_step(),
        }

    async def run(self):
        self.model_ref = await self.model.clone_inf()
        await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)

        self.stage_notifier.report_progress(
            tot_num_samples=self.total_num_samples,
            processed_num_samples=self.dataset.idx,
            monitoring_link=self.logger.training_monitoring_link,
        )

        while self.training_completion_percentage < 1.0:
            self.num_batches_processed += 1

            for callback in self.callbacks:
                if logs := await callback.maybe_call(self.training_completion_percentage):
                    self.logger(logs)

            # Generate training samples
            data = await async_map_batch(self.gen_data, self.dataset, self.prompts_per_batch)
            scorer_logs = self.grader.get_logs(clear=True)
            batch_logs = {
                **{f"rewards/{key}": value for key, value in scorer_logs.items()},
                **self.get_train_batch_logs(data),
            }

            current_lr = self.lr_schedule(self.training_completion_percentage)
            # Train on generated samples
            flattened_data = sum([inner_list for inner_list in data], start=[])
            minibatches = get_minibatches(flattened_data, self.samples_per_mini_batch, self.mini_epochs_per_batch)
            for idx, mini_batch in enumerate(minibatches):
                await async_map(self.train_sample, mini_batch)
                optim_logs = await self.model.optim_step(
                    current_lr,
                    wd=self.weight_decay,
                    max_grad_norm=self.max_grad_norm,
                    skip_nan_gradients=self.skip_nan_gradients,
                )
                if idx == len(minibatches) - 1:
                    # only log tables and full batch-related logs on the final minibatch
                    self.logger(optim_logs | batch_logs)
                else:
                    self.logger(optim_logs | dict(completion_percentage=self.training_completion_percentage))

            self.stage_notifier.report_progress(
                tot_num_samples=self.total_num_samples,
                processed_num_samples=self.dataset.idx,
                monitoring_link=self.logger.training_monitoring_link,
            )

            if await self.checkpoint_manager.maybe_checkpoint(
                self.training_completion_percentage, self._recipe_specific_checkpoint_saving
            ):
                break

    def get_train_batch_logs(self, data: list[list[Sample]]) -> dict:
        return {
            **dict(
                completion_percentage=self.training_completion_percentage,
                percentage_no_advantages=np.mean(
                    [all(sample.advantage == batch[0].advantage for sample in batch) for batch in data]
                ).item(),
                score_mean=np.mean([[sample.score for sample in batch] for batch in data]).item(),
                score_std=np.std([[sample.score for sample in batch] for batch in data]).item(),
                kl_div=np.mean([[np.mean(sample.kl_div) for sample in batch] for batch in data]).item(),
                advantages=np.mean(np.concatenate([[sample.advantage for sample in batch] for batch in data])).item(),
                generation_length=np.mean([np.mean([sample.gen_len for sample in batch]) for batch in data]).item(),
                logprobs=np.mean(
                    np.concatenate([np.concatenate([sample.logprobs for sample in batch]) for batch in data])
                ).item(),
                ref_logprobs=np.mean(
                    np.concatenate([np.concatenate([sample.ref_logprobs for sample in batch]) for batch in data])
                ).item(),
            ),
            **{"training/completion_percentage": self.training_completion_percentage},
        }
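Unlike the per-turn advantages in env_grpo.py, gen_data in grpo.py assigns one scalar advantage per completion by normalizing the grader scores within each group of completions_per_sample rollouts, and train_sample then repeats that scalar over every generated token before calling train_grpo. A short sketch of that group-relative normalization follows; the scores and token count are invented for illustration and are not taken from the package.

import numpy as np

# Invented grader scores for 8 completions of the same prompt
scores = np.array([0.0, 1.0, 1.0, 0.0, 0.5, 1.0, 0.0, 0.5], dtype=np.float32)

# group-relative advantage, same epsilon as gen_data
advantages = scores - scores.mean()
advantages /= advantages.std() + 1e-8

# shape of what train_grpo ultimately receives: the scalar broadcast per generated token
num_generated_tokens = 5  # made up for illustration
token_level = [[float(a)] * num_generated_tokens for a in advantages]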
--- /dev/null
+++ b/adaptive_harmony/common/gspo.py
@@ -0,0 +1,70 @@
from typing import Callable, Sequence

from adaptive_harmony import (
    JobNotifier,
    Logger,
    StageNotifier,
    StringThread,
    TrainingModel,
)
from adaptive_harmony.common.callbacks import RecipeCallback
from adaptive_harmony.common.grpo import GRPO, Sample
from adaptive_harmony.graders import BaseGrader
from adaptive_harmony.metric_logger import StdoutLogger


class GSPO(GRPO):  # grpo already logs args so we don't do it here
    def __init__(
        self,
        dataset: list[StringThread],
        model: TrainingModel,
        grader: BaseGrader,
        logger: Logger = StdoutLogger(),
        stage_notifier: StageNotifier = JobNotifier().stage_notifier("GSPO Training"),
        callbacks: Sequence[RecipeCallback] = [],
        max_num_gspo_steps: int | None = None,
        completions_per_sample=8,
        lr: float = 7.5e-7,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch=128,
        samples_per_mini_batch=128,
        mini_epochs_per_batch=1,
        max_grad_norm=1.0,
        # wildly different defaults than GRPO because we are looking at the
        # entire sequence at once, I tried the number from the GSPO paper but
        # it was skipping ~60% of the samples. I had good success with 0.01 but
        # it's not properly swept yet.
        clip_range=0.01,
        kl_beta=0.01,
        weight_decay=0.0,
    ):
        super().__init__(
            dataset,
            model,
            grader,
            logger,
            stage_notifier,
            callbacks,
            max_num_grpo_steps=max_num_gspo_steps,
            completions_per_sample=completions_per_sample,
            lr=lr,
            lr_scheduler=lr_scheduler,
            samples_per_batch=samples_per_batch,
            samples_per_mini_batch=samples_per_mini_batch,
            mini_epochs_per_batch=mini_epochs_per_batch,
            max_grad_norm=max_grad_norm,
            clip_range=clip_range,
            kl_beta=kl_beta,
            weight_decay=weight_decay,
        )

    async def train_sample(self, sample: Sample):
        await self.model.train_gspo(
            sample.sample,
            sample.logprobs,
            sample.ref_logprobs,
            [sample.advantage],  # only diff with grpo is we train with a single advantage
            self.clip_range,
            self.clip_range,
            self.kl_beta,
        )
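As the inline comments note, GSPO reuses the GRPO recipe end to end and differs only in train_sample, which passes a single sequence-level advantage (with a much tighter default clip range) to train_gspo. The sketch below is a conceptual illustration of why the clip range shrinks, assuming the GSPO-style length-normalized sequence importance ratio described in the GSPO paper; it is not the package's actual train_gspo kernel, and all values are invented.

import numpy as np

# Invented per-token logprobs under the current and old policies
logp_new = np.array([-0.10, -0.30, -0.20, -0.40])
logp_old = np.array([-0.12, -0.28, -0.25, -0.35])

token_ratios = np.exp(logp_new - logp_old)        # GRPO clips each of these (default clip_range=0.1)
seq_ratio = np.exp(np.mean(logp_new - logp_old))  # GSPO-style length-normalized sequence ratio

clipped_token = np.clip(token_ratios, 1 - 0.1, 1 + 0.1)
clipped_seq = float(np.clip(seq_ratio, 1 - 0.01, 1 + 0.01))  # GSPO default clip_range=0.01 above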