adaptive-harmony 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/graders/templated_prompt_judge.py
@@ -0,0 +1,237 @@
import html
from abc import ABC, abstractmethod
from typing import Any, ClassVar

from pybars import Compiler
from pydantic import BaseModel, Field

from adaptive_harmony import Grade, InferenceModel, StringThread, StringTurn
from adaptive_harmony.core.structured_output import JsonParseError, render_schema
from adaptive_harmony.core.utils import stringify_thread
from adaptive_harmony.graders import BaseGrader
from adaptive_harmony.graders.utils import (
    FailedJudgeLog,
    SuccessJudgeLog,
    separate_context_from_last_user_turn,
    validate_thread_last_assistant,
)

OPENAI_MODEL_FAMILIES_TEMPERATURE_1_ONLY = ["gpt-5", "o1", "o3", "o4"]


def _turns_to_dicts(turns: list[StringTurn]) -> list[dict[str, str]]:
    return [{"role": turn.role, "content": turn.content} for turn in turns]


class BaseTemplatedPromptJudgeOutput(ABC, BaseModel):
    @abstractmethod
    def get_score(self) -> float:
        pass

    def get_reasoning(self) -> str | None:
        return None


class SimpleReasonedFloatOutput(BaseTemplatedPromptJudgeOutput):
    """Sample default output format for simple float scores."""

    score: float = Field(description="The numerical score for the sample")
    reasoning: str = Field(description="Reasoning behind the score")

    def get_score(self) -> float:
        return self.score

    def get_reasoning(self) -> str:
        """Extract the reasoning from this output."""
        return self.reasoning


class BinaryJudgeOutput(BaseTemplatedPromptJudgeOutput):
    """Output format for binary PASS/FAIL/NA judges."""

    reasoning: str = Field(description="Reasoning string to support the rationale behind the score")
    score: str = Field(description="The literal score for the sample", pattern="^(PASS|FAIL|NA)$")

    SCORES_MAP: ClassVar[dict[str, float]] = {"PASS": 1.0, "FAIL": 0.0}

    def get_score(self) -> float:
        """Convert PASS/FAIL/NA to float. NA raises an exception."""
        if self.score not in self.SCORES_MAP:
            from adaptive_harmony.graders.exceptions import IgnoreScoreException

            raise IgnoreScoreException(f"Non applicable score: {self.reasoning}")
        return self.SCORES_MAP[self.score]

    def get_reasoning(self) -> str:
        return self.reasoning


class TemplatedPromptJudgeGrader[T: BaseTemplatedPromptJudgeOutput](BaseGrader[SuccessJudgeLog | FailedJudgeLog]):
    r"""
    A flexible grader that uses Handlebars templates for system and user prompts.
    Generic over T, which must inherit from BaseTemplatedPromptJudgeOutput.
    This enforces that all output models implement get_score() and get_reasoning() methods.
    Templates have access to comprehensive context extracted from the StringThread:
    - output_schema: Expected output structure schema as a string
    - turns: List of all turns as dicts with "role" and "content" keys
    - metadata: Thread metadata dict
    - context_turns: All turns without the assistant's completion (includes system prompt)
    - context_str: Context formatted as string ('turn.role:\nturn.content\n' for every turn)
    - context_turns_without_last_user: Same as context_turns, but without the last user turn
    - context_str_without_last_user: Same as context_str, but without the last user turn
    - last_user_turn_content: Content of the last user turn
    - completion: Assistant's completion
    Example Handlebars templates:
        System: "You are a judge. Evaluate responses based on: {{criteria}}.
        Always output the following JSON schema, with no preamble or postamble: {{output_schema}}"
        User: "Conversation context:\n {{context_str_without_last_user}}
        Question:\n {{last_user_turn_content}}
        Response:\n {{completion}}"
    Advanced examples:
    - Conditionals: "{{#if metadata.domain}}Domain: {{metadata.domain}}{{/if}}"
    - Loops with element index: "{{#each user_turns}}Turn {{@index}}: {{content}}{{/each}}"
    - Built-in vars: "{{@index}}, {{@key}}, {{@first}}, {{@last}}"
    """

    def __init__(
        self,
        grader_key: str,
        model: InferenceModel,
        system_template: str,
        user_template: str,
        output_model: type[T],
        template_variables: dict[str, Any] | None = None,
        temperature: float = 0.0,
        grader_id: str | None = None,
    ):
        """
        Initialize the templated prompt judge.
        Args:
            grader_key: Unique identifier for this grader
            model: Live InferenceModel to use as judge for grading
            system_template: Handlebars template string for system prompt
            user_template: Handlebars template string for user prompt
            output_model: Pydantic model class that inherits from BaseOutput
                (must implement get_score() and get_reasoning() methods)
            template_variables: Additional variables to make available in templates
            grader_id: Optional grader ID (defaults to grader_key)
        """
        super().__init__(grader_key)
        self._logs: list[SuccessJudgeLog | FailedJudgeLog] = []  # type: ignore[assignment]

        # Set temperature to 1.0 if model_key is an OpenAI model in the temperature-1-only list
        model_path: str = model.get_builder_args().get("path")  # type: ignore[assignment]
        if model_path.startswith("openai://"):
            model_name = model_path.removeprefix("openai://").split("?")[0]
            if any(model_name.startswith(model) for model in OPENAI_MODEL_FAMILIES_TEMPERATURE_1_ONLY):
                temperature = 1.0
        self.model = model.temperature(temperature)
        self.grader_id_or_key = grader_id or grader_key

        # Template setup
        self.compiler = Compiler()
        self.system_template = self.compiler.compile(system_template)
        self.user_template = self.compiler.compile(user_template)
        self.template_variables = template_variables or {}

        # Output configuration
        self.output_model = output_model

    @classmethod
    def render_template(cls, template: str, thread: StringThread, output_model: type[BaseModel]):
        compiler = Compiler()
        compiled_template = compiler.compile(template)
        context = cls.extract_template_context(thread, output_model)
        return html.unescape(compiled_template(context))

    @classmethod
    def extract_template_context(
        cls,
        thread: StringThread,
        output_model: type[BaseModel],
        template_variables: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Extract context from StringThread for template rendering"""
        validate_thread_last_assistant(thread)

        turns = thread.get_turns()
        context_without_last_user, last_user_turn = separate_context_from_last_user_turn(
            thread, include_system_prompt=True
        )
        context = [turn for turn in turns[:-1]]

        # Build comprehensive context
        context = {
            # Core thread data
            "metadata": thread.metadata,
            "turns": _turns_to_dicts(turns),
            # Context and key turns
            "context_turns": _turns_to_dicts(context),
            "context_str": stringify_thread(StringThread(context)),
            "context_turns_without_last_user": _turns_to_dicts(context_without_last_user),
            "context_str_without_last_user": stringify_thread(StringThread(context_without_last_user)),
            "last_user_turn_content": last_user_turn,
            "completion": thread.last_content(),
            "output_schema": render_schema(output_model),
            # Additional template variables
            **(template_variables or {}),
        }

        return context

    def _build_judge_prompt(self, thread: StringThread) -> StringThread:
        """Build the judging prompt using Handlebars templates"""
        context = self.extract_template_context(thread, self.output_model, self.template_variables)

        # Render templates using pybars with helpers
        system_content = html.unescape(self.system_template(context))
        user_content = html.unescape(self.user_template(context))

        # Build prompt thread
        judge_thread = StringThread().system(system_content).user(user_content)
        return judge_thread

    async def grade(self, sample: StringThread) -> Grade:
        """Grade a sample using the templated prompt"""

        judging_prompt = self._build_judge_prompt(sample)
        str_prompt = stringify_thread(judging_prompt, sep=f"\n\n{'-' * 10}\n\n")

        try:
            _, parsed_output = await self.model.generate_and_validate(judging_prompt, self.output_model)
        except JsonParseError as e:
            self.add_log({"prompt": str_prompt, "error": f"{str(e)}\n\nCOMPLETION:\n{e.completion}"})
            raise
        except Exception as e:
            self.add_log({"prompt": str_prompt, "error": str(e)})
            raise

        # Extract score and reasoning using the abstract methods
        score = parsed_output.get_score()
        reasoning = parsed_output.get_reasoning() or ""

        grade = Grade(value=score, grader_key=self.grader_id_or_key, reasoning=reasoning)
        self.add_log({"score": score, "prompt": str_prompt, "reasoning": reasoning})

        return grade

    def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, float | Any]:
        """Get aggregated logs from all grading calls"""
        # Get base statistics
        logs = super().get_logs(clear=False)

        # Get sample logs
        successfully_scored_samples = [log for log in self._logs if "score" in log]
        failed_scored_samples = [log for log in self._logs if "error" in log]

        if not log_all_samples and successfully_scored_samples:
            # Limit samples for display
            successfully_scored_samples = successfully_scored_samples[:10]

        sample_logs = self.get_sample_tables(successfully_scored_samples, failed_scored_samples)
        logs.update(sample_logs)

        if clear:
            self.clear_logs()

        return logs
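For orientation, here is a minimal, hedged usage sketch of the grader above. It is not part of the package diff: it assumes an already-constructed InferenceModel (`judge_model`) and a StringThread that ends in an assistant turn (`sample`), the templates mirror the ones shown in the class docstring, and the `grader_key` and `criteria` values are invented for illustration.

# Hedged usage sketch, not part of the package diff. Assumes `judge_model`
# (an InferenceModel) and `sample` (a StringThread ending in an assistant
# turn) already exist in the surrounding code.
import asyncio

from adaptive_harmony.graders.templated_prompt_judge import (
    SimpleReasonedFloatOutput,
    TemplatedPromptJudgeGrader,
)

SYSTEM_TEMPLATE = (
    "You are a judge. Evaluate responses based on: {{criteria}}. "
    "Always output the following JSON schema, with no preamble or postamble: {{output_schema}}"
)
USER_TEMPLATE = (
    "Conversation context:\n{{context_str_without_last_user}}\n\n"
    "Question:\n{{last_user_turn_content}}\n\n"
    "Response:\n{{completion}}"
)

grader = TemplatedPromptJudgeGrader(
    grader_key="helpfulness_judge",  # hypothetical identifier
    model=judge_model,
    system_template=SYSTEM_TEMPLATE,
    user_template=USER_TEMPLATE,
    output_model=SimpleReasonedFloatOutput,
    template_variables={"criteria": "helpfulness and factual accuracy"},
)

# grade() is async; the returned Grade carries value, grader_key and reasoning.
grade = asyncio.run(grader.grade(sample))
print(grade.value, grade.reasoning)

Note that, per the constructor above, the judge temperature is forced to 1.0 whenever the model path points at one of the OpenAI families listed in OPENAI_MODEL_FAMILIES_TEMPERATURE_1_ONLY.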
adaptive_harmony/graders/utils.py
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypedDict

from adaptive_harmony import StringThread

if TYPE_CHECKING:
    from adaptive_harmony import StringTurn


def validate_thread_last_assistant(thread: StringThread):
    turns = thread.get_turns()
    assert len(turns) > 0, "The thread must have at least one turn"
    assert turns[-1].role == "assistant", "The last turn must be an assistant turn"


def separate_context_from_last_user_turn(
    thread: StringThread,
    include_system_prompt: bool = False,
) -> tuple[list[StringTurn], str | None]:
    """
    Separates turns into context and last user turn.
    Includes system prompt in context if include_system_prompt is True.
    If there is no user turn, last user turn is None.
    """
    validate_thread_last_assistant(thread)
    turns = thread.get_turns()

    # Possibly remove system prompt
    if not include_system_prompt and turns[0].role == "system":
        turns = turns[1:]

    # Find last user turn
    user_question = None
    user_question_idx = -1
    for i, turn in enumerate(turns):
        if turn.role == "user":
            user_question = turn.content
            user_question_idx = i

    if user_question is None:
        # Last turn is guaranteed to be assistant due to validate_thread_last_assistant
        context_turns = turns[:-1]
    else:
        context_turns = turns[:user_question_idx]

    return context_turns, user_question


class SuccessJudgeLog(TypedDict):
    prompt: str
    reasoning: str
    score: float


class FailedJudgeLog(TypedDict):
    prompt: str
    error: str | None


def sample_score_distribution(success_samples: list[SuccessJudgeLog], max_n_samples: int = 15) -> list[SuccessJudgeLog]:
    # sort samples by score for percentile-based sampling
    sorted_samples = sorted(success_samples, key=lambda x: x["score"])
    total_samples = len(sorted_samples)

    if total_samples >= max_n_samples:
        # sample max_n_samples samples distributed across percentiles
        indices = []
        for i in range(max_n_samples):
            # calculate percentile position (0% to 100% spread across 15 samples)
            percentile = i / (max_n_samples - 1)  # 14 intervals for 15 samples
            index = int(percentile * (total_samples - 1))
            indices.append(index)

        subset_successfully_scored_samples = [sorted_samples[i] for i in indices]
    else:
        subset_successfully_scored_samples = sorted_samples

    return subset_successfully_scored_samples
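As a quick illustration of the percentile sampling in sample_score_distribution above: with 30 score-sorted logs and the default max_n_samples=15, the chosen indices work out to 0, 2, 4, ..., 26, 29, so the subset spans the full score range. The toy log dicts below are invented purely for this sketch.

# Illustrative sketch only; the log entries are fabricated toy data.
from adaptive_harmony.graders.utils import sample_score_distribution

logs = [
    {"prompt": f"p{i}", "reasoning": f"r{i}", "score": i / 100}
    for i in range(30)
]

subset = sample_score_distribution(logs, max_n_samples=15)

print(len(subset))                   # 15
print([s["score"] for s in subset])  # scores spread evenly from 0.0 up to 0.29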
adaptive_harmony/logging_table.py
@@ -0,0 +1 @@
from harmony_client.logging_table import Table as Table