adaptive-harmony 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/common/ppo.py
@@ -0,0 +1,303 @@
import os
from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np

from adaptive_harmony import (
    CombinedSchedule,
    CosineScheduler,
    DataSet,
    InferenceModel,
    JobNotifier,
    Logger,
    StageNotifier,
    StringThread,
    TokenizedThread,
    TrainingModel,
)
from adaptive_harmony.common.callbacks import RecipeCallback
from adaptive_harmony.common.checkpointing import CheckpointManager
from adaptive_harmony.core import rl_utils
from adaptive_harmony.core.utils import async_map, async_map_batch, get_minibatches, hash_hyperparams, log_args
from adaptive_harmony.graders import BaseGrader
from adaptive_harmony.metric_logger import StdoutLogger


@dataclass
class Sample:
    sample: TokenizedThread
    string_sample: StringThread
    logprobs: list[float]
    ref_logprobs: list[float]
    advantages: list[float]
    returns: list[float]
    score: float
    kl_div: list[float]
    values: list[float]
    kl_pen: float
    cumulative_reward: float


PPO_HYPERPARAMS = {
    "max_num_ppo_steps",
    "value_only_fraction",
    "lr_policy",
    "lr_scheduler_policy",
    "lr_scheduler_value",
    "lr_value",
    "samples_per_batch",
    "samples_per_mini_batch",
    "mini_epochs_per_batch",
    "max_grad_norm",
    "clip_range",
    "kl_beta",
    "gae_lambda",
    "gae_gamma",
    "weight_decay",
    "skip_nan_gradients",
}


class PPO:
    @log_args
    @hash_hyperparams(include=PPO_HYPERPARAMS)
    def __init__(
        self,
        dataset: list[StringThread],
        model: TrainingModel,
        value_model: TrainingModel,
        grader: BaseGrader,
        logger: Logger = StdoutLogger(),
        stage_notifier: StageNotifier = JobNotifier().stage_notifier("PPO Training"),
        callbacks: Sequence[RecipeCallback] = [],
        max_num_ppo_steps: int | None = None,
        value_only_fraction=0.25,
        lr_policy: float = 0.75e-6,
        lr_scheduler_policy: Callable[[float], float] | None = None,
        lr_scheduler_value: Callable[[float], float] | None = None,
        lr_value: float = 1e-6,
        samples_per_batch=128,
        samples_per_mini_batch=128,
        mini_epochs_per_batch=1,
        max_grad_norm=1.0,
        clip_range=0.1,
        kl_beta=0.1,
        gae_lambda=0.95,
        gae_gamma=1.0,
        weight_decay: float = 0,
        skip_nan_gradients: bool = False,
        restart_from_checkpoint: str | None = None,
        checkpoint_frequency: float = 0.2,
    ):
        assert value_model.is_scalar(), "You must give a scalar model to PPO for the value network"
        # Core components
        self.model_ref: InferenceModel | None = None  # Instantiated when run() is called
        self.dataset = DataSet(dataset, allow_looping=True)
        self.model = model
        self.value_model = value_model
        self.grader = grader
        self.scoring_fn = grader.score_float_value
        self.logger = logger
        self.stage_notifier = stage_notifier
        self.callbacks = callbacks
        self.skip_nan_gradients = skip_nan_gradients
        # PPO hyperparameters
        self.max_num_batches = max_num_ppo_steps
        self.lr_schedule_policy = lr_scheduler_policy or CombinedSchedule(
            lambda _: 0, CosineScheduler(lr_policy), value_only_fraction
        )
        self.lr_schedule_value = lr_scheduler_value or CosineScheduler(lr_value)
        self.samples_per_batch = samples_per_batch
        self.samples_per_mini_batch = samples_per_mini_batch
        self.total_num_samples = (
            self.max_num_batches * self.samples_per_batch if self.max_num_batches else len(self.dataset)
        )
        self.mini_epochs_per_batch = mini_epochs_per_batch
        self.max_grad_norm = max_grad_norm
        self.clip_range = clip_range
        self.kl_beta = kl_beta
        self.gae_lambda = gae_lambda
        self.gae_gamma = gae_gamma
        self.weight_decay = weight_decay

        self.num_batches_processed = 0

        self.checkpoint_manager = CheckpointManager(
            recipe_name="PPO",
            dataset=self.dataset,
            threads_dataset=dataset,
            callbacks=self.callbacks,
            hyperparams_hash=self._hyperparams_hash,  # type: ignore
            job_id=os.environ.get("HARMONY_JOB_ID"),
            checkpoint_frequency=checkpoint_frequency,
            restart_from_checkpoint=restart_from_checkpoint,
        )

    @property
    def training_completion_percentage(self):
        return (
            self.dataset.completion_percentage()
            if self.max_num_batches is None
            else min(self.num_batches_processed / self.max_num_batches, 1.0)
        )

    async def generate_sample(self, prompt: StringThread):
        assert self.model_ref is not None, "Calling `generate_sample` before reference model has been set"

        sample = await self.model.generate_tokens(prompt)
        string_sample = await self.model.detokenize_thread(sample)
        score = await self.scoring_fn(string_sample)
        values = await self.value_model.score(sample)

        logprobs = await self.model.logprobs_per_token(sample)
        ref_logprobs = await self.model_ref.logprobs_per_token(sample)

        kl = np.array(logprobs, dtype=np.float32) - np.array(ref_logprobs, dtype=np.float32)
        kl_pen = -kl * self.kl_beta
        rewards = np.array(kl_pen)
        rewards[-1] += score

        advantages = rl_utils.gae_advantages(values, rewards.tolist(), self.gae_lambda, self.gae_gamma)
        returns = rl_utils.discounted_cumulative_rewards(rewards.tolist(), self.gae_gamma)

        return Sample(
            sample=sample,
            string_sample=string_sample,
            logprobs=logprobs,
            ref_logprobs=ref_logprobs,
            advantages=advantages,
            returns=returns,
            score=score,
            values=values,
            cumulative_reward=sum(rewards),
            kl_div=kl.tolist(),
            kl_pen=np.sum(kl_pen).item(),
        )

    async def train_ppo(self, sample: Sample):
        await self.model.train_ppo(sample.sample, sample.logprobs, sample.advantages, self.clip_range)

    async def train_value(self, sample: Sample):
        await self.value_model.train_mse_per_token(sample.sample, sample.returns)

    async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
        self.num_batches_processed = checkpoint_data["num_batches_processed"]

        model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
        await self.model.load(f"model_registry://{model_checkpoint_name}")
        value_model_checkpoint_name = checkpoint_data["value_model_checkpoint_name"]
        await self.value_model.load(f"model_registry://{value_model_checkpoint_name}")

        self.model.set_optim_step(checkpoint_data["optim_step"])
        self.value_model.set_optim_step(checkpoint_data["value_optim_step"])

    async def _recipe_specific_checkpoint_saving(self) -> dict:
        progress_pct = int(self.training_completion_percentage * 100)
        model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-policy"
        model_checkpoint_name = await self.model.save(model_checkpoint_name, inference_only=False)

        value_model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-value"
        value_model_checkpoint_name = await self.value_model.save(value_model_checkpoint_name, inference_only=False)

        return {
            "num_batches_processed": self.num_batches_processed,
            "model_checkpoint_name": model_checkpoint_name,
            "value_model_checkpoint_name": value_model_checkpoint_name,
            "optim_step": self.model.get_optim_step(),
            "value_optim_step": self.value_model.get_optim_step(),
        }

    async def run(self):
        self.model_ref = await self.model.clone_inf()
        await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)

        self.stage_notifier.report_progress(
            tot_num_samples=self.total_num_samples,
            processed_num_samples=self.dataset.idx,
            monitoring_link=self.logger.training_monitoring_link,
        )

        while self.training_completion_percentage < 1.0:
            self.num_batches_processed += 1

            for callback in self.callbacks:
                if logs := await callback.maybe_call(self.training_completion_percentage):
                    self.logger(logs)

            # Generate training samples
            data = await async_map_batch(
                self.generate_sample,
                self.dataset,
                self.samples_per_batch,
            )
            scorer_logs = self.grader.get_logs(clear=True)
            batch_logs = {
                **{f"rewards/{key}": value for key, value in scorer_logs.items()},
                **self.get_train_batch_logs(data),
            }

            lr_policy = self.lr_schedule_policy(self.training_completion_percentage)
            lr_value = self.lr_schedule_value(self.training_completion_percentage)

            # Train on generated samples
            if lr_policy > 0:
                minibatches = get_minibatches(data, self.samples_per_mini_batch, self.mini_epochs_per_batch)
                for idx, mini_batch in enumerate(minibatches):
                    await async_map(self.train_ppo, mini_batch)
                    optim_logs = await self.model.optim_step(
                        lr_policy,
                        wd=self.weight_decay,
                        max_grad_norm=self.max_grad_norm,
                        skip_nan_gradients=self.skip_nan_gradients,
                    )
                    if idx == len(minibatches) - 1:
                        # only log tables and full batch-related logs on the final minibatch
                        self.logger(optim_logs | batch_logs)
                    else:
                        self.logger(optim_logs | dict(completion_percentage=self.training_completion_percentage))

            for mini_batch in get_minibatches(data, self.samples_per_mini_batch, self.mini_epochs_per_batch):
                await async_map(self.train_value, mini_batch)
                batch_logs |= await self.value_model.optim_step(
                    lr_value, wd=0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
                )
                self.logger(batch_logs)

            self.stage_notifier.report_progress(
                tot_num_samples=self.total_num_samples,
                processed_num_samples=self.dataset.idx,
                monitoring_link=self.logger.training_monitoring_link,
            )

            if await self.checkpoint_manager.maybe_checkpoint(
                self.training_completion_percentage, self._recipe_specific_checkpoint_saving
            ):
                break

    def get_train_batch_logs(self, data: list[Sample]) -> dict:
        returns = np.concatenate([batch.returns for batch in data])
        cur_values = np.concatenate([batch.values for batch in data])

        var_return = returns.var()
        mean_error = ((cur_values - returns) ** 2).mean()
        explained_variance = (1 - mean_error / (var_return + 1e-8)).item()

        logs = dict(
            completion_percentage=self.training_completion_percentage,
            score_mean=np.mean([batch.score for batch in data]).item(),
            score_std=np.std([batch.score for batch in data]).item(),
            returns=np.mean(np.concatenate([batch.returns for batch in data])),
            kl_div=np.mean(np.concatenate([batch.kl_div for batch in data])),
            advantages=np.mean(np.concatenate([batch.advantages for batch in data])),
            generation_length=np.mean([batch.sample.len_last_turn() for batch in data]),
            logprobs=np.mean(np.concatenate([batch.logprobs for batch in data])),
            ref_logprobs=np.mean(np.concatenate([batch.ref_logprobs for batch in data])),
            kl_penalty=np.mean([batch.kl_pen for batch in data]),
            explained_variance=explained_variance,
            cumulative_reward=np.mean([batch.cumulative_reward for batch in data]),
        ) | {
            "training/completion_percentage": self.training_completion_percentage
        }  # to have a comparable axis with prior runs

        return logs
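
A minimal usage sketch (not part of the package diff): how the PPO recipe above might be wired together, assuming the policy/value TrainingModel handles and a BaseGrader implementation come from the surrounding Harmony runtime. make_training_model and MyGrader are hypothetical placeholders; the dataset helpers are the real ones from adaptive_harmony/core/dataset.py shown further below.

# Illustrative sketch only; model and grader construction are not part of this diff.
import asyncio

from adaptive_harmony.common.ppo import PPO
from adaptive_harmony.core.dataset import convert_sample_dict, load_from_hf


async def main():
    # Prompts as StringThreads, with trailing assistant turns stripped so the
    # policy generates the final turn itself.
    prompts = load_from_hf(
        "org/prompt-dataset",  # hypothetical HF repo
        split="train",
        convert_sample_fn=convert_sample_dict(trim_final_assistant_turns=True),
    )

    policy = await make_training_model("base-model")   # placeholder helper
    value = await make_training_model("scalar-head")   # must satisfy value_model.is_scalar()
    grader = MyGrader()                                 # any BaseGrader subclass

    recipe = PPO(
        dataset=prompts,
        model=policy,
        value_model=value,
        grader=grader,
        samples_per_batch=128,
        kl_beta=0.1,
    )
    await recipe.run()


asyncio.run(main())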
adaptive_harmony/common/rm.py
@@ -0,0 +1,79 @@
from typing import Callable, Sequence

from tqdm.auto import tqdm

from adaptive_harmony import CosineScheduler, DataSet, JobNotifier, Logger, StageNotifier, StringThread, TrainingModel
from adaptive_harmony.common.callbacks import RecipeCallback
from adaptive_harmony.core.utils import async_map_batch, log_args
from adaptive_harmony.metric_logger import StdoutLogger


class RewardModelling:
    @log_args
    def __init__(
        self,
        dataset: list[tuple[StringThread, StringThread]] | list[StringThread],
        model: TrainingModel,
        logger: Logger = StdoutLogger(),
        job_notifier: StageNotifier = JobNotifier().stage_notifier("Reward Model Training"),
        callbacks: Sequence[RecipeCallback] = [],
        lr: float = 1e-06,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch: int = 64,
        max_grad_norm: float = 1.0,
        epochs: int = 1,
        skip_nan_gradients: bool = False,
    ):
        self.dataset: DataSet[StringThread | tuple[StringThread, StringThread]] = DataSet(dataset, allow_looping=True)
        self.lr_schedule = lr_scheduler or CosineScheduler(lr)
        self.model = model
        self.logger = logger
        self.job_notifier = job_notifier
        self.callbacks = callbacks
        self.samples_per_batch = samples_per_batch
        self.max_grad_norm = max_grad_norm
        self.skip_nan_gradients = skip_nan_gradients
        self.epochs = epochs

    @property
    def training_completion_percentage(self):
        return self.dataset.completion_percentage() / self.epochs

    async def train_rm(self, sample: tuple[StringThread, StringThread] | StringThread):
        # having both preference and metric feedback in a dataset is unlikely, but it is possible
        if isinstance(sample, tuple):
            await self.model.train_ranking(sample[0], sample[1])
        else:
            if "res" not in sample.metadata:
                raise ValueError(f"Sample missing required 'res' field in metadata: {sample.metadata.keys()}")
            target_value = sample.metadata["res"]
            await self.model.train_mse(sample, target_value)

    async def run(self):
        with tqdm(total=100) as pbar:
            self.job_notifier.report_progress(
                tot_num_samples=len(self.dataset) * self.epochs,
                processed_num_samples=self.dataset.idx,
                monitoring_link=self.logger.training_monitoring_link,
            )
            while self.training_completion_percentage < 1.0:
                for callback in self.callbacks:
                    if logs := await callback.maybe_call(self.training_completion_percentage):
                        self.logger(logs)

                await async_map_batch(self.train_rm, self.dataset, self.samples_per_batch)
                cp = self.training_completion_percentage
                current_lr = self.lr_schedule(cp)
                pbar.update(cp * 100.0 - pbar.n)

                logs = await self.model.optim_step(
                    current_lr, wd=0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
                )

                self.job_notifier.report_progress(
                    tot_num_samples=len(self.dataset) * self.epochs,
                    processed_num_samples=self.dataset.idx,
                    monitoring_link=self.logger.training_monitoring_link,
                )

                self.logger(logs | dict(completion_percentage=cp))
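
A sketch of the two sample shapes train_rm() accepts (not part of the package diff). The preference-pair path calls train_ranking, the metric path calls train_mse with a scalar target read from metadata["res"]. How metadata is populated is not shown in this diff; the example assumes the dict is directly assignable, as the DataSet iterator above suggests. reward_model is a placeholder TrainingModel.

# Illustrative sketch only.
from adaptive_harmony import StringThread
from adaptive_harmony.common.rm import RewardModelling

# 1) Preference pair: (chosen, rejected) -> model.train_ranking(chosen, rejected)
chosen = StringThread([("user", "Summarise the report."), ("assistant", "A concise, accurate summary.")])
rejected = StringThread([("user", "Summarise the report."), ("assistant", "Off-topic rambling.")])

# 2) Metric feedback: a single thread whose metadata carries the scalar target
#    under the "res" key -> model.train_mse(sample, target_value)
scored = StringThread([("user", "Summarise the report."), ("assistant", "A summary.")])
scored.metadata["res"] = 0.7  # assumes metadata is a mutable dict

recipe = RewardModelling(dataset=[(chosen, rejected)], model=reward_model)  # reward_model: placeholder
# await recipe.run()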
adaptive_harmony/common/sft.py
@@ -0,0 +1,121 @@
import os
from typing import Callable, Sequence

from tqdm.auto import tqdm

from adaptive_harmony import CosineScheduler, DataSet, JobNotifier, Logger, StageNotifier, StringThread, TrainingModel
from adaptive_harmony.common.callbacks import RecipeCallback
from adaptive_harmony.common.checkpointing import CheckpointManager
from adaptive_harmony.core.utils import async_map_batch, hash_hyperparams, log_args
from adaptive_harmony.metric_logger import StdoutLogger

SFT_HYPERPARAMS = {
    "lr",
    "lr_scheduler",
    "samples_per_batch",
    "max_grad_norm",
    "epochs",
    "weight_decay",
    "skip_nan_gradients",
}


class SFT:
    @log_args
    @hash_hyperparams(include=SFT_HYPERPARAMS)
    def __init__(
        self,
        dataset: list[StringThread],
        model: TrainingModel,
        logger: Logger = StdoutLogger(),
        stage_notifier: StageNotifier = JobNotifier().stage_notifier("SFT Training"),
        callbacks: Sequence[RecipeCallback] = [],
        lr: float = 1e-5,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch=512,  # axel magic number: "pretty well validated across different scales"
        max_grad_norm=1.0,
        epochs: int = 1,
        weight_decay: float = 0,
        skip_nan_gradients: bool = False,
        restart_from_checkpoint: str | None = None,
        checkpoint_frequency: float = 0.2,
    ):
        self.dataset = DataSet(dataset, allow_looping=epochs != 1)
        self.lr_schedule = lr_scheduler or CosineScheduler(lr)
        self.model = model
        self.logger = logger
        self.stage_notifier = stage_notifier
        self.callbacks = callbacks
        self.samples_per_batch = samples_per_batch
        self.max_grad_norm = max_grad_norm
        self.epochs = epochs
        self.weight_decay = weight_decay
        self.skip_nan_gradients = skip_nan_gradients

        self.checkpoint_manager = CheckpointManager(
            recipe_name="SFT",
            dataset=self.dataset,
            threads_dataset=dataset,
            callbacks=self.callbacks,
            hyperparams_hash=self._hyperparams_hash,  # type: ignore
            job_id=os.environ.get("HARMONY_JOB_ID"),
            checkpoint_frequency=checkpoint_frequency,
            restart_from_checkpoint=restart_from_checkpoint,
        )

    @property
    def training_completion_percentage(self):
        return self.dataset.completion_percentage() / self.epochs

    async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
        model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
        await self.model.load(f"model_registry://{model_checkpoint_name}")
        self.model.set_optim_step(checkpoint_data["optim_step"])

    async def _recipe_specific_checkpoint_saving(self) -> dict:
        progress_pct = int(self.training_completion_percentage * 100)
        model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}"
        model_checkpoint_name = await self.model.save(model_checkpoint_name, inference_only=False)

        return {
            "model_checkpoint_name": model_checkpoint_name,
            "optim_step": self.model.get_optim_step(),
        }

    async def run(self):
        await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)

        self.stage_notifier.report_progress(
            tot_num_samples=len(self.dataset) * self.epochs,
            processed_num_samples=self.dataset.idx,
            monitoring_link=self.logger.training_monitoring_link,
        )

        with tqdm(total=100) as pbar:
            while self.training_completion_percentage < 1.0:
                for callback in self.callbacks:
                    if logs := await callback.maybe_call(self.training_completion_percentage):
                        self.logger(logs)

                await async_map_batch(self.model.train_language_modelling, self.dataset, self.samples_per_batch)
                cp = self.training_completion_percentage
                current_lr = self.lr_schedule(cp)
                pbar.update(cp * 100.0 - pbar.n)

                logs = await self.model.optim_step(
                    current_lr,
                    wd=self.weight_decay,
                    max_grad_norm=self.max_grad_norm,
                    skip_nan_gradients=self.skip_nan_gradients,
                )

                self.logger(logs | dict(completion_percentage=cp))

                self.stage_notifier.report_progress(
                    tot_num_samples=len(self.dataset) * self.epochs,
                    processed_num_samples=self.dataset.idx,
                    monitoring_link=self.logger.training_monitoring_link,
                )

                if await self.checkpoint_manager.maybe_checkpoint(cp, self._recipe_specific_checkpoint_saving):
                    break
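
A minimal SFT run sketch (not part of the package diff), showing the checkpointing knobs exposed above. make_training_model is a hypothetical placeholder; model construction is not part of this diff.

# Illustrative sketch only.
import asyncio

from adaptive_harmony.common.sft import SFT
from adaptive_harmony.core.dataset import convert_sample_dict, load_from_hf


async def main():
    threads = load_from_hf("org/chat-sft-data", split="train", convert_sample_fn=convert_sample_dict())
    model = await make_training_model("base-model")  # placeholder helper

    recipe = SFT(
        dataset=threads,
        model=model,
        epochs=2,
        checkpoint_frequency=0.2,          # checkpoint roughly every 20% of training
        restart_from_checkpoint=None,      # or the name of a previously saved checkpoint
    )
    await recipe.run()


asyncio.run(main())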
adaptive_harmony/core/__init__.py: File without changes
adaptive_harmony/core/dataset.py
@@ -0,0 +1,72 @@
from typing import Callable, Sequence

import numpy as np
from datasets import load_dataset

from adaptive_harmony import StringThread


class DataSet[T]:
    def __init__(self, threads: Sequence[T], allow_looping: bool = False, seed: int = 42):
        self.threads = threads
        self.allow_looping = allow_looping
        # This will be used to shuffle the dataset when we cross the epoch boundary; initially it is just the
        # indices of the threads, to respect the given order in the first epoch
        self.rng = np.random.default_rng(seed)
        self.access_indices = self.rng.permutation(len(threads))
        self.idx = 0

    def __iter__(self) -> "DataSet":
        return self

    def __len__(self) -> int:
        return len(self.threads)

    def __next__(self) -> T:
        if not self.allow_looping and self.idx == len(self.threads):
            raise StopIteration()
        elif self.allow_looping and self.idx == len(self.access_indices):
            self.access_indices = np.concatenate([self.access_indices, self.rng.permutation(len(self.threads))])

        sample_idx = self.access_indices[self.idx]
        ret = self.threads[sample_idx]

        if hasattr(ret, "metadata") and isinstance(ret.metadata, dict):
            ret.metadata["sample_index"] = self.idx

        self.idx += 1

        return ret

    def __getitem__(self, x):
        return self.threads.__getitem__(x)

    def completion_percentage(self) -> float:
        """If dataset is looping, this can return a value greater than 1.0. Handle in recipe."""
        return self.idx / len(self.threads)

    def reset(self):
        self.idx = 0


def convert_sample_dict(
    turns_key: str | None = "messages", role_key="role", content_key="content", trim_final_assistant_turns=False
):
    def f(dialogue: dict) -> StringThread:
        if turns_key is not None:
            dialogue = dialogue[turns_key]
        turns = [(turn[role_key], turn[content_key]) for turn in dialogue]

        if trim_final_assistant_turns:
            while len(turns) > 0 and turns[-1][0] == "assistant":
                turns = turns[:-1]

        return StringThread(turns)

    return f


def load_from_hf(repo: str, split: str, convert_sample_fn: Callable[..., StringThread]) -> list[StringThread]:
    dataset = load_dataset(repo, split=split, keep_in_memory=True)
    dataset = dataset.select(range(len(dataset)))  # type: ignore
    return [convert_sample_fn(x) for x in dataset]
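
A small self-contained example of DataSet's looping and progress semantics (not part of the package diff): with allow_looping=True the iterator appends a fresh permutation each time it exhausts its access indices, and completion_percentage() is a plain counter over the base length, so it can exceed 1.0 (the recipes above divide it by epochs).

# Illustrative sketch only.
from adaptive_harmony.core.dataset import DataSet

data = DataSet([f"sample-{i}" for i in range(4)], allow_looping=True, seed=0)

# Draw 6 items from a 4-item dataset; the second epoch starts automatically.
for _, item in zip(range(6), data):
    print(item)

print(data.completion_percentage())  # 6 / 4 = 1.5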
adaptive_harmony/core/display.py
@@ -0,0 +1,93 @@
import textwrap

from rich.console import Console
from rich.table import Table

from adaptive_harmony import StringThread, TokenizedThread


def _stringthread_repr(self: StringThread) -> str:
    """Rich-based __repr__ for StringThread."""
    # Get turns from the thread
    turns = self.get_turns()

    # Create a table without borders
    table = Table(show_header=False, show_edge=False, box=None, padding=(0, 1))
    table.add_column("Role", style="bold blue", no_wrap=True, justify="right")
    table.add_column("Content", overflow="fold")

    wrap_width = 90

    for turn in turns:
        # Wrap content if needed
        if wrap_width > 0 and turn.content:
            wrapped_lines = []
            for line in turn.content.split("\n"):
                if line:
                    wrapped = textwrap.fill(line, width=wrap_width)
                    wrapped_lines.append(wrapped)
                else:
                    wrapped_lines.append("")
            content = "\n".join(wrapped_lines)
        else:
            content = turn.content

        table.add_row(turn.role.upper(), content)

    # Capture the output with horizontal lines
    from io import StringIO

    buffer = StringIO()
    # Use styling for __repr__ since it's typically for interactive display
    console = Console(file=buffer, width=120, markup=False)

    # Get max width for the separator line
    max_role_len = max(len(turn.role) for turn in turns) if turns else 0
    total_width = max_role_len + 2 + wrap_width
    separator = "─" * total_width

    # Print with separators like the Rust version
    buffer.write(separator + "\n")
    console.print(table)
    if self.metadata:
        buffer.write(f"Metadata={self.metadata}\n")
    buffer.write(separator + "\n")

    return buffer.getvalue().rstrip()


def _tokenizedthread_repr(self: TokenizedThread) -> str:
    """Rich-based __repr__ for TokenizedThread."""
    # Get turns from the thread
    turns = self.get_turns()

    # Create a table without borders
    table = Table(show_header=False, show_edge=False, box=None, padding=(0, 1))
    table.add_column("Role", style="bold blue", no_wrap=True, justify="right")
    table.add_column("Tokens", overflow="fold")

    for turn in turns:
        # Format tokens as a string of integers
        tokens_str = " ".join(str(t) for t in turn.content)

        table.add_row(turn.role.upper(), tokens_str)

    # Capture the output with horizontal lines
    from io import StringIO

    buffer = StringIO()
    # Use styling for __repr__ since it's typically for interactive display
    console = Console(file=buffer, width=120, markup=False)

    # Get max width for the separator line
    max_role_len = max(len(turn.role) for turn in turns) if turns else 0
    # Estimate token display width (rough estimate)
    total_width = max_role_len + 2 + 90
    separator = "─" * total_width

    # Print with separators like the Rust version
    buffer.write(separator + "\n")
    console.print(table)
    buffer.write(separator + "\n")

    return buffer.getvalue().rstrip()
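
A quick sketch of what the helper produces (not part of the package diff). The binding of this function as __repr__ on StringThread happens elsewhere in the package and is not shown in this diff; here it is simply called directly, and the commented output is approximate.

# Illustrative sketch only.
from adaptive_harmony import StringThread
from adaptive_harmony.core.display import _stringthread_repr

thread = StringThread([("user", "Hi"), ("assistant", "Hello!")])
print(_stringthread_repr(thread))
# Approximate shape: roles upper-cased and right-aligned, content wrapped at 90
# characters, "─" separators above and below, metadata appended when present:
# ──────────────────────
#       USER  Hi
#  ASSISTANT  Hello!
# ──────────────────────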