adaptive-harmony 0.1.24__py3-none-any.whl → 0.1.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/artifacts/__init__.py +4 -0
- adaptive_harmony/common/env_grpo.py +18 -10
- adaptive_harmony/common/env_gspo.py +190 -0
- {adaptive_harmony-0.1.24.dist-info → adaptive_harmony-0.1.25.dist-info}/METADATA +2 -2
- {adaptive_harmony-0.1.24.dist-info → adaptive_harmony-0.1.25.dist-info}/RECORD +7 -5
- {adaptive_harmony-0.1.24.dist-info → adaptive_harmony-0.1.25.dist-info}/WHEEL +0 -0
- {adaptive_harmony-0.1.24.dist-info → adaptive_harmony-0.1.25.dist-info}/top_level.txt +0 -0
--- adaptive_harmony/common/env_grpo.py
+++ adaptive_harmony/common/env_grpo.py
@@ -1,5 +1,6 @@
 import os
 from dataclasses import dataclass
+from itertools import groupby
 from typing import Callable, Sequence
 
 import numpy as np
@@ -7,6 +8,7 @@ import numpy as np
 from adaptive_harmony import (
     CosineScheduler,
     DataSet,
+    InferenceModel,
     JobNotifier,
     Logger,
     StageNotifier,
@@ -21,19 +23,23 @@ from adaptive_harmony.environment import EnvironmentFactory, TrajectoryScore
 from adaptive_harmony.metric_logger import StdoutLogger
 
 
-def compute_advantages(
+async def compute_advantages(
     scores: list[TrajectoryScore],
     logprobs: list[list[float]],
     samples: list[TokenizedThread],
     num_generated_turns: list[int],
+    model: InferenceModel,
 ) -> list[list[float]]:
-    def get_assistant_lengths(samples: list[TokenizedThread], num_generated_turns: list[int]) -> list[list[int]]:
-
-
-
-
-
-
+    async def get_assistant_lengths(samples: list[TokenizedThread], num_generated_turns: list[int]) -> list[list[int]]:
+        async def get_number_weight_per_assistant_turn(thread: TokenizedThread):
+            # you cannot rely on the tokens of the thread because templates can modify
+            # the number of tokens that will get weights and you need to match the weights
+            # for the advantages
+            weights = (await model.serialize_tokenized_thread(thread))[2]
+            return [len(list(group)) for key, group in groupby(weights, key=bool) if key]
+
+        all_lengths = await async_map(get_number_weight_per_assistant_turn, samples)
+        return [lengths[-num_gen:] for lengths, num_gen in zip(all_lengths, num_generated_turns)]
 
     # FROM https://arxiv.org/pdf/2402.03300 -> Process Supervision RL with GRPO
     # HERE PADDING DOES NOT PLAYS A ROLE IN ADVANTAGE COMPUTATION. SINCE nan are ignored.
@@ -68,7 +74,7 @@ def compute_advantages(
         for adv, score in zip(score_level_advantage, scores)
     ]
 
-    assistant_lengths = get_assistant_lengths(samples, num_generated_turns)
+    assistant_lengths = await get_assistant_lengths(samples, num_generated_turns)
    assert all([len(lp) == sum(al) for lp, al in zip(logprobs, assistant_lengths)])
 
     token_level_advantage = [np.repeat(adv, al).tolist() for adv, al in zip(turn_level_advantage, assistant_lengths)]
@@ -227,7 +233,9 @@ class ENVGRPO:
         ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)
 
         all_trajectory_scores = [score for _, score, _ in trajs_and_scores]
-        advantages = compute_advantages(
+        advantages = await compute_advantages(
+            all_trajectory_scores, logprobs, all_samples, num_generated_turns_list, self.model
+        )
 
         kl = [
             (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
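In short, `compute_advantages` is now a coroutine that receives the `InferenceModel`, so per-assistant-turn lengths are read from the weight mask of the serialized thread rather than from raw token counts. The sketch below illustrates the `groupby` step used in the hunk above; the `weights` list is a made-up example of such a mask, not taken from the package.

```python
from itertools import groupby

# Hypothetical per-token weight mask, standing in for the third element returned
# by model.serialize_tokenized_thread(thread): non-zero entries mark tokens that
# belong to assistant turns and will receive advantages.
weights = [0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1]

# Group consecutive tokens by the truthiness of their weight and keep only the
# weighted runs: one length per assistant turn, in order of appearance.
turn_lengths = [len(list(group)) for key, group in groupby(weights, key=bool) if key]

print(turn_lengths)  # [3, 2, 4]
```

The last `num_generated_turns` of these lengths per thread are then the ones matched against the turn-level advantages.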
--- /dev/null
+++ adaptive_harmony/common/env_gspo.py
@@ -0,0 +1,190 @@
+from dataclasses import dataclass
+from typing import Callable, Sequence, TypeAlias
+
+import numpy as np
+from numpy.typing import NDArray
+
+from adaptive_harmony import (
+    JobNotifier,
+    Logger,
+    StageNotifier,
+    StringThread,
+    TokenizedThread,
+    TrainingModel,
+)
+from adaptive_harmony.common import RecipeCallback
+from adaptive_harmony.common.env_grpo import ENVGRPO
+from adaptive_harmony.core.utils import (
+    async_map,
+    hash_hyperparams,
+    log_args,
+)
+from adaptive_harmony.environment import EnvironmentFactory, TrajectoryScore
+from adaptive_harmony.metric_logger import StdoutLogger
+
+FloatArray: TypeAlias = NDArray[np.float32]
+
+
+@dataclass
+class Sample:
+    sample: TokenizedThread
+    logprobs: list[float]
+    ref_logprobs: list[float]
+    advantage: float
+    kl_div: list[float]
+    # for logging
+    score: float
+    gen_len: float
+
+
+ENVGSPO_HYPERPARAMS = {
+    "max_num_grpo_steps",
+    "completions_per_sample",
+    "lr",
+    "lr_scheduler",
+    "samples_per_batch",
+    "samples_per_mini_batch",
+    "mini_epochs_per_batch",
+    "max_grad_norm",
+    "clip_range",
+    "kl_beta",
+    "weight_decays",
+    "skip_nan_gradients",
+}
+
+
+class ENVGSPO(ENVGRPO):
+    @log_args
+    @hash_hyperparams(include=ENVGSPO_HYPERPARAMS)
+    def __init__(
+        self,
+        dataset: list[StringThread],
+        model: TrainingModel,
+        environment_factory: EnvironmentFactory,
+        logger: Logger = StdoutLogger(),
+        stage_notifier: StageNotifier = JobNotifier().stage_notifier("ENVGSPO Training"),
+        callbacks: Sequence[RecipeCallback] = [],
+        validation_dataset: list[StringThread] | None = None,
+        validation_frequency: float = 0.2,
+        max_num_grpo_steps: int | None = None,
+        completions_per_sample=8,
+        lr: float = 7.5e-7,
+        lr_scheduler: Callable[[float], float] | None = None,
+        samples_per_batch=128,
+        samples_per_mini_batch=128,
+        mini_epochs_per_batch=1,
+        max_grad_norm=1.0,
+        clip_range=0.01,
+        kl_beta=0.1,
+        weight_decays: float = 0.0,
+        skip_nan_gradients: bool = False,
+        restart_from_checkpoint: str | None = None,
+        checkpoint_frequency: float = 0.2,
+        data_seed: int = 42,
+    ):
+        super().__init__(
+            dataset,
+            model,
+            environment_factory,
+            logger,
+            stage_notifier,
+            callbacks,
+            max_num_grpo_steps=max_num_grpo_steps,
+            completions_per_sample=completions_per_sample,
+            lr=lr,
+            lr_scheduler=lr_scheduler,
+            samples_per_batch=samples_per_batch,
+            samples_per_mini_batch=samples_per_mini_batch,
+            mini_epochs_per_batch=mini_epochs_per_batch,
+            max_grad_norm=max_grad_norm,
+            clip_range=clip_range,
+            kl_beta=kl_beta,
+            weight_decays=weight_decays,
+            data_seed=data_seed,
+        )
+
+    async def gen_data(self, sample: StringThread) -> list[Sample]:
+        # need to override gen data due to the single reward check
+        async def generate_trajectory(
+            prompt: StringThread,
+        ) -> tuple[TokenizedThread, TrajectoryScore, int]:
+            # this create the environment for the first turn.
+            environment = self.environment_factory.create_environment(prompt.metadata)
+            prompt = await environment.bootstrap_prompt(prompt)
+
+            # Count assistant turns in the context (before generation)
+            nb_context_assistant_turns = sum(1 for turn in prompt.get_turns() if turn.role == "assistant")
+
+            string_trajectory = await self.model.generate(prompt)  # generate the first response from the agent.
+            num_generated_turns = 1
+            # we loop until the environment returns a score.
+            # notice how the environment can return a score or a tool or user response.
+            while not isinstance(
+                environment_response := await environment.react_to(string_trajectory),
+                TrajectoryScore,
+            ):
+                for env_role, env_content in environment_response:
+                    if not isinstance(env_content, str):
+                        raise ValueError(f"env_content should be a str, got {env_content}")
+                    if env_role == "user":
+                        string_trajectory = string_trajectory.user(env_content)
+                    elif env_role == "tool":
+                        string_trajectory = string_trajectory.tool(env_content)
+                    else:
+                        raise ValueError
+                string_trajectory = await self.model.generate(string_trajectory)
+                num_generated_turns += 1
+
+            tokenized_trajectory = (
+                await self.model.tokenize_thread(string_trajectory)
+            ).with_weight_assistant_turns_from_index(nb_context_assistant_turns)
+
+            return tokenized_trajectory, environment_response, num_generated_turns
+
+        assert self.model_ref is not None, "Calling `gen_data` before reference model has been set"
+
+        trajs_and_scores = await async_map(generate_trajectory, [sample] * self.completions_per_sample)
+        all_samples = [traj for traj, _, _ in trajs_and_scores]
+        logprobs = await async_map(self.model.logprobs_per_token, all_samples)
+        ref_logprobs = await async_map(self.model_ref.logprobs_per_token, all_samples)
+
+        all_trajectory_scores = [score for _, score, _ in trajs_and_scores]
+        assert all(len(traj_score.scores) == 1 for traj_score in all_trajectory_scores)
+        all_scores = np.array(
+            [traj_score.scores[0].score for traj_score in all_trajectory_scores],
+            dtype=np.float32,
+        )
+
+        advantages: FloatArray = all_scores - all_scores.mean()
+        advantages /= advantages.std() + 1e-8
+
+        kl = [
+            (np.array(lp, dtype=np.float32) - np.array(ref_lp, dtype=np.float32)).tolist()
+            for lp, ref_lp in zip(logprobs, ref_logprobs)
+        ]
+
+        samples = []
+        for i in range(len(logprobs)):
+            samples.append(
+                Sample(
+                    sample=all_samples[i],
+                    logprobs=logprobs[i],
+                    ref_logprobs=ref_logprobs[i],
+                    advantage=advantages[i],
+                    kl_div=kl[i],
+                    score=all_trajectory_scores[i].cumulative_score,
+                    gen_len=all_samples[i].len_last_turn(),
+                )
+            )
+        return samples
+
+    async def train_sample(self, sample: Sample):
+        await self.model.train_gspo(
+            sample.sample,
+            sample.logprobs,
+            sample.ref_logprobs,
+            advantage=[sample.advantage],
+            left_clip=self.clip_range,
+            right_clip=self.clip_range,
+            kl_beta=self.kl_beta,
+        )
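The new `ENVGSPO` recipe reuses the ENVGRPO rollout loop but assigns a single sequence-level advantage per trajectory: the scores of the `completions_per_sample` rollouts for one prompt are mean-centered and divided by their standard deviation, and each trajectory then carries that one scalar into `train_gspo` under a symmetric clip range, in contrast to the per-token advantages built by `compute_advantages` in `env_grpo.py`. A minimal numpy sketch of that normalization, using made-up scores, looks like this:

```python
import numpy as np

# Hypothetical environment scores for the 8 completions of one prompt
# (completions_per_sample=8 is the recipe default).
all_scores = np.array([0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0], dtype=np.float32)

# Group-normalized, sequence-level advantages as in ENVGSPO.gen_data:
# center on the group mean, then scale by the group std (epsilon for stability).
advantages = all_scores - all_scores.mean()
advantages /= advantages.std() + 1e-8

print(advantages)  # roughly [-1, 1, 1, -1, 1, -1, -1, 1]
```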
--- adaptive_harmony-0.1.24.dist-info/METADATA
+++ adaptive_harmony-0.1.25.dist-info/METADATA
@@ -1,13 +1,13 @@
 Metadata-Version: 2.4
 Name: adaptive-harmony
-Version: 0.1.24
+Version: 0.1.25
 Summary: Adaptive Harmony training recipes and utilities for LLM fine-tuning
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
 Requires-Python: >=3.12
-Requires-Dist: harmony-client~=0.
+Requires-Dist: harmony-client~=0.2.0
 Requires-Dist: rich>=13.7.0
 Requires-Dist: datasets>=2.14.0
 Requires-Dist: hf-xet>=1.1.2
--- adaptive_harmony-0.1.24.dist-info/RECORD
+++ adaptive_harmony-0.1.25.dist-info/RECORD
@@ -2,11 +2,13 @@ adaptive_harmony/__init__.py,sha256=_KoDEWVU-mCtXWp7ZXXlWcTWSVVkE6_r8xlJDXyOxRw,
 adaptive_harmony/logging_table.py,sha256=kN5jS-PO0Y1B6KFicv3BnSyXz5OfThV4L1pCY3_kUmk,56
 adaptive_harmony/metric_logger.py,sha256=6KAp7UhhygiHgWj5l9Bhwc7Sg9cIhxSzAilpxp_7iZM,16619
 adaptive_harmony/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+adaptive_harmony/artifacts/__init__.py,sha256=6iZNKLr3gDC9swbmlwPaUyev5N7lY-ukLZgtV93BeZQ,224
 adaptive_harmony/common/__init__.py,sha256=qebnYmwNBurtouGDbK27mtwt9zLm3P0tHR_M9LnFZT4,967
 adaptive_harmony/common/callbacks.py,sha256=Q5qxVOAdnQRUZxy_ZcBAVxXTmSNA3o7L-cfEZ3JPnWs,8636
 adaptive_harmony/common/checkpointing.py,sha256=rNfzwTEvWzNbUMjkl4CUD3zfsYdsWU_ksR3Lqn-Ghck,6569
 adaptive_harmony/common/dpo.py,sha256=ioionFEnxzagfBVnIvLBh6rb6-d8WeWtVHgp-VDBKf8,3463
-adaptive_harmony/common/env_grpo.py,sha256=
+adaptive_harmony/common/env_grpo.py,sha256=HR5CFrK1MiVXNljV1uE5ylTDeryF8hbdqWskqs_BqBE,15440
+adaptive_harmony/common/env_gspo.py,sha256=AVkh8qTZ-IGgPrPYGifb1t-3mBsn-blW3aM5vZj-joA,6877
 adaptive_harmony/common/grpo.py,sha256=LlG0NxpTtFga06YguTNDnEOVfBjRYHJoRyz4fbAFCRc,10384
 adaptive_harmony/common/gspo.py,sha256=O4z-BrKLusGeM8P6LWz77h8i0HrUhLR7_wxrAluxdxQ,2407
 adaptive_harmony/common/ppo.py,sha256=owJlajLDnOxq4LpjjIn-dLXJVmKlsQh3wMG0zfnbUxU,12393
@@ -61,7 +63,7 @@ adaptive_harmony/runtime/decorators.py,sha256=zDNnG_fNz-zgHnb-d5WCPNLMMKFRtL_ncz
 adaptive_harmony/runtime/model_artifact_save.py,sha256=1Ui-Q1hP_eDAhKBFOXpEVix5Q3TY9_d11viXs0xsk3o,137
 adaptive_harmony/runtime/runner.py,sha256=70lNz2pe2dGEgqH8Igwp8ppGLDLxHVwNmxcyV4Y6HMM,898
 adaptive_harmony/runtime/simple_notifier.py,sha256=iVXtZwfcOvkZlWQgFC0qjE1P-yA6Y7Wx0SxQ9FoJ-0s,129
-adaptive_harmony-0.1.
-adaptive_harmony-0.1.
-adaptive_harmony-0.1.
-adaptive_harmony-0.1.
+adaptive_harmony-0.1.25.dist-info/METADATA,sha256=LnlRqVeKLSw07fnI5_nFRLOP1pjMsL2wPAQ-6xYCt5A,1436
+adaptive_harmony-0.1.25.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+adaptive_harmony-0.1.25.dist-info/top_level.txt,sha256=ZEmoKxkFM4M7H2mgH15wQ4Tf0Eb13FBmghRvC2seacU,17
+adaptive_harmony-0.1.25.dist-info/RECORD,,
File without changes
|
|
File without changes
|