adaptive-harmony 0.1.24__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
+ from harmony_client.artifacts import CustomArtifact as CustomArtifact
+ from harmony_client.artifacts import DatasetArtifact as DatasetArtifact
+
+ from adaptive_harmony.evaluation import EvaluationArtifact as EvaluationArtifact
@@ -1,5 +1,6 @@
  import os
  from dataclasses import dataclass
+ from itertools import groupby
  from typing import Callable, Sequence

  import numpy as np
@@ -7,6 +8,7 @@ import numpy as np
  from adaptive_harmony import (
      CosineScheduler,
      DataSet,
+     InferenceModel,
      JobNotifier,
      Logger,
      StageNotifier,
@@ -21,19 +23,23 @@ from adaptive_harmony.environment import EnvironmentFactory, TrajectoryScore
  from adaptive_harmony.metric_logger import StdoutLogger


- def compute_advantages(
+ async def compute_advantages(
      scores: list[TrajectoryScore],
      logprobs: list[list[float]],
      samples: list[TokenizedThread],
      num_generated_turns: list[int],
+     model: InferenceModel,
  ) -> list[list[float]]:
-     def get_assistant_lengths(samples: list[TokenizedThread], num_generated_turns: list[int]) -> list[list[int]]:
-         # here the +1 is because we have a loss weight on the EOD token of a turn, which is not represented when you look at the tokenized
-         # Keep only the last num_generated_turns assistant turns (those with weight>0) since logprobs_per_token only returns logprobs for them
-         return [
-             [len(turn.content) + 1 for turn in sample.get_turns() if turn.role == "assistant"][-num_gen:]
-             for sample, num_gen in zip(samples, num_generated_turns)
-         ]
+     async def get_assistant_lengths(samples: list[TokenizedThread], num_generated_turns: list[int]) -> list[list[int]]:
+         async def get_number_weight_per_assistant_turn(thread: TokenizedThread):
+             # you cannot rely on the tokens of the thread because templates can modify
+             # the number of tokens that will get weights and you need to match the weights
+             # for the advantages
+             weights = (await model.serialize_tokenized_thread(thread))[2]
+             return [len(list(group)) for key, group in groupby(weights, key=bool) if key]
+
+         all_lengths = await async_map(get_number_weight_per_assistant_turn, samples)
+         return [lengths[-num_gen:] for lengths, num_gen in zip(all_lengths, num_generated_turns)]

      # FROM https://arxiv.org/pdf/2402.03300 -> Process Supervision RL with GRPO
      # HERE PADDING DOES NOT PLAYS A ROLE IN ADVANTAGE COMPUTATION. SINCE nan are ignored.
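In the new helper above, `groupby(weights, key=bool)` run-length encodes the per-token loss weights returned as the third element of `model.serialize_tokenized_thread(...)`, so each truthy run corresponds to one weighted assistant turn. A standalone sketch of that counting step, using an invented weight vector purely for illustration:

    from itertools import groupby

    # hypothetical per-token loss weights: 0.0 for prompt/user/tool tokens,
    # non-zero for assistant tokens that receive a training weight
    weights = [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0]

    # key=bool splits the sequence into falsy/truthy runs; keeping only the
    # truthy runs yields the token count of each weighted assistant turn
    turn_lengths = [len(list(group)) for key, group in groupby(weights, key=bool) if key]
    assert turn_lengths == [3, 2]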
@@ -68,7 +74,7 @@ def compute_advantages(
          for adv, score in zip(score_level_advantage, scores)
      ]

-     assistant_lengths = get_assistant_lengths(samples, num_generated_turns)
+     assistant_lengths = await get_assistant_lengths(samples, num_generated_turns)
      assert all([len(lp) == sum(al) for lp, al in zip(logprobs, assistant_lengths)])

      token_level_advantage = [np.repeat(adv, al).tolist() for adv, al in zip(turn_level_advantage, assistant_lengths)]
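The `np.repeat` call above broadcasts each turn-level advantage to one value per token of that turn, so the advantage vector lines up index-for-index with the per-token logprobs. A minimal numeric illustration (values invented):

    import numpy as np

    turn_level_advantage = [0.7, -0.3]  # one advantage per generated assistant turn
    assistant_lengths = [3, 2]          # weighted token count of each generated turn

    token_level_advantage = np.repeat(turn_level_advantage, assistant_lengths).tolist()
    assert token_level_advantage == [0.7, 0.7, 0.7, -0.3, -0.3]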
@@ -227,7 +233,9 @@ class ENVGRPO:
          ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)

          all_trajectory_scores = [score for _, score, _ in trajs_and_scores]
-         advantages = compute_advantages(all_trajectory_scores, logprobs, all_samples, num_generated_turns_list)
+         advantages = await compute_advantages(
+             all_trajectory_scores, logprobs, all_samples, num_generated_turns_list, self.model
+         )

          kl = [
              (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
@@ -0,0 +1,190 @@
+ from dataclasses import dataclass
+ from typing import Callable, Sequence, TypeAlias
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from adaptive_harmony import (
+     JobNotifier,
+     Logger,
+     StageNotifier,
+     StringThread,
+     TokenizedThread,
+     TrainingModel,
+ )
+ from adaptive_harmony.common import RecipeCallback
+ from adaptive_harmony.common.env_grpo import ENVGRPO
+ from adaptive_harmony.core.utils import (
+     async_map,
+     hash_hyperparams,
+     log_args,
+ )
+ from adaptive_harmony.environment import EnvironmentFactory, TrajectoryScore
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+ FloatArray: TypeAlias = NDArray[np.float32]
+
+
+ @dataclass
+ class Sample:
+     sample: TokenizedThread
+     logprobs: list[float]
+     ref_logprobs: list[float]
+     advantage: float
+     kl_div: list[float]
+     # for logging
+     score: float
+     gen_len: float
+
+
+ ENVGSPO_HYPERPARAMS = {
+     "max_num_grpo_steps",
+     "completions_per_sample",
+     "lr",
+     "lr_scheduler",
+     "samples_per_batch",
+     "samples_per_mini_batch",
+     "mini_epochs_per_batch",
+     "max_grad_norm",
+     "clip_range",
+     "kl_beta",
+     "weight_decays",
+     "skip_nan_gradients",
+ }
+
+
+ class ENVGSPO(ENVGRPO):
+     @log_args
+     @hash_hyperparams(include=ENVGSPO_HYPERPARAMS)
+     def __init__(
+         self,
+         dataset: list[StringThread],
+         model: TrainingModel,
+         environment_factory: EnvironmentFactory,
+         logger: Logger = StdoutLogger(),
+         stage_notifier: StageNotifier = JobNotifier().stage_notifier("ENVGSPO Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         validation_dataset: list[StringThread] | None = None,
+         validation_frequency: float = 0.2,
+         max_num_grpo_steps: int | None = None,
+         completions_per_sample=8,
+         lr: float = 7.5e-7,
+         lr_scheduler: Callable[[float], float] | None = None,
+         samples_per_batch=128,
+         samples_per_mini_batch=128,
+         mini_epochs_per_batch=1,
+         max_grad_norm=1.0,
+         clip_range=0.01,
+         kl_beta=0.1,
+         weight_decays: float = 0.0,
+         skip_nan_gradients: bool = False,
+         restart_from_checkpoint: str | None = None,
+         checkpoint_frequency: float = 0.2,
+         data_seed: int = 42,
+     ):
+         super().__init__(
+             dataset,
+             model,
+             environment_factory,
+             logger,
+             stage_notifier,
+             callbacks,
+             max_num_grpo_steps=max_num_grpo_steps,
+             completions_per_sample=completions_per_sample,
+             lr=lr,
+             lr_scheduler=lr_scheduler,
+             samples_per_batch=samples_per_batch,
+             samples_per_mini_batch=samples_per_mini_batch,
+             mini_epochs_per_batch=mini_epochs_per_batch,
+             max_grad_norm=max_grad_norm,
+             clip_range=clip_range,
+             kl_beta=kl_beta,
+             weight_decays=weight_decays,
+             data_seed=data_seed,
+         )
+
+     async def gen_data(self, sample: StringThread) -> list[Sample]:
+         # need to override gen data due to the single reward check
+         async def generate_trajectory(
+             prompt: StringThread,
+         ) -> tuple[TokenizedThread, TrajectoryScore, int]:
+             # this create the environment for the first turn.
+             environment = self.environment_factory.create_environment(prompt.metadata)
+             prompt = await environment.bootstrap_prompt(prompt)
+
+             # Count assistant turns in the context (before generation)
+             nb_context_assistant_turns = sum(1 for turn in prompt.get_turns() if turn.role == "assistant")
+
+             string_trajectory = await self.model.generate(prompt)  # generate the first response from the agent.
+             num_generated_turns = 1
+             # we loop until the environment returns a score.
+             # notice how the environment can return a score or a tool or user response.
+             while not isinstance(
+                 environment_response := await environment.react_to(string_trajectory),
+                 TrajectoryScore,
+             ):
+                 for env_role, env_content in environment_response:
+                     if not isinstance(env_content, str):
+                         raise ValueError(f"env_content should be a str, got {env_content}")
+                     if env_role == "user":
+                         string_trajectory = string_trajectory.user(env_content)
+                     elif env_role == "tool":
+                         string_trajectory = string_trajectory.tool(env_content)
+                     else:
+                         raise ValueError
+                 string_trajectory = await self.model.generate(string_trajectory)
+                 num_generated_turns += 1
+
+             tokenized_trajectory = (
+                 await self.model.tokenize_thread(string_trajectory)
+             ).with_weight_assistant_turns_from_index(nb_context_assistant_turns)
+
+             return tokenized_trajectory, environment_response, num_generated_turns
+
+         assert self.model_ref is not None, "Calling `gen_data` before reference model has been set"
+
+         trajs_and_scores = await async_map(generate_trajectory, [sample] * self.completions_per_sample)
+         all_samples = [traj for traj, _, _ in trajs_and_scores]
+         logprobs = await async_map(self.model.logprobs_per_token, all_samples)
+         ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)
+
+         all_trajectory_scores = [score for _, score, _ in trajs_and_scores]
+         assert all(len(traj_score.scores) == 1 for traj_score in all_trajectory_scores)
+         all_scores = np.array(
+             [traj_score.scores[0].score for traj_score in all_trajectory_scores],
+             dtype=np.float32,
+         )
+
+         advantages: FloatArray = all_scores - all_scores.mean()
+         advantages /= advantages.std() + 1e-8
+
+         kl = [
+             (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
+             for lp, ref_lp in zip(logprobs, ref_logprobs)
+         ]
+
+         samples = []
+         for i in range(len(logprobs)):
+             samples.append(
+                 Sample(
+                     sample=all_samples[i],
+                     logprobs=logprobs[i],
+                     ref_logprobs=ref_logprobs[i],
+                     advantage=advantages[i],
+                     kl_div=kl[i],
+                     score=all_trajectory_scores[i].cumulative_score,
+                     gen_len=all_samples[i].len_last_turn(),
+                 )
+             )
+         return samples
+
+     async def train_sample(self, sample: Sample):
+         await self.model.train_gspo(
+             sample.sample,
+             sample.logprobs,
+             sample.ref_logprobs,
+             advantage=[sample.advantage],
+             left_clip=self.clip_range,
+             right_clip=self.clip_range,
+             kl_beta=self.kl_beta,
+         )
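In `ENVGSPO.gen_data`, the single score of each trajectory is standardised across the `completions_per_sample` rollouts generated for one prompt, mirroring GRPO's group-relative advantage but with one sequence-level score per completion. A small numeric sketch of that normalisation, with invented scores:

    import numpy as np

    # hypothetical scores for four completions of the same prompt
    all_scores = np.array([1.0, 0.0, 0.5, 0.5], dtype=np.float32)

    advantages = all_scores - all_scores.mean()  # centre within the completion group
    advantages /= advantages.std() + 1e-8        # scale; the epsilon guards against zero std
    # completions scored above the group mean end up with positive advantages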
@@ -1,13 +1,13 @@
  Metadata-Version: 2.4
  Name: adaptive-harmony
- Version: 0.1.24
+ Version: 0.1.25
  Summary: Adaptive Harmony training recipes and utilities for LLM fine-tuning
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
  Classifier: Programming Language :: Python :: Implementation :: CPython
  Classifier: Programming Language :: Python :: Implementation :: PyPy
  Requires-Python: >=3.12
- Requires-Dist: harmony-client~=0.1.0
+ Requires-Dist: harmony-client~=0.2.0
  Requires-Dist: rich>=13.7.0
  Requires-Dist: datasets>=2.14.0
  Requires-Dist: hf-xet>=1.1.2
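The dependency constraint moves from `harmony-client~=0.1.0` to `harmony-client~=0.2.0`; the `~=` compatible-release operator keeps the dependency within the stated minor series. A quick check of what the new specifier accepts, using the third-party `packaging` library (assumed available, for illustration only):

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=0.2.0")  # equivalent to >=0.2.0, ==0.2.*
    assert spec.contains("0.2.5")
    assert not spec.contains("0.3.0")
    assert not spec.contains("0.1.9")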
@@ -2,11 +2,13 @@ adaptive_harmony/__init__.py,sha256=_KoDEWVU-mCtXWp7ZXXlWcTWSVVkE6_r8xlJDXyOxRw,
  adaptive_harmony/logging_table.py,sha256=kN5jS-PO0Y1B6KFicv3BnSyXz5OfThV4L1pCY3_kUmk,56
  adaptive_harmony/metric_logger.py,sha256=6KAp7UhhygiHgWj5l9Bhwc7Sg9cIhxSzAilpxp_7iZM,16619
  adaptive_harmony/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ adaptive_harmony/artifacts/__init__.py,sha256=6iZNKLr3gDC9swbmlwPaUyev5N7lY-ukLZgtV93BeZQ,224
  adaptive_harmony/common/__init__.py,sha256=qebnYmwNBurtouGDbK27mtwt9zLm3P0tHR_M9LnFZT4,967
  adaptive_harmony/common/callbacks.py,sha256=Q5qxVOAdnQRUZxy_ZcBAVxXTmSNA3o7L-cfEZ3JPnWs,8636
  adaptive_harmony/common/checkpointing.py,sha256=rNfzwTEvWzNbUMjkl4CUD3zfsYdsWU_ksR3Lqn-Ghck,6569
  adaptive_harmony/common/dpo.py,sha256=ioionFEnxzagfBVnIvLBh6rb6-d8WeWtVHgp-VDBKf8,3463
- adaptive_harmony/common/env_grpo.py,sha256=bQEBJ7TojBxnCIRopu_pkjzfTl1zAXRcW8olRDMDtIE,15149
+ adaptive_harmony/common/env_grpo.py,sha256=HR5CFrK1MiVXNljV1uE5ylTDeryF8hbdqWskqs_BqBE,15440
+ adaptive_harmony/common/env_gspo.py,sha256=AVkh8qTZ-IGgPrPYGifb1t-3mBsn-blW3aM5vZj-joA,6877
  adaptive_harmony/common/grpo.py,sha256=LlG0NxpTtFga06YguTNDnEOVfBjRYHJoRyz4fbAFCRc,10384
  adaptive_harmony/common/gspo.py,sha256=O4z-BrKLusGeM8P6LWz77h8i0HrUhLR7_wxrAluxdxQ,2407
  adaptive_harmony/common/ppo.py,sha256=owJlajLDnOxq4LpjjIn-dLXJVmKlsQh3wMG0zfnbUxU,12393
@@ -61,7 +63,7 @@ adaptive_harmony/runtime/decorators.py,sha256=zDNnG_fNz-zgHnb-d5WCPNLMMKFRtL_ncz
  adaptive_harmony/runtime/model_artifact_save.py,sha256=1Ui-Q1hP_eDAhKBFOXpEVix5Q3TY9_d11viXs0xsk3o,137
  adaptive_harmony/runtime/runner.py,sha256=70lNz2pe2dGEgqH8Igwp8ppGLDLxHVwNmxcyV4Y6HMM,898
  adaptive_harmony/runtime/simple_notifier.py,sha256=iVXtZwfcOvkZlWQgFC0qjE1P-yA6Y7Wx0SxQ9FoJ-0s,129
- adaptive_harmony-0.1.24.dist-info/METADATA,sha256=G67ZZoxVySEAvW2NAZM5wJK5HBGFmWpHw7Gi7QNS0pA,1436
- adaptive_harmony-0.1.24.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- adaptive_harmony-0.1.24.dist-info/top_level.txt,sha256=ZEmoKxkFM4M7H2mgH15wQ4Tf0Eb13FBmghRvC2seacU,17
- adaptive_harmony-0.1.24.dist-info/RECORD,,
+ adaptive_harmony-0.1.25.dist-info/METADATA,sha256=LnlRqVeKLSw07fnI5_nFRLOP1pjMsL2wPAQ-6xYCt5A,1436
+ adaptive_harmony-0.1.25.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ adaptive_harmony-0.1.25.dist-info/top_level.txt,sha256=ZEmoKxkFM4M7H2mgH15wQ4Tf0Eb13FBmghRvC2seacU,17
+ adaptive_harmony-0.1.25.dist-info/RECORD,,