adaptive-harmony 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/common/env_grpo.py
@@ -0,0 +1,361 @@
+ import os
+ from dataclasses import dataclass
+ from typing import Callable, Sequence
+
+ import numpy as np
+
+ from adaptive_harmony import (
+     CosineScheduler,
+     DataSet,
+     JobNotifier,
+     Logger,
+     StageNotifier,
+     StringThread,
+     TokenizedThread,
+     TrainingModel,
+ )
+ from adaptive_harmony.common import RecipeCallback
+ from adaptive_harmony.common.checkpointing import CheckpointManager
+ from adaptive_harmony.core.utils import async_map, async_map_batch, get_minibatches, hash_hyperparams, log_args
+ from adaptive_harmony.environment import EnvironmentFactory, TrajectoryScore
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+
+ def compute_advantages(
+     scores: list[TrajectoryScore],
+     logprobs: list[list[float]],
+     samples: list[TokenizedThread],
+     num_generated_turns: list[int],
+ ) -> list[list[float]]:
+     def get_assistant_lengths(samples: list[TokenizedThread], num_generated_turns: list[int]) -> list[list[int]]:
+         # The +1 is because we have a loss weight on the EOD token of a turn, which is not represented in the tokenized turn content.
+         # Keep only the last num_generated_turns assistant turns (those with weight > 0), since logprobs_per_token only returns logprobs for them.
+         return [
+             [len(turn.content) + 1 for turn in sample.get_turns() if turn.role == "assistant"][-num_gen:]
+             for sample, num_gen in zip(samples, num_generated_turns)
+         ]
+
+     # From https://arxiv.org/pdf/2402.03300 -> Process Supervision RL with GRPO.
+     # Padding does not play a role in the advantage computation, since NaNs are ignored.
+
+     mapped_scores = [[turn_score.score for turn_score in score.scores] for score in scores]
+
+     max_len = max(map(len, mapped_scores))
+
+     # pad with np.nan instead of 0
+     all_scores = np.full((len(mapped_scores), max_len), np.nan)
+     for i, s in enumerate(mapped_scores):
+         all_scores[i, : len(s)] = s
+
+     # nan-aware mean and std
+     mean = np.nanmean(all_scores)
+     std = np.nanstd(all_scores) + 1e-8
+
+     normalized_rewards = (all_scores - mean) / std
+
+     # cumulative sum per row, ignoring nans
+     score_level_advantage = np.where(
+         np.isnan(normalized_rewards),
+         np.nan,
+         np.cumsum(np.nan_to_num(normalized_rewards)[:, ::-1], axis=1)[:, ::-1],
+     )
+
+     turn_level_advantage = [
+         np.repeat(
+             adv[: len(score.scores)],
+             [turn_score.num_assistant_turns for turn_score in score.scores],
+         )
+         for adv, score in zip(score_level_advantage, scores)
+     ]
+
+     assistant_lengths = get_assistant_lengths(samples, num_generated_turns)
+     assert all([len(lp) == sum(al) for lp, al in zip(logprobs, assistant_lengths)])
+
+     token_level_advantage = [np.repeat(adv, al).tolist() for adv, al in zip(turn_level_advantage, assistant_lengths)]
+
+     return token_level_advantage
+
+
+ @dataclass
+ class Sample:
+     sample: TokenizedThread
+     logprobs: list[float]
+     ref_logprobs: list[float]
+     advantage: list[float]
+     kl_div: list[float]
+     # for logging
+     score: float
+     gen_len: float
+
+
+ ENVGRPO_HYPERPARAMS = {
+     "max_num_grpo_steps",
+     "completions_per_sample",
+     "lr",
+     "lr_scheduler",
+     "samples_per_batch",
+     "samples_per_mini_batch",
+     "mini_epochs_per_batch",
+     "max_grad_norm",
+     "clip_range",
+     "kl_beta",
+     "weight_decays",
+     "skip_nan_gradients",
+ }
+
+
+ class ENVGRPO:
+     @log_args
+     @hash_hyperparams(include=ENVGRPO_HYPERPARAMS)
+     def __init__(
+         self,
+         dataset: list[StringThread],
+         model: TrainingModel,
+         environment_factory: EnvironmentFactory,
+         logger: Logger = StdoutLogger(),
+         stage_notifier: StageNotifier = JobNotifier().stage_notifier("ENVGRPO Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         validation_dataset: list[StringThread] | None = None,
+         validation_frequency: float = 0.2,
+         max_num_grpo_steps: int | None = None,
+         completions_per_sample=8,
+         lr: float = 7.5e-7,
+         lr_scheduler: Callable[[float], float] | None = None,
+         samples_per_batch=128,
+         samples_per_mini_batch=128,
+         mini_epochs_per_batch=1,
+         max_grad_norm=1.0,
+         clip_range=0.1,
+         kl_beta=0.1,
+         weight_decays: float = 0.0,
+         skip_nan_gradients: bool = False,
+         restart_from_checkpoint: str | None = None,
+         checkpoint_frequency: float = 0.2,
+     ):
+         # Core components
+         self.dataset = DataSet(dataset, allow_looping=True)
+         self.model = model
+         self.logger = logger
+         self.stage_notifier = stage_notifier
+         self.sample_index_counter = 0
+         self.skip_nan_gradients = skip_nan_gradients
+         # Validation data/params
+         self.validation_dataset = validation_dataset
+         self.validation_frequency = validation_frequency
+         self.last_validation_percentage = -1.0  # Validation will run before training starts
+         # GRPO HP's
+         self.max_num_batches = max_num_grpo_steps
+         self.completions_per_sample = completions_per_sample
+         self.lr_schedule = lr_scheduler or CosineScheduler(lr)
+         self.samples_per_batch = samples_per_batch // completions_per_sample
+         self.samples_per_mini_batch = samples_per_mini_batch
+         self.total_num_samples = (
+             self.max_num_batches * self.samples_per_batch if self.max_num_batches else len(self.dataset)
+         )
+         self.max_grad_norm = max_grad_norm
+         self.environment_factory = environment_factory
+         self.clip_range = clip_range
+         self.kl_beta = kl_beta
+         self.weight_decays = weight_decays
+         self.mini_epochs_per_batch = mini_epochs_per_batch
+
+         self.num_batches_processed = 0
+         self.callbacks = callbacks
+
+         self.checkpoint_manager = CheckpointManager(
+             recipe_name="ENVGRPO",
+             dataset=self.dataset,
+             threads_dataset=dataset,
+             callbacks=self.callbacks,
+             hyperparams_hash=self._hyperparams_hash,  # type: ignore
+             job_id=os.environ.get("HARMONY_JOB_ID"),
+             checkpoint_frequency=checkpoint_frequency,
+             restart_from_checkpoint=restart_from_checkpoint,
+         )
+
+     @property
+     def training_completion_percentage(self):
+         return (
+             self.dataset.completion_percentage()
+             if self.max_num_batches is None
+             else min(self.num_batches_processed / self.max_num_batches, 1.0)
+         )
+
+     async def gen_data(self, sample: StringThread) -> list[Sample]:
+         async def generate_trajectory(
+             prompt: StringThread,
+         ) -> tuple[TokenizedThread, TrajectoryScore, int]:
+             # This creates the environment for the first turn.
+             environment = self.environment_factory.create_environment(prompt.metadata)
+             prompt = await environment.bootstrap_prompt(prompt)
+
+             # Count assistant turns in the context (before generation)
+             nb_context_assistant_turns = sum(1 for turn in prompt.get_turns() if turn.role == "assistant")
+
+             string_trajectory = await self.model.generate(prompt)  # generate the first response from the agent
+             num_generated_turns = 1
+             # Loop until the environment returns a score.
+             # Note that the environment can return either a score or a tool/user response.
+             while not isinstance(
+                 environment_response := await environment.react_to(string_trajectory),
+                 TrajectoryScore,
+             ):
+                 for env_role, env_content in environment_response:
+                     if not isinstance(env_content, str):
+                         raise ValueError(f"env_content should be a str, got {env_content}")
+                     if env_role == "user":
+                         string_trajectory = string_trajectory.user(env_content)
+                     elif env_role == "tool":
+                         string_trajectory = string_trajectory.tool(env_content)
+                     else:
+                         raise ValueError
+                 string_trajectory = await self.model.generate(string_trajectory)
+                 num_generated_turns += 1
+
+             tokenized_trajectory = (
+                 await self.model.tokenize_thread(string_trajectory)
+             ).with_weight_assistant_turns_from_index(nb_context_assistant_turns)
+
+             return tokenized_trajectory, environment_response, num_generated_turns
+
+         assert self.model_ref is not None, "Calling `gen_data` before reference model has been set"
+
+         trajs_and_scores = await async_map(generate_trajectory, [sample] * self.completions_per_sample)
+         all_samples = [traj for traj, _, _ in trajs_and_scores]
+         num_generated_turns_list = [num_turns for _, _, num_turns in trajs_and_scores]
+         logprobs = await async_map(self.model.logprobs_per_token, all_samples)
+         ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)
+
+         all_trajectory_scores = [score for _, score, _ in trajs_and_scores]
+         advantages = compute_advantages(all_trajectory_scores, logprobs, all_samples, num_generated_turns_list)
+
+         kl = [
+             (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
+             for lp, ref_lp in zip(logprobs, ref_logprobs)
+         ]
+
+         samples = []
+         for i in range(len(logprobs)):
+             samples.append(
+                 Sample(
+                     sample=all_samples[i],
+                     logprobs=logprobs[i],
+                     ref_logprobs=ref_logprobs[i],
+                     advantage=advantages[i],
+                     kl_div=kl[i],
+                     score=all_trajectory_scores[i].cumulative_score,
+                     gen_len=all_samples[i].len_last_turn(),
+                 )
+             )
+         return samples
+
+     async def train_sample(self, sample: Sample):
+         await self.model.train_grpo(
+             sample.sample,
+             sample.logprobs,
+             sample.ref_logprobs,
+             sample.advantage,
+             self.clip_range,
+             self.kl_beta,
+         )
+
+     async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
+         self.num_batches_processed = checkpoint_data["num_batches_processed"]
+
+         model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
+         model_checkpoint_name = await self.model.load(f"model_registry://{model_checkpoint_name}")
+
+         self.last_validation_percentage = checkpoint_data.get("last_validation_percentage", -1.0)
+         self.sample_index_counter = checkpoint_data.get("sample_index_counter", 0)
+
+         self.model.set_optim_step(checkpoint_data["optim_step"])
+
+     async def _recipe_specific_checkpoint_saving(self) -> dict:
+         progress_pct = int(self.training_completion_percentage * 100)
+         model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-policy"
+         await self.model.save(model_checkpoint_name, inference_only=False)
+
+         return {
+             "num_batches_processed": self.num_batches_processed,
+             "model_checkpoint_name": model_checkpoint_name,
+             "last_validation_percentage": self.last_validation_percentage,
+             "sample_index_counter": self.sample_index_counter,
+             "optim_step": self.model.get_optim_step(),
+         }
+
+     async def run(self):
+         self.model_ref = await self.model.clone_inf()
+         await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)
+
+         self.stage_notifier.report_progress(
+             tot_num_samples=self.total_num_samples,
+             processed_num_samples=self.num_batches_processed * self.samples_per_batch,
+             monitoring_link=self.logger.training_monitoring_link,
+         )
+
+         while self.training_completion_percentage < 1.0:
+             self.num_batches_processed += 1
+
+             for callback in self.callbacks:
+                 if logs := await callback.maybe_call(self.training_completion_percentage):
+                     self.logger(logs)
+
+             # Generate training samples
+             data = await async_map_batch(self.gen_data, self.dataset, self.samples_per_batch)
+
+             scorer_logs = {}
+             for key, value in self.environment_factory.get_logs(clear=True).items():
+                 if "/" not in key:
+                     key = f"environment/{key}"
+                 scorer_logs[key] = value
+             batch_logs = {
+                 **scorer_logs,
+                 **self.get_train_batch_logs(data),
+             }
+
+             current_lr = self.lr_schedule(self.training_completion_percentage)
+
+             # Train on generated samples
+             flattened_data = sum([inner_list for inner_list in data], start=[])
+             minibatches = get_minibatches(flattened_data, self.samples_per_mini_batch, self.mini_epochs_per_batch)
+             for idx, mini_batch in enumerate(minibatches):
+                 await async_map(self.train_sample, mini_batch)
+                 optim_logs = await self.model.optim_step(
+                     current_lr, wd=0.0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
+                 )
+                 if idx == len(minibatches) - 1:
+                     # only log tables and full batch-related logs on the final minibatch
+                     self.logger(optim_logs | batch_logs)
+                 else:
+                     self.logger(optim_logs | dict(completion_percentage=self.training_completion_percentage))
+
+             self.stage_notifier.report_progress(
+                 tot_num_samples=self.total_num_samples,
+                 processed_num_samples=self.num_batches_processed * self.samples_per_batch,
+                 monitoring_link=self.logger.training_monitoring_link,
+             )
+
+             if await self.checkpoint_manager.maybe_checkpoint(
+                 self.training_completion_percentage, self._recipe_specific_checkpoint_saving
+             ):
+                 break
+
+     def get_train_batch_logs(self, data: list[list[Sample]]) -> dict:
+         return {
+             **dict(
+                 completion_percentage=self.training_completion_percentage,
+                 score_mean=np.mean([[sample.score for sample in batch] for batch in data]).item(),
+                 percentage_no_advantages=np.mean(
+                     [all(sample.advantage == batch[0].advantage for sample in batch) for batch in data]
+                 ).item(),
+                 score_std=np.std([[sample.score for sample in batch] for batch in data]).item(),
+                 kl_div=np.mean([[np.mean(sample.kl_div) for sample in batch] for batch in data]).item(),
+                 generation_length=np.mean([np.mean([sample.gen_len for sample in batch]) for batch in data]).item(),
+                 logprobs=np.mean(
+                     np.concatenate([np.concatenate([sample.logprobs for sample in batch]) for batch in data])
+                 ).item(),
+                 ref_logprobs=np.mean(
+                     np.concatenate([np.concatenate([sample.ref_logprobs for sample in batch]) for batch in data])
+                 ).item(),
+             ),
+         }
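The `compute_advantages` function above implements the process-supervision GRPO scheme from the referenced paper: all step scores in the batch are z-normalized, and each step's advantage is the reversed cumulative sum of normalized scores from that step onward, later broadcast to turns and tokens. A minimal, self-contained sketch of that core idea (not part of the package; toy scores, without the NaN padding used above for ragged trajectories):

```python
# Toy sketch of the process-supervision advantage scheme in compute_advantages:
# z-normalize all step scores across the group, then take a reversed cumulative
# sum per trajectory so each step's advantage covers itself and every later step.
import numpy as np

step_scores = np.array([
    [0.0, 1.0, 1.0],  # trajectory A: per-step scores from the environment
    [1.0, 0.0, 0.0],  # trajectory B
], dtype=np.float32)

normalized = (step_scores - step_scores.mean()) / (step_scores.std() + 1e-8)

# advantage of step t = sum of normalized scores from step t to the end
step_advantages = np.cumsum(normalized[:, ::-1], axis=1)[:, ::-1]
print(step_advantages)
```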
adaptive_harmony/common/grpo.py
@@ -0,0 +1,260 @@
+ import os
+ from dataclasses import dataclass
+ from typing import Callable, Sequence, TypeAlias
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from adaptive_harmony import (
+     CosineScheduler,
+     DataSet,
+     JobNotifier,
+     Logger,
+     StageNotifier,
+     StringThread,
+     TokenizedThread,
+     TrainingModel,
+ )
+ from adaptive_harmony.common.callbacks import RecipeCallback
+ from adaptive_harmony.common.checkpointing import CheckpointManager
+ from adaptive_harmony.core.utils import async_map, async_map_batch, get_minibatches, hash_hyperparams, log_args
+ from adaptive_harmony.graders import BaseGrader
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+ FloatArray: TypeAlias = NDArray[np.float32]
+
+
+ @dataclass
+ class Sample:
+     sample: TokenizedThread
+     logprobs: list[float]
+     ref_logprobs: list[float]
+     advantage: float
+     score: float
+     kl_div: list[float]
+     gen_len: float
+
+
+ GRPO_HYPERPARAMS = {
+     "max_num_grpo_steps",
+     "completions_per_sample",
+     "lr",
+     "lr_scheduler",
+     "samples_per_batch",
+     "samples_per_mini_batch",
+     "mini_epochs_per_batch",
+     "max_grad_norm",
+     "clip_range",
+     "kl_beta",
+     "weight_decay",
+     "skip_nan_gradients",
+ }
+
+
+ class GRPO:
+     @log_args
+     @hash_hyperparams(include=GRPO_HYPERPARAMS)
+     def __init__(
+         self,
+         dataset: list[StringThread],
+         model: TrainingModel,
+         grader: BaseGrader,
+         logger: Logger = StdoutLogger(),
+         stage_notifier: StageNotifier = JobNotifier().stage_notifier("GRPO Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         max_num_grpo_steps: int | None = None,
+         completions_per_sample=8,
+         lr: float = 7.5e-7,
+         lr_scheduler: Callable[[float], float] | None = None,
+         samples_per_batch=128,
+         samples_per_mini_batch=128,
+         mini_epochs_per_batch=1,
+         max_grad_norm=1.0,
+         clip_range=0.1,
+         kl_beta=0.01,
+         weight_decay=0.0,
+         skip_nan_gradients: bool = False,
+         restart_from_checkpoint: str | None = None,
+         checkpoint_frequency: float = 0.2,
+     ):
+         # Core components
+         self.dataset = DataSet(dataset, allow_looping=True)
+         self.model = model
+         self.grader = grader
+         self.scoring_fn = grader.score_float_value
+         self.logger = logger
+         self.stage_notifier = stage_notifier
+         self.callbacks = callbacks
+         self.skip_nan_gradients = skip_nan_gradients
+         # GRPO HP's
+         self.max_num_batches = max_num_grpo_steps
+         self.completions_per_sample = completions_per_sample
+         self.lr_schedule = lr_scheduler or CosineScheduler(lr)
+         self.prompts_per_batch = samples_per_batch // completions_per_sample
+         self.samples_per_mini_batch = samples_per_mini_batch
+         self.total_num_samples = (
+             self.max_num_batches * self.prompts_per_batch if self.max_num_batches else len(self.dataset)
+         )
+         self.max_grad_norm = max_grad_norm
+         self.clip_range = clip_range
+         self.kl_beta = kl_beta
+         self.weight_decay = weight_decay
+         self.mini_epochs_per_batch = mini_epochs_per_batch
+
+         self.num_batches_processed = 0
+
+         self.checkpoint_manager = CheckpointManager(
+             recipe_name="GRPO",
+             dataset=self.dataset,
+             threads_dataset=dataset,
+             callbacks=self.callbacks,
+             hyperparams_hash=self._hyperparams_hash,  # type: ignore
+             job_id=os.environ.get("HARMONY_JOB_ID"),
+             checkpoint_frequency=checkpoint_frequency,
+             restart_from_checkpoint=restart_from_checkpoint,
+         )
+
+     @property
+     def training_completion_percentage(self):
+         return (
+             self.dataset.completion_percentage()
+             if self.max_num_batches is None
+             else min(self.num_batches_processed / self.max_num_batches, 1.0)
+         )
+
+     async def gen_data(self, sample: StringThread) -> list[Sample]:
+         assert self.model_ref is not None, "Calling `gen_data` before reference model has been set"
+
+         all_samples = await async_map(self.model.generate_tokens, [sample] * self.completions_per_sample)
+         string_samples = await async_map(self.model.detokenize_thread, all_samples)
+         all_scores = np.array(await async_map(self.scoring_fn, string_samples), dtype=np.float32)
+
+         advantages: FloatArray = all_scores - all_scores.mean()
+         advantages /= advantages.std() + 1e-8
+
+         logprobs = await async_map(self.model.logprobs_per_token, all_samples)
+         ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)
+         kl = [
+             (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
+             for lp, ref_lp in zip(logprobs, ref_logprobs)
+         ]
+
+         samples = []
+         for i in range(len(logprobs)):
+             samples.append(
+                 Sample(
+                     sample=all_samples[i],
+                     logprobs=logprobs[i],
+                     ref_logprobs=ref_logprobs[i],
+                     advantage=advantages[i],
+                     score=all_scores[i],
+                     kl_div=kl[i],
+                     gen_len=all_samples[i].len_last_turn(),
+                 )
+             )
+         return samples
+
+     async def train_sample(self, sample: Sample):
+         await self.model.train_grpo(
+             sample.sample,
+             sample.logprobs,
+             sample.ref_logprobs,
+             [sample.advantage] * len(sample.logprobs),
+             self.clip_range,
+             self.kl_beta,
+         )
+
+     async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
+         self.num_batches_processed = checkpoint_data["num_batches_processed"]
+
+         model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
+         await self.model.load(f"model_registry://{model_checkpoint_name}")
+
+         self.model.set_optim_step(checkpoint_data["optim_step"])
+
+     async def _recipe_specific_checkpoint_saving(self) -> dict:
+         progress_pct = int(self.training_completion_percentage * 100)
+         model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-policy"
+         model_checkpoint_name = await self.model.save(model_checkpoint_name, inference_only=False)
+
+         return {
+             "num_batches_processed": self.num_batches_processed,
+             "model_checkpoint_name": model_checkpoint_name,
+             "optim_step": self.model.get_optim_step(),
+         }
+
+     async def run(self):
+         self.model_ref = await self.model.clone_inf()
+         await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)
+
+         self.stage_notifier.report_progress(
+             tot_num_samples=self.total_num_samples,
+             processed_num_samples=self.dataset.idx,
+             monitoring_link=self.logger.training_monitoring_link,
+         )
+
+         while self.training_completion_percentage < 1.0:
+             self.num_batches_processed += 1
+
+             for callback in self.callbacks:
+                 if logs := await callback.maybe_call(self.training_completion_percentage):
+                     self.logger(logs)
+
+             # Generate training samples
+             data = await async_map_batch(self.gen_data, self.dataset, self.prompts_per_batch)
+             scorer_logs = self.grader.get_logs(clear=True)
+             batch_logs = {
+                 **{f"rewards/{key}": value for key, value in scorer_logs.items()},
+                 **self.get_train_batch_logs(data),
+             }
+
+             current_lr = self.lr_schedule(self.training_completion_percentage)
+             # Train on generated samples
+             flattened_data = sum([inner_list for inner_list in data], start=[])
+             minibatches = get_minibatches(flattened_data, self.samples_per_mini_batch, self.mini_epochs_per_batch)
+             for idx, mini_batch in enumerate(minibatches):
+                 await async_map(self.train_sample, mini_batch)
+                 optim_logs = await self.model.optim_step(
+                     current_lr,
+                     wd=self.weight_decay,
+                     max_grad_norm=self.max_grad_norm,
+                     skip_nan_gradients=self.skip_nan_gradients,
+                 )
+                 if idx == len(minibatches) - 1:
+                     # only log tables and full batch-related logs on the final minibatch
+                     self.logger(optim_logs | batch_logs)
+                 else:
+                     self.logger(optim_logs | dict(completion_percentage=self.training_completion_percentage))
+
+             self.stage_notifier.report_progress(
+                 tot_num_samples=self.total_num_samples,
+                 processed_num_samples=self.dataset.idx,
+                 monitoring_link=self.logger.training_monitoring_link,
+             )
+
+             if await self.checkpoint_manager.maybe_checkpoint(
+                 self.training_completion_percentage, self._recipe_specific_checkpoint_saving
+             ):
+                 break
+
+     def get_train_batch_logs(self, data: list[list[Sample]]) -> dict:
+         return {
+             **dict(
+                 completion_percentage=self.training_completion_percentage,
+                 percentage_no_advantages=np.mean(
+                     [all(sample.advantage == batch[0].advantage for sample in batch) for batch in data]
+                 ).item(),
+                 score_mean=np.mean([[sample.score for sample in batch] for batch in data]).item(),
+                 score_std=np.std([[sample.score for sample in batch] for batch in data]).item(),
+                 kl_div=np.mean([[np.mean(sample.kl_div) for sample in batch] for batch in data]).item(),
+                 advantages=np.mean(np.concatenate([[sample.advantage for sample in batch] for batch in data])).item(),
+                 generation_length=np.mean([np.mean([sample.gen_len for sample in batch]) for batch in data]).item(),
+                 logprobs=np.mean(
+                     np.concatenate([np.concatenate([sample.logprobs for sample in batch]) for batch in data])
+                 ).item(),
+                 ref_logprobs=np.mean(
+                     np.concatenate([np.concatenate([sample.ref_logprobs for sample in batch]) for batch in data])
+                 ).item(),
+             ),
+             **{"training/completion_percentage": self.training_completion_percentage},
+         }
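`GRPO.gen_data` above uses the outcome-supervision variant: each prompt is rolled out `completions_per_sample` times, and every completion's scalar score is normalized against its own group before being broadcast over the completion's tokens in `train_sample`. A minimal numeric sketch of that normalization (toy scores, not part of the package):

```python
# Toy sketch of the group-relative advantage used in GRPO.gen_data:
# the scores of the completions for ONE prompt are centered and scaled by the
# group's own statistics, so "better than the group" maps to a positive advantage.
import numpy as np

group_scores = np.array([0.0, 0.0, 1.0, 1.0], dtype=np.float32)  # 4 completions of one prompt

advantages = group_scores - group_scores.mean()
advantages /= advantages.std() + 1e-8
print(advantages)  # roughly [-1, -1, 1, 1]
```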
adaptive_harmony/common/gspo.py
@@ -0,0 +1,70 @@
+ from typing import Callable, Sequence
+
+ from adaptive_harmony import (
+     JobNotifier,
+     Logger,
+     StageNotifier,
+     StringThread,
+     TrainingModel,
+ )
+ from adaptive_harmony.common.callbacks import RecipeCallback
+ from adaptive_harmony.common.grpo import GRPO, Sample
+ from adaptive_harmony.graders import BaseGrader
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+
+ class GSPO(GRPO):  # GRPO already logs args, so we don't do it here
+     def __init__(
+         self,
+         dataset: list[StringThread],
+         model: TrainingModel,
+         grader: BaseGrader,
+         logger: Logger = StdoutLogger(),
+         stage_notifier: StageNotifier = JobNotifier().stage_notifier("GSPO Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         max_num_gspo_steps: int | None = None,
+         completions_per_sample=8,
+         lr: float = 7.5e-7,
+         lr_scheduler: Callable[[float], float] | None = None,
+         samples_per_batch=128,
+         samples_per_mini_batch=128,
+         mini_epochs_per_batch=1,
+         max_grad_norm=1.0,
+         # Wildly different defaults from GRPO because we are looking at the
+         # entire sequence at once. I tried the number from the GSPO paper, but
+         # it was skipping ~60% of the samples. I had good success with 0.01,
+         # but it has not been properly swept yet.
+         clip_range=0.01,
+         kl_beta=0.01,
+         weight_decay=0.0,
+     ):
+         super().__init__(
+             dataset,
+             model,
+             grader,
+             logger,
+             stage_notifier,
+             callbacks,
+             max_num_grpo_steps=max_num_gspo_steps,
+             completions_per_sample=completions_per_sample,
+             lr=lr,
+             lr_scheduler=lr_scheduler,
+             samples_per_batch=samples_per_batch,
+             samples_per_mini_batch=samples_per_mini_batch,
+             mini_epochs_per_batch=mini_epochs_per_batch,
+             max_grad_norm=max_grad_norm,
+             clip_range=clip_range,
+             kl_beta=kl_beta,
+             weight_decay=weight_decay,
+         )
+
+     async def train_sample(self, sample: Sample):
+         await self.model.train_gspo(
+             sample.sample,
+             sample.logprobs,
+             sample.ref_logprobs,
+             [sample.advantage],  # the only diff with GRPO: we train with a single sequence-level advantage
+             self.clip_range,
+             self.clip_range,
+             self.kl_beta,
+         )
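GSPO reuses the GRPO pipeline but passes a single sequence-level advantage to `train_gspo` and defaults to a much tighter clip range, as the constructor comment explains. A hypothetical wiring sketch, for illustration only; the prompts, model, and grader are assumed to be built elsewhere, and their setup is not part of this diff:

```python
# Hypothetical usage sketch (not from the package): wiring a GSPO run,
# assuming `prompts`, `model`, and `grader` are constructed elsewhere.
from adaptive_harmony.common.gspo import GSPO


async def run_gspo(prompts, model, grader):
    # prompts: list[StringThread], model: TrainingModel, grader: BaseGrader
    recipe = GSPO(
        dataset=prompts,
        model=model,
        grader=grader,
        max_num_gspo_steps=100,   # hypothetical training budget
        completions_per_sample=8,
        clip_range=0.01,          # sequence-level clipping default noted above
    )
    await recipe.run()
```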