adaptive-harmony 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
@@ -0,0 +1,162 @@
1
+ # ruff: noqa: F403, F401
2
+ from typing import TYPE_CHECKING
3
+
4
+ from harmony_client import (
5
+ EvalSample as EvalSample,
6
+ )
7
+ from harmony_client import (
8
+ EvalSampleInteraction as EvalSampleInteraction,
9
+ )
10
+ from harmony_client import (
11
+ Grade as Grade,
12
+ )
13
+ from harmony_client import (
14
+ HarmonyClient as HarmonyClient,
15
+ )
16
+ from harmony_client import (
17
+ HarmonyJobNotifier as HarmonyJobNotifier,
18
+ )
19
+ from harmony_client import (
20
+ InferenceModel as InferenceModel,
21
+ )
22
+ from harmony_client import (
23
+ JobArtifact as JobArtifact,
24
+ )
25
+ from harmony_client import (
26
+ JobNotifier as JobNotifier,
27
+ )
28
+ from harmony_client import (
29
+ ModelBuilder as ModelBuilder,
30
+ )
31
+ from harmony_client import (
32
+ StageNotifier as StageNotifier,
33
+ )
34
+ from harmony_client import (
35
+ StringThread as StringThread,
36
+ )
37
+ from harmony_client import (
38
+ TokenizedThread as TokenizedThread,
39
+ )
40
+ from harmony_client import (
41
+ TrainingModel as TrainingModel,
42
+ )
43
+ from harmony_client import (
44
+ get_client as get_client,
45
+ )
46
+ from harmony_client import parameters as parameters
47
+ from harmony_client import runtime as runtime
48
+ from rich.progress import Progress
49
+
50
if TYPE_CHECKING:
    # Type checkers see the real StringTurn exported by harmony_client.
    from harmony_client import StringTurn as StringTurn
else:
    # At runtime, substitute a lightweight NamedTuple with (role, content)
    # fields — presumably mirroring harmony_client's StringTurn layout;
    # TODO(review): confirm the real type exposes the same two fields.
    from typing import NamedTuple

    class StringTurn(NamedTuple):
        # Speaker role for the turn (e.g. "system", "assistant").
        role: str
        # Text content of the turn.
        content: str
58
+
59
+
60
+ from harmony_client.artifacts.custom_artifact import CustomArtifact
61
+ from harmony_client.artifacts.dataset_artifact import DatasetArtifact
62
+ from harmony_client.file_storage import (
63
+ FileStorage,
64
+ FileStorageConfig,
65
+ LocalFileStorageConfig,
66
+ S3FileStorageConfig,
67
+ StoredFile,
68
+ )
69
+
70
+ import adaptive_harmony.core.rl_utils as rl_utils
71
+ from adaptive_harmony.core.dataset import DataSet
72
+ from adaptive_harmony.core.schedulers import CombinedSchedule, CosineScheduler, CosineSchedulerWithoutWarmup, Scheduler
73
+ from adaptive_harmony.evaluation.evaluation_artifact import EvaluationArtifact
74
+ from adaptive_harmony.metric_logger import Logger, WandbLogger
75
+
76
+ # Ensure key classes are available at module level
77
+ __all__ = [
78
+ "StringThread",
79
+ "StringTurn",
80
+ "TokenizedThread",
81
+ "InferenceModel",
82
+ "ModelBuilder",
83
+ "TrainingModel",
84
+ "HarmonyClient",
85
+ "get_client",
86
+ "DataSet",
87
+ "CosineScheduler",
88
+ "CombinedSchedule",
89
+ "CosineSchedulerWithoutWarmup",
90
+ "Scheduler",
91
+ "WandbLogger",
92
+ "Logger",
93
+ "FileStorage",
94
+ "FileStorageConfig",
95
+ "LocalFileStorageConfig",
96
+ "S3FileStorageConfig",
97
+ "StoredFile",
98
+ "EvaluationArtifact",
99
+ "CustomArtifact",
100
+ "DatasetArtifact",
101
+ "rl_utils",
102
+ "Grade",
103
+ "EvalSample",
104
+ "EvalSampleInteraction",
105
+ "JobArtifact",
106
+ ]
107
+
108
+
109
+ # Patch StringThread to use rich for display
110
+ from harmony_client.runtime.model_artifact_save import save_with_artifact
111
+
112
+ from adaptive_harmony.core.display import _stringthread_repr, _tokenizedthread_repr
113
+ from adaptive_harmony.core.image_utils import string_thread_to_html_string
114
+
115
+ # Patch InferenceModel to have json output capabilities
116
+ from adaptive_harmony.core.structured_output import generate_and_validate, render_pydantic_model, render_schema
117
+
118
# Attach rich-based reprs so thread objects render readably in terminals.
StringThread.__repr__ = _stringthread_repr  # type: ignore
TokenizedThread.__repr__ = _tokenizedthread_repr  # type: ignore
# HTML repr hook (picked up by Jupyter's display machinery).
setattr(StringThread, "_repr_html_", string_thread_to_html_string)
# Structured-output helpers: generate_and_validate becomes an instance
# method; the two render_* helpers are exposed as statics.
setattr(InferenceModel, "generate_and_validate", generate_and_validate)
setattr(InferenceModel, "render_schema", staticmethod(render_schema))
setattr(InferenceModel, "render_pydantic_model", staticmethod(render_pydantic_model))

# Keep a handle on the unpatched save so the artifact-saving wrapper can
# delegate to it after TrainingModel.save is replaced.
_original_training_model_save = TrainingModel.save
126
+
127
+
128
async def _save_with_artifact_wrapper(model: TrainingModel, model_name: str, inference_only: bool = True, ctx=None):
    """Replacement for TrainingModel.save that also records a model artifact.

    Delegates to save_with_artifact, passing along the original (unpatched)
    save method so the underlying save behavior is preserved.
    """
    result = await save_with_artifact(model, model_name, inference_only, ctx, _original_training_model_save)
    return result


# Monkey-patch the wrapper in as the public save method.
TrainingModel.save = _save_with_artifact_wrapper  # type: ignore[method-assign]
133
+
134
+
135
async def spawn_train(self: ModelBuilder, name: str, max_batch_size: int) -> TrainingModel:
    """Spawn a training model while showing a rich progress bar during load.

    Progress is reported by the future as a fraction in [0, 1]; the bar's
    total is therefore 1.0 (the previous total=1000 briefly rendered a
    bogus scale before the first update overwrote it).
    """
    fut = await self.spawn_train_with_progress(name, max_batch_size)  # type:ignore

    with Progress() as pbar:
        task = pbar.add_task("Loading model", total=1.0)

        # Poll until the loader reports completion (exactly 1.0).
        while (prog := await fut._await_progress()) != 1.0:
            pbar.update(task, completed=prog, total=1.0)
        pbar.update(task, completed=1.0, total=1.0)

    return await fut.get()
146
+
147
+
148
async def spawn_inference(self: ModelBuilder, name: str) -> InferenceModel:
    """Spawn an inference model while showing a rich progress bar during load.

    Progress is reported by the future as a fraction in [0, 1]; the bar's
    total is therefore 1.0 (the previous total=1000 briefly rendered a
    bogus scale before the first update overwrote it).
    """
    fut = await self.spawn_inference_with_progress(name)  # type:ignore

    with Progress() as pbar:
        task = pbar.add_task("Loading model", total=1.0)

        # Poll until the loader reports completion (exactly 1.0).
        while (prog := await fut._await_progress()) != 1.0:
            pbar.update(task, completed=prog, total=1.0)
        pbar.update(task, completed=1.0, total=1.0)

    return await fut.get()
159
+
160
+
161
+ setattr(ModelBuilder, "spawn_inference", spawn_inference)
162
+ setattr(ModelBuilder, "spawn_train", spawn_train)
@@ -0,0 +1,40 @@
1
+ from .callbacks import (
2
+ CheckpointCallback as CheckpointCallback,
3
+ )
4
+ from .callbacks import (
5
+ EnvironmentValidationCallback as EnvironmentValidationCallback,
6
+ )
7
+ from .callbacks import (
8
+ GenerateSamplesCallback as GenerateSamplesCallback,
9
+ )
10
+ from .callbacks import (
11
+ GraderEvalCallback as GraderEvalCallback,
12
+ )
13
+ from .callbacks import (
14
+ RecipeCallback as RecipeCallback,
15
+ )
16
+ from .callbacks import (
17
+ ValidationLossCallback as ValidationLossCallback,
18
+ )
19
+ from .dpo import DPO as DPO
20
+ from .env_grpo import ENVGRPO
21
+ from .grpo import GRPO as GRPO
22
+ from .gspo import GSPO as GSPO
23
+ from .ppo import PPO as PPO
24
+ from .rm import RewardModelling as RewardModelling
25
+ from .sft import SFT as SFT
26
+
27
+ __all__ = [
28
+ "SFT",
29
+ "PPO",
30
+ "GRPO",
31
+ "ENVGRPO",
32
+ "DPO",
33
+ "RewardModelling",
34
+ "RecipeCallback",
35
+ "GenerateSamplesCallback",
36
+ "ValidationLossCallback",
37
+ "CheckpointCallback",
38
+ "GraderEvalCallback",
39
+ "EnvironmentValidationCallback",
40
+ ]
@@ -0,0 +1,219 @@
1
+ from abc import abstractmethod
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from harmony_client import (
6
+ InferenceModel,
7
+ StringThread,
8
+ TrainingModel,
9
+ )
10
+ from loguru import logger
11
+
12
+ from adaptive_harmony.core.utils import async_map, async_map_fallible
13
+ from adaptive_harmony.environment import EnvironmentFactory
14
+ from adaptive_harmony.graders import BaseGrader
15
+ from adaptive_harmony.logging_table import Table
16
+
17
+
18
+ class RecipeCallback:
19
+ def __init__(self, frequency: float, log_key_prefix: str | None = None):
20
+ self.frequency = frequency
21
+ self.last_call = -1.0
22
+ self.log_key_prefix = log_key_prefix
23
+
24
+ async def maybe_call(self, current_percentage: float) -> dict[str, Any]:
25
+ if current_percentage - self.last_call >= self.frequency:
26
+ self.last_call = current_percentage
27
+ callback_dict = await self.callback(current_percentage)
28
+ prefixed_dict = {
29
+ (f"{self.log_key_prefix}/{key}" if self.log_key_prefix else key): value
30
+ for key, value in callback_dict.items()
31
+ }
32
+ return prefixed_dict
33
+ return {}
34
+
35
+ @abstractmethod
36
+ async def callback(self, current_percentage: float) -> dict[str, Any]: ...
37
+
38
+
39
class GenerateSamplesCallback(RecipeCallback):
    """Periodically generates completions for a fixed prompt set and logs them
    as a table, together with generation-length statistics."""

    def __init__(
        self,
        thread_set: list[StringThread],
        model: InferenceModel,
        frequency: float,
        log_key: str = "samples",
    ):
        super().__init__(frequency, log_key_prefix="generation")
        self.thread_set = thread_set
        self.model = model
        self.log_key = log_key

    async def callback(self, current_percentage: float) -> dict[str, Any]:
        logger.info("Entering generation callback...")
        generation_tokens = await async_map_fallible(self.model.generate_tokens, self.thread_set)
        generation_results = await async_map_fallible(self.model.detokenize_thread, generation_tokens)
        gen_lengths = [sample.len_last_turn() for sample in generation_tokens]

        # Build the three table columns in one pass over the results.
        system_col: list[str] = []
        prompt_col: list[str] = []
        response_col: list[str] = []
        for sample in generation_results:
            turns = sample.get_turns()
            has_system = bool(turns) and turns[0].role == "system"
            system_col.append(turns[0].content if turns[0].role == "system" else "")
            # Prompt = everything between the (optional) system turn and the
            # final generated turn.
            prompt_turns = turns[1:-1] if has_system else turns[:-1]
            prompt_col.append(repr(StringThread(prompt_turns)))
            response_col.append(sample.last_content())

        samples_table = (
            Table()
            .add_column("system", system_col)
            .add_column("prompt", prompt_col)
            .add_column("response", response_col)
        )
        return {
            self.log_key: samples_table,
            "generation_length_mean": np.mean(gen_lengths).item(),
            "generation_length_std": np.std(gen_lengths).item(),
            "num_samples": len(generation_results),
        }
86
+
87
+
88
class ValidationLossCallback(RecipeCallback):
    """Periodically computes the mean per-token negative log-likelihood of the
    model on a held-out validation set."""

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        frequency: float = 0.1,
        log_key: str = "loss",
    ):
        super().__init__(frequency, log_key_prefix="validation")
        self.validation_set = validation_set
        self.model = model
        self.log_key = log_key

    async def callback(self, current_percentage: float) -> dict[str, float]:
        """Tokenize and score the validation set; return {log_key: mean loss}.

        NOTE(review): assumes the validation set and every per-sample logprob
        list are non-empty; an empty one would raise ZeroDivisionError.
        """
        logger.info("Entering validation loss callback...")
        tokens = await async_map_fallible(self.model.tokenize_thread, self.validation_set)
        logprobs = await async_map(self.model.logprobs_per_token, tokens)
        # Per-sample loss is the negative mean token logprob.
        # (The original initialized `losses = []` and immediately overwrote
        # it — dead code, removed.)
        losses = [-(sum(lp) / len(lp)) for lp in logprobs]

        return {self.log_key: sum(losses) / len(losses)}
109
+
110
+
111
class CheckpointCallback(RecipeCallback):
    """Periodically saves the training model under a progress-tagged name."""

    def __init__(
        self,
        model: TrainingModel,
        checkpoint_name: str,
        frequency: float = 0.2,
    ):
        super().__init__(frequency, log_key_prefix="checkpointing")
        # Override the base class's -1.0 so no checkpoint is written at the
        # very start of training.
        self.last_call = 0.0
        self.model = model
        self.model_log_name = checkpoint_name

    async def callback(self, current_percentage: float):
        """Save the model; emits no metrics."""
        logger.info(f"Saving checkpoint at {current_percentage * 100} % of training ...")
        checkpoint_id = f"{self.model_log_name}-{round(current_percentage, 3)}"
        await self.model.save(checkpoint_id)
        return {}
127
+
128
+
129
class GraderEvalCallback(RecipeCallback):
    """Periodically generates on a validation set and scores the generations
    with a grader, logging grader metrics and length statistics."""

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        grader: BaseGrader,
        frequency: float,
        log_key: str = "validation",
        clear_grader_logs: bool = True,
        temperature: float = 0.0,
    ):
        super().__init__(frequency, log_key_prefix=log_key)
        self.validation_set = validation_set
        self.model = model
        self.grader = grader
        self.clear_grader_logs = clear_grader_logs
        self.temperature = temperature

    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
        logger.info("Entering grader evaluation callback...")
        # Evaluate at a fixed temperature (greedy by default) for comparability.
        temp_model = self.model.temperature(self.temperature)

        tokenized_results = await async_map_fallible(temp_model.generate_tokens, self.validation_set)
        string_results = await async_map(temp_model.detokenize_thread, tokenized_results)
        grades = await async_map_fallible(self.grader.grade, string_results)
        gen_lengths = [sample.len_last_turn() for sample in tokenized_results]

        logs: dict[str, float | Table] = {}
        grader_logs = self.grader.get_logs(clear=self.clear_grader_logs)
        for key, value in grader_logs.items():
            logs[f"rewards/{key}"] = value
        logs["generation_length_mean"] = float(np.mean(gen_lengths).item())
        logs["generation_length_std"] = float(np.std(gen_lengths).item())
        logs["num_samples"] = float(len(grades))
        return logs
163
+
164
+
165
class EnvironmentValidationCallback(RecipeCallback):
    """Periodically rolls out environment trajectories from a validation set,
    logging score/turn statistics and, optionally, sample trajectories."""

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        env_factory: EnvironmentFactory,
        frequency: float,
        log_key: str = "validation",
        clear_env_logs: bool = True,
        temperature: float = 0.0,
        num_samples_log: int = 0,
    ):
        super().__init__(frequency, log_key_prefix=log_key)
        self.validation_set = validation_set
        self.model = model
        self.env_factory = env_factory
        self.clear_env_logs = clear_env_logs
        self.temperature = temperature
        self.num_samples_log = num_samples_log

    async def generate_trajectory(self, initial_thread: StringThread) -> tuple[StringThread, float, int]:
        """Roll out one trajectory; return (trajectory, cumulative score,
        number of assistant turns)."""
        env = self.env_factory.create_environment(initial_thread.metadata)
        temp_model = self.model.temperature(self.temperature)
        trajectory, trajectory_score = await env.generate_trajectory_and_grade(temp_model, initial_thread)
        assistant_turns = [turn for turn in trajectory.get_turns() if turn.role == "assistant"]
        return trajectory, trajectory_score.cumulative_score, len(assistant_turns)

    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
        logger.info("Entering environment validation callback...")

        results = await async_map_fallible(self.generate_trajectory, self.validation_set)

        # Unzip the (trajectory, score, num_turns) triples.
        trajectories: list[StringThread] = []
        scores: list[float] = []
        num_turns_list: list[int] = []
        for trajectory, score, num_turns in results:
            trajectories.append(trajectory)
            scores.append(score)
            num_turns_list.append(num_turns)

        validation_logs: dict[str, float | Table] = {
            "score_mean": np.mean(scores).item(),
            "score_std": np.std(scores).item(),
            "num_turns_mean": np.mean(num_turns_list).item(),
            "num_turns_std": np.std(num_turns_list).item(),
            "num_samples": len(results),
        }

        for key, value in self.env_factory.get_logs(clear=self.clear_env_logs).items():
            validation_logs[f"env/{key}"] = value

        if self.num_samples_log > 0:
            shown = trajectories[: self.num_samples_log]
            table = (
                Table()
                .add_column("trajectory", [repr(traj) for traj in shown])
                .add_column("score", scores[: self.num_samples_log])
            )
            validation_logs["samples"] = table

        logger.info(f"Validation Mean score: {validation_logs['score_mean']:.4f}")
        return validation_logs
@@ -0,0 +1,163 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Awaitable, Callable, Sequence
5
+
6
+ import anyio
7
+ import numpy as np
8
+ from loguru import logger as loguru
9
+
10
+ from adaptive_harmony import DataSet, StringThread
11
+ from adaptive_harmony.common.callbacks import RecipeCallback
12
+ from adaptive_harmony.core.utils import hash_dataset
13
+
14
+
15
class CheckpointManager:
    """Saves and restores recipe training state as JSON checkpoints on disk.

    State covered here: dataset cursor/shuffle state, per-callback timers,
    and completion percentage; recipe-specific state is delegated to the
    save/load coroutines passed in by the recipe.  Checkpointing is disabled
    when no job_id is given or HARMONY_NO_CHECKPOINTING is set.
    """

    def __init__(
        self,
        recipe_name: str,
        dataset: DataSet,
        threads_dataset: Sequence[StringThread],
        callbacks: Sequence[RecipeCallback],
        hyperparams_hash: str,
        job_id: str | None = None,
        checkpoint_frequency: float = 0.2,
        restart_from_checkpoint: str | None = None,
    ):
        self.recipe_name = recipe_name
        self.dataset = dataset
        # Hash of the thread data; used to refuse restoring a checkpoint
        # produced from a different dataset.
        self.dataset_hash = hash_dataset(threads_dataset)
        self.hyperparams_hash = hyperparams_hash
        self.callbacks = callbacks
        self.checkpoint_frequency = checkpoint_frequency
        self.last_checkpoint_percentage = 0.0
        self.restart_from_checkpoint = restart_from_checkpoint
        self.job_id = job_id
        self.checkpointing_folder = self._init_folder()

    def _init_folder(self) -> str | None:
        """Return the checkpoint directory for this job, or None if disabled."""
        if self.job_id is None or os.getenv("HARMONY_NO_CHECKPOINTING") is not None:
            loguru.warning("Checkpointing is disabled for this recipe.")
            return None
        return os.path.join(os.getenv("RECIPE_CHECKPOINTS_DIR", "/checkpoints"), self.job_id)

    async def maybe_restore_checkpoint(
        self,
        recipe_specific_checkpoint_loading: Callable[[dict], Awaitable[None]],
    ) -> None:
        """Restore state from `restart_from_checkpoint`, if one was requested.

        Validates recipe type, dataset hash and hyperparameter hash before
        restoring anything; `recipe_specific_checkpoint_loading` receives the
        raw checkpoint dict to restore recipe-owned state.
        """
        if self.restart_from_checkpoint is None:
            return

        checkpoint_path = Path(self.restart_from_checkpoint)
        # Accepts either a direct file path or a directory (latest wins).
        checkpoint_file = self._resolve_checkpoint_file(checkpoint_path)

        assert checkpoint_file, f"Checkpoint file not found: {checkpoint_path}."

        loguru.info(f"Loading {self.recipe_name} checkpoint from: {checkpoint_file}")

        contents = ""
        async with await anyio.open_file(checkpoint_file, "r") as f:
            contents = await f.read()
        checkpoint_data = json.loads(contents)

        assert checkpoint_data.get("recipe_type") == self.recipe_name, (
            f"Recipe type mismatch: checkpoint is '{checkpoint_data.get('recipe_type')}', "
            f"but trying to load into {self.recipe_name}"
        )

        assert checkpoint_data.get("dataset_hash") == self.dataset_hash, (
            "Dataset hash mismatch between checkpoint and current dataset."
        )

        assert checkpoint_data.get("hyperparams_hash") == self.hyperparams_hash, (
            "Hyperparameters hash mismatch between checkpoint and current recipe configuration."
        )

        # Restore the dataset cursor and shuffle state so iteration resumes
        # exactly where the checkpoint left off.
        self.dataset.idx = checkpoint_data.get("dataset_idx", 0)

        access_indices_list = checkpoint_data.get("dataset_access_indices", [])
        if access_indices_list:
            self.dataset.access_indices = np.array(access_indices_list)

        rng_state = checkpoint_data.get("dataset_rng_state")
        if rng_state:
            self.dataset.rng.bit_generator.state = rng_state

        # Restore per-callback "last fired" timers so callbacks don't re-fire
        # immediately after resuming.
        callback_states = checkpoint_data.get("callback_last_calls", [])
        assert len(callback_states) == len(self.callbacks), "Mismatch in number of callbacks when loading checkpoint"
        for i, callback in enumerate(self.callbacks):
            callback.last_call = callback_states[i]

        await recipe_specific_checkpoint_loading(checkpoint_data)

        self.last_checkpoint_percentage = checkpoint_data.get("completion_percentage", 0.0)

        loguru.info(f"Checkpoint restored: starting {self.recipe_name} from {self.last_checkpoint_percentage:.2%}.")

    async def maybe_checkpoint(
        self,
        completion_percentage: float,
        recipe_specific_checkpoint_saving: Callable[[], Awaitable[dict]],
    ) -> bool:
        """Save a checkpoint if due; return True if the recipe should exit.

        A checkpoint is written either when a graceful exit was requested via
        the GRACEFUL_EXIT marker file (then True is returned), or when
        progress advanced by `checkpoint_frequency` since the last save.
        """
        if self.checkpointing_folder is None:
            return False

        # No point checkpointing a finished run.
        if completion_percentage >= 1.0:
            return False

        if await self._check_graceful_exit_file():
            loguru.info(f"Graceful exit requested. Saving checkpoint and exiting {self.recipe_name} training loop.")
            await self._save_checkpoint(completion_percentage, recipe_specific_checkpoint_saving)
            return True

        if completion_percentage - self.last_checkpoint_percentage >= self.checkpoint_frequency:
            await self._save_checkpoint(completion_percentage, recipe_specific_checkpoint_saving)
            self.last_checkpoint_percentage = completion_percentage

        return False

    async def _save_checkpoint(
        self,
        completion_percentage: float,
        get_save_config: Callable[[], Awaitable[dict]],
    ) -> None:
        """Write a checkpoint-<pct>.json combining manager and recipe state."""
        assert self.checkpointing_folder is not None  # will never be called outside of this condition
        progress_pct = int(completion_percentage * 100)
        checkpoint_dir = Path(self.checkpointing_folder)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        loguru.info(f"Checkpointing {self.recipe_name} at {checkpoint_dir} ({progress_pct}%)...")

        recipe_data = await get_save_config()

        checkpoint_data = {
            "recipe_type": self.recipe_name,
            "dataset_hash": self.dataset_hash,
            "hyperparams_hash": self.hyperparams_hash,
            "dataset_idx": self.dataset.idx,
            "dataset_access_indices": self.dataset.access_indices.tolist(),
            "dataset_rng_state": self.dataset.rng.bit_generator.state,
            "callback_last_calls": [callback.last_call for callback in self.callbacks],
            "completion_percentage": completion_percentage,
            **recipe_data,
        }

        checkpoint_file = checkpoint_dir / f"checkpoint-{progress_pct}.json"

        data_dump = json.dumps(checkpoint_data, indent=2)
        async with await anyio.open_file(checkpoint_file, "w") as f:
            await f.write(data_dump)

        loguru.info(f"Checkpoint saved: {checkpoint_file}")

    async def _check_graceful_exit_file(self) -> bool:
        # A GRACEFUL_EXIT marker file in the checkpoint folder requests a
        # save-and-stop — presumably dropped by the orchestrator; TODO confirm.
        if self.checkpointing_folder is None:
            return False
        return (Path(self.checkpointing_folder) / "GRACEFUL_EXIT").exists()

    @staticmethod
    def _resolve_checkpoint_file(path: Path) -> Path | None:
        """Resolve a checkpoint path: for a directory, pick the
        highest-numbered checkpoint-*.json; for a file, return it if it
        exists; otherwise None."""
        if path.is_dir():
            files = sorted(path.glob("checkpoint-*.json"), key=lambda p: int(p.stem.split("-")[1]))
            return files[-1] if files else None
        return path if path.exists() else None
@@ -0,0 +1,92 @@
1
+ from typing import Callable, Sequence
2
+
3
+ from tqdm.auto import tqdm
4
+
5
+ from adaptive_harmony import (
6
+ CosineScheduler,
7
+ DataSet,
8
+ JobNotifier,
9
+ Logger,
10
+ StageNotifier,
11
+ StringThread,
12
+ TrainingModel,
13
+ )
14
+ from adaptive_harmony.common.callbacks import RecipeCallback
15
+ from adaptive_harmony.core.utils import async_map_batch, log_args
16
+ from adaptive_harmony.metric_logger import StdoutLogger
17
+
18
+
19
class DPO:
    """Direct Preference Optimization training recipe.

    Trains `model` on (positive, negative) preference pairs against a frozen
    inference clone of itself (created at the start of `run`).
    """

    @log_args
    def __init__(
        self,
        dataset: list[tuple[StringThread, StringThread]],  # (positive_sample, negative_sample)
        model: TrainingModel,
        logger: Logger | None = None,
        stage_notifier: StageNotifier | None = None,
        callbacks: Sequence[RecipeCallback] = (),
        lr: float = 1e-6,
        lr_scheduler: Callable[[float], float] | None = None,
        samples_per_batch: int = 32,
        max_grad_norm: float = 1.0,
        kl_beta: float = 0.1,
        epochs: int = 1,
        skip_nan_gradients: bool = False,
    ):
        # logger/stage_notifier default to None and are materialized here:
        # the previous defaults (`StdoutLogger()`, `JobNotifier().stage_notifier(...)`)
        # were evaluated once at import time and shared across every DPO
        # instance.  The mutable `callbacks=[]` default is now a tuple.
        # Core components
        self.model_ref = None  # set to a frozen reference clone in run()
        self.dataset = DataSet(dataset)
        self.model = model
        self.logger = logger if logger is not None else StdoutLogger()
        self.stage_notifier = (
            stage_notifier if stage_notifier is not None else JobNotifier().stage_notifier("DPO Training")
        )
        self.callbacks = callbacks
        self.lr_schedule = lr_scheduler or CosineScheduler(lr)
        self.samples_per_batch = samples_per_batch
        self.max_grad_norm = max_grad_norm
        self.skip_nan_gradients = skip_nan_gradients

        # DPO HP's
        self.kl_beta = kl_beta
        self.epochs = epochs

    @property
    def training_completion_percentage(self) -> float:
        # Fraction of total training done, spread uniformly across epochs.
        return self.dataset.completion_percentage() / self.epochs

    async def process_sample(self, sample: tuple[StringThread, StringThread]):
        """Accumulate DPO gradients for one (positive, negative) pair."""
        assert self.model_ref is not None, "Calling `process_sample_dpo` before reference model has been set"

        pos, neg = sample
        ref_logprobs_pos = await self.model_ref.logprobs(pos)
        ref_logprobs_neg = await self.model_ref.logprobs(neg)
        await self.model.train_dpo(pos, neg, ref_logprobs_pos, ref_logprobs_neg, self.kl_beta)

    async def run(self):
        """Run the full DPO training loop, reporting progress before and after."""
        # Freeze a reference copy of the current policy for the KL term.
        self.model_ref = await self.model.clone_inf()

        self.stage_notifier.report_progress(
            tot_num_samples=len(self.dataset) * self.epochs,
            processed_num_samples=self.dataset.idx,
            monitoring_link=self.logger.training_monitoring_link,
        )

        with tqdm(total=100) as pbar:
            while self.training_completion_percentage < 1.0:
                # Fire any callbacks that are due and log whatever they return.
                for callback in self.callbacks:
                    if logs := await callback.maybe_call(self.training_completion_percentage):
                        self.logger(logs)

                # One optimizer step per batch of accumulated gradients.
                await async_map_batch(self.process_sample, self.dataset, self.samples_per_batch)
                cp = self.training_completion_percentage
                current_lr = self.lr_schedule(cp)
                pbar.update(cp * 100.0 - pbar.n)
                logs = await self.model.optim_step(
                    current_lr, wd=0, max_grad_norm=self.max_grad_norm, skip_nan_gradients=self.skip_nan_gradients
                )
                self.logger(logs | dict(completion_percentage=cp))

        self.stage_notifier.report_progress(
            tot_num_samples=len(self.dataset) * self.epochs,
            processed_num_samples=self.dataset.idx,
            monitoring_link=self.logger.training_monitoring_link,
        )