adaptive_harmony-0.1.23.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony-0.1.23/PKG-INFO +37 -0
- adaptive_harmony-0.1.23/pyproject.toml +110 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/__init__.py +162 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/display.py +93 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/py.typed +0 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/PKG-INFO +37 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/SOURCES.txt +78 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/dependency_links.txt +1 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/requires.txt +33 -0
- adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/top_level.txt +1 -0
- adaptive_harmony-0.1.23/setup.cfg +4 -0
- adaptive_harmony-0.1.23/tests/test.py +3 -0
- adaptive_harmony-0.1.23/tests/test_adaptive_dataset.py +257 -0
- adaptive_harmony-0.1.23/tests/test_eval_artifact.py +108 -0
- adaptive_harmony-0.1.23/tests/test_graders.py +78 -0
- adaptive_harmony-0.1.23/tests/test_job_artifact.py +77 -0
- adaptive_harmony-0.1.23/tests/test_local_eval_logging.py +131 -0
- adaptive_harmony-0.1.23/tests/test_prebuilt_graders.py +273 -0
- adaptive_harmony-0.1.23/tests/test_reasoning.py +115 -0
- adaptive_harmony-0.1.23/tests/test_runtime_runner.py +361 -0
--- /dev/null
+++ adaptive_harmony-0.1.23/PKG-INFO
@@ -0,0 +1,37 @@
+Metadata-Version: 2.4
+Name: adaptive-harmony
+Version: 0.1.23
+Summary: Adaptive Harmony training recipes and utilities for LLM fine-tuning
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Requires-Python: >=3.12
+Requires-Dist: harmony-client
+Requires-Dist: rich>=13.7.0
+Requires-Dist: datasets>=2.14.0
+Requires-Dist: hf-xet>=1.1.2
+Requires-Dist: loguru>=0.7.2
+Requires-Dist: pydantic-settings>=2.9.1
+Requires-Dist: pydantic>=2.10.5
+Requires-Dist: pybars3>=0.9.7
+Requires-Dist: pysbd>=0.3.4
+Requires-Dist: websockets>=15.0.1
+Requires-Dist: setproctitle>=1.3.3
+Requires-Dist: pydantic-xml>=2.16.0
+Requires-Dist: openai>=1.42.0
+Requires-Dist: pillow>=11.3.0
+Requires-Dist: boto3>=1.40
+Requires-Dist: tomli>=2
+Provides-Extra: tensorboard
+Requires-Dist: tensorboardX>=2.6.2.2; extra == "tensorboard"
+Provides-Extra: mlflow
+Requires-Dist: mlflow>=3.4.0; extra == "mlflow"
+Provides-Extra: wandb
+Requires-Dist: wandb>=0.19.11; extra == "wandb"
+Provides-Extra: all-monitoring
+Requires-Dist: tensorboardX>=2.6.2.2; extra == "all-monitoring"
+Requires-Dist: mlflow>=3.4.0; extra == "all-monitoring"
+Requires-Dist: wandb>=0.19.11; extra == "all-monitoring"
+Provides-Extra: mlflow-skinny
+Requires-Dist: mlflow-skinny>=3.4.0; extra == "mlflow-skinny"
--- /dev/null
+++ adaptive_harmony-0.1.23/pyproject.toml
@@ -0,0 +1,110 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+version = "0.1.23"
+name = "adaptive-harmony"
+description = "Adaptive Harmony training recipes and utilities for LLM fine-tuning"
+requires-python = ">=3.12"
+classifiers = [
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Programming Language :: Python :: Implementation :: PyPy",
+]
+dependencies = [
+    "harmony-client",
+    "rich>=13.7.0",
+    "datasets>=2.14.0",
+    "hf-xet>=1.1.2",
+    "loguru>=0.7.2",
+    "pydantic-settings>=2.9.1",
+    "pydantic>=2.10.5",
+    "pybars3>=0.9.7",
+    "pysbd>=0.3.4",
+    "websockets>=15.0.1",
+    "setproctitle>=1.3.3",
+    "pydantic-xml>=2.16.0",
+    "openai>=1.42.0",
+    "pillow>=11.3.0",
+    "boto3>=1.40",
+    "tomli>=2",
+]
+
+[project.optional-dependencies]
+tensorboard = ["tensorboardX>=2.6.2.2"]
+mlflow = ["mlflow>=3.4.0"]
+wandb = ["wandb>=0.19.11"]
+all_monitoring = ["tensorboardX>=2.6.2.2", "mlflow>=3.4.0", "wandb>=0.19.11"]
+mlflow_skinny = ["mlflow-skinny>=3.4.0"]
+
+[dependency-groups]
+dev = ["pyright>=1.1.407", "pytest", "pytest-asyncio>=0.23.5", "ruff>=0.14.3"]
+
+[tool.pyright]
+pythonVersion = "3.12"
+pythonPlatform = "Linux"
+typeCheckingMode = "standard"
+reportUntypedFunctionDecorator = "warning"
+reportMissingTypeStubs = "warning"
+reportMissingImports = false
+reportUnusedVariable = "warning"
+reportUnusedImport = "warning"
+reportDuplicateImport = "error"
+reportArgumentType = "error"
+reportPossiblyUnboundVariable = "warning"
+reportIncompatibleMethodOverride = "warning"
+exclude = [
+    ".venv/**",
+    "**/__pycache__",
+    "tests",
+    "examples",
+    "python/adaptive_harmony/metric_logger.py",
+    "python/adaptive_harmony/runtime", # auto-generated code
+]
+extraPaths = ["python"]
+
+[tool.ruff]
+line-length = 120
+lint.ignore = [
+    "E731",
+    "E402",
+    "UP035",
+    "UP045",
+    "UP040",
+    "UP006",
+    "F841",
+    "E501",
+]
+lint.select = [
+    "E",     # pycodestyle errors
+    "F",     # pyflakes
+    "I",     # isort (import sorting)
+    "UP",    # pyupgrade (modern Python syntax)
+    "ASYNC", # async-specific checks
+]
+lint.fixable = ["I", "UP", "F401"]
+exclude = [
+    ".venv/**",
+    "**/__pycache__",
+    "tests",
+    "examples",
+    "python/adaptive_harmony/runtime", # auto-generated code
+]
+
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+markers = [
+    "harmony: mark tests that require harmony to run (deselect with '-m \"not harmony\"'')",
+]
+
+[tool.setuptools]
+package-dir = { "" = "python" }
+
+[tool.setuptools.packages.find]
+where = ["python"]
+
+[tool.uv]
+cache-keys = [{ file = "pyproject.toml" }]
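The `[tool.pytest.ini_options]` block above registers a custom `harmony` marker and enables `asyncio_mode = "auto"`. As a minimal sketch of how that configuration is consumed (the test name and body are hypothetical, not from the package's test suite):

import pytest

# Hypothetical test: with asyncio_mode = "auto", pytest-asyncio collects
# bare async test functions without an explicit @pytest.mark.asyncio.
@pytest.mark.harmony  # marker registered in [tool.pytest.ini_options]
async def test_requires_harmony_backend():
    assert True

Tests marked this way can then be skipped in environments without a harmony backend via `pytest -m "not harmony"`, as the marker's own help text suggests.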
--- /dev/null
+++ adaptive_harmony-0.1.23/python/adaptive_harmony/__init__.py
@@ -0,0 +1,162 @@
+# ruff: noqa: F403, F401
+from typing import TYPE_CHECKING
+
+from harmony_client import (
+    EvalSample as EvalSample,
+)
+from harmony_client import (
+    EvalSampleInteraction as EvalSampleInteraction,
+)
+from harmony_client import (
+    Grade as Grade,
+)
+from harmony_client import (
+    HarmonyClient as HarmonyClient,
+)
+from harmony_client import (
+    HarmonyJobNotifier as HarmonyJobNotifier,
+)
+from harmony_client import (
+    InferenceModel as InferenceModel,
+)
+from harmony_client import (
+    JobArtifact as JobArtifact,
+)
+from harmony_client import (
+    JobNotifier as JobNotifier,
+)
+from harmony_client import (
+    ModelBuilder as ModelBuilder,
+)
+from harmony_client import (
+    StageNotifier as StageNotifier,
+)
+from harmony_client import (
+    StringThread as StringThread,
+)
+from harmony_client import (
+    TokenizedThread as TokenizedThread,
+)
+from harmony_client import (
+    TrainingModel as TrainingModel,
+)
+from harmony_client import (
+    get_client as get_client,
+)
+from harmony_client import parameters as parameters
+from harmony_client import runtime as runtime
+from rich.progress import Progress
+
+if TYPE_CHECKING:
+    from harmony_client import StringTurn as StringTurn
+else:
+    from typing import NamedTuple
+
+    class StringTurn(NamedTuple):
+        role: str
+        content: str
+
+
+from harmony_client.artifacts.custom_artifact import CustomArtifact
+from harmony_client.artifacts.dataset_artifact import DatasetArtifact
+from harmony_client.file_storage import (
+    FileStorage,
+    FileStorageConfig,
+    LocalFileStorageConfig,
+    S3FileStorageConfig,
+    StoredFile,
+)
+
+import adaptive_harmony.core.rl_utils as rl_utils
+from adaptive_harmony.core.dataset import DataSet
+from adaptive_harmony.core.schedulers import CombinedSchedule, CosineScheduler, CosineSchedulerWithoutWarmup, Scheduler
+from adaptive_harmony.evaluation.evaluation_artifact import EvaluationArtifact
+from adaptive_harmony.metric_logger import Logger, WandbLogger
+
+# Ensure key classes are available at module level
+__all__ = [
+    "StringThread",
+    "StringTurn",
+    "TokenizedThread",
+    "InferenceModel",
+    "ModelBuilder",
+    "TrainingModel",
+    "HarmonyClient",
+    "get_client",
+    "DataSet",
+    "CosineScheduler",
+    "CombinedSchedule",
+    "CosineSchedulerWithoutWarmup",
+    "Scheduler",
+    "WandbLogger",
+    "Logger",
+    "FileStorage",
+    "FileStorageConfig",
+    "LocalFileStorageConfig",
+    "S3FileStorageConfig",
+    "StoredFile",
+    "EvaluationArtifact",
+    "CustomArtifact",
+    "DatasetArtifact",
+    "rl_utils",
+    "Grade",
+    "EvalSample",
+    "EvalSampleInteraction",
+    "JobArtifact",
+]
+
+
+# Patch StringThread to use rich for display
+from harmony_client.runtime.model_artifact_save import save_with_artifact
+
+from adaptive_harmony.core.display import _stringthread_repr, _tokenizedthread_repr
+from adaptive_harmony.core.image_utils import string_thread_to_html_string
+
+# Patch InferenceModel to have json output capabilities
+from adaptive_harmony.core.structured_output import generate_and_validate, render_pydantic_model, render_schema
+
+StringThread.__repr__ = _stringthread_repr  # type: ignore
+TokenizedThread.__repr__ = _tokenizedthread_repr  # type: ignore
+setattr(StringThread, "_repr_html_", string_thread_to_html_string)
+setattr(InferenceModel, "generate_and_validate", generate_and_validate)
+setattr(InferenceModel, "render_schema", staticmethod(render_schema))
+setattr(InferenceModel, "render_pydantic_model", staticmethod(render_pydantic_model))
+
+_original_training_model_save = TrainingModel.save
+
+
+async def _save_with_artifact_wrapper(model: TrainingModel, model_name: str, inference_only: bool = True, ctx=None):
+    return await save_with_artifact(model, model_name, inference_only, ctx, _original_training_model_save)
+
+
+setattr(TrainingModel, "save", _save_with_artifact_wrapper)
+
+
+async def spawn_train(self: ModelBuilder, name: str, max_batch_size: int) -> TrainingModel:
+    fut = await self.spawn_train_with_progress(name, max_batch_size)  # type:ignore
+
+    with Progress() as pbar:
+        task = pbar.add_task("Loading model", total=1000)
+
+        while (prog := await fut._await_progress()) != 1.0:
+            pbar.update(task, completed=prog, total=1.0)
+        pbar.update(task, completed=1.0, total=1.0)
+
+    return await fut.get()
+
+
+async def spawn_inference(self: ModelBuilder, name: str) -> InferenceModel:
+    fut = await self.spawn_inference_with_progress(name)  # type:ignore
+
+    with Progress() as pbar:
+        task = pbar.add_task("Loading model", total=1000)
+
+        while (prog := await fut._await_progress()) != 1.0:
+            pbar.update(task, completed=prog, total=1.0)
+        pbar.update(task, completed=1.0, total=1.0)
+
+    return await fut.get()
+
+
+setattr(ModelBuilder, "spawn_inference", spawn_inference)
+setattr(ModelBuilder, "spawn_train", spawn_train)
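For orientation, this module monkey-patches the harmony_client classes it re-exports: `spawn_train`/`spawn_inference` wrap the `*_with_progress` futures in a rich progress bar, and `TrainingModel.save` is wrapped so checkpoints are also recorded as model artifacts. A minimal usage sketch under those patches (the model names are hypothetical, and it assumes a `ModelBuilder` has already been obtained from harmony_client, e.g. via `get_client`):

from adaptive_harmony import ModelBuilder, TrainingModel


async def demo(builder: ModelBuilder) -> None:
    # Patched spawn_train: shows a "Loading model" progress bar while the
    # spawn_train_with_progress future resolves, then returns the model.
    model: TrainingModel = await builder.spawn_train("base-model", max_batch_size=8)
    # Patched save: routed through _save_with_artifact_wrapper, so the
    # checkpoint is additionally registered as a job artifact.
    await model.save("demo-checkpoint")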
--- /dev/null
+++ adaptive_harmony-0.1.23/python/adaptive_harmony/common/__init__.py
@@ -0,0 +1,40 @@
+from .callbacks import (
+    CheckpointCallback as CheckpointCallback,
+)
+from .callbacks import (
+    EnvironmentValidationCallback as EnvironmentValidationCallback,
+)
+from .callbacks import (
+    GenerateSamplesCallback as GenerateSamplesCallback,
+)
+from .callbacks import (
+    GraderEvalCallback as GraderEvalCallback,
+)
+from .callbacks import (
+    RecipeCallback as RecipeCallback,
+)
+from .callbacks import (
+    ValidationLossCallback as ValidationLossCallback,
+)
+from .dpo import DPO as DPO
+from .env_grpo import ENVGRPO
+from .grpo import GRPO as GRPO
+from .gspo import GSPO as GSPO
+from .ppo import PPO as PPO
+from .rm import RewardModelling as RewardModelling
+from .sft import SFT as SFT
+
+__all__ = [
+    "SFT",
+    "PPO",
+    "GRPO",
+    "ENVGRPO",
+    "DPO",
+    "RewardModelling",
+    "RecipeCallback",
+    "GenerateSamplesCallback",
+    "ValidationLossCallback",
+    "CheckpointCallback",
+    "GraderEvalCallback",
+    "EnvironmentValidationCallback",
+]
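These explicit `X as X` re-exports make the training recipes and callbacks importable directly from the `adaptive_harmony.common` namespace, for example:

# Import recipes and callbacks from the package namespace re-exported above.
from adaptive_harmony.common import GRPO, SFT, CheckpointCallback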
--- /dev/null
+++ adaptive_harmony-0.1.23/python/adaptive_harmony/common/callbacks.py
@@ -0,0 +1,219 @@
+from abc import abstractmethod
+from typing import Any
+
+import numpy as np
+from harmony_client import (
+    InferenceModel,
+    StringThread,
+    TrainingModel,
+)
+from loguru import logger
+
+from adaptive_harmony.core.utils import async_map, async_map_fallible
+from adaptive_harmony.environment import EnvironmentFactory
+from adaptive_harmony.graders import BaseGrader
+from adaptive_harmony.logging_table import Table
+
+
+class RecipeCallback:
+    def __init__(self, frequency: float, log_key_prefix: str | None = None):
+        self.frequency = frequency
+        self.last_call = -1.0
+        self.log_key_prefix = log_key_prefix
+
+    async def maybe_call(self, current_percentage: float) -> dict[str, Any]:
+        if current_percentage - self.last_call >= self.frequency:
+            self.last_call = current_percentage
+            callback_dict = await self.callback(current_percentage)
+            prefixed_dict = {
+                (f"{self.log_key_prefix}/{key}" if self.log_key_prefix else key): value
+                for key, value in callback_dict.items()
+            }
+            return prefixed_dict
+        return {}
+
+    @abstractmethod
+    async def callback(self, current_percentage: float) -> dict[str, Any]: ...
+
+
+class GenerateSamplesCallback(RecipeCallback):
+    def __init__(
+        self,
+        thread_set: list[StringThread],
+        model: InferenceModel,
+        frequency: float,
+        log_key: str = "samples",
+    ):
+        super().__init__(frequency, log_key_prefix="generation")
+        self.thread_set = thread_set
+        self.model = model
+        self.log_key = log_key
+
+    async def callback(self, current_percentage: float) -> dict[str, Any]:
+        logger.info("Entering generation callback...")
+        generation_tokens = await async_map_fallible(self.model.generate_tokens, self.thread_set)
+        generation_results = await async_map_fallible(self.model.detokenize_thread, generation_tokens)
+        gen_lengths = [sample.len_last_turn() for sample in generation_tokens]
+
+        generation_logs = {
+            self.log_key: Table()
+            .add_column(
+                "system",
+                [
+                    sample.get_turns()[0].content if sample.get_turns()[0].role == "system" else ""
+                    for sample in generation_results
+                ],
+            )
+            .add_column(
+                "prompt",
+                [
+                    repr(
+                        StringThread(
+                            sample.get_turns()[1:-1]
+                            if (sample.get_turns() and sample.get_turns()[0].role == "system")
+                            else sample.get_turns()[:-1]
+                        )
+                    )
+                    for sample in generation_results
+                ],
+            )
+            .add_column("response", [response.last_content() for response in generation_results]),
+            "generation_length_mean": np.mean(gen_lengths).item(),
+            "generation_length_std": np.std(gen_lengths).item(),
+            "num_samples": len(generation_results),
+        }
+        return generation_logs
+
+
+class ValidationLossCallback(RecipeCallback):
+    def __init__(
+        self,
+        validation_set: list[StringThread],
+        model: InferenceModel,
+        frequency: float = 0.1,
+        log_key: str = "loss",
+    ):
+        super().__init__(frequency, log_key_prefix="validation")
+        self.validation_set = validation_set
+        self.model = model
+        self.log_key = log_key
+
+    async def callback(self, current_percentage: float) -> dict[str, float]:
+        logger.info("Entering validation loss callback...")
+        losses = []
+        tokens = await async_map_fallible(self.model.tokenize_thread, self.validation_set)
+        logprobs = await async_map(self.model.logprobs_per_token, tokens)
+        losses = [-(sum(lp) / len(lp)) for lp in logprobs]
+
+        return {self.log_key: sum(losses) / len(losses)}
+
+
+class CheckpointCallback(RecipeCallback):
+    def __init__(
+        self,
+        model: TrainingModel,
+        checkpoint_name: str,
+        frequency: float = 0.2,
+    ):
+        super().__init__(frequency, log_key_prefix="checkpointing")
+        self.last_call = 0.0  # avoid saving the model at the first period
+        self.model = model
+        self.model_log_name = checkpoint_name
+
+    async def callback(self, current_percentage: float):
+        logger.info(f"Saving checkpoint at {current_percentage * 100} % of training ...")
+        await self.model.save(f"{self.model_log_name}-{round(current_percentage, 3)}")
+        return {}
+
+
+class GraderEvalCallback(RecipeCallback):
+    def __init__(
+        self,
+        validation_set: list[StringThread],
+        model: InferenceModel,
+        grader: BaseGrader,
+        frequency: float,
+        log_key: str = "validation",
+        clear_grader_logs: bool = True,
+        temperature: float = 0.0,
+    ):
+        super().__init__(frequency, log_key_prefix=log_key)
+        self.validation_set = validation_set
+        self.model = model
+        self.grader = grader
+        self.clear_grader_logs = clear_grader_logs
+        self.temperature = temperature
+
+    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
+        logger.info("Entering grader evaluation callback...")
+        temp_model = self.model.temperature(self.temperature)
+
+        tokenized_results = await async_map_fallible(temp_model.generate_tokens, self.validation_set)
+        string_results = await async_map(temp_model.detokenize_thread, tokenized_results)
+        grades = await async_map_fallible(self.grader.grade, string_results)
+        gen_lengths = [sample.len_last_turn() for sample in tokenized_results]
+
+        grader_logs = self.grader.get_logs(clear=self.clear_grader_logs)
+        return {
+            **{f"rewards/{key}": value for key, value in grader_logs.items()},
+            "generation_length_mean": float(np.mean(gen_lengths).item()),
+            "generation_length_std": float(np.std(gen_lengths).item()),
+            "num_samples": float(len(grades)),
+        }
+
+
+class EnvironmentValidationCallback(RecipeCallback):
+    def __init__(
+        self,
+        validation_set: list[StringThread],
+        model: InferenceModel,
+        env_factory: EnvironmentFactory,
+        frequency: float,
+        log_key: str = "validation",
+        clear_env_logs: bool = True,
+        temperature: float = 0.0,
+        num_samples_log: int = 0,
+    ):
+        super().__init__(frequency, log_key_prefix=log_key)
+        self.validation_set = validation_set
+        self.model = model
+        self.env_factory = env_factory
+        self.clear_env_logs = clear_env_logs
+        self.temperature = temperature
+        self.num_samples_log = num_samples_log
+
+    async def generate_trajectory(self, initial_thread: StringThread) -> tuple[StringThread, float, int]:
+        env = self.env_factory.create_environment(initial_thread.metadata)
+        temp_model = self.model.temperature(self.temperature)
+        trajectory, trajectory_score = await env.generate_trajectory_and_grade(temp_model, initial_thread)
+        num_turns = len([turn for turn in trajectory.get_turns() if turn.role == "assistant"])
+        return trajectory, trajectory_score.cumulative_score, num_turns
+
+    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
+        logger.info("Entering environment validation callback...")
+
+        results = await async_map_fallible(self.generate_trajectory, self.validation_set)
+
+        trajectories = [traj for traj, _, _ in results]
+        scores = [score for _, score, _ in results]
+        num_turns_list = [num_turns for _, _, num_turns in results]
+
+        validation_logs = {
+            "score_mean": np.mean(scores).item(),
+            "score_std": np.std(scores).item(),
+            "num_turns_mean": np.mean(num_turns_list).item(),
+            "num_turns_std": np.std(num_turns_list).item(),
+            "num_samples": len(results),
+        }
+
+        env_logs = self.env_factory.get_logs(clear=self.clear_env_logs)
+        validation_logs.update({f"env/{key}": value for key, value in env_logs.items()})
+
+        if self.num_samples_log > 0:
+            samples = [repr(traj) for traj in trajectories[: self.num_samples_log]]
+            samples_scores = scores[: self.num_samples_log]
+            table = Table().add_column("trajectory", samples).add_column("score", samples_scores)
+            validation_logs["samples"] = table
+
+        logger.info(f"Validation Mean score: {validation_logs['score_mean']:.4f}")
+        return validation_logs
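To illustrate the `RecipeCallback` contract defined above: `maybe_call` fires `callback` whenever training progress has advanced by at least `frequency` since the last firing, and prefixes the returned keys with `log_key_prefix`. A hypothetical subclass (not part of the package) might look like:

from typing import Any

from adaptive_harmony.common import RecipeCallback


class ProgressLogCallback(RecipeCallback):
    # Hypothetical example: logs how far training has advanced.
    def __init__(self, frequency: float = 0.05):
        super().__init__(frequency, log_key_prefix="progress")

    async def callback(self, current_percentage: float) -> dict[str, Any]:
        # Keys returned here come back from maybe_call prefixed,
        # e.g. "progress/percentage".
        return {"percentage": current_percentage}

A recipe loop would then call `await cb.maybe_call(current_percentage)` at each step and merge the result into its metric logs; between firing points `maybe_call` returns `{}`, so the merge is a no-op.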