adaptive-harmony 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/core/utils.py
@@ -0,0 +1,365 @@
+ import asyncio
+ import functools
+ import hashlib
+ import itertools
+ import json
+ import random
+ from pathlib import Path
+ from typing import Any, Callable, Coroutine, Dict, Iterator, List, NamedTuple, Sequence, TypedDict, TypeVar
+
+ import numpy as np
+ from loguru import logger
+
+ from adaptive_harmony import InferenceModel, StringThread, TrainingModel
+ from adaptive_harmony.core.rich_counter import ProgressCounter, get_progress_counter_or_wrapper
+ from adaptive_harmony.metric_logger import Logger, StdoutLogger
+
+ S = TypeVar("S")
+ T = TypeVar("T")
+
+
+ async def wrap_coroutine_with_progress[T](coroutine: Coroutine[Any, Any, T], progress_counter: ProgressCounter) -> T:
+     try:
+         return await coroutine
+     finally:
+         progress_counter.increment_total_counter()
+
+
+ async def async_map_batch[S, T](
+     f: Callable[[S], Coroutine[Any, Any, T]],
+     data: Iterator[S],
+     batch_size: int,
+     max_failure_fraction: float = 0.5,
+ ) -> List[T]:
+ """
35
+ Process items from an iterator in batches using concurrent coroutines.
36
+
37
+ This function processes items from an iterator in batches, executing the
38
+ provided coroutine function concurrently for each item. It excludes failing
39
+ samples until it can create a new batch of results of size # batch size.
40
+ If more than max_failure_fraction % of # batch size tasks fail in the process
41
+ of creating a new batch, the function will raise the last exception encountered.
42
+ Results are not ordered.
43
+
44
+ Args:
45
+ f: Coroutine function to apply to each item
46
+ data: Iterator of items to process
47
+ batch_size: Number of items to process in each batch
48
+
49
+ Returns:
50
+ List of results from successful task executions
51
+
52
+ Note:
53
+ - Failed tasks are not retried
54
+ - If more than max_failure_fraction of # batch size tasks fail, the function fails
55
+ - Tasks are automatically cancelled if the function exits early
56
+ """
57
+     batch_items_from_iterator = list(itertools.islice(data, batch_size))
+     num_items = len(batch_items_from_iterator)
+
+     async with get_progress_counter_or_wrapper(f"async_map_batch({f.__name__})", batch_size) as counter:
+         final_results: list[Any] = [None] * num_items
+         active_tasks_this_batch: Dict[asyncio.Task, int] = {}
+
+         num_retries = 0
+
+         for i, item_value in enumerate(batch_items_from_iterator):
+             task: asyncio.Task[T] = asyncio.create_task(wrap_coroutine_with_progress(f(item_value), counter))
+             counter.register_task(task)
+             active_tasks_this_batch[task] = i
+
+         try:
+             while active_tasks_this_batch:
+                 done_tasks, _ = await asyncio.wait(active_tasks_this_batch.keys(), return_when=asyncio.FIRST_COMPLETED)
+
+                 for task_item in done_tasks:
+                     original_batch_slot_idx = active_tasks_this_batch.pop(task_item)
+
+                     try:
+                         result: T = await task_item
+                         final_results[original_batch_slot_idx] = result
+                     except Exception as ex:
+                         try:
+                             if num_retries > batch_size * max_failure_fraction:
+                                 # too many failures while building this batch; re-raise the last exception
+                                 raise ex
+
+                             logger.debug(ex)
+                             retry_item_value: S = next(data)
+                             new_retry_task: asyncio.Task[T] = asyncio.create_task(
+                                 wrap_coroutine_with_progress(f(retry_item_value), counter)
+                             )
+                             active_tasks_this_batch[new_retry_task] = original_batch_slot_idx
+                             num_retries += 1
+                         except StopIteration:
+                             ...
+         finally:
+             tasks_to_cancel = list(active_tasks_this_batch.keys())
+             for task_to_cancel in tasks_to_cancel:
+                 task_to_cancel.cancel()
+
+             if tasks_to_cancel:
+                 await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
+
+     if num_retries > 0:
+         print(f"WARNING: had to retry {num_retries} times to get a batch of {batch_size}")
+     ret = [res for res in final_results if res is not None]
+
+     print(f"Final number of tasks with non-None results: {len(ret)}")
+
+     return ret
+
+
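A minimal usage sketch of async_map_batch; the fetch_completion coroutine and the prompts below are illustrative, not part of the package:

import asyncio

from adaptive_harmony.core.utils import async_map_batch


async def fetch_completion(prompt: str) -> str:
    # stand-in for a model call that can occasionally fail
    if not prompt:
        raise ValueError("empty prompt")
    return prompt.upper()


async def main() -> None:
    prompts = iter(["tell me a joke", "", "summarize this", "translate that"])
    # collects batch_size successful results, pulling replacement items from the
    # iterator whenever a task fails (up to max_failure_fraction * batch_size failures)
    results = await async_map_batch(fetch_completion, prompts, batch_size=3)
    print(results)


asyncio.run(main())
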
+ def hash_hyperparams(include: set[str]):
+     """
+     A decorator that computes a hash of specified hyperparameters and stores it on `self._hyperparams_hash`.
+
+     Must be used on an `__init__` method. Only parameters listed in `include` will be hashed.
+     Non-serializable values are converted to their string representation.
+
+     Args:
+         include: Set of parameter names to include in the hash.
+
+     Example:
+         @hash_hyperparams(include={"lr", "batch_size", "epochs"})
+         def __init__(self, lr, batch_size, epochs, logger, ...):
+             ...
+     """
+     import inspect
+
+     def decorator(func):
+         @functools.wraps(func)
+         def wrapper(self, *args, **kwargs):
+             sig = inspect.signature(func)
+             bound_args = sig.bind_partial(self, *args, **kwargs)
+             bound_args.apply_defaults()
+             all_args = bound_args.arguments
+
+             hyperparams = {}
+             for key in include:
+                 if key in all_args:
+                     value = all_args[key]
+                     try:
+                         json.dumps(value)
+                         hyperparams[key] = value
+                     except (TypeError, OverflowError):
+                         hyperparams[key] = repr(value)
+
+             serialized = json.dumps(hyperparams, sort_keys=True)
+             self._hyperparams_hash = hashlib.sha256(serialized.encode()).hexdigest()[:16]
+
+             return func(self, *args, **kwargs)
+
+         return wrapper
+
+     return decorator
+
+
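A short sketch of how the stored hash might be consumed; the Trainer class here is hypothetical:

from adaptive_harmony.core.utils import hash_hyperparams


class Trainer:
    @hash_hyperparams(include={"lr", "batch_size"})
    def __init__(self, lr: float, batch_size: int, run_name: str = "run"):
        self.lr = lr
        self.batch_size = batch_size
        self.run_name = run_name


trainer = Trainer(lr=3e-4, batch_size=8)
# 16-hex-character digest, identical for identical lr/batch_size regardless of run_name
print(trainer._hyperparams_hash)
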
158
+ def log_args(func):
159
+ """
160
+ A Python decorator that logs the arguments of the decorated function
161
+ to experiment tracking tools (wandb, mlflow) or stdout.
162
+
163
+ Attempts to log to wandb if available and initialized, then to mlflow
164
+ if available and has an active run. If neither is available, logs to stdout.
165
+ """
166
+
167
+ @functools.wraps(func)
168
+ def wrapper(*args, **kwargs):
169
+ import inspect
170
+
171
+ # Helper to check serializability and prepare value
172
+ def prepare_value(value):
173
+ # we need to log the model builder args here because they are not serializable by default
174
+ if isinstance(value, list) and len(value) > 100:
175
+ # exclude long lists since we want to skip datasets
176
+ return None
177
+ if isinstance(value, InferenceModel) or isinstance(value, TrainingModel):
178
+ return value.get_builder_args() # type: ignore PyRight being dumb
179
+ else:
180
+ # Check if the value itself is a complex object that might not be fully serializable
181
+ try:
182
+ json.dumps({"test_key": value})
183
+ return value
184
+ except (TypeError, OverflowError):
185
+ return None
186
+
187
+ # Get function arguments once
188
+ sig = inspect.signature(func)
189
+ bound_args = sig.bind_partial(*args, **kwargs)
190
+ bound_args.apply_defaults()
191
+ all_args = bound_args.arguments
192
+
193
+ # find the loggers that are given to recipe, if None are found, we will log to stdout
194
+ loggers = [v for v in all_args.values() if isinstance(v, Logger)]
195
+ if not loggers:
196
+ loggers.append(StdoutLogger())
197
+
198
+ # get loggable args only
199
+ loggable_args = {k: new_v for k, v in all_args.items() if (new_v := prepare_value(v)) is not None}
200
+
201
+ for logger_instance in loggers:
202
+ logger_instance.log_config(loggable_args)
203
+
204
+ return func(*args, **kwargs)
205
+
206
+ return wrapper
207
+
208
+
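A sketch of decorating a recipe entry point; the train_recipe function is hypothetical, while StdoutLogger and log_config come from the imports above:

from adaptive_harmony.core.utils import log_args
from adaptive_harmony.metric_logger import StdoutLogger


@log_args
def train_recipe(lr: float, batch_size: int, logger: StdoutLogger | None = None):
    # lr, batch_size and any other serializable args are passed to logger.log_config(...)
    # before the body runs; with no Logger argument, a StdoutLogger is used as fallback
    ...


train_recipe(lr=1e-5, batch_size=16, logger=StdoutLogger())
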
+ async def async_map[S, T](f: Callable[[S], Coroutine[Any, Any, T]], data: Sequence[S]) -> list[T]:
+     """
+     Process all items in an iterable concurrently using the provided coroutine function.
+
+     Args:
+         f: Coroutine function to apply to each item
+         data: Iterable of items to process
+
+     Returns:
+         List of results from all task executions
+     """
+
+     # Check if a Progress bar is already active
+     async with get_progress_counter_or_wrapper(f"async_map({f.__name__})", len(list(data))) as counter:
+         all_tasks = [asyncio.create_task(wrap_coroutine_with_progress(f(item), counter)) for item in data]
+         for t in all_tasks:
+             counter.register_task(t)
+         results = await asyncio.gather(*all_tasks)
+         return results
+
+
+ async def async_map_fallible[S, T](f: Callable[[S], Coroutine[Any, Any, T]], data: Sequence[S]) -> list[T]:
+     """
+     Process all items in an iterable concurrently using the provided coroutine function.
+     Failed tasks are dropped silently; only results from successful executions are returned.
+
+     Args:
+         f: Coroutine function to apply to each item
+         data: Iterable of items to process
+
+     Returns:
+         List of results from successful task executions
+     """
+
+     async def wrap_coroutine_with_error_handling(coro: Coroutine[Any, Any, T]) -> tuple[T, bool]:
+         try:
+             result = await coro
+             return result, True
+         except Exception:
+             return None, False  # type: ignore
+
+     async with get_progress_counter_or_wrapper(f"async_map_fallible({f.__name__})", len(list(data))) as counter:
+         all_tasks = [
+             asyncio.create_task(wrap_coroutine_with_error_handling(wrap_coroutine_with_progress(f(item), counter)))
+             for item in data
+         ]
+         for t in all_tasks:
+             counter.register_task(t)
+         results = await asyncio.gather(*all_tasks)
+
+     return [result for result, success in results if success]
+
+
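A sketch contrasting the two helpers; the square coroutine is illustrative:

import asyncio

from adaptive_harmony.core.utils import async_map, async_map_fallible


async def square(x: int) -> int:
    if x < 0:
        raise ValueError("negative input")
    return x * x


async def main() -> None:
    # async_map propagates the first failure raised by any task
    strict = await async_map(square, [1, 2, 3])              # [1, 4, 9]
    # async_map_fallible silently drops failed items
    tolerant = await async_map_fallible(square, [1, -2, 3])  # [1, 9]
    print(strict, tolerant)


asyncio.run(main())
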
+ def get_minibatches[T](dataset: list[T], mini_batch_size: int, number_of_epochs: int) -> list[list[T]]:
+     all_batches: list[list[T]] = []
+
+     for _ in range(number_of_epochs):
+         shuffled_dataset = random.sample(dataset, k=len(dataset))
+
+         epoch_batches: list[list[T]] = []
+         for i in range(0, len(shuffled_dataset), mini_batch_size):
+             batch = shuffled_dataset[i : i + mini_batch_size]
+             epoch_batches.append(batch)
+         all_batches.extend(epoch_batches)
+
+     return all_batches
+
+
+ def sample_data[T](data: list[T], epochs: float) -> list[T]:
+     num_samples = len(data) * epochs
+     return [data[x] for x in np.random.permutation(len(data))[: int(num_samples)]]
+
+
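A quick illustration of the epoch semantics of these two helpers:

from adaptive_harmony.core.utils import get_minibatches, sample_data

dataset = list(range(10))

# each epoch is an independent shuffle, so 2 epochs of size-4 minibatches
# yields 2 * ceil(10 / 4) = 6 batches covering the data twice
batches = get_minibatches(dataset, mini_batch_size=4, number_of_epochs=2)
assert len(batches) == 6

# fractional epochs subsample without replacement: 0.5 epochs of 10 items -> 5 items
subset = sample_data(dataset, epochs=0.5)
assert len(subset) == 5
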
+ def weighted_mean(values: list[list[float]], weights: list[list[float]]) -> float:
+     return np.average(np.concatenate(values), weights=np.concatenate(weights)).item()
+
+
+ def stringify_thread(thread: StringThread, sep: str = "\n\n") -> str:
+     """Convert StringThread to readable text format."""
+     turns = thread.get_turns()
+     return sep.join([f"[{turn.role}]\n{turn.content}" for turn in turns])
+
+
+ class SingleTurnShot(TypedDict):
+     user: dict[str, str]
+     assistant: dict[str, str]
+
+
+ class TurnTemplates(NamedTuple):
+     system: str | None
+     user: str | None
+     assistant: str | None
+     shots: list[SingleTurnShot] | None
+
+
+ def turn_templates_from_dir(root_dir: str) -> TurnTemplates:
+     """
+     Returns system, user and assistant turn string templates from a directory, as well as a list of shot dicts.
+     Expects files to be named system.md, user.md, assistant.md and shots.jsonl.
+     Returns None for any turn template file that does not exist.
+     """
+     root_path = Path(root_dir)
+     expected_files = ["system.md", "user.md", "assistant.md", "shots.jsonl"]
+     missing_templates = []
+     turn_templates: list[str | list[SingleTurnShot] | None] = []
+
+     for file in expected_files:
+         path = root_path / file
+         if not path.exists():
+             missing_templates.append(file)
+             turn_templates.append(None)
+         else:
+             if file == "shots.jsonl":
+                 shots = []
+                 for line in path.read_text().splitlines():
+                     data = json.loads(line)
+                     shot = SingleTurnShot(user=data["user"], assistant=data["assistant"])
+                     shots.append(shot)
+                 turn_templates.append(shots)
+             else:
+                 turn_templates.append(path.read_text())
+
+     # Ensure proper typing: first 3 are str | None, last is list[SingleTurnShot] | None
+     system, user, assistant, shots = turn_templates
+     return TurnTemplates(
+         system=system if isinstance(system, str) else None,
+         user=user if isinstance(user, str) else None,
+         assistant=assistant if isinstance(assistant, str) else None,
+         shots=shots if isinstance(shots, list) else None,
+     )
+
+
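A usage sketch; the prompts/ directory layout in the comments is an assumed example:

from adaptive_harmony.core.utils import turn_templates_from_dir

# prompts/
#   system.md      -> raw system turn template
#   user.md        -> raw user turn template
#   shots.jsonl    -> one JSON object per line: {"user": {...}, "assistant": {...}}
# (assistant.md is absent in this example, so templates.assistant comes back as None)
templates = turn_templates_from_dir("prompts")

system_template = templates.system   # str | None
few_shots = templates.shots or []    # list[SingleTurnShot]
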
+ def hash_dataset(dataset: Sequence[StringThread], num_samples: int = 5) -> str:
+     """Compute a hash of a dataset for quick, unsafe comparison.
+
+     Hashes: dataset length + first N elements + last N elements.
+     This catches most dataset changes without processing the entire dataset.
+
+     Args:
+         dataset: List of dataset items
+         num_samples: Number of elements to sample from start/end (default 5)
+
+     Returns:
+         SHA256 hash
+     """
+     hasher = hashlib.sha256()
+
+     hasher.update(str(len(dataset)).encode())
+
+     num_samples = min(num_samples, len(dataset) // 2)
+
+     for item in dataset[:num_samples]:
+         hasher.update(str(item).encode())
+
+     for item in dataset[-num_samples:]:
+         hasher.update(str(item).encode())
+
+     return hasher.hexdigest()
adaptive_harmony/environment/__init__.py
@@ -0,0 +1,8 @@
+ from .environment import Environment, EnvironmentFactory, TrajectoryScore, TurnScore
+
+ __all__ = [
+     "Environment",
+     "EnvironmentFactory",
+     "TrajectoryScore",
+     "TurnScore",
+ ]
adaptive_harmony/environment/environment.py
@@ -0,0 +1,121 @@
+ from abc import ABC, abstractmethod
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from numbers import Number
+ from typing import Any, Callable
+
+ import numpy as np
+
+ from adaptive_harmony import StringThread
+ from adaptive_harmony.logging_table import Table
+
+
+ @dataclass
+ class TurnScore:
+     score: float
+     num_assistant_turns: int
+
+
+ @dataclass
+ class TrajectoryScore:
+     scores: list[TurnScore]
+
+     def to_turn_scores(self) -> list[float]:
+         return [turn_score.score for turn_score in self.scores for _ in range(turn_score.num_assistant_turns)]
+
+     @property
+     def cumulative_score(self) -> float:
+         return sum([turn_score.score for turn_score in self.scores])
+
+
+ class Environment(ABC):
+     """
+     Base class to inherit from when building an environment.
+     """
+
+     def __init__(self, logging_function: Callable):
+         self.logging_function = logging_function
+
+     @abstractmethod
+     async def react_to(self, thread: StringThread) -> list[tuple[str, str]] | TrajectoryScore:
+         """Returns either [("tool", tool_response), ...] or [("user", user_question)], or a TrajectoryScore once the trajectory is done."""
+         pass
+
+     async def bootstrap_prompt(self, thread: StringThread) -> StringThread:
+         return thread
+
+     async def generate_trajectory_and_grade(
+         self, model, initial_thread: StringThread
+     ) -> tuple[StringThread, TrajectoryScore]:
+         """Generate a full trajectory by interacting with the model until termination."""
+         thread = await self.bootstrap_prompt(initial_thread)
+
+         while True:
+             # Generate model response
+             thread = await model.generate(thread)
+
+             # Get environment reaction
+             env_response = await self.react_to(thread)
+
+             # If we got a score, we're done
+             if isinstance(env_response, TrajectoryScore):
+                 return thread, env_response
+
+             # Otherwise, add the environment responses to the thread
+             for role, content in env_response:
+                 if role == "tool":
+                     thread = thread.tool(content)
+                 elif role == "user":
+                     thread = thread.user(content)
+                 else:
+                     raise ValueError(f"Unknown role: {role}")
+
+
+ class EnvironmentFactory(ABC):
+     """
+     Abstract class for building environments. It is needed because each trajectory must have its own Environment instance.
+     """
+
+     def __init__(self, logging_name: str | None = None):
+         self._logs: dict[str, list] = defaultdict(list)
+         self.logging_name = logging_name
+
+     def add_log(self, key, new_value) -> None:
+         """Add a log entry to the scorer's log collection."""
+         self._logs[key].append(new_value)
+
+     def get_logs(self, clear: bool = False) -> dict[str, float | Table]:
+         """
+         Get aggregated logs from all score calls.
+         The base implementation averages numeric values and concatenates Table logs, per key.
+         If there are no logs, returns an empty dict.
+         """
+         if not self._logs:
+             return {}
+
+         logs = {}
+         for k, v in self._logs.items():
+             if isinstance(v[0], Number):
+                 logs[k] = np.asarray(v).mean()
+             elif isinstance(v[0], Table):
+                 headers = v[0].headers
+                 assert all(table.headers == headers for table in v)
+                 overall_table = Table(headers)
+                 for table in v:
+                     overall_table.add_rows(table.rows)
+                 logs[k] = overall_table
+             else:
+                 raise ValueError(f"Unknown type: {type(v[0])}")
+
+         if clear:
+             self.clear_logs()
+         return logs
+
+     def clear_logs(self) -> None:
+         """
+         Clear all accumulated logs.
+         """
+         self._logs.clear()
+
+     @abstractmethod
+     def create_environment(self, metadata: Any) -> Environment: ...
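A minimal sketch of the intended subclassing pattern. The exact-match task, the metadata layout, and passing add_log as the logging_function are illustrative assumptions, not part of the package:

from adaptive_harmony import StringThread
from adaptive_harmony.environment import Environment, EnvironmentFactory, TrajectoryScore, TurnScore


class ExactMatchEnvironment(Environment):
    """Single-turn environment: score 1.0 if the assistant's reply matches the target string."""

    def __init__(self, logging_function, target: str):
        super().__init__(logging_function)
        self.target = target

    async def react_to(self, thread: StringThread) -> list[tuple[str, str]] | TrajectoryScore:
        # one assistant turn and we are done: return a score instead of more turns
        score = 1.0 if thread.last_content().strip() == self.target else 0.0
        self.logging_function("exact_match", score)
        return TrajectoryScore(scores=[TurnScore(score=score, num_assistant_turns=1)])


class ExactMatchFactory(EnvironmentFactory):
    def create_environment(self, metadata) -> Environment:
        # every trajectory gets its own Environment instance
        return ExactMatchEnvironment(self.add_log, target=metadata["target"])
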
adaptive_harmony/evaluation/__init__.py
@@ -0,0 +1 @@
+ from .evaluation_artifact import EvaluationArtifact as EvaluationArtifact
adaptive_harmony/evaluation/evaluation_artifact.py
@@ -0,0 +1,67 @@
+ import logging
+ import uuid
+ from typing import List, Self
+
+ from harmony_client import (
+     EvalSample,
+     EvaluationArtifactBase,
+ )
+ from harmony_client.runtime.context import RecipeContext
+
+ logger = logging.getLogger(__name__)
+
+
+ class EvaluationArtifact:
+     def __init__(self, name: str, ctx: RecipeContext) -> None:
+         artifact_id = str(uuid.uuid4())
+         url = ctx.file_storage.mk_url(f"artifacts/eval_samples_{artifact_id}.jsonl")
+         self._base = EvaluationArtifactBase(name, url, artifact_id)
+         self.ctx = ctx
+         self.ctx.job.register_artifact(self._base.artifact)
+
+     @property
+     def id(self) -> str:
+         return self._base.id
+
+     @property
+     def name(self) -> str:
+         return self._base.name
+
+     @property
+     def kind(self) -> str:
+         return self._base.kind
+
+     @property
+     def uri(self) -> str:
+         assert self._base.uri is not None
+         return self._base.uri
+
+     def add_samples(self, samples: List[EvalSample]) -> Self:
+         """Add evaluation samples to this artifact.
+
+         Args:
+             samples: List of evaluation samples to add
+
+         Returns:
+             Self for method chaining
+
+         Raises:
+             ValueError: If samples list is empty
+             Exception: If serialization or storage fails
+         """
+         if not samples:
+             raise ValueError("Cannot add empty samples list")
+
+         try:
+             samples_json = self._base.samples_to_adaptive_json(samples)
+             for json_str in samples_json:
+                 self.ctx.file_storage.append((json_str + "\n").encode("utf-8"), self.uri)
+             logger.debug(f"Added {len(samples)} samples to artifact {self.id}")
+         except Exception as e:
+             logger.error(f"Failed to add samples to artifact {self.id}: {e}")
+             raise
+
+         return self
+
+     def __repr__(self):
+         return f"EvaluationArtifact(id={self.id}, name={self.name}, kind={self.kind}, uri={self.uri})"
adaptive_harmony/graders/__init__.py
@@ -0,0 +1,20 @@
+ from adaptive_harmony import Grade
+
+ from .answer_relevancy_judge import AnswerRelevancyGrader
+ from .base_grader import BaseGrader
+ from .binary_judge import BinaryJudgeGrader
+ from .context_relevancy_judge import ContextRelevancyGrader
+ from .exceptions import IgnoreScoreException
+ from .faithfulness_judge import FaithfulnessGrader
+ from .range_judge import RangeJudgeGrader
+
+ __all__ = [
+     "BaseGrader",
+     "Grade",
+     "IgnoreScoreException",
+     "BinaryJudgeGrader",
+     "RangeJudgeGrader",
+     "FaithfulnessGrader",
+     "AnswerRelevancyGrader",
+     "ContextRelevancyGrader",
+ ]
adaptive_harmony/graders/answer_relevancy_judge/__init__.py
@@ -0,0 +1,3 @@
+ from .answer_relevancy_judge import AnswerRelevancyGrader
+
+ __all__ = ["AnswerRelevancyGrader"]
adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py
@@ -0,0 +1,102 @@
+ from __future__ import annotations
+
+ from typing import Literal
+
+ import numpy as np
+ import pysbd
+ from harmony_client import Grade, InferenceModel, StringThread
+ from pydantic import BaseModel, Field
+
+ from adaptive_harmony.core.utils import stringify_thread
+ from adaptive_harmony.graders.answer_relevancy_judge.prompts import DEFAULT_SHOTS, SYSTEM, USER
+ from adaptive_harmony.graders.base_grader import BaseGrader
+ from adaptive_harmony.graders.faithfulness_judge.faithfulness_judge import SupportedLanguages
+ from adaptive_harmony.graders.utils import sample_score_distribution, separate_context_from_last_user_turn
+ from adaptive_harmony.logging_table import Table
+
+
+ class StatementRelevancy(BaseModel):
+     reason: str = Field(description="The justification for the score given to a statement. Keep it short and concise.")
+     score: Literal[0, 1] = Field(
+         description="The score for the statement. A score of 1 if the statement is relevant to addressing the original input, and 0 if the statement is irrelevant to addressing the input"
+     )
+
+
+ class AnswerRelevancyResults(BaseModel):
+     results: list[StatementRelevancy] = Field(description="A list of relevancy results for the statements")
+
+
+ class AnswerRelevancyGrader(BaseGrader):
+     def __init__(
+         self,
+         model: InferenceModel,
+         language: SupportedLanguages = "en",
+         grader_key: str = "answer_relevancy_judge",
+         grader_id: str | None = None,
+     ):
+         super().__init__(grader_key)
+         self.model = model
+         self.language = language
+         self.grader_id_or_key = grader_id or grader_key
+         self.sentence_splitter = pysbd.Segmenter(language=language)
+         self.shots = DEFAULT_SHOTS
+
+     async def grade(self, sample: StringThread) -> Grade:
+         _, user_question = separate_context_from_last_user_turn(sample)
+
+         completion = sample.last_content()
+         split_sentences = self.sentence_splitter.segment(completion)
+         sentences = [sentence.strip() for sentence in split_sentences if sentence.strip()]
+         sentences_judge_str = "\n".join(f"{i}: {sentence}" for i, sentence in enumerate(sentences))
+
+         judging_thread = (
+             StringThread()
+             .system(SYSTEM.format(json_schema=self.model.render_schema(AnswerRelevancyResults), shots=self.shots))
+             .user(USER.format(user_question=user_question, statements=sentences_judge_str))
+         )
+
+         try:
+             _, response = await self.model.temperature(0.0).generate_and_validate(
+                 judging_thread, AnswerRelevancyResults
+             )
+             results = response.results
+         except Exception as e:
+             self.add_log({"prompt": stringify_thread(judging_thread, sep=f"\n\n{'-' * 10}\n\n"), "error": str(e)})
+             raise
+
+         reason = ""
+         for i, (result, statement) in enumerate(zip(results, sentences)):
+             emoji = "✅" if result.score == 1 else "❌"
+             verdict = "PASS" if result.score == 1 else "FAIL"
+             statement_display = statement[:150] + ("..." if len(statement) > 150 else "")
+             reason += f"{emoji} Statement {i}: {verdict}\n Excerpt: {statement_display}\nReason: {result.reason}\n\n"
+
+         score = np.mean([float(result.score) for result in results]) if results else 0.0
+         self.add_log(
+             {"score": score, "prompt": stringify_thread(judging_thread, sep=f"\n\n{'-' * 10}\n\n"), "reasoning": reason}
+         )
+         return Grade(value=float(score), grader_key=self.grader_id_or_key, reasoning=reason)
+
+     def get_logs(self, clear: bool = False, log_all_samples: bool = False) -> dict[str, float | Table]:
+         # Only clear logs at the end if clear is True
+         logs = super().get_logs(clear=False)
+
+         successfully_scored_samples = [log for log in self._logs if "score" in log]
+
+         # stratified sample across the score range so both high- and low-scoring samples are logged
+         if not log_all_samples:
+             subset_successfully_scored_samples = sample_score_distribution(successfully_scored_samples, 15)
+         else:
+             # log_all_samples is set, so keep every successfully scored sample
+             subset_successfully_scored_samples = successfully_scored_samples
+
+         failed_scored_samples = [log for log in self._logs if "error" in log]
+
+         sample_logs = self.get_sample_tables(subset_successfully_scored_samples, failed_scored_samples)
+
+         logs.update(sample_logs)
+
+         if clear:
+             self.clear_logs()
+
+         return logs
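A usage sketch, assuming an InferenceModel (model) obtained through the usual harmony model-loading path and a StringThread (thread) whose last turn is the assistant completion to grade; neither is constructed here:

from adaptive_harmony.graders import AnswerRelevancyGrader

grader = AnswerRelevancyGrader(model, language="en")

grade = await grader.grade(thread)   # run inside an async context
print(grade.value)                   # fraction of completion statements judged relevant to the user question
print(grade.reasoning)               # per-statement PASS/FAIL breakdown with excerpts

# aggregated metrics and sample tables for experiment logging, e.g. after a batch of grade() calls
metrics = grader.get_logs(clear=True)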