adaptive-harmony 0.1.23 (adaptive_harmony-0.1.23-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/common/ppo.py
@@ -0,0 +1,303 @@
+ import os
+ from dataclasses import dataclass
+ from typing import Callable, Sequence
+
+ import numpy as np
+
+ from adaptive_harmony import (
+     CombinedSchedule,
+     CosineScheduler,
+     DataSet,
+     InferenceModel,
+     JobNotifier,
+     Logger,
+     StageNotifier,
+     StringThread,
+     TokenizedThread,
+     TrainingModel,
+ )
+ from adaptive_harmony.common.callbacks import RecipeCallback
+ from adaptive_harmony.common.checkpointing import CheckpointManager
+ from adaptive_harmony.core import rl_utils
+ from adaptive_harmony.core.utils import async_map, async_map_batch, get_minibatches, hash_hyperparams, log_args
+ from adaptive_harmony.graders import BaseGrader
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+
+ @dataclass
+ class Sample:
+     sample: TokenizedThread
+     string_sample: StringThread
+     logprobs: list[float]
+     ref_logprobs: list[float]
+     advantages: list[float]
+     returns: list[float]
+     score: float
+     kl_div: list[float]
+     values: list[float]
+     kl_pen: float
+     cumulative_reward: float
+
+
+ PPO_HYPERPARAMS = {
+     "max_num_ppo_steps",
+     "value_only_fraction",
+     "lr_policy",
+     "lr_scheduler_policy",
+     "lr_scheduler_value",
+     "lr_value",
+     "samples_per_batch",
+     "samples_per_mini_batch",
+     "mini_epochs_per_batch",
+     "max_grad_norm",
+     "clip_range",
+     "kl_beta",
+     "gae_lambda",
+     "gae_gamma",
+     "weight_decay",
+     "skip_nan_gradients",
+ }
+
+
+ class PPO:
+     @log_args
+     @hash_hyperparams(include=PPO_HYPERPARAMS)
+     def __init__(
+         self,
+         dataset: list[StringThread],
+         model: TrainingModel,
+         value_model: TrainingModel,
+         grader: BaseGrader,
+         logger: Logger = StdoutLogger(),
+         stage_notifier: StageNotifier = JobNotifier().stage_notifier("PPO Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         max_num_ppo_steps: int | None = None,
+         value_only_fraction=0.25,
+         lr_policy: float = 0.75e-6,
+         lr_scheduler_policy: Callable[[float], float] | None = None,
+         lr_scheduler_value: Callable[[float], float] | None = None,
+         lr_value: float = 1e-6,
+         samples_per_batch=128,
+         samples_per_mini_batch=128,
+         mini_epochs_per_batch=1,
+         max_grad_norm=1.0,
+         clip_range=0.1,
+         kl_beta=0.1,
+         gae_lambda=0.95,
+         gae_gamma=1.0,
+         weight_decay: float = 0,
+         skip_nan_gradients: bool = False,
+         restart_from_checkpoint: str | None = None,
+         checkpoint_frequency: float = 0.2,
+     ):
+         assert value_model.is_scalar(), "You must give a scalar model to PPO for the value network"
+         # Core components
+         self.model_ref: InferenceModel | None = None  # Instantiated when run() is called
+         self.dataset = DataSet(dataset, allow_looping=True)
+         self.model = model
+         self.value_model = value_model
+         self.grader = grader
+         self.scoring_fn = grader.score_float_value
+         self.logger = logger
+         self.stage_notifier = stage_notifier
+         self.callbacks = callbacks
+         self.skip_nan_gradients = skip_nan_gradients
+         # PPO HP's
+         self.max_num_batches = max_num_ppo_steps
+         self.lr_schedule_policy = lr_scheduler_policy or CombinedSchedule(
+             lambda _: 0, CosineScheduler(lr_policy), value_only_fraction
+         )
+         self.lr_schedule_value = lr_scheduler_value or CosineScheduler(lr_value)
+         self.samples_per_batch = samples_per_batch
+         self.samples_per_mini_batch = samples_per_mini_batch
+         self.total_num_samples = (
+             self.max_num_batches * self.samples_per_batch if self.max_num_batches else len(self.dataset)
+         )
+         self.mini_epochs_per_batch = mini_epochs_per_batch
+         self.max_grad_norm = max_grad_norm
+         self.clip_range = clip_range
+         self.kl_beta = kl_beta
+         self.gae_lambda = gae_lambda
+         self.gae_gamma = gae_gamma
+         self.weight_decay = weight_decay
+
+         self.num_batches_processed = 0
+
+         self.checkpoint_manager = CheckpointManager(
+             recipe_name="PPO",
+             dataset=self.dataset,
+             threads_dataset=dataset,
+             callbacks=self.callbacks,
+             hyperparams_hash=self._hyperparams_hash,  # type: ignore
+             job_id=os.environ.get("HARMONY_JOB_ID"),
+             checkpoint_frequency=checkpoint_frequency,
+             restart_from_checkpoint=restart_from_checkpoint,
+         )
+
+     @property
+     def training_completion_percentage(self):
+         return (
+             self.dataset.completion_percentage()
+             if self.max_num_batches is None
+             else min(self.num_batches_processed / self.max_num_batches, 1.0)
+         )
+
+     async def generate_sample(self, prompt: StringThread):
+         assert self.model_ref is not None, "Calling `generate_sample` before reference model has been set"
+
+         sample = await self.model.generate_tokens(prompt)
+         string_sample = await self.model.detokenize_thread(sample)
+         score = await self.scoring_fn(string_sample)
+         values = await self.value_model.score(sample)
+
+         logprobs = await self.model.logprobs_per_token(sample)
+         ref_logprobs = await self.model_ref.logprobs_per_token(sample)
+
+         kl = np.array(logprobs, dtype=np.float32) - np.array(ref_logprobs, dtype=np.float32)
+         kl_pen = -kl * self.kl_beta
+         rewards = np.array(kl_pen)
+         rewards[-1] += score
+
+         advantages = rl_utils.gae_advantages(values, rewards.tolist(), self.gae_lambda, self.gae_gamma)
+         returns = rl_utils.discounted_cumulative_rewards(rewards.tolist(), self.gae_gamma)
+
+         return Sample(
+             sample=sample,
+             string_sample=string_sample,
+             logprobs=logprobs,
+             ref_logprobs=ref_logprobs,
+             advantages=advantages,
+             returns=returns,
+             score=score,
+             values=values,
+             cumulative_reward=sum(rewards),
+             kl_div=kl.tolist(),
+             kl_pen=np.sum(kl_pen).item(),
+         )
+
+     async def train_ppo(self, sample: Sample):
+         await self.model.train_ppo(sample.sample, sample.logprobs, sample.advantages, self.clip_range)
+
+     async def train_value(self, sample: Sample):
+         await self.value_model.train_mse_per_token(sample.sample, sample.returns)
+
+     async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
+         self.num_batches_processed = checkpoint_data["num_batches_processed"]
+
+         model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
+         await self.model.load(f"model_registry://{model_checkpoint_name}")
+         value_model_checkpoint_name = checkpoint_data["value_model_checkpoint_name"]
+         await self.value_model.load(f"model_registry://{value_model_checkpoint_name}")
+
+         self.model.set_optim_step(checkpoint_data["optim_step"])
+         self.value_model.set_optim_step(checkpoint_data["value_optim_step"])
+
+     async def _recipe_specific_checkpoint_saving(self) -> dict:
+         progress_pct = int(self.training_completion_percentage * 100)
+         model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-policy"
+         model_checkpoint_name = await self.model.save(model_checkpoint_name, inference_only=False)
+
+         value_model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}-value"
+         value_model_checkpoint_name = await self.value_model.save(value_model_checkpoint_name, inference_only=False)
+
+         return {
+             "num_batches_processed": self.num_batches_processed,
+             "model_checkpoint_name": model_checkpoint_name,
+             "value_model_checkpoint_name": value_model_checkpoint_name,
+             "optim_step": self.model.get_optim_step(),
+             "value_optim_step": self.value_model.get_optim_step(),
+         }
+
+     async def run(self):
+         self.model_ref = await self.model.clone_inf()
+         await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)
+
+         self.stage_notifier.report_progress(
+             tot_num_samples=self.total_num_samples,
+             processed_num_samples=self.dataset.idx,
+             monitoring_link=self.logger.training_monitoring_link,
+         )
+
+         while self.training_completion_percentage < 1.0:
+             self.num_batches_processed += 1
+
+             for callback in self.callbacks:
+                 if logs := await callback.maybe_call(self.training_completion_percentage):
+                     self.logger(logs)
+
+             # Generate training samples
+             data = await async_map_batch(
+                 self.generate_sample,
+                 self.dataset,
+                 self.samples_per_batch,
+             )
+             scorer_logs = self.grader.get_logs(clear=True)
+             batch_logs = {
+                 **{f"rewards/{key}": value for key, value in scorer_logs.items()},
+                 **self.get_train_batch_logs(data),
+             }
+
+             lr_policy = self.lr_schedule_policy(self.training_completion_percentage)
+             lr_value = self.lr_schedule_value(self.training_completion_percentage)
+
+             # Train on generated samples
+             if lr_policy > 0:
+                 minibatches = get_minibatches(data, self.samples_per_mini_batch, self.mini_epochs_per_batch)
+                 for idx, mini_batch in enumerate(minibatches):
+                     await async_map(self.train_ppo, mini_batch)
+                     optim_logs = await self.model.optim_step(
+                         lr_policy,
+                         wd=self.weight_decay,
+                         max_grad_norm=self.max_grad_norm,
+                         skip_nan_gradients=self.skip_nan_gradients,
+                     )
+                     if idx == len(minibatches) - 1:
+                         # only log tables and full batch-related logs on the final minibatch
+                         self.logger(optim_logs | batch_logs)
+                     else:
+                         self.logger(optim_logs | dict(completion_percentage=self.training_completion_percentage))
+
+             for mini_batch in get_minibatches(data, self.samples_per_mini_batch, self.mini_epochs_per_batch):
+                 await async_map(self.train_value, mini_batch)
+                 batch_logs |= await self.value_model.optim_step(
+                     lr_value, wd=0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
+                 )
+             self.logger(batch_logs)
+
+             self.stage_notifier.report_progress(
+                 tot_num_samples=self.total_num_samples,
+                 processed_num_samples=self.dataset.idx,
+                 monitoring_link=self.logger.training_monitoring_link,
+             )
+
+             if await self.checkpoint_manager.maybe_checkpoint(
+                 self.training_completion_percentage, self._recipe_specific_checkpoint_saving
+             ):
+                 break
+
+     def get_train_batch_logs(self, data: list[Sample]) -> dict:
+         returns = np.concatenate([batch.returns for batch in data])
+         cur_values = np.concatenate([batch.values for batch in data])
+
+         var_return = returns.var()
+         mean_error = ((cur_values - returns) ** 2).mean()
+         explained_variance = (1 - mean_error / (var_return + 1e-8)).item()
+
+         logs = dict(
+             completion_percentage=self.training_completion_percentage,
+             score_mean=np.mean([batch.score for batch in data]).item(),
+             score_std=np.std([batch.score for batch in data]).item(),
+             returns=np.mean(np.concatenate([batch.returns for batch in data])),
+             kl_div=np.mean(np.concatenate([batch.kl_div for batch in data])),
+             advantages=np.mean(np.concatenate([batch.advantages for batch in data])),
+             generation_length=np.mean([batch.sample.len_last_turn() for batch in data]),
+             logprobs=np.mean(np.concatenate([batch.logprobs for batch in data])),
+             ref_logprobs=np.mean(np.concatenate([batch.ref_logprobs for batch in data])),
+             kl_penalty=np.mean([batch.kl_pen for batch in data]),
+             explained_variance=explained_variance,
+             cumulative_reward=np.mean([batch.cumulative_reward for batch in data]),
+         ) | {
+             "training/completion_percentage": self.training_completion_percentage
+         }  # to have a comparable axis with prior runs
+
+         return logs
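For orientation, here is a minimal sketch of how this PPO recipe would be driven, assuming `TrainingModel` handles for the policy and the scalar value network and a `BaseGrader` instance have already been constructed elsewhere (their setup is not part of this file); the prompt text is a placeholder and the direct `adaptive_harmony.common.ppo` import path is inferred from the file location.

```python
from adaptive_harmony import StringThread, TrainingModel
from adaptive_harmony.common.ppo import PPO
from adaptive_harmony.graders import BaseGrader


async def run_ppo(model: TrainingModel, value_model: TrainingModel, grader: BaseGrader) -> None:
    # Prompts are StringThreads ending on a user turn; the policy generates the reply,
    # and grader.score_float_value supplies the scalar reward added on the final token.
    prompts = [StringThread([("user", "Summarise the plot of Hamlet in two sentences.")])]

    recipe = PPO(
        dataset=prompts,
        model=model,
        value_model=value_model,  # must be a scalar model (asserted in __init__)
        grader=grader,
        samples_per_batch=128,
        kl_beta=0.1,
    )
    await recipe.run()

# asyncio.run(run_ppo(model, value_model, grader)) once those handles exist
```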
adaptive_harmony/common/rm.py
@@ -0,0 +1,79 @@
+ from typing import Callable, Sequence
+
+ from tqdm.auto import tqdm
+
+ from adaptive_harmony import CosineScheduler, DataSet, JobNotifier, Logger, StageNotifier, StringThread, TrainingModel
+ from adaptive_harmony.common.callbacks import RecipeCallback
+ from adaptive_harmony.core.utils import async_map_batch, log_args
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+
+ class RewardModelling:
+     @log_args
+     def __init__(
+         self,
+         dataset: list[tuple[StringThread, StringThread]] | list[StringThread],
+         model: TrainingModel,
+         logger: Logger = StdoutLogger(),
+         job_notifier: StageNotifier = JobNotifier().stage_notifier("Reward Model Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         lr: float = 1e-06,
+         lr_scheduler: Callable[[float], float] | None = None,
+         samples_per_batch: int = 64,
+         max_grad_norm: float = 1.0,
+         epochs: int = 1,
+         skip_nan_gradients: bool = False,
+     ):
+         self.dataset: DataSet[StringThread | tuple[StringThread, StringThread]] = DataSet(dataset, allow_looping=True)
+         self.lr_schedule = lr_scheduler or CosineScheduler(lr)
+         self.model = model
+         self.logger = logger
+         self.job_notifier = job_notifier
+         self.callbacks = callbacks
+         self.samples_per_batch = samples_per_batch
+         self.max_grad_norm = max_grad_norm
+         self.skip_nan_gradients = skip_nan_gradients
+         self.epochs = epochs
+
+     @property
+     def training_completion_percentage(self):
+         return self.dataset.completion_percentage() / self.epochs
+
+     async def train_rm(self, sample: tuple[StringThread, StringThread] | StringThread):
+         # having both preference and metric feedback in dataset likely will never happen, but it's possible
+         if isinstance(sample, tuple):
+             await self.model.train_ranking(sample[0], sample[1])
+         else:
+             if "res" not in sample.metadata:
+                 raise ValueError(f"Sample missing required 'res' field in metadata: {sample.metadata.keys()}")
+             target_value = sample.metadata["res"]
+             await self.model.train_mse(sample, target_value)
+
+     async def run(self):
+         with tqdm(total=100) as pbar:
+             self.job_notifier.report_progress(
+                 tot_num_samples=len(self.dataset) * self.epochs,
+                 processed_num_samples=self.dataset.idx,
+                 monitoring_link=self.logger.training_monitoring_link,
+             )
+             while self.training_completion_percentage < 1.0:
+                 for callback in self.callbacks:
+                     if logs := await callback.maybe_call(self.training_completion_percentage):
+                         self.logger(logs)
+
+                 await async_map_batch(self.train_rm, self.dataset, self.samples_per_batch)
+                 cp = self.training_completion_percentage
+                 current_lr = self.lr_schedule(cp)
+                 pbar.update(cp * 100.0 - pbar.n)
+
+                 logs = await self.model.optim_step(
+                     current_lr, wd=0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
+                 )
+
+                 self.job_notifier.report_progress(
+                     tot_num_samples=len(self.dataset) * self.epochs,
+                     processed_num_samples=self.dataset.idx,
+                     monitoring_link=self.logger.training_monitoring_link,
+                 )
+
+                 self.logger(logs | dict(completion_percentage=cp))
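A hedged usage sketch for the reward-modelling recipe above: preference pairs go through `train_ranking`, while single threads need a float target under the `res` metadata key and go through `train_mse`. The `TrainingModel` handle and the direct module import path are assumptions here; the example conversations are placeholders.

```python
from adaptive_harmony import StringThread, TrainingModel
from adaptive_harmony.common.rm import RewardModelling


async def run_rm(model: TrainingModel) -> None:
    # (chosen, rejected) preference pairs; the recipe ranks chosen above rejected.
    chosen = StringThread([("user", "What is 2 + 2?"), ("assistant", "4")])
    rejected = StringThread([("user", "What is 2 + 2?"), ("assistant", "5")])

    recipe = RewardModelling(
        dataset=[(chosen, rejected)],
        model=model,
        lr=1e-6,
        samples_per_batch=64,
    )
    await recipe.run()
```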
adaptive_harmony/common/sft.py
@@ -0,0 +1,121 @@
+ import os
+ from typing import Callable, Sequence
+
+ from tqdm.auto import tqdm
+
+ from adaptive_harmony import CosineScheduler, DataSet, JobNotifier, Logger, StageNotifier, StringThread, TrainingModel
+ from adaptive_harmony.common.callbacks import RecipeCallback
+ from adaptive_harmony.common.checkpointing import CheckpointManager
+ from adaptive_harmony.core.utils import async_map_batch, hash_hyperparams, log_args
+ from adaptive_harmony.metric_logger import StdoutLogger
+
+ SFT_HYPERPARAMS = {
+     "lr",
+     "lr_scheduler",
+     "samples_per_batch",
+     "max_grad_norm",
+     "epochs",
+     "weight_decay",
+     "skip_nan_gradients",
+ }
+
+
+ class SFT:
+     @log_args
+     @hash_hyperparams(include=SFT_HYPERPARAMS)
+     def __init__(
+         self,
+         dataset: list[StringThread],
+         model: TrainingModel,
+         logger: Logger = StdoutLogger(),
+         stage_notifier: StageNotifier = JobNotifier().stage_notifier("SFT Training"),
+         callbacks: Sequence[RecipeCallback] = [],
+         lr: float = 1e-5,
+         lr_scheduler: Callable[[float], float] | None = None,
+         samples_per_batch=512,  # axel magic number: "pretty well validated across different scales"
+         max_grad_norm=1.0,
+         epochs: int = 1,
+         weight_decay: float = 0,
+         skip_nan_gradients: bool = False,
+         restart_from_checkpoint: str | None = None,
+         checkpoint_frequency: float = 0.2,
+     ):
+         self.dataset = DataSet(dataset, allow_looping=epochs != 1)
+         self.lr_schedule = lr_scheduler or CosineScheduler(lr)
+         self.model = model
+         self.logger = logger
+         self.stage_notifier = stage_notifier
+         self.callbacks = callbacks
+         self.samples_per_batch = samples_per_batch
+         self.max_grad_norm = max_grad_norm
+         self.epochs = epochs
+         self.weight_decay = weight_decay
+         self.skip_nan_gradients = skip_nan_gradients
+
+         self.checkpoint_manager = CheckpointManager(
+             recipe_name="SFT",
+             dataset=self.dataset,
+             threads_dataset=dataset,
+             callbacks=self.callbacks,
+             hyperparams_hash=self._hyperparams_hash,  # type: ignore
+             job_id=os.environ.get("HARMONY_JOB_ID"),
+             checkpoint_frequency=checkpoint_frequency,
+             restart_from_checkpoint=restart_from_checkpoint,
+         )
+
+     @property
+     def training_completion_percentage(self):
+         return self.dataset.completion_percentage() / self.epochs
+
+     async def _recipe_specific_checkpoint_loading(self, checkpoint_data: dict) -> None:
+         model_checkpoint_name = checkpoint_data["model_checkpoint_name"]
+         await self.model.load(f"model_registry://{model_checkpoint_name}")
+         self.model.set_optim_step(checkpoint_data["optim_step"])
+
+     async def _recipe_specific_checkpoint_saving(self) -> dict:
+         progress_pct = int(self.training_completion_percentage * 100)
+         model_checkpoint_name = f"checkpoint-{self.checkpoint_manager.job_id}-{progress_pct}"
+         model_checkpoint_name = await self.model.save(model_checkpoint_name, inference_only=False)
+
+         return {
+             "model_checkpoint_name": model_checkpoint_name,
+             "optim_step": self.model.get_optim_step(),
+         }
+
+     async def run(self):
+         await self.checkpoint_manager.maybe_restore_checkpoint(self._recipe_specific_checkpoint_loading)
+
+         self.stage_notifier.report_progress(
+             tot_num_samples=len(self.dataset) * self.epochs,
+             processed_num_samples=self.dataset.idx,
+             monitoring_link=self.logger.training_monitoring_link,
+         )
+
+         with tqdm(total=100) as pbar:
+             while self.training_completion_percentage < 1.0:
+                 for callback in self.callbacks:
+                     if logs := await callback.maybe_call(self.training_completion_percentage):
+                         self.logger(logs)
+
+                 await async_map_batch(self.model.train_language_modelling, self.dataset, self.samples_per_batch)
+                 cp = self.training_completion_percentage
+                 current_lr = self.lr_schedule(cp)
+                 pbar.update(cp * 100.0 - pbar.n)
+
+                 logs = await self.model.optim_step(
+                     current_lr,
+                     wd=self.weight_decay,
+                     max_grad_norm=self.max_grad_norm,
+                     skip_nan_gradients=self.skip_nan_gradients,
+                 )
+
+                 self.logger(logs | dict(completion_percentage=cp))
+
+                 self.stage_notifier.report_progress(
+                     tot_num_samples=len(self.dataset) * self.epochs,
+                     processed_num_samples=self.dataset.idx,
+                     monitoring_link=self.logger.training_monitoring_link,
+                 )
+
+                 if await self.checkpoint_manager.maybe_checkpoint(cp, self._recipe_specific_checkpoint_saving):
+                     break
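Similarly, a minimal SFT sketch, again assuming a `TrainingModel` handle obtained elsewhere; the conversation is a placeholder and the import path mirrors the file location.

```python
from adaptive_harmony import StringThread, TrainingModel
from adaptive_harmony.common.sft import SFT


async def run_sft(model: TrainingModel) -> None:
    # Full conversations; the assistant turns are what train_language_modelling fits.
    data = [StringThread([("user", "Name a prime number."), ("assistant", "7.")])]

    recipe = SFT(
        dataset=data,
        model=model,
        lr=1e-5,
        epochs=1,
        checkpoint_frequency=0.2,
    )
    await recipe.run()
```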
adaptive_harmony/core/__init__.py
File without changes
adaptive_harmony/core/dataset.py
@@ -0,0 +1,72 @@
+ from typing import Callable, Sequence
+
+ import numpy as np
+ from datasets import load_dataset
+
+ from adaptive_harmony import StringThread
+
+
+ class DataSet[T]:
+     def __init__(self, threads: Sequence[T], allow_looping: bool = False, seed: int = 42):
+         self.threads = threads
+         self.allow_looping = allow_looping
+         # This will be used to shuffle the dataset when we cross the epoch boundary, initially it is just the indices of the threads
+         # to respect the given order in the first epoch
+         self.rng = np.random.default_rng(seed)
+         self.access_indices = self.rng.permutation(len(threads))
+         self.idx = 0
+
+     def __iter__(self) -> "DataSet":
+         return self
+
+     def __len__(self) -> int:
+         return len(self.threads)
+
+     def __next__(self) -> T:
+         if not self.allow_looping and self.idx == len(self.threads):
+             raise StopIteration()
+         elif self.allow_looping and self.idx == len(self.access_indices):
+             self.access_indices = np.concatenate([self.access_indices, self.rng.permutation(len(self.threads))])
+
+         sample_idx = self.access_indices[self.idx]
+         ret = self.threads[sample_idx]
+
+         if hasattr(ret, "metadata") and isinstance(ret.metadata, dict):
+             ret.metadata["sample_index"] = self.idx
+
+         self.idx += 1
+
+         return ret
+
+     def __getitem__(self, x):
+         return self.threads.__getitem__(x)
+
+     def completion_percentage(self) -> float:
+         """If dataset is looping, this can return a value greater than 1.0. Handle in recipe."""
+         return self.idx / len(self.threads)
+
+     def reset(self):
+         self.idx = 0
+
+
+ def convert_sample_dict(
+     turns_key: str | None = "messages", role_key="role", content_key="content", trim_final_assistant_turns=False
+ ):
+     def f(dialogue: dict) -> StringThread:
+         if turns_key is not None:
+             dialogue = dialogue[turns_key]
+         turns = [(turn[role_key], turn[content_key]) for turn in dialogue]
+
+         if trim_final_assistant_turns:
+             while len(turns) > 0 and turns[-1][0] == "assistant":
+                 turns = turns[:-1]
+
+         return StringThread(turns)
+
+     return f
+
+
+ def load_from_hf(repo: str, split: str, convert_sample_fn: Callable[..., StringThread]) -> list[StringThread]:
+     dataset = load_dataset(repo, split=split, keep_in_memory=True)
+     dataset = dataset.select(range(len(dataset)))  # type: ignore
+     return [convert_sample_fn(x) for x in dataset]
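To illustrate the loader and the looping behaviour above, a short sketch; the Hugging Face repo and split names are stand-ins for any chat-format dataset with a `messages` column.

```python
from adaptive_harmony.core.dataset import DataSet, convert_sample_dict, load_from_hf

# Build a converter that drops trailing assistant turns, so threads end on a user
# message and can be used as generation prompts.
to_thread = convert_sample_dict(turns_key="messages", trim_final_assistant_turns=True)
threads = load_from_hf("some-org/some-chat-dataset", split="train", convert_sample_fn=to_thread)

# allow_looping=True lets recipes draw more samples than len(threads); the access
# order is extended with a fresh seeded permutation each time the epoch boundary is crossed.
ds = DataSet(threads, allow_looping=True, seed=42)
first = next(ds)                   # also stamps metadata["sample_index"] when the sample has metadata
print(ds.completion_percentage())  # can exceed 1.0 when looping
```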
adaptive_harmony/core/display.py
@@ -0,0 +1,93 @@
+ import textwrap
+
+ from rich.console import Console
+ from rich.table import Table
+
+ from adaptive_harmony import StringThread, TokenizedThread
+
+
+ def _stringthread_repr(self: StringThread) -> str:
+     """Rich-based __repr__ for StringThread."""
+     # Get turns from the thread
+     turns = self.get_turns()
+
+     # Create a table without borders
+     table = Table(show_header=False, show_edge=False, box=None, padding=(0, 1))
+     table.add_column("Role", style="bold blue", no_wrap=True, justify="right")
+     table.add_column("Content", overflow="fold")
+
+     wrap_width = 90
+
+     for turn in turns:
+         # Wrap content if needed
+         if wrap_width > 0 and turn.content:
+             wrapped_lines = []
+             for line in turn.content.split("\n"):
+                 if line:
+                     wrapped = textwrap.fill(line, width=wrap_width)
+                     wrapped_lines.append(wrapped)
+                 else:
+                     wrapped_lines.append("")
+             content = "\n".join(wrapped_lines)
+         else:
+             content = turn.content
+
+         table.add_row(turn.role.upper(), content)
+
+     # Capture the output with horizontal lines
+     from io import StringIO
+
+     buffer = StringIO()
+     # Use styling for __repr__ since it's typically for interactive display
+     console = Console(file=buffer, width=120, markup=False)
+
+     # Get max width for the separator line
+     max_role_len = max(len(turn.role) for turn in turns) if turns else 0
+     total_width = max_role_len + 2 + wrap_width
+     separator = "─" * total_width
+
+     # Print with separators like the Rust version
+     buffer.write(separator + "\n")
+     console.print(table)
+     if self.metadata:
+         buffer.write(f"Metadata={self.metadata}\n")
+     buffer.write(separator + "\n")
+
+     return buffer.getvalue().rstrip()
+
+
+ def _tokenizedthread_repr(self: TokenizedThread) -> str:
+     """Rich-based __repr__ for TokenizedThread."""
+     # Get turns from the thread
+     turns = self.get_turns()
+
+     # Create a table without borders
+     table = Table(show_header=False, show_edge=False, box=None, padding=(0, 1))
+     table.add_column("Role", style="bold blue", no_wrap=True, justify="right")
+     table.add_column("Tokens", overflow="fold")
+
+     for turn in turns:
+         # Format tokens as a string of integers
+         tokens_str = " ".join(str(t) for t in turn.content)
+
+         table.add_row(turn.role.upper(), tokens_str)
+
+     # Capture the output with horizontal lines
+     from io import StringIO
+
+     buffer = StringIO()
+     # Use styling for __repr__ since it's typically for interactive display
+     console = Console(file=buffer, width=120, markup=False)
+
+     # Get max width for the separator line
+     max_role_len = max(len(turn.role) for turn in turns) if turns else 0
+     # Estimate token display width (rough estimate)
+     total_width = max_role_len + 2 + 90
+     separator = "─" * total_width
+
+     # Print with separators like the Rust version
+     buffer.write(separator + "\n")
+     console.print(table)
+     buffer.write(separator + "\n")
+
+     return buffer.getvalue().rstrip()
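Finally, a small sketch of how these helpers render a thread. Calling `_stringthread_repr` directly is shown only because the code that attaches it as `StringThread.__repr__` is not part of this file; the example thread is a placeholder.

```python
from adaptive_harmony import StringThread
from adaptive_harmony.core.display import _stringthread_repr

thread = StringThread([("user", "What is the capital of France?"), ("assistant", "Paris.")])

# Renders a right-aligned ROLE column next to wrapped content, framed by ─ rules,
# with any thread metadata appended before the closing rule.
print(_stringthread_repr(thread))
```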