adaptive-harmony 0.1.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. adaptive_harmony-0.1.23/PKG-INFO +37 -0
  2. adaptive_harmony-0.1.23/pyproject.toml +110 -0
  3. adaptive_harmony-0.1.23/python/adaptive_harmony/__init__.py +162 -0
  4. adaptive_harmony-0.1.23/python/adaptive_harmony/common/__init__.py +40 -0
  5. adaptive_harmony-0.1.23/python/adaptive_harmony/common/callbacks.py +219 -0
  6. adaptive_harmony-0.1.23/python/adaptive_harmony/common/checkpointing.py +163 -0
  7. adaptive_harmony-0.1.23/python/adaptive_harmony/common/dpo.py +92 -0
  8. adaptive_harmony-0.1.23/python/adaptive_harmony/common/env_grpo.py +361 -0
  9. adaptive_harmony-0.1.23/python/adaptive_harmony/common/grpo.py +260 -0
  10. adaptive_harmony-0.1.23/python/adaptive_harmony/common/gspo.py +70 -0
  11. adaptive_harmony-0.1.23/python/adaptive_harmony/common/ppo.py +303 -0
  12. adaptive_harmony-0.1.23/python/adaptive_harmony/common/rm.py +79 -0
  13. adaptive_harmony-0.1.23/python/adaptive_harmony/common/sft.py +121 -0
  14. adaptive_harmony-0.1.23/python/adaptive_harmony/core/__init__.py +0 -0
  15. adaptive_harmony-0.1.23/python/adaptive_harmony/core/dataset.py +72 -0
  16. adaptive_harmony-0.1.23/python/adaptive_harmony/core/display.py +93 -0
  17. adaptive_harmony-0.1.23/python/adaptive_harmony/core/image_utils.py +110 -0
  18. adaptive_harmony-0.1.23/python/adaptive_harmony/core/reasoning.py +12 -0
  19. adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/__init__.py +19 -0
  20. adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/client.py +160 -0
  21. adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/reward_types.py +49 -0
  22. adaptive_harmony-0.1.23/python/adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  23. adaptive_harmony-0.1.23/python/adaptive_harmony/core/rich_counter.py +351 -0
  24. adaptive_harmony-0.1.23/python/adaptive_harmony/core/rl_utils.py +38 -0
  25. adaptive_harmony-0.1.23/python/adaptive_harmony/core/schedulers.py +38 -0
  26. adaptive_harmony-0.1.23/python/adaptive_harmony/core/structured_output.py +385 -0
  27. adaptive_harmony-0.1.23/python/adaptive_harmony/core/utils.py +365 -0
  28. adaptive_harmony-0.1.23/python/adaptive_harmony/environment/__init__.py +8 -0
  29. adaptive_harmony-0.1.23/python/adaptive_harmony/environment/environment.py +121 -0
  30. adaptive_harmony-0.1.23/python/adaptive_harmony/evaluation/__init__.py +1 -0
  31. adaptive_harmony-0.1.23/python/adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  32. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/__init__.py +20 -0
  33. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  34. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  35. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  36. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/base_grader.py +265 -0
  37. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  38. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  39. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  40. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/combined_grader.py +118 -0
  41. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  42. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  43. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  44. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/exceptions.py +9 -0
  45. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  46. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  47. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  48. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/__init__.py +7 -0
  49. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/prompts.py +232 -0
  50. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  51. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/range_judge/types.py +12 -0
  52. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/reward_server_grader.py +36 -0
  53. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  54. adaptive_harmony-0.1.23/python/adaptive_harmony/graders/utils.py +79 -0
  55. adaptive_harmony-0.1.23/python/adaptive_harmony/logging_table.py +1 -0
  56. adaptive_harmony-0.1.23/python/adaptive_harmony/metric_logger.py +452 -0
  57. adaptive_harmony-0.1.23/python/adaptive_harmony/parameters/__init__.py +2 -0
  58. adaptive_harmony-0.1.23/python/adaptive_harmony/py.typed +0 -0
  59. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/__init__.py +2 -0
  60. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/context.py +2 -0
  61. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/data.py +2 -0
  62. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/decorators.py +2 -0
  63. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/model_artifact_save.py +2 -0
  64. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/runner.py +27 -0
  65. adaptive_harmony-0.1.23/python/adaptive_harmony/runtime/simple_notifier.py +2 -0
  66. adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/PKG-INFO +37 -0
  67. adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/SOURCES.txt +78 -0
  68. adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/dependency_links.txt +1 -0
  69. adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/requires.txt +33 -0
  70. adaptive_harmony-0.1.23/python/adaptive_harmony.egg-info/top_level.txt +1 -0
  71. adaptive_harmony-0.1.23/setup.cfg +4 -0
  72. adaptive_harmony-0.1.23/tests/test.py +3 -0
  73. adaptive_harmony-0.1.23/tests/test_adaptive_dataset.py +257 -0
  74. adaptive_harmony-0.1.23/tests/test_eval_artifact.py +108 -0
  75. adaptive_harmony-0.1.23/tests/test_graders.py +78 -0
  76. adaptive_harmony-0.1.23/tests/test_job_artifact.py +77 -0
  77. adaptive_harmony-0.1.23/tests/test_local_eval_logging.py +131 -0
  78. adaptive_harmony-0.1.23/tests/test_prebuilt_graders.py +273 -0
  79. adaptive_harmony-0.1.23/tests/test_reasoning.py +115 -0
  80. adaptive_harmony-0.1.23/tests/test_runtime_runner.py +361 -0
@@ -0,0 +1,37 @@
1
+ Metadata-Version: 2.4
2
+ Name: adaptive-harmony
3
+ Version: 0.1.23
4
+ Summary: Adaptive Harmony training recipes and utilities for LLM fine-tuning
5
+ Classifier: Programming Language :: Python :: 3.12
6
+ Classifier: Programming Language :: Python :: 3.13
7
+ Classifier: Programming Language :: Python :: Implementation :: CPython
8
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
9
+ Requires-Python: >=3.12
10
+ Requires-Dist: harmony-client
11
+ Requires-Dist: rich>=13.7.0
12
+ Requires-Dist: datasets>=2.14.0
13
+ Requires-Dist: hf-xet>=1.1.2
14
+ Requires-Dist: loguru>=0.7.2
15
+ Requires-Dist: pydantic-settings>=2.9.1
16
+ Requires-Dist: pydantic>=2.10.5
17
+ Requires-Dist: pybars3>=0.9.7
18
+ Requires-Dist: pysbd>=0.3.4
19
+ Requires-Dist: websockets>=15.0.1
20
+ Requires-Dist: setproctitle>=1.3.3
21
+ Requires-Dist: pydantic-xml>=2.16.0
22
+ Requires-Dist: openai>=1.42.0
23
+ Requires-Dist: pillow>=11.3.0
24
+ Requires-Dist: boto3>=1.40
25
+ Requires-Dist: tomli>=2
26
+ Provides-Extra: tensorboard
27
+ Requires-Dist: tensorboardX>=2.6.2.2; extra == "tensorboard"
28
+ Provides-Extra: mlflow
29
+ Requires-Dist: mlflow>=3.4.0; extra == "mlflow"
30
+ Provides-Extra: wandb
31
+ Requires-Dist: wandb>=0.19.11; extra == "wandb"
32
+ Provides-Extra: all-monitoring
33
+ Requires-Dist: tensorboardX>=2.6.2.2; extra == "all-monitoring"
34
+ Requires-Dist: mlflow>=3.4.0; extra == "all-monitoring"
35
+ Requires-Dist: wandb>=0.19.11; extra == "all-monitoring"
36
+ Provides-Extra: mlflow-skinny
37
+ Requires-Dist: mlflow-skinny>=3.4.0; extra == "mlflow-skinny"
@@ -0,0 +1,110 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ version = "0.1.23"
7
+ name = "adaptive-harmony"
8
+ description = "Adaptive Harmony training recipes and utilities for LLM fine-tuning"
9
+ requires-python = ">=3.12"
10
+ classifiers = [
11
+ "Programming Language :: Python :: 3.12",
12
+ "Programming Language :: Python :: 3.13",
13
+ "Programming Language :: Python :: Implementation :: CPython",
14
+ "Programming Language :: Python :: Implementation :: PyPy",
15
+ ]
16
+ dependencies = [
17
+ "harmony-client",
18
+ "rich>=13.7.0",
19
+ "datasets>=2.14.0",
20
+ "hf-xet>=1.1.2",
21
+ "loguru>=0.7.2",
22
+ "pydantic-settings>=2.9.1",
23
+ "pydantic>=2.10.5",
24
+ "pybars3>=0.9.7",
25
+ "pysbd>=0.3.4",
26
+ "websockets>=15.0.1",
27
+ "setproctitle>=1.3.3",
28
+ "pydantic-xml>=2.16.0",
29
+ "openai>=1.42.0",
30
+ "pillow>=11.3.0",
31
+ "boto3>=1.40",
32
+ "tomli>=2",
33
+ ]
34
+
35
+ [project.optional-dependencies]
36
+ tensorboard = ["tensorboardX>=2.6.2.2"]
37
+ mlflow = ["mlflow>=3.4.0"]
38
+ wandb = ["wandb>=0.19.11"]
39
+ all_monitoring = ["tensorboardX>=2.6.2.2", "mlflow>=3.4.0", "wandb>=0.19.11"]
40
+ mlflow_skinny = ["mlflow-skinny>=3.4.0"]
41
+
42
+ [dependency-groups]
43
+ dev = ["pyright>=1.1.407", "pytest", "pytest-asyncio>=0.23.5", "ruff>=0.14.3"]
44
+
45
+ [tool.pyright]
46
+ pythonVersion = "3.12"
47
+ pythonPlatform = "Linux"
48
+ typeCheckingMode = "standard"
49
+ reportUntypedFunctionDecorator = "warning"
50
+ reportMissingTypeStubs = "warning"
51
+ reportMissingImports = false
52
+ reportUnusedVariable = "warning"
53
+ reportUnusedImport = "warning"
54
+ reportDuplicateImport = "error"
55
+ reportArgumentType = "error"
56
+ reportPossiblyUnboundVariable = "warning"
57
+ reportIncompatibleMethodOverride = "warning"
58
+ exclude = [
59
+ ".venv/**",
60
+ "**/__pycache__",
61
+ "tests",
62
+ "examples",
63
+ "python/adaptive_harmony/metric_logger.py",
64
+ "python/adaptive_harmony/runtime", # auto-generated code
65
+ ]
66
+ extraPaths = ["python"]
67
+
68
+ [tool.ruff]
69
+ line-length = 120
70
+ lint.ignore = [
71
+ "E731",
72
+ "E402",
73
+ "UP035",
74
+ "UP045",
75
+ "UP040",
76
+ "UP006",
77
+ "F841",
78
+ "E501",
79
+ ]
80
+ lint.select = [
81
+ "E", # pycodestyle errors
82
+ "F", # pyflakes
83
+ "I", # isort (import sorting)
84
+ "UP", # pyupgrade (modern Python syntax)
85
+ "ASYNC", # async-specific checks
86
+ ]
87
+ lint.fixable = ["I", "UP", "F401"]
88
+ exclude = [
89
+ ".venv/**",
90
+ "**/__pycache__",
91
+ "tests",
92
+ "examples",
93
+ "python/adaptive_harmony/runtime", # auto-generated code
94
+ ]
95
+
96
+
97
+ [tool.pytest.ini_options]
98
+ asyncio_mode = "auto"
99
+ markers = [
100
+ "harmony: mark tests that require harmony to run (deselect with '-m \"not harmony\"')",
101
+ ]
102
+
103
+ [tool.setuptools]
104
+ package-dir = { "" = "python" }
105
+
106
+ [tool.setuptools.packages.find]
107
+ where = ["python"]
108
+
109
+ [tool.uv]
110
+ cache-keys = [{ file = "pyproject.toml" }]
@@ -0,0 +1,162 @@
1
+ # ruff: noqa: F403, F401
2
+ from typing import TYPE_CHECKING
3
+
4
+ from harmony_client import (
5
+ EvalSample as EvalSample,
6
+ )
7
+ from harmony_client import (
8
+ EvalSampleInteraction as EvalSampleInteraction,
9
+ )
10
+ from harmony_client import (
11
+ Grade as Grade,
12
+ )
13
+ from harmony_client import (
14
+ HarmonyClient as HarmonyClient,
15
+ )
16
+ from harmony_client import (
17
+ HarmonyJobNotifier as HarmonyJobNotifier,
18
+ )
19
+ from harmony_client import (
20
+ InferenceModel as InferenceModel,
21
+ )
22
+ from harmony_client import (
23
+ JobArtifact as JobArtifact,
24
+ )
25
+ from harmony_client import (
26
+ JobNotifier as JobNotifier,
27
+ )
28
+ from harmony_client import (
29
+ ModelBuilder as ModelBuilder,
30
+ )
31
+ from harmony_client import (
32
+ StageNotifier as StageNotifier,
33
+ )
34
+ from harmony_client import (
35
+ StringThread as StringThread,
36
+ )
37
+ from harmony_client import (
38
+ TokenizedThread as TokenizedThread,
39
+ )
40
+ from harmony_client import (
41
+ TrainingModel as TrainingModel,
42
+ )
43
+ from harmony_client import (
44
+ get_client as get_client,
45
+ )
46
+ from harmony_client import parameters as parameters
47
+ from harmony_client import runtime as runtime
48
+ from rich.progress import Progress
49
+
50
if TYPE_CHECKING:
    # For static analysis, use the real harmony_client StringTurn type.
    from harmony_client import StringTurn as StringTurn
else:
    from typing import NamedTuple

    # Runtime stand-in: a lightweight (role, content) pair with the same
    # field names as harmony_client's StringTurn.
    # NOTE(review): presumably harmony_client does not export StringTurn at
    # runtime in all versions — confirm before removing this fallback.
    class StringTurn(NamedTuple):
        role: str
        content: str
58
+
59
+
60
+ from harmony_client.artifacts.custom_artifact import CustomArtifact
61
+ from harmony_client.artifacts.dataset_artifact import DatasetArtifact
62
+ from harmony_client.file_storage import (
63
+ FileStorage,
64
+ FileStorageConfig,
65
+ LocalFileStorageConfig,
66
+ S3FileStorageConfig,
67
+ StoredFile,
68
+ )
69
+
70
+ import adaptive_harmony.core.rl_utils as rl_utils
71
+ from adaptive_harmony.core.dataset import DataSet
72
+ from adaptive_harmony.core.schedulers import CombinedSchedule, CosineScheduler, CosineSchedulerWithoutWarmup, Scheduler
73
+ from adaptive_harmony.evaluation.evaluation_artifact import EvaluationArtifact
74
+ from adaptive_harmony.metric_logger import Logger, WandbLogger
75
+
76
+ # Ensure key classes are available at module level
77
+ __all__ = [
78
+ "StringThread",
79
+ "StringTurn",
80
+ "TokenizedThread",
81
+ "InferenceModel",
82
+ "ModelBuilder",
83
+ "TrainingModel",
84
+ "HarmonyClient",
85
+ "get_client",
86
+ "DataSet",
87
+ "CosineScheduler",
88
+ "CombinedSchedule",
89
+ "CosineSchedulerWithoutWarmup",
90
+ "Scheduler",
91
+ "WandbLogger",
92
+ "Logger",
93
+ "FileStorage",
94
+ "FileStorageConfig",
95
+ "LocalFileStorageConfig",
96
+ "S3FileStorageConfig",
97
+ "StoredFile",
98
+ "EvaluationArtifact",
99
+ "CustomArtifact",
100
+ "DatasetArtifact",
101
+ "rl_utils",
102
+ "Grade",
103
+ "EvalSample",
104
+ "EvalSampleInteraction",
105
+ "JobArtifact",
106
+ ]
107
+
108
+
109
# Patch StringThread to use rich for display
from harmony_client.runtime.model_artifact_save import save_with_artifact

from adaptive_harmony.core.display import _stringthread_repr, _tokenizedthread_repr
from adaptive_harmony.core.image_utils import string_thread_to_html_string

# Patch InferenceModel to have json output capabilities
from adaptive_harmony.core.structured_output import generate_and_validate, render_pydantic_model, render_schema

# Replace the default reprs so threads pretty-print when echoed in a REPL.
StringThread.__repr__ = _stringthread_repr  # type: ignore
TokenizedThread.__repr__ = _tokenizedthread_repr  # type: ignore
# _repr_html_ is the Jupyter/IPython rich-display hook.
setattr(StringThread, "_repr_html_", string_thread_to_html_string)
# Attach structured-output helpers to InferenceModel: generate_and_validate is
# bound as an instance method; the two schema renderers take no model instance
# and are therefore exposed as staticmethods.
setattr(InferenceModel, "generate_and_validate", generate_and_validate)
setattr(InferenceModel, "render_schema", staticmethod(render_schema))
setattr(InferenceModel, "render_pydantic_model", staticmethod(render_pydantic_model))
124
+
125
# Capture the unpatched TrainingModel.save BEFORE monkey-patching it below so
# the wrapper can delegate to the original implementation.
_original_training_model_save = TrainingModel.save


async def _save_with_artifact_wrapper(model: TrainingModel, model_name: str, inference_only: bool = True, ctx=None):
    """Save the model via save_with_artifact, delegating the actual save to the
    original (pre-patch) TrainingModel.save."""
    return await save_with_artifact(model, model_name, inference_only, ctx, _original_training_model_save)


# From here on, TrainingModel.save also records an artifact.
setattr(TrainingModel, "save", _save_with_artifact_wrapper)
133
+
134
+
135
async def spawn_train(self: ModelBuilder, name: str, max_batch_size: int) -> TrainingModel:
    """Spawn a training model, rendering load progress with a rich bar.

    Blocks until loading completes, then returns the ready TrainingModel.
    Progress values reported by the future are fractions in [0, 1].
    """
    fut = await self.spawn_train_with_progress(name, max_batch_size)  # type:ignore

    with Progress() as pbar:
        # Progress is a fraction, so the bar total is 1.0 from the start
        # (was 1000, which was only corrected on the first update).
        task = pbar.add_task("Loading model", total=1.0)

        # NOTE(review): assumes the future reports exactly 1.0 on completion —
        # float equality is intentional here; confirm against harmony_client.
        while (prog := await fut._await_progress()) != 1.0:
            pbar.update(task, completed=prog)
        pbar.update(task, completed=1.0)

    return await fut.get()
146
+
147
+
148
async def spawn_inference(self: ModelBuilder, name: str) -> InferenceModel:
    """Spawn an inference model, rendering load progress with a rich bar.

    Blocks until loading completes, then returns the ready InferenceModel.
    Progress values reported by the future are fractions in [0, 1].
    """
    fut = await self.spawn_inference_with_progress(name)  # type:ignore

    with Progress() as pbar:
        # Progress is a fraction, so the bar total is 1.0 from the start
        # (was 1000, which was only corrected on the first update).
        task = pbar.add_task("Loading model", total=1.0)

        # NOTE(review): assumes the future reports exactly 1.0 on completion —
        # float equality is intentional here; confirm against harmony_client.
        while (prog := await fut._await_progress()) != 1.0:
            pbar.update(task, completed=prog)
        pbar.update(task, completed=1.0)

    return await fut.get()
159
+
160
+
161
+ setattr(ModelBuilder, "spawn_inference", spawn_inference)
162
+ setattr(ModelBuilder, "spawn_train", spawn_train)
@@ -0,0 +1,40 @@
1
# Public surface of adaptive_harmony.common: recipe callbacks and training
# recipes. Everything is re-exported with an explicit `x as x` alias so type
# checkers treat these as public re-exports.
from .callbacks import (
    CheckpointCallback as CheckpointCallback,
    EnvironmentValidationCallback as EnvironmentValidationCallback,
    GenerateSamplesCallback as GenerateSamplesCallback,
    GraderEvalCallback as GraderEvalCallback,
    RecipeCallback as RecipeCallback,
    ValidationLossCallback as ValidationLossCallback,
)
from .dpo import DPO as DPO

# `as ENVGRPO` added for consistency: this was the only import in the module
# missing the explicit re-export alias used by every sibling.
from .env_grpo import ENVGRPO as ENVGRPO
from .grpo import GRPO as GRPO
from .gspo import GSPO as GSPO
from .ppo import PPO as PPO
from .rm import RewardModelling as RewardModelling
from .sft import SFT as SFT

__all__ = [
    "SFT",
    "PPO",
    "GRPO",
    "ENVGRPO",
    "DPO",
    "RewardModelling",
    "RecipeCallback",
    "GenerateSamplesCallback",
    "ValidationLossCallback",
    "CheckpointCallback",
    "GraderEvalCallback",
    "EnvironmentValidationCallback",
]
@@ -0,0 +1,219 @@
1
+ from abc import abstractmethod
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from harmony_client import (
6
+ InferenceModel,
7
+ StringThread,
8
+ TrainingModel,
9
+ )
10
+ from loguru import logger
11
+
12
+ from adaptive_harmony.core.utils import async_map, async_map_fallible
13
+ from adaptive_harmony.environment import EnvironmentFactory
14
+ from adaptive_harmony.graders import BaseGrader
15
+ from adaptive_harmony.logging_table import Table
16
+
17
+
18
+ class RecipeCallback:
19
+ def __init__(self, frequency: float, log_key_prefix: str | None = None):
20
+ self.frequency = frequency
21
+ self.last_call = -1.0
22
+ self.log_key_prefix = log_key_prefix
23
+
24
+ async def maybe_call(self, current_percentage: float) -> dict[str, Any]:
25
+ if current_percentage - self.last_call >= self.frequency:
26
+ self.last_call = current_percentage
27
+ callback_dict = await self.callback(current_percentage)
28
+ prefixed_dict = {
29
+ (f"{self.log_key_prefix}/{key}" if self.log_key_prefix else key): value
30
+ for key, value in callback_dict.items()
31
+ }
32
+ return prefixed_dict
33
+ return {}
34
+
35
+ @abstractmethod
36
+ async def callback(self, current_percentage: float) -> dict[str, Any]: ...
37
+
38
+
39
class GenerateSamplesCallback(RecipeCallback):
    """Periodically generates completions for a fixed prompt set and logs them
    as a table, alongside generation-length statistics."""

    def __init__(
        self,
        thread_set: list[StringThread],
        model: InferenceModel,
        frequency: float,
        log_key: str = "samples",
    ):
        super().__init__(frequency, log_key_prefix="generation")
        self.thread_set = thread_set
        self.model = model
        self.log_key = log_key

    @staticmethod
    def _system_content(sample: StringThread) -> str:
        # First turn's content when it is a system turn, otherwise empty.
        turns = sample.get_turns()
        return turns[0].content if turns[0].role == "system" else ""

    @staticmethod
    def _prompt_repr(sample: StringThread) -> str:
        # Prompt = everything between the (optional) system turn and the
        # final generated turn.
        turns = sample.get_turns()
        body = turns[1:-1] if (turns and turns[0].role == "system") else turns[:-1]
        return repr(StringThread(body))

    async def callback(self, current_percentage: float) -> dict[str, Any]:
        logger.info("Entering generation callback...")
        token_threads = await async_map_fallible(self.model.generate_tokens, self.thread_set)
        string_threads = await async_map_fallible(self.model.detokenize_thread, token_threads)
        lengths = [thread.len_last_turn() for thread in token_threads]
        # NOTE(review): lengths come from token_threads while the table comes
        # from string_threads; if detokenization can drop items the two lists
        # may diverge — confirm async_map_fallible semantics.

        table = (
            Table()
            .add_column("system", [self._system_content(s) for s in string_threads])
            .add_column("prompt", [self._prompt_repr(s) for s in string_threads])
            .add_column("response", [s.last_content() for s in string_threads])
        )
        return {
            self.log_key: table,
            "generation_length_mean": np.mean(lengths).item(),
            "generation_length_std": np.std(lengths).item(),
            "num_samples": len(string_threads),
        }
86
+
87
+
88
class ValidationLossCallback(RecipeCallback):
    """Periodically computes mean negative log-likelihood over a validation set.

    Each sample's loss is the mean NLL over its tokens; the logged value is
    the mean of those per-sample losses.
    """

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        frequency: float = 0.1,
        log_key: str = "loss",
    ):
        super().__init__(frequency, log_key_prefix="validation")
        self.validation_set = validation_set
        self.model = model
        self.log_key = log_key

    async def callback(self, current_percentage: float) -> dict[str, float]:
        logger.info("Entering validation loss callback...")
        tokens = await async_map_fallible(self.model.tokenize_thread, self.validation_set)
        logprobs = await async_map(self.model.logprobs_per_token, tokens)
        # Per-sample loss: mean NLL across that sample's tokens.
        # (Removed a dead `losses = []` that was immediately overwritten.)
        losses = [-(sum(lp) / len(lp)) for lp in logprobs]

        # NOTE(review): raises ZeroDivisionError if the validation set is empty
        # or every tokenization fails — assumed non-empty upstream; confirm.
        return {self.log_key: sum(losses) / len(losses)}
109
+
110
+
111
class CheckpointCallback(RecipeCallback):
    """Periodically saves the training model under a progress-suffixed name."""

    def __init__(
        self,
        model: TrainingModel,
        checkpoint_name: str,
        frequency: float = 0.2,
    ):
        super().__init__(frequency, log_key_prefix="checkpointing")
        # Start at 0.0 rather than the base class's -1.0 so no checkpoint is
        # written at the very start of training.
        self.last_call = 0.0
        self.model = model
        self.model_log_name = checkpoint_name

    async def callback(self, current_percentage: float):
        logger.info(f"Saving checkpoint at {current_percentage * 100} % of training ...")
        checkpoint_id = f"{self.model_log_name}-{round(current_percentage, 3)}"
        await self.model.save(checkpoint_id)
        return {}
127
+
128
+
129
class GraderEvalCallback(RecipeCallback):
    """Periodically generates on a validation set and scores it with a grader.

    Logs the grader's accumulated metrics (under "rewards/") together with
    generation-length statistics and the number of graded samples.
    """

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        grader: BaseGrader,
        frequency: float,
        log_key: str = "validation",
        clear_grader_logs: bool = True,
        temperature: float = 0.0,
    ):
        super().__init__(frequency, log_key_prefix=log_key)
        self.validation_set = validation_set
        self.model = model
        self.grader = grader
        self.clear_grader_logs = clear_grader_logs
        self.temperature = temperature

    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
        logger.info("Entering grader evaluation callback...")
        # Sample with the callback's fixed temperature rather than the
        # training-time sampling settings.
        sampler = self.model.temperature(self.temperature)

        token_threads = await async_map_fallible(sampler.generate_tokens, self.validation_set)
        string_threads = await async_map(sampler.detokenize_thread, token_threads)
        grades = await async_map_fallible(self.grader.grade, string_threads)
        lengths = [thread.len_last_turn() for thread in token_threads]

        # Grading above populates the grader's internal logs; read (and
        # optionally clear) them only afterwards.
        metrics: dict[str, float | Table] = {
            f"rewards/{name}": value for name, value in self.grader.get_logs(clear=self.clear_grader_logs).items()
        }
        metrics["generation_length_mean"] = float(np.mean(lengths).item())
        metrics["generation_length_std"] = float(np.std(lengths).item())
        metrics["num_samples"] = float(len(grades))
        return metrics
163
+
164
+
165
class EnvironmentValidationCallback(RecipeCallback):
    """Periodically rolls out full environment trajectories on a validation set.

    Logs mean/std of cumulative scores and assistant-turn counts, any metrics
    accumulated by the environment factory (under "env/"), and optionally a
    table with the first `num_samples_log` trajectories.
    """

    def __init__(
        self,
        validation_set: list[StringThread],
        model: InferenceModel,
        env_factory: EnvironmentFactory,
        frequency: float,
        log_key: str = "validation",
        clear_env_logs: bool = True,
        temperature: float = 0.0,
        num_samples_log: int = 0,
    ):
        super().__init__(frequency, log_key_prefix=log_key)
        self.validation_set = validation_set
        self.model = model
        self.env_factory = env_factory
        self.clear_env_logs = clear_env_logs
        self.temperature = temperature
        self.num_samples_log = num_samples_log

    async def generate_trajectory(self, initial_thread: StringThread) -> tuple[StringThread, float, int]:
        """Roll out one trajectory; return (trajectory, cumulative score, #assistant turns)."""
        environment = self.env_factory.create_environment(initial_thread.metadata)
        sampler = self.model.temperature(self.temperature)
        trajectory, grade = await environment.generate_trajectory_and_grade(sampler, initial_thread)
        assistant_turns = sum(1 for turn in trajectory.get_turns() if turn.role == "assistant")
        return trajectory, grade.cumulative_score, assistant_turns

    async def callback(self, current_percentage: float) -> dict[str, float | Table]:
        logger.info("Entering environment validation callback...")

        rollouts = await async_map_fallible(self.generate_trajectory, self.validation_set)

        trajectories = [rollout[0] for rollout in rollouts]
        scores = [rollout[1] for rollout in rollouts]
        turn_counts = [rollout[2] for rollout in rollouts]

        logs: dict[str, float | Table] = {
            "score_mean": np.mean(scores).item(),
            "score_std": np.std(scores).item(),
            "num_turns_mean": np.mean(turn_counts).item(),
            "num_turns_std": np.std(turn_counts).item(),
            "num_samples": len(rollouts),
        }

        # Merge environment-side metrics collected during the rollouts.
        for key, value in self.env_factory.get_logs(clear=self.clear_env_logs).items():
            logs[f"env/{key}"] = value

        if self.num_samples_log > 0:
            shown = trajectories[: self.num_samples_log]
            logs["samples"] = (
                Table()
                .add_column("trajectory", [repr(trajectory) for trajectory in shown])
                .add_column("score", scores[: self.num_samples_log])
            )

        logger.info(f"Validation Mean score: {logs['score_mean']:.4f}")
        return logs