adaptive-harmony 0.1.23 (adaptive_harmony-0.1.23-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0

adaptive_harmony/core/utils.py
@@ -0,0 +1,365 @@
import asyncio
import functools
import hashlib
import itertools
import json
import random
from pathlib import Path
from typing import Any, Callable, Coroutine, Dict, Iterator, List, NamedTuple, Sequence, TypedDict, TypeVar

import numpy as np
from loguru import logger

from adaptive_harmony import InferenceModel, StringThread, TrainingModel
from adaptive_harmony.core.rich_counter import ProgressCounter, get_progress_counter_or_wrapper
from adaptive_harmony.metric_logger import Logger, StdoutLogger

S = TypeVar("S")
T = TypeVar("T")


async def wrap_coroutine_with_progress[T](coroutine: Coroutine[Any, Any, T], progress_counter: ProgressCounter) -> T:
    try:
        return await coroutine
    finally:
        progress_counter.increment_total_counter()


async def async_map_batch[S, T](
    f: Callable[[S], Coroutine[Any, Any, T]],
    data: Iterator[S],
    batch_size: int,
    max_failure_fraction: float = 0.5,
) -> List[T]:
    """
    Process items from an iterator in batches using concurrent coroutines.

    This function processes items from an iterator in batches, executing the
    provided coroutine function concurrently for each item. Failing samples are
    replaced with new items drawn from the iterator until a full batch of
    `batch_size` results is available. If more than `max_failure_fraction * batch_size`
    tasks fail while building a batch, the last exception encountered is raised.
    Results are not ordered.

    Args:
        f: Coroutine function to apply to each item
        data: Iterator of items to process
        batch_size: Number of items to process in each batch

    Returns:
        List of results from successful task executions

    Note:
        - Failed tasks are not retried
        - If more than max_failure_fraction * batch_size tasks fail, the function fails
        - Tasks are automatically cancelled if the function exits early
    """
    batch_items_from_iterator = list(itertools.islice(data, batch_size))
    num_items = len(batch_items_from_iterator)

    async with get_progress_counter_or_wrapper(f"async_map_batch({f.__name__})", batch_size) as counter:
        final_results: list[Any] = [None] * num_items
        active_tasks_this_batch: Dict[asyncio.Task, int] = {}

        num_retries = 0

        for i, item_value in enumerate(batch_items_from_iterator):
            task: asyncio.Task[T] = asyncio.create_task(wrap_coroutine_with_progress(f(item_value), counter))
            counter.register_task(task)
            active_tasks_this_batch[task] = i

        try:
            while active_tasks_this_batch:
                done_tasks, _ = await asyncio.wait(active_tasks_this_batch.keys(), return_when=asyncio.FIRST_COMPLETED)

                for task_item in done_tasks:
                    original_batch_slot_idx = active_tasks_this_batch.pop(task_item)

                    try:
                        result: T = await task_item
                        final_results[original_batch_slot_idx] = result
                    except Exception as ex:
                        try:
                            if num_retries > batch_size * max_failure_fraction:
                                # too many failures for this batch: re-raise the last exception
                                raise ex

                            logger.debug(ex)
                            retry_item_value: S = next(data)
                            new_retry_task: asyncio.Task[T] = asyncio.create_task(
                                wrap_coroutine_with_progress(f(retry_item_value), counter)
                            )
                            active_tasks_this_batch[new_retry_task] = original_batch_slot_idx
                            num_retries += 1
                        except StopIteration:
                            ...
        finally:
            tasks_to_cancel = list(active_tasks_this_batch.keys())
            for task_to_cancel in tasks_to_cancel:
                task_to_cancel.cancel()

            if tasks_to_cancel:
                await asyncio.gather(*tasks_to_cancel, return_exceptions=True)

    if num_retries > 0:
        print(f"WARNING: had to retry {num_retries} times to get a batch of {batch_size}")
    ret = [res for res in final_results if res is not None]

    print(f"Final number of tasks with non-None results: {len(ret)}")

    return ret


def hash_hyperparams(include: set[str]):
    """
    A decorator that computes a hash of specified hyperparameters and stores it on `self._hyperparams_hash`.

    Must be used on an `__init__` method. Only parameters listed in `include` will be hashed.
    Non-serializable values are converted to their string representation.

    Args:
        include: Set of parameter names to include in the hash.

    Example:
        @hash_hyperparams(include={"lr", "batch_size", "epochs"})
        def __init__(self, lr, batch_size, epochs, logger, ...):
            ...
    """
    import inspect

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            sig = inspect.signature(func)
            bound_args = sig.bind_partial(self, *args, **kwargs)
            bound_args.apply_defaults()
            all_args = bound_args.arguments

            hyperparams = {}
            for key in include:
                if key in all_args:
                    value = all_args[key]
                    try:
                        json.dumps(value)
                        hyperparams[key] = value
                    except (TypeError, OverflowError):
                        hyperparams[key] = repr(value)

            serialized = json.dumps(hyperparams, sort_keys=True)
            self._hyperparams_hash = hashlib.sha256(serialized.encode()).hexdigest()[:16]

            return func(self, *args, **kwargs)

        return wrapper

    return decorator


def log_args(func):
    """
    A Python decorator that logs the arguments of the decorated function
    to experiment tracking tools (wandb, mlflow) or stdout.

    Attempts to log to wandb if it is available and initialized, then to mlflow
    if it is available and has an active run. If neither is available, logs to stdout.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        import inspect

        # Helper to check serializability and prepare value
        def prepare_value(value):
            # we need to log the model builder args here because they are not serializable by default
            if isinstance(value, list) and len(value) > 100:
                # exclude long lists since we want to skip datasets
                return None
            if isinstance(value, InferenceModel) or isinstance(value, TrainingModel):
                return value.get_builder_args()  # type: ignore PyRight being dumb
            else:
                # Check if the value itself is a complex object that might not be fully serializable
                try:
                    json.dumps({"test_key": value})
                    return value
                except (TypeError, OverflowError):
                    return None

        # Get function arguments once
        sig = inspect.signature(func)
        bound_args = sig.bind_partial(*args, **kwargs)
        bound_args.apply_defaults()
        all_args = bound_args.arguments

        # find the loggers that are given to the recipe; if none are found, we will log to stdout
        loggers = [v for v in all_args.values() if isinstance(v, Logger)]
        if not loggers:
            loggers.append(StdoutLogger())

        # get loggable args only
        loggable_args = {k: new_v for k, v in all_args.items() if (new_v := prepare_value(v)) is not None}

        for logger_instance in loggers:
            logger_instance.log_config(loggable_args)

        return func(*args, **kwargs)

    return wrapper


async def async_map[S, T](f: Callable[[S], Coroutine[Any, Any, T]], data: Sequence[S]) -> list[T]:
    """
    Process all items in an iterable concurrently using the provided coroutine function.

    Args:
        f: Coroutine function to apply to each item
        data: Iterable of items to process

    Returns:
        List of results from all task executions
    """

    # Check if a Progress bar is already active
    async with get_progress_counter_or_wrapper(f"async_map({f.__name__})", len(list(data))) as counter:
        all_tasks = [asyncio.create_task(wrap_coroutine_with_progress(f(item), counter)) for item in data]
        for t in all_tasks:
            counter.register_task(t)
        results = await asyncio.gather(*all_tasks)
        return results


async def async_map_fallible[S, T](f: Callable[[S], Coroutine[Any, Any, T]], data: Sequence[S]) -> list[T]:
    """
    Process all items in an iterable concurrently using the provided coroutine function.

    Args:
        f: Coroutine function to apply to each item
        data: Iterable of items to process

    Returns:
        List of results from the successful task executions; failed tasks are dropped
    """

    async def wrap_coroutine_with_error_handling(coro: Coroutine[Any, Any, T]) -> tuple[T, bool]:
        try:
            result = await coro
            return result, True
        except Exception:
            return None, False  # type: ignore

    async with get_progress_counter_or_wrapper(f"async_map_fallible({f.__name__})", len(list(data))) as counter:
        all_tasks = [
            asyncio.create_task(wrap_coroutine_with_error_handling(wrap_coroutine_with_progress(f(item), counter)))
            for item in data
        ]
        for t in all_tasks:
            counter.register_task(t)
        results = await asyncio.gather(*all_tasks)

        return [result for result, success in results if success]


def get_minibatches[T](dataset: list[T], mini_batch_size: int, number_of_epochs: int) -> list[list[T]]:
    all_batches: list[list[T]] = []

    for _ in range(number_of_epochs):
        shuffled_dataset = random.sample(dataset, k=len(dataset))

        epoch_batches: list[list[T]] = []
        for i in range(0, len(shuffled_dataset), mini_batch_size):
            batch = shuffled_dataset[i : i + mini_batch_size]
            epoch_batches.append(batch)
        all_batches.extend(epoch_batches)

    return all_batches


def sample_data[T](data: list[T], epochs: float) -> list[T]:
    num_samples = len(data) * epochs
    return [data[x] for x in np.random.permutation(len(data))[: int(num_samples)]]


def weighted_mean(values: list[list[float]], weights: list[list[float]]) -> float:
    return np.average(np.concatenate(values), weights=np.concatenate(weights)).item()


def stringify_thread(thread: StringThread, sep: str = "\n\n") -> str:
    """Convert StringThread to readable text format."""
    turns = thread.get_turns()
    return sep.join([f"[{turn.role}]\n{turn.content}" for turn in turns])


class SingleTurnShot(TypedDict):
    user: dict[str, str]
    assistant: dict[str, str]


class TurnTemplates(NamedTuple):
    system: str | None
    user: str | None
    assistant: str | None
    shots: list[SingleTurnShot] | None


def turn_templates_from_dir(root_dir: str) -> TurnTemplates:
    """
    Returns system, user and assistant turn string templates from a directory, as well as a list of shot dicts.
    Expects files to be named system.md, user.md, assistant.md and shots.jsonl.
    Returns None for any turn template file that does not exist.
    """
    root_path = Path(root_dir)
    expected_files = ["system.md", "user.md", "assistant.md", "shots.jsonl"]
    missing_templates = []
    turn_templates: list[str | list[SingleTurnShot] | None] = []

    for file in expected_files:
        path = root_path / file
        if not path.exists():
            missing_templates.append(file)
            turn_templates.append(None)
        else:
            if file == "shots.jsonl":
                shots = []
                for line in path.read_text().splitlines():
                    data = json.loads(line)
                    shot = SingleTurnShot(user=data["user"], assistant=data["assistant"])
                    shots.append(shot)
                turn_templates.append(shots)
            else:
                turn_templates.append(path.read_text())

    # Ensure proper typing: first 3 are str|None, last is list[SingleTurnShot]|None
    system, user, assistant, shots = turn_templates
    return TurnTemplates(
        system=system if isinstance(system, str) else None,
        user=user if isinstance(user, str) else None,
        assistant=assistant if isinstance(assistant, str) else None,
        shots=shots if isinstance(shots, list) else None,
    )


def hash_dataset(dataset: Sequence[StringThread], num_samples: int = 5) -> str:
    """Compute a hash of dataset for quick unsafe comparison.

    Hashes: dataset length + first N elements + last N elements.
    This catches most dataset changes without processing the entire dataset.

    Args:
        dataset: List of dataset items
        num_samples: Number of elements to sample from start/end (default 5)

    Returns:
        SHA256 hash
    """
    hasher = hashlib.sha256()

    hasher.update(str(len(dataset)).encode())

    num_samples = min(num_samples, len(dataset) // 2)

    for item in dataset[:num_samples]:
        hasher.update(str(item).encode())

    for item in dataset[-num_samples:]:
        hasher.update(str(item).encode())

    return hasher.hexdigest()
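
A minimal editorial sketch (not part of the wheel) of driving the two batch helpers above, using only the `async_map` / `async_map_batch` signatures shown in this hunk; `fake_completion` and `main` are hypothetical stand-ins for a real inference call and recipe entry point:

import asyncio

from adaptive_harmony.core.utils import async_map, async_map_batch


async def fake_completion(prompt: str) -> str:
    # hypothetical work; replace with a real model call
    await asyncio.sleep(0.01)
    if "bad" in prompt:
        raise ValueError("simulated failure")
    return prompt.upper()


async def main() -> None:
    prompts = ["alpha", "bad input", "beta", "gamma", "delta"]

    # async_map: run everything concurrently; any failure propagates via asyncio.gather
    ok_only = [p for p in prompts if "bad" not in p]
    print(await async_map(fake_completion, ok_only))

    # async_map_batch: draw from an iterator, replacing failed items with new ones
    results = await async_map_batch(fake_completion, iter(prompts), batch_size=3)
    print(len(results), "results in the batch")


asyncio.run(main())

The difference in argument types follows from the retry behaviour: `async_map_batch` keeps calling `next(data)` to replace failures, which is why it takes an `Iterator` while `async_map` and `async_map_fallible` take a `Sequence`.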

adaptive_harmony/environment/environment.py
@@ -0,0 +1,121 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from numbers import Number
from typing import Any, Callable

import numpy as np

from adaptive_harmony import StringThread
from adaptive_harmony.logging_table import Table


@dataclass
class TurnScore:
    score: float
    num_assistant_turns: int


@dataclass
class TrajectoryScore:
    scores: list[TurnScore]

    def to_turn_scores(self) -> list[float]:
        return [turn_score.score for turn_score in self.scores for _ in range(turn_score.num_assistant_turns)]

    @property
    def cumulative_score(self) -> float:
        return sum([turn_score.score for turn_score in self.scores])


class Environment(ABC):
    """
    Environment to inherit from when building an environment.
    """

    def __init__(self, logging_function: Callable):
        self.logging_function = logging_function

    @abstractmethod
    async def react_to(self, thread: StringThread) -> list[tuple[str, str]] | TrajectoryScore:
        """Returns either [("tool", tool response), ...] or [("user", user question)], or a TrajectoryScore when the trajectory is done."""
        pass

    async def bootstrap_prompt(self, thread: StringThread) -> StringThread:
        return thread

    async def generate_trajectory_and_grade(
        self, model, initial_thread: StringThread
    ) -> tuple[StringThread, TrajectoryScore]:
        """Generate a full trajectory by interacting with the model until termination."""
        thread = await self.bootstrap_prompt(initial_thread)

        while True:
            # Generate model response
            thread = await model.generate(thread)

            # Get environment reaction
            env_response = await self.react_to(thread)

            # If we got a score, we're done
            if isinstance(env_response, TrajectoryScore):
                return thread, env_response

            # Otherwise, add the environment responses to the thread
            for role, content in env_response:
                if role == "tool":
                    thread = thread.tool(content)
                elif role == "user":
                    thread = thread.user(content)
                else:
                    raise ValueError(f"Unknown role: {role}")


class EnvironmentFactory(ABC):
    """
    Abstract class to build environments. It is necessary because each trajectory must have its own unique Environment.
    """

    def __init__(self, logging_name: str | None = None):
        self._logs: dict[str, list] = defaultdict(list)
        self.logging_name = logging_name

    def add_log(self, key, new_value) -> None:
        """Add a log entry to the factory's log collection."""
        self._logs[key].append(new_value)

    def get_logs(self, clear: bool = False) -> dict[str, float | Table]:
        """
        Get aggregated logs from all score calls.
        The base implementation averages numeric log values and concatenates Table logs.
        If there are no logs, returns an empty dict.
        """
        if not self._logs:
            return {}

        logs = {}
        for k, v in self._logs.items():
            if isinstance(v[0], Number):
                logs[k] = np.asarray(v).mean()
            elif isinstance(v[0], Table):
                headers = v[0].headers
                assert all(table.headers == headers for table in v)
                overall_table = Table(headers)
                for table in v:
                    overall_table.add_rows(table.rows)
                logs[k] = overall_table
            else:
                raise ValueError(f"Unknown type: {type(v[0])}")

        if clear:
            self.clear_logs()
        return logs

    def clear_logs(self) -> None:
        """
        Clear all accumulated logs.
        """
        self._logs.clear()

    @abstractmethod
    def create_environment(self, metadata: Any) -> Environment: ...
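
An illustrative sketch (editorial, not shipped in the package) of the interaction contract above: a single-step environment whose react_to always terminates with a TrajectoryScore. The class name and scoring rule are invented; only Environment, TrajectoryScore, TurnScore and StringThread come from the package, and last_content() is the same thread accessor used by the graders later in this diff.

from adaptive_harmony import StringThread
from adaptive_harmony.environment.environment import Environment, TrajectoryScore, TurnScore


class OneShotEnvironment(Environment):
    """Hypothetical environment: grade the first assistant reply, then stop."""

    async def react_to(self, thread: StringThread) -> list[tuple[str, str]] | TrajectoryScore:
        last_reply = thread.last_content()
        # invented scoring rule: reward non-empty replies
        score = 1.0 if last_reply and last_reply.strip() else 0.0
        return TrajectoryScore(scores=[TurnScore(score=score, num_assistant_turns=1)])

Because Environment.__init__ takes a logging_function, an instance would be built as OneShotEnvironment(logging_function=print) and then driven by generate_trajectory_and_grade(model, initial_thread), which loops until react_to returns the TrajectoryScore.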

adaptive_harmony/evaluation/__init__.py
@@ -0,0 +1 @@
from .evaluation_artifact import EvaluationArtifact as EvaluationArtifact

adaptive_harmony/evaluation/evaluation_artifact.py
@@ -0,0 +1,67 @@
import logging
import uuid
from typing import List, Self

from harmony_client import (
    EvalSample,
    EvaluationArtifactBase,
)
from harmony_client.runtime.context import RecipeContext

logger = logging.getLogger(__name__)


class EvaluationArtifact:
    def __init__(self, name: str, ctx: RecipeContext) -> None:
        artifact_id = str(uuid.uuid4())
        url = ctx.file_storage.mk_url(f"artifacts/eval_samples_{artifact_id}.jsonl")
        self._base = EvaluationArtifactBase(name, url, artifact_id)
        self.ctx = ctx
        self.ctx.job.register_artifact(self._base.artifact)

    @property
    def id(self) -> str:
        return self._base.id

    @property
    def name(self) -> str:
        return self._base.name

    @property
    def kind(self) -> str:
        return self._base.kind

    @property
    def uri(self) -> str:
        assert self._base.uri is not None
        return self._base.uri

    def add_samples(self, samples: List[EvalSample]) -> Self:
        """Add evaluation samples to this artifact.

        Args:
            samples: List of evaluation samples to add

        Returns:
            Self for method chaining

        Raises:
            ValueError: If samples list is empty
            Exception: If serialization or storage fails
        """
        if not samples:
            raise ValueError("Cannot add empty samples list")

        try:
            samples_json = self._base.samples_to_adaptive_json(samples)
            for json_str in samples_json:
                self.ctx.file_storage.append((json_str + "\n").encode("utf-8"), self.uri)
            logger.debug(f"Added {len(samples)} samples to artifact {self.id}")
        except Exception as e:
            logger.error(f"Failed to add samples to artifact {self.id}: {e}")
            raise

        return self

    def __repr__(self):
        return f"EvaluationArtifact(id={self.id}, name={self.name}, kind={self.kind}, uri={self.uri})"
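
A hedged sketch of how the artifact above would be used inside a recipe. The helper name and artifact name are invented for the example; `ctx` is assumed to be the RecipeContext handed to the recipe by the runtime and `samples` an already-constructed list of EvalSample objects, neither of which is defined here:

from typing import List

from harmony_client import EvalSample
from harmony_client.runtime.context import RecipeContext

from adaptive_harmony.evaluation import EvaluationArtifact


def publish_eval_results(ctx: RecipeContext, samples: List[EvalSample]) -> str:
    # constructing the artifact registers it with the running job
    artifact = EvaluationArtifact("answer_relevancy_eval", ctx)
    # add_samples serializes the samples to JSONL and appends them to artifact.uri
    artifact.add_samples(samples)
    return artifact.uri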

adaptive_harmony/graders/__init__.py
@@ -0,0 +1,20 @@
from adaptive_harmony import Grade

from .answer_relevancy_judge import AnswerRelevancyGrader
from .base_grader import BaseGrader
from .binary_judge import BinaryJudgeGrader
from .context_relevancy_judge import ContextRelevancyGrader
from .exceptions import IgnoreScoreException
from .faithfulness_judge import FaithfulnessGrader
from .range_judge import RangeJudgeGrader

__all__ = [
    "BaseGrader",
    "Grade",
    "IgnoreScoreException",
    "BinaryJudgeGrader",
    "RangeJudgeGrader",
    "FaithfulnessGrader",
    "AnswerRelevancyGrader",
    "ContextRelevancyGrader",
]

adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py
@@ -0,0 +1,102 @@
from __future__ import annotations

from typing import Literal

import numpy as np
import pysbd
from harmony_client import Grade, InferenceModel, StringThread
from pydantic import BaseModel, Field

from adaptive_harmony.core.utils import stringify_thread
from adaptive_harmony.graders.answer_relevancy_judge.prompts import DEFAULT_SHOTS, SYSTEM, USER
from adaptive_harmony.graders.base_grader import BaseGrader
from adaptive_harmony.graders.faithfulness_judge.faithfulness_judge import SupportedLanguages
from adaptive_harmony.graders.utils import sample_score_distribution, separate_context_from_last_user_turn
from adaptive_harmony.logging_table import Table


class StatementRelevancy(BaseModel):
    reason: str = Field(description="The justification for the score given to a statement. Keep it short and concise.")
    score: Literal[0, 1] = Field(
        description="The score for the statement. A score of 1 if the statement is relevant to addressing the original input, and 0 if the statement is irrelevant to addressing the input"
    )


class AnswerRelevancyResults(BaseModel):
    results: list[StatementRelevancy] = Field(description="A list of relevancy results for the statements")


class AnswerRelevancyGrader(BaseGrader):
    def __init__(
        self,
        model: InferenceModel,
        language: SupportedLanguages = "en",
        grader_key: str = "answer_relevancy_judge",
        grader_id: str | None = None,
    ):
        super().__init__(grader_key)
        self.model = model
        self.language = language
        self.grader_id_or_key = grader_id or grader_key
        self.sentence_splitter = pysbd.Segmenter(language=language)
        self.shots = DEFAULT_SHOTS

    async def grade(self, sample: StringThread) -> Grade:
        _, user_question = separate_context_from_last_user_turn(sample)

        completion = sample.last_content()
        split_sentences = self.sentence_splitter.segment(completion)
        sentences = [sentence.strip() for sentence in split_sentences if sentence.strip()]
        sentences_judge_str = "\n".join(f"{i}: {sentence}" for i, sentence in enumerate(sentences))

        judging_thread = (
            StringThread()
            .system(SYSTEM.format(json_schema=self.model.render_schema(AnswerRelevancyResults), shots=self.shots))
            .user(USER.format(user_question=user_question, statements=sentences_judge_str))
        )

        try:
            _, response = await self.model.temperature(0.0).generate_and_validate(
                judging_thread, AnswerRelevancyResults
            )
            results = response.results
        except Exception as e:
            self.add_log({"prompt": stringify_thread(judging_thread, sep=f"\n\n{'-' * 10}\n\n"), "error": str(e)})
            raise

        reason = ""
        for i, (result, statement) in enumerate(zip(results, sentences)):
            emoji = "✅" if result.score == 1 else "❌"
            verdict = "PASS" if result.score == 1 else "FAIL"
            statement_display = statement[:150] + ("..." if len(statement) > 150 else "")
            reason += f"{emoji} Statement {i}: {verdict}\n Excerpt: {statement_display}\nReason: {result.reason}\n\n"

        score = np.mean([float(result.score) for result in results]) if results else 0.0
        self.add_log(
            {"score": score, "prompt": stringify_thread(judging_thread, sep=f"\n\n{'-' * 10}\n\n"), "reasoning": reason}
        )
        return Grade(value=float(score), grader_key=self.grader_id_or_key, reasoning=reason)

    def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, float | Table]:
        # Only clear logs at the end if clear is True
        logs = super().get_logs(clear=False)

        successfully_scored_samples = [log for log in self._logs if "score" in log]

        # stratified sample across the score range so both high and low scores are represented
        if not log_all_samples:
            subset_successfully_scored_samples = sample_score_distribution(successfully_scored_samples, 15)
        else:
            # log every successfully scored sample
            subset_successfully_scored_samples = successfully_scored_samples

        failed_scored_samples = [log for log in self._logs if "error" in log]

        sample_logs = self.get_sample_tables(subset_successfully_scored_samples, failed_scored_samples)

        logs.update(sample_logs)

        if clear:
            self.clear_logs()

        return logs
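
A usage sketch for the grader above (editorial, not from the package); `judge_model` is assumed to be an already-configured InferenceModel and `thread` a StringThread that ends with the assistant answer to be graded, and the function name is invented:

import asyncio

from adaptive_harmony import InferenceModel, StringThread
from adaptive_harmony.graders import AnswerRelevancyGrader


async def grade_answer(judge_model: InferenceModel, thread: StringThread) -> float:
    grader = AnswerRelevancyGrader(judge_model, language="en")
    grade = await grader.grade(thread)  # fraction of statements judged relevant
    print(grade.reasoning)
    return grade.value

After a batch of grade calls, grader.get_logs(clear=True) returns the aggregated score plus the sampled prompt/reasoning tables shown in the get_logs implementation above, ready to pass to a metric logger.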