harmony_client-0.1.0-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- harmony_client/__init__.py +78 -0
- harmony_client/artifacts/__init__.py +5 -0
- harmony_client/artifacts/custom_artifact.py +46 -0
- harmony_client/artifacts/dataset_artifact.py +268 -0
- harmony_client/artifacts/model_artifact.py +34 -0
- harmony_client/file_storage.py +378 -0
- harmony_client/harmony_client.cp312-win_amd64.pyd +0 -0
- harmony_client/harmony_client.pyi +1615 -0
- harmony_client/internal/__init__.py +7 -0
- harmony_client/internal/eval_samples_html.py +122 -0
- harmony_client/internal/utils.py +9 -0
- harmony_client/logging_table.py +121 -0
- harmony_client/parameters/__init__.py +295 -0
- harmony_client/parameters/dataset_kinds.py +49 -0
- harmony_client/parameters/model_kinds.py +13 -0
- harmony_client/py.typed +0 -0
- harmony_client/runtime/__init__.py +29 -0
- harmony_client/runtime/context.py +191 -0
- harmony_client/runtime/data.py +76 -0
- harmony_client/runtime/decorators.py +19 -0
- harmony_client/runtime/dto/AdaptiveDataset.py +23 -0
- harmony_client/runtime/dto/AdaptiveGrader.py +68 -0
- harmony_client/runtime/dto/AdaptiveModel.py +19 -0
- harmony_client/runtime/dto/DatasetSampleFormats.py +93 -0
- harmony_client/runtime/dto/__init__.py +2 -0
- harmony_client/runtime/dto/base.py +7 -0
- harmony_client/runtime/model_artifact_save.py +23 -0
- harmony_client/runtime/runner.py +368 -0
- harmony_client/runtime/simple_notifier.py +21 -0
- harmony_client-0.1.0.dist-info/METADATA +38 -0
- harmony_client-0.1.0.dist-info/RECORD +32 -0
- harmony_client-0.1.0.dist-info/WHEEL +4 -0
harmony_client/internal/eval_samples_html.py
ADDED
@@ -0,0 +1,122 @@
+import json
+from pathlib import Path
+
+from harmony_client import EvalSample, StringThread
+from harmony_client.internal.utils import stringify_thread
+from harmony_client.logging_table import Table
+
+
+def _extract_model_key(model_path: str | None = None) -> str:
+    model_path = model_path or ""
+    if model_path.startswith("model_registry://"):
+        return model_path[len("model_registry://") :]
+    return model_path
+
+
+def _save_detailed_eval_table(eval_samples: list[EvalSample], output_dir: str | None = None) -> None:
+    """
+    Method subject to change, only for internal library use.
+    Do not use on its own, use RecipeContext.log_eval_result instead.
+    """
+    default_output_file: str = "evaluation_samples.html"
+    if not eval_samples:
+        print("No evaluation samples to save")
+        return
+
+    # Force provided path into a directory
+    if output_dir is not None:
+        output_path = Path(output_dir)
+        if output_path.is_file() or output_path.suffix:
+            eval_dir = output_path.parent
+        else:
+            eval_dir = output_path
+    else:
+        eval_dir = Path.cwd()
+
+    # Create the directory structure
+    eval_dir.mkdir(parents=True, exist_ok=True)
+    html_path = eval_dir / Path(default_output_file).name
+
+    # Collect all unique grader names to determine column structure
+    all_grader_names = set()
+    for sample in eval_samples:
+        for grade in sample.grades:
+            all_grader_names.add(grade.grader_key)
+
+    all_grader_names = sorted(all_grader_names)
+
+    # Create Table
+    headers = ["Prompt", "Model", "Completion"]
+    for grader_name in all_grader_names:
+        headers.extend([f"{grader_name}_score", f"{grader_name}_reason"])
+
+    table = Table(headers)
+
+    # Group samples by prompt (stringified thread)
+    prompt_groups = {}
+    for sample in eval_samples:
+        # Extract just the prompt part for grouping
+        turns = sample.interaction.thread.get_turns()
+        # Find prompt turns (everything except the last assistant turn)
+        prompt_turns: list[tuple[str, str]] = []
+        completion = ""
+        # Keep all turns except the final assistant turn, which is extracted as the completion
+        for i, (role, content) in enumerate(turns):
+            if role.lower() == "assistant" and i == len(turns) - 1:
+                completion = content
+            else:
+                prompt_turns.append((role, content))
+
+        # Create prompt string for grouping
+        prompt_str = stringify_thread(StringThread(prompt_turns))
+        if prompt_str not in prompt_groups:
+            prompt_groups[prompt_str] = []
+
+        prompt_groups[prompt_str].append(
+            {
+                "model": _extract_model_key(sample.interaction.source or "Unknown"),
+                "completion": completion,
+                "grades": sample.grades,
+            }
+        )
+
+    # Create table rows
+    for prompt_str, models_data in prompt_groups.items():
+        for i, model_data in enumerate(models_data):
+            row = [
+                prompt_str if i == 0 else "",  # Only show prompt for first model in group
+                model_data["model"],
+                model_data["completion"],
+            ]
+
+            # Add grader score and reasoning columns
+            grade_dict = {grade.grader_key: grade for grade in model_data["grades"]}
+
+            for grader_name in all_grader_names:
+                if grader_name in grade_dict:
+                    grade = grade_dict[grader_name]
+                    row.extend([grade.value, grade.reasoning or ""])
+                else:
+                    row.extend(["N/A", "No evaluation"])
+
+            table.add_row(row)
+
+    # Save HTML table
+    with open(html_path, "w") as f:
+        f.write(table.to_html_table())
+
+    # Save summary metadata as JSON
+    metadata = {
+        "total_samples": len(eval_samples),
+        "unique_prompts": len(prompt_groups),
+        "models": sorted(
+            set(model_data["model"] for models_data in prompt_groups.values() for model_data in models_data)
+        ),
+        "graders": all_grader_names,
+    }
+    metadata_path = eval_dir / "metadata.json"
+
+    with open(metadata_path, "w") as f:
+        json.dump(metadata, f, indent=2)
+
+    print(f"📁 Detailed evaluation samples saved to: {html_path}")
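For orientation, here is a minimal sketch (with invented sample values) of the table layout this helper produces: one row per model answering a given prompt, the prompt shown only on the group's first row, and a `<grader>_score` / `<grader>_reason` column pair per grader:

```python
from harmony_client.logging_table import Table

# Hypothetical data: one prompt, two models, a single grader named "accuracy".
table = Table(["Prompt", "Model", "Completion", "accuracy_score", "accuracy_reason"])
table.add_row(["[user]\nWhat is 2+2?", "model-a", "4", 1.0, "Correct"])
table.add_row(["", "model-b", "5", 0.0, "Incorrect"])  # prompt left blank after the first row
html = table.to_html_table()
```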
harmony_client/internal/utils.py
ADDED
@@ -0,0 +1,9 @@
+"""Utility functions for internal use."""
+
+from harmony_client import StringThread
+
+
+def stringify_thread(thread: StringThread, sep: str = "\n\n") -> str:
+    """Convert StringThread to readable text format."""
+    turns = thread.get_turns()
+    return sep.join([f"[{turn.role}]\n{turn.content}" for turn in turns])
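A quick sketch of the output shape (hypothetical turns; assumes, as the code above does, that the turns returned by `StringThread.get_turns()` expose `role` and `content`):

```python
from harmony_client import StringThread
from harmony_client.internal.utils import stringify_thread

thread = StringThread([("user", "Hi"), ("assistant", "Hello!")])
print(stringify_thread(thread))
# [user]
# Hi
#
# [assistant]
# Hello!
```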
harmony_client/logging_table.py
ADDED
@@ -0,0 +1,121 @@
+from typing import Any, Sequence, overload
+
+from typing_extensions import Self
+
+
+class Table:
+    Column = list[str] | list[float]
+    Row = list[str | float]
+
+    def __init__(self, initial_headers: Sequence[str] = []):
+        self.headers: list[str] = list(initial_headers)
+        self.rows: list[Table.Row] = []
+
+    @overload
+    def add_column(self, header: str, column: list[str]) -> Self: ...
+
+    @overload
+    def add_column(self, header: str, column: list[float]) -> Self: ...
+
+    def add_column(self, header: str, column: list[str] | list[float]) -> Self:
+        if self.rows:
+            # if the table is not empty, append one value to each existing row
+            for row, new_value in zip(self.rows, column, strict=True):
+                row.append(new_value)  # type: ignore
+        else:
+            for value in column:
+                self.rows.append([value])
+        self.headers.append(header)
+
+        return self
+
+    # The Any in the union is me giving up because I have too much stuff to do to be stuck on this
+    def add_row(self, row: Sequence[str | float | Any]) -> Self:
+        assert len(row) == len(self.headers)
+        self.rows.append(row)  # type: ignore
+        return self
+
+    def add_rows(self, rows: list[Sequence[str | float | Any]]) -> Self:
+        for row in rows:
+            self.add_row(row)
+        return self
+
+    def __getitem__(self, key: str) -> Column:
+        column_index = self.headers.index(key)
+        return [row[column_index] for row in self.rows]  # type: ignore
+
+    def __contains__(self, key: str):
+        return key in self.headers
+
+    def export(self) -> tuple[list[str], list[list[str]]]:
+        return self.headers, [[str(inner) for inner in row] for row in self.rows]
+
+    def _repr_html_(self) -> str:
+        return self.to_html_table()
+
+    def to_html_table(self):
+        """
+        Generates a nicely styled HTML table with CSS for better readability,
+        especially for cells containing a lot of text.
+        """
+        # CSS styles for a professional-looking table.
+        # This is defined once and applied to the table via a class.
+        css_style = """
+        <style>
+            .classy-table {
+                width: 100%;
+                border-collapse: collapse; /* Removes space between borders */
+                font-family: Arial, sans-serif; /* A clean, readable font */
+                font-size: 14px;
+                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
+                margin: 20px 0;
+            }
+            .classy-table th, .classy-table td {
+                padding: 12px 15px; /* Adds space inside cells */
+                border: 1px solid #ddd; /* Light grey borders */
+                text-align: left;
+                vertical-align: top; /* Aligns content to the top */
+                /* This is key for handling long text */
+                word-wrap: break-word;
+                overflow-wrap: break-word;
+                white-space: pre-wrap;
+            }
+            .classy-table th {
+                background-color: #f2f2f2; /* Light grey header background */
+                font-weight: bold;
+                color: #333;
+            }
+            .classy-table tr:nth-child(even) {
+                background-color: #f9f9f9; /* Zebra-striping for even rows */
+            }
+            .classy-table tr:hover {
+                background-color: #f1f1f1; /* Highlight row on hover */
+            }
+        </style>"""
+        # Create HTML table with a specific class
+        html_table = "<table class='classy-table'>\n"
+
+        # Add table header
+        html_table += "  <thead>\n"  # Using <thead> for semantic HTML
+        html_table += "    <tr>\n"
+        for col_name in self.headers:
+            # No more inline styles needed here!
+            html_table += f"      <th>{col_name}</th>\n"
+        html_table += "    </tr>\n"
+        html_table += "  </thead>\n"
+
+        # Add table rows
+        html_table += "  <tbody>\n"  # Using <tbody> for semantic HTML
+        for row in self.rows:
+            html_table += "    <tr>\n"
+            for value in row:
+                # No more inline styles needed here either
+                html_table += f"      <td>{value}</td>\n"
+            html_table += "    </tr>\n"
+        html_table += "  </tbody>\n"
+
+        # Close HTML table
+        html_table += "</table>"
+
+        # Return the styles and the table together
+        return css_style + html_table
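A `Table` can also be built column-first: `add_column` seeds an empty table and then extends each existing row, and the fluent return value lets calls chain. A small sketch:

```python
from harmony_client.logging_table import Table

table = Table()
table.add_column("name", ["a", "b"]).add_column("score", [0.9, 0.4])
assert "score" in table              # __contains__ checks headers
assert table["score"] == [0.9, 0.4]  # __getitem__ returns a column
headers, rows = table.export()       # rows stringified: [["a", "0.9"], ["b", "0.4"]]
```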
harmony_client/parameters/__init__.py
ADDED
@@ -0,0 +1,295 @@
+import json
+from typing import TYPE_CHECKING, Any
+from uuid import UUID
+
+import aiofiles
+from loguru import logger
+from pydantic import BaseModel, model_serializer, model_validator
+
+from harmony_client import StringThread
+from harmony_client.parameters import dataset_kinds, model_kinds
+from harmony_client.runtime.context import RecipeContext
+from harmony_client.runtime.dto.DatasetSampleFormats import SampleMetadata
+
+if TYPE_CHECKING:
+    from adaptive_harmony.graders.base_grader import BaseGrader
+
+
+class Dataset[T: dataset_kinds.DatasetKind](BaseModel):
+    dataset_key: str
+    feedback_key: str | None = None
+    local_file_path: str | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_from_json(cls, data):
+        """Handle deserialization from JSON - accepts dict with dataset_key."""
+        # If it's a string, convert it to the expected dict format
+        if isinstance(data, str):
+            return {"dataset_key": data, "feedback_key": None}
+        return data
+
+    async def load(self, ctx: RecipeContext | None = None) -> list[StringThread]:
+        """
+        Load dataset samples from the Harmony service.
+
+        Args:
+            ctx: RecipeContext with client and file_storage
+
+        Returns:
+            list of StringThread objects
+        """
+
+        def with_metadata(thread: StringThread, metadata: SampleMetadata):
+            if metadata.external_data:
+                thread.metadata = {**metadata.model_dump(exclude_none=True), **metadata.external_data}
+            else:
+                thread.metadata = metadata.model_dump()
+
+        async def parse_internal_format(line_dict: Any) -> StringThread | None:
+            thread = None
+
+            format: type[dataset_kinds.DtoBaseModel] | None = None
+            match T:
+                case dataset_kinds.Prompt:
+                    format = dataset_kinds.DatasetPromptSample
+                case dataset_kinds.Completion:
+                    format = dataset_kinds.DatasetSample
+                case dataset_kinds.Metric:
+                    format = dataset_kinds.DatasetMetricSample
+                case dataset_kinds.Preference:
+                    format = dataset_kinds.DatasetPreferenceSample
+                case _:  # Mixed dataset case
+                    # order is important here: try the most constrained formats first
+                    formats: list[type[dataset_kinds.DtoBaseModel]] = [
+                        dataset_kinds.DatasetPreferenceSample,
+                        dataset_kinds.DatasetMetricSample,
+                        dataset_kinds.DatasetSample,
+                        dataset_kinds.DatasetPromptSample,
+                    ]
+                    for f in formats:
+                        try:
+                            f.model_validate(line_dict)
+                            format = f
+                            break
+                        except Exception:
+                            pass
+
+            if format is None:
+                raise ValueError(f"Could not determine format for line in dataset. Line: {line_dict}")
+
+            match format:
+                case dataset_kinds.DatasetPromptSample:
+                    sample = format.model_validate(line_dict)
+                    thread = await StringThread.from_dataset(
+                        [(turn.root[0], turn.root[1]) for turn in sample.prompt], None
+                    )
+                    with_metadata(thread, sample.metadata)
+                case dataset_kinds.DatasetSample:
+                    sample = format.model_validate(line_dict)
+                    thread = await StringThread.from_dataset(
+                        [(turn.root[0], turn.root[1]) for turn in sample.prompt], None
+                    )
+                    thread = thread.assistant(sample.completion.root[1])
+                    with_metadata(thread, sample.metadata)
+                case dataset_kinds.DatasetMetricSample:
+                    sample = dataset_kinds.DatasetMetricSample.model_validate(line_dict)
+                    thread = await StringThread.from_dataset(
+                        [(turn.root[0], turn.root[1]) for turn in sample.prompt], None
+                    )
+                    thread = thread.assistant(sample.completion.root[1])
+                    with_metadata(thread, sample.metadata)
+                    if self.feedback_key and sample.metrics:
+                        # put the metric value in the "res" key in the metadata
+                        thread.metadata["res"] = sample.metrics.get(self.feedback_key)
+
+                case dataset_kinds.DatasetPreferenceSample:
+                    sample = dataset_kinds.DatasetPreferenceSample.model_validate(line_dict)
+                    thread = await StringThread.from_dataset(
+                        [(turn.root[0], turn.root[1]) for turn in sample.prompt], None
+                    )
+                    with_metadata(thread, sample.metadata)
+                    thread.metadata["other_completion"] = sample.bad_completion.root[1]
+                    thread.metadata["preferred_completion"] = sample.good_completion.root[1]
+
+            return thread
+
+        async def parse_external_format(line_dict: Any) -> StringThread | None:
+            thread = None
+            if "input" in line_dict or "messages" in line_dict:
+                key = "input" if "input" in line_dict else "messages"
+                thread = StringThread(
+                    [(inner_turn_dict["role"], inner_turn_dict["content"]) for inner_turn_dict in line_dict[key]]
+                )
+                if "completion" in line_dict and line_dict["completion"]:
+                    thread = thread.assistant(line_dict["completion"])
+            else:
+                print("Did not find an `input` or `messages` key in sample, ignoring")
+            if thread is not None:
+                thread.metadata = line_dict.get("metadata", {})
+                if "other_completion" in line_dict and "preferred_completion" in line_dict:
+                    thread.metadata["other_completion"] = line_dict["other_completion"]
+                    thread.metadata["preferred_completion"] = line_dict["preferred_completion"]
+            return thread
+
+        if ctx:
+            config_response = await ctx.client.get_dataset_config(self.dataset_key)
+            lines = ctx.file_storage.read(config_response.file_path, use_raw_path=True).decode("utf-8").splitlines()
+        else:
+            assert self.local_file_path is not None, "Local file path is required when ctx is not provided"
+            lines = []
+            async with aiofiles.open(self.local_file_path, encoding="utf-8") as f:
+                async for line in f:
+                    lines.append(line.rstrip("\n"))
+
+        threads = []
+        parse_function = None
+        for line in lines:
+            if len(line.strip()) == 0:
+                continue
+            line_dict = json.loads(line)
+
+            if parse_function is None:
+                try:
+                    thread = await parse_internal_format(line_dict)
+                    parse_function = parse_internal_format
+                except Exception as e:
+                    logger.warning("Could not read dataset as internal format, falling back to external format: {}", e)
+                    thread = await parse_external_format(line_dict)
+                    parse_function = parse_external_format
+            else:
+                thread = await parse_function(line_dict)
+
+            if thread is not None:
+                threads.append(thread)
+
+        if len(threads) == 0:
+            raise ValueError("Did not find any valid format samples in the dataset")
+        return threads
+
+    def __hash__(self):
+        """Make Dataset hashable based on its keys."""
+        return hash((self.dataset_key, self.feedback_key))
+
+
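As a usage sketch (hypothetical key and path): without a `RecipeContext`, `load` reads `local_file_path` as JSONL; in the external format each line carries an `input` or `messages` turn list plus an optional `completion`:

```python
import asyncio

from harmony_client.parameters import Dataset, dataset_kinds

# Hypothetical local JSONL where each line looks like:
# {"messages": [{"role": "user", "content": "Hi"}], "completion": "Hello!"}
ds = Dataset[dataset_kinds.Completion](
    dataset_key="my-dataset",           # hypothetical key
    local_file_path="data/train.jsonl",
)
threads = asyncio.run(ds.load())        # -> list[StringThread]
```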
+class Model[T: model_kinds.ModelKind](BaseModel):
+    model_key: str
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_from_json(cls, data):
+        """Handle deserialization - accepts string or dict."""
+        if isinstance(data, str):
+            return {"model_key": data}
+        return data
+
+    @model_serializer
+    def serialize_model(self) -> str:
+        """Serialize as just the model_key string."""
+        return self.model_key
+
+    async def to_builder(
+        self,
+        ctx: RecipeContext,
+        kv_cache_len: int | None = None,
+        tokens_to_generate: int | None = None,
+        tp: int | None = None,
+    ):
+        """
+        Create a ModelBuilder instance configured with this model's parameters.
+
+        Args:
+            ctx: RecipeContext with client
+            kv_cache_len (int | None, optional): KV cache length override
+            tokens_to_generate (int | None, optional): Tokens to generate override
+            tp (int | None, optional): Tensor parallelism override
+
+        Returns:
+            ModelBuilder: A configured model builder instance ready for use
+        """
+        config_response = await ctx.client.get_model_config(self.model_key)
+
+        kwargs = {
+            k: v
+            for k, v in {
+                "kv_cache_len": kv_cache_len or config_response.kv_cache_len,
+                "tokens_to_generate": tokens_to_generate,
+            }.items()
+            if v is not None
+        }
+        builder = ctx.client.model(config_response.path, **kwargs)
+
+        if tp_to_pass := tp or config_response.tp:
+            builder = builder.tp(tp_to_pass)
+
+        return builder
+
+    def __hash__(self):
+        """Make Model hashable based on its key."""
+        return hash(self.model_key)
+
+
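A sketch of resolving a registered model into a builder (hypothetical key; `ctx` is the `RecipeContext` available inside a running recipe):

```python
from harmony_client.parameters import Model, model_kinds
from harmony_client.runtime import RecipeContext


async def build_model(ctx: RecipeContext):
    model = Model[model_kinds.ModelKind](model_key="my-org/my-model")  # hypothetical key
    # Explicit overrides win over the stored config values.
    return await model.to_builder(ctx, tokens_to_generate=512)
```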
+class Grader(BaseModel):
+    grader_key: str
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_from_json(cls, data):
+        """Handle deserialization - accepts string or dict."""
+        if isinstance(data, str):
+            return {"grader_key": data}
+        return data
+
+    @model_serializer
+    def serialize_model(self) -> str:
+        """Serialize as just the grader_key string."""
+        return self.grader_key
+
+    async def load(
+        self,
+        ctx: RecipeContext,
+        tp: int | None = None,
+        kv_cache_len: int | None = None,
+        max_tokens: int | None = None,
+    ) -> "BaseGrader":
+        """
+        Load a grader instance configured with this grader's parameters.
+
+        Args:
+            ctx: RecipeContext with client
+            tp (int | None, optional): Tensor parallelism override
+            kv_cache_len (int | None, optional): KV cache length override
+
+        Returns:
+            BaseGrader: A configured grader instance ready for use
+        """
+        # TODO: need to figure out something better here.
+        # I don't like code duplication, but maybe we should duplicate these guys?
+        # I'm not sure tbh...
+        try:
+            from adaptive_harmony.graders import BaseGrader as GraderImpl
+        except ImportError:
+            raise ImportError(
+                "To load and instantiate a grader, the adaptive_harmony package is required. Install it with `pip install adaptive-harmony`"
+            )
+
+        from harmony_client.runtime.data import AdaptiveGrader
+
+        config_response = await ctx.client.get_grader_config(self.grader_key)
+        grader_config_data = json.loads(config_response.grader_config_json)
+
+        grader_config = AdaptiveGrader(
+            grader_id=UUID(config_response.grader_id),
+            key=config_response.key,
+            metric_id=UUID("00000000-0000-0000-0000-000000000000"),  # unused in from_config
+            name=config_response.name,
+            config=grader_config_data,
+        )
+
+        return await GraderImpl.from_config(
+            grader_config=grader_config, ctx=ctx, tp=tp, kv_cache_len=kv_cache_len, max_tokens=max_tokens
+        )
+
+    def __hash__(self):
+        """Make Grader hashable based on its key."""
+        return hash(self.grader_key)
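Graders follow the same pattern (hypothetical key; per the ImportError above, `adaptive-harmony` must be installed, and `ctx` again comes from the running recipe):

```python
from harmony_client.parameters import Grader
from harmony_client.runtime import RecipeContext


async def load_judge(ctx: RecipeContext):
    grader = Grader(grader_key="accuracy-judge")     # hypothetical key
    return await grader.load(ctx, max_tokens=1024)   # -> an adaptive_harmony BaseGrader
```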
harmony_client/parameters/dataset_kinds.py
ADDED
@@ -0,0 +1,49 @@
+from harmony_client.runtime.dto.DatasetSampleFormats import (
+    DatasetMetricSample,
+    DatasetPreferenceSample,
+    DatasetPromptSample,
+    DatasetSample,
+    DtoBaseModel,
+)
+
+
+class DatasetKind:
+    def __init__(
+        self,
+        kind: str,
+        formats: list[type[DtoBaseModel]] = [
+            DatasetPreferenceSample,
+            DatasetMetricSample,
+            DatasetSample,
+            DatasetPromptSample,
+        ],
+    ):
+        self.kind = kind
+        self.formats = formats
+
+    def parse(self, line: dict):
+        for f in self.formats:
+            try:
+                return f.model_validate(line)
+            except Exception:
+                pass
+
+
+class Prompt(DatasetKind):
+    def __init__(self):
+        super().__init__("prompts", [DatasetPromptSample])
+
+
+class Preference(DatasetKind):
+    def __init__(self):
+        super().__init__("preference", [DatasetPreferenceSample])
+
+
+class Completion(DatasetKind):
+    def __init__(self):
+        super().__init__("completions", [DatasetSample])
+
+
+class Metric(DatasetKind):
+    def __init__(self):
+        super().__init__("feedbacks", [DatasetMetricSample])
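`parse` returns the first format that validates, falling through to an implicit `None` when nothing matches; a sketch (the concrete field layout of the DTO classes lives in `DatasetSampleFormats.py` and is not shown in this diff):

```python
from harmony_client.parameters.dataset_kinds import Completion

kind = Completion()                # only tries DatasetSample
sample = kind.parse({"bogus": 1})  # assumed not to validate against DatasetSample
assert sample is None              # no format matched; the caller decides what to do
```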
harmony_client/py.typed
ADDED
File without changes
harmony_client/runtime/__init__.py
ADDED
@@ -0,0 +1,29 @@
+from .context import RecipeContext
+from .data import (
+    AdaptiveDatasetKind,
+    InputConfig,
+)
+from .decorators import recipe_main
+from .dto.DatasetSampleFormats import (
+    DatasetMetricSample,
+    DatasetPreferenceSample,
+    DatasetPromptSample,
+    DatasetSample,
+    SampleMetadata,
+    TurnTuple,
+)
+from .simple_notifier import SimpleProgressNotifier
+
+__all__ = [
+    "RecipeContext",
+    "AdaptiveDatasetKind",
+    "InputConfig",
+    "recipe_main",
+    "DatasetMetricSample",
+    "DatasetPreferenceSample",
+    "DatasetPromptSample",
+    "DatasetSample",
+    "SampleMetadata",
+    "TurnTuple",
+    "SimpleProgressNotifier",
+]
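With these re-exports, recipe code can import the public runtime surface from one place, e.g.:

```python
from harmony_client.runtime import RecipeContext, recipe_main
```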