harmony-client 0.1.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- harmony_client/__init__.py +78 -0
- harmony_client/artifacts/__init__.py +5 -0
- harmony_client/artifacts/custom_artifact.py +46 -0
- harmony_client/artifacts/dataset_artifact.py +268 -0
- harmony_client/artifacts/model_artifact.py +34 -0
- harmony_client/file_storage.py +378 -0
- harmony_client/harmony_client.cp312-win_amd64.pyd +0 -0
- harmony_client/harmony_client.pyi +1615 -0
- harmony_client/internal/__init__.py +7 -0
- harmony_client/internal/eval_samples_html.py +122 -0
- harmony_client/internal/utils.py +9 -0
- harmony_client/logging_table.py +121 -0
- harmony_client/parameters/__init__.py +295 -0
- harmony_client/parameters/dataset_kinds.py +49 -0
- harmony_client/parameters/model_kinds.py +13 -0
- harmony_client/py.typed +0 -0
- harmony_client/runtime/__init__.py +29 -0
- harmony_client/runtime/context.py +191 -0
- harmony_client/runtime/data.py +76 -0
- harmony_client/runtime/decorators.py +19 -0
- harmony_client/runtime/dto/AdaptiveDataset.py +23 -0
- harmony_client/runtime/dto/AdaptiveGrader.py +68 -0
- harmony_client/runtime/dto/AdaptiveModel.py +19 -0
- harmony_client/runtime/dto/DatasetSampleFormats.py +93 -0
- harmony_client/runtime/dto/__init__.py +2 -0
- harmony_client/runtime/dto/base.py +7 -0
- harmony_client/runtime/model_artifact_save.py +23 -0
- harmony_client/runtime/runner.py +368 -0
- harmony_client/runtime/simple_notifier.py +21 -0
- harmony_client-0.1.0.dist-info/METADATA +38 -0
- harmony_client-0.1.0.dist-info/RECORD +32 -0
- harmony_client-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import json
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Self
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from pydantic import Field
|
|
9
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.table import Table as RichTable
|
|
12
|
+
|
|
13
|
+
from harmony_client import EvalSample, HarmonyClient, HarmonyJobNotifier, JobNotifier, StringThread, get_client
|
|
14
|
+
from harmony_client.file_storage import FileStorage, FileStorageConfig
|
|
15
|
+
from harmony_client.internal import _extract_model_key, _save_detailed_eval_table
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RecipeConfig(BaseSettings):
    """Runtime configuration for a recipe run.

    Values are parsed from CLI arguments (kebab-case) and from environment
    variables prefixed with ``ADAPTIVE_`` (e.g. ``ADAPTIVE_HARMONY_URL``).
    """

    model_config = SettingsConfigDict(env_prefix="ADAPTIVE_", cli_parse_args=True, cli_kebab_case=True)

    harmony_url: str = Field(description="url of harmony service")
    control_plane_url: str | None = Field(
        default=None,
        description="URL of the control plane service (Concorde). Required for fetching grader, dataset, and model configurations. Set this if you need to access centralized configuration; may be omitted for local or test runs where such configurations are not needed."
    )
    control_plane_api_token: str | None = Field(
        default=None,
        description="JWT token for authenticating with the control plane service (Concorde). Required when accessing protected internal API endpoints; may be omitted for local or test runs."
    )
    # Optional path to a JSON file with user-provided recipe input.
    user_input_file: str | None = None
    # Defaults to "test" for local runs without a real job attached.
    job_id: str = "test"
    use_case: str | None = None
    api_key: str | None = None
    compute_pool: str | None = None
    # URL understood by FileStorageConfig.from_url (e.g. "file:///tmp/...").
    storage_url: str | None = None
    num_gpus: int = 0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RecipeContext:
    """Runtime context handed to a recipe.

    Bundles the harmony client, a job notifier, file storage and the parsed
    RecipeConfig, plus helpers for loading datasets and reporting eval results.
    """

    client: HarmonyClient
    job: JobNotifier
    file_storage: FileStorage
    config: RecipeConfig
    # todo: pass world size
    world_size: int = 1

    def __init__(self, client: HarmonyClient, config: RecipeConfig):
        self.client = client
        self.config = config
        self.job = HarmonyJobNotifier(client, config.job_id)
        print(f"{config.storage_url=}")
        if config.storage_url:
            self.file_storage = FileStorage.new(FileStorageConfig.from_url(config.storage_url))
        else:
            # No storage configured: fall back to a local scratch directory.
            self.file_storage = FileStorage.new(FileStorageConfig.from_url("file:///tmp/recipe_storage"))

    @classmethod
    async def load(cls) -> Self:
        """Create a context from environment variables / CLI args (see RecipeConfig)."""
        config = RecipeConfig()  # type: ignore
        return await cls.from_config(config)

    @classmethod
    async def from_config(cls, config: RecipeConfig) -> Self:
        """Connect a HarmonyClient according to ``config`` and wrap it in a context."""
        client = await get_client(
            config.harmony_url,
            num_gpus=config.num_gpus,
            api_key=config.api_key,
            use_case=config.use_case,
            compute_pool=config.compute_pool,
            job_id=config.job_id,
            control_plane_url=config.control_plane_url,
            control_plane_api_token=config.control_plane_api_token,
        )
        return cls(client, config)

    def load_dataset(self, path: str) -> list[StringThread]:
        """Load a JSONL dataset from file storage into StringThreads.

        Each line must carry an ``input`` or ``messages`` list of role/content
        turns; an optional non-empty ``completion`` is appended as an assistant
        turn. Sample ``metadata`` plus preference pairs
        (``other_completion``/``preferred_completion``) are attached to
        ``thread.metadata``. Lines in any other shape are skipped with a warning.

        Raises:
            ValueError: when no line could be parsed into a thread.
        """
        lines = self.file_storage.read(path, use_raw_path=True).decode("utf-8").splitlines()
        threads = []
        for line in lines:
            line_dict = json.loads(line)
            thread = None
            if "input" in line_dict or "messages" in line_dict:
                key = "input" if "input" in line_dict else "messages"
                thread = StringThread(
                    [(inner_turn_dict["role"], inner_turn_dict["content"]) for inner_turn_dict in line_dict[key]]
                )
                if "completion" in line_dict and line_dict["completion"]:
                    thread = thread.assistant(line_dict["completion"])
            else:
                print("Did not find `input`, or `messages` key in sample, ignoring")

            if thread is not None:
                thread.metadata = line_dict.get("metadata", {})
                if "other_completion" in line_dict and "preferred_completion" in line_dict:
                    thread.metadata["other_completion"] = line_dict["other_completion"]
                    thread.metadata["preferred_completion"] = line_dict["preferred_completion"]

                threads.append(thread)

        if len(threads) == 0:
            raise ValueError("Did not find any valid format samples in the dataset")

        return threads

    @staticmethod
    def log_eval_result(eval_samples: list[EvalSample]) -> None:
        """Print a per-model/per-grader mean-score table and persist results.

        Writes ``aggregate_scores.json`` plus a detailed HTML report under
        ``<caller_dir>/adaptive_eval_samples/<timestamp>/``, where caller_dir
        is the directory of the calling module (cwd as a fallback).
        """
        # Convert to DataFrame for easy aggregation
        data = []
        for eval_sample in eval_samples:
            for grade in eval_sample.grades:
                data.append(
                    {
                        "model": _extract_model_key(eval_sample.interaction.source),
                        "grader": grade.grader_key,
                        "score": grade.value,
                    }
                )
        df = pd.DataFrame(data)

        # Create pivot table with models as rows and graders as columns
        if not df.empty:
            pivot_df = df.pivot_table(index="model", columns="grader", values="score", aggfunc="mean")

            # Create Rich table from pivot
            console = Console()
            table = RichTable(title="GRADER EVALUATION RESULTS")
            table.add_column("Model", style="cyan", no_wrap=True)
            # One column per grader
            for grader_name in pivot_df.columns:
                table.add_column(grader_name, justify="center")

            # Find the maximum score for each grader (excluding NaN values)
            max_scores = {}
            for grader_name in pivot_df.columns:
                valid_scores = pivot_df[grader_name].dropna()
                if len(valid_scores) > 0:
                    max_scores[grader_name] = valid_scores.max()

            # One row per model
            for model_name in pivot_df.index:
                row = [model_name]
                for grader_name in pivot_df.columns:
                    score = pivot_df.loc[model_name, grader_name]
                    # Avg will be nan if grader failed on all samples
                    if pd.isna(score):
                        row.append("All failed")
                    else:
                        score_str = f"{score:.3f}"
                        # Highlight the highest score in green
                        if grader_name in max_scores and abs(score - max_scores[grader_name]) < 1e-9:
                            row.append(f"[green]{score_str}[/green]")
                        else:
                            row.append(score_str)
                table.add_row(*row)

            console.print(table)

            # Get the directory of the original caller (skip this method's frame)
            caller_frame = inspect.currentframe()
            original_caller_dir = None
            if caller_frame is not None and caller_frame.f_back is not None:
                # __file__ may be absent (interactive or embedded callers);
                # fall back to the cwd default below instead of raising KeyError.
                original_caller_file = caller_frame.f_back.f_globals.get("__file__")
                if original_caller_file is not None:
                    original_caller_dir = str(Path(original_caller_file).parent)

            # Save compact JSON of grader results
            original_caller_dir = original_caller_dir or Path.cwd()
            results_json = {}
            for model_name in pivot_df.index:
                results_json[model_name] = {}
                for grader_name in pivot_df.columns:
                    score = pivot_df.loc[model_name, grader_name]
                    if pd.isna(score):
                        results_json[model_name][grader_name] = None
                    else:
                        # Convert pandas scalar to native Python type
                        score_val = score.item() if hasattr(score, "item") else float(score)  # type: ignore
                        results_json[model_name][grader_name] = round(score_val, 3)  # type:ignore

            # Filesystem-safe timestamp: datetime.isoformat() contains ':',
            # which is not a valid character in Windows directory names (this
            # package ships a win_amd64 wheel), so mkdir below would fail.
            timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
            eval_dir = Path(original_caller_dir) / "adaptive_eval_samples" / timestamp
            eval_dir.mkdir(parents=True, exist_ok=True)

            results_path = eval_dir / "aggregate_scores.json"
            with open(results_path, "w") as f:
                json.dump(results_json, f, indent=2)
            print(f"\n📁 Aggregate results saved to: {results_path}")

            # Save detailed evaluation samples as html
            _save_detailed_eval_table(eval_samples, output_dir=str(eval_dir))
        else:
            print("No evaluation data to display")
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import types
|
|
3
|
+
from typing import Annotated, Literal, Self, Union, get_args, get_origin
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class InputConfig(BaseModel):
    """Base class for recipe input-configuration models.

    Subclassing hook: union-typed fields are automatically turned into
    pydantic discriminated unions keyed on a synthetic ``union_variant``
    field, so variants round-trip through JSON unambiguously.
    """

    def __init_subclass__(cls, **kwargs):
        # Add union_variant BEFORE calling super (before Pydantic processes it)
        if not hasattr(cls, "__annotations__"):
            cls.__annotations__ = {}
        if "union_variant" not in cls.__annotations__:
            cls.__annotations__["union_variant"] = Annotated[Literal[cls.__name__], Field(default=cls.__name__)]
        super().__init_subclass__(**kwargs)
        if hasattr(cls, "__annotations__"):
            for field_name, field_type in cls.__annotations__.items():
                origin = get_origin(field_type)
                # Handles both typing.Union[...] and PEP 604 `X | Y` unions.
                if origin is Union or origin is types.UnionType:
                    # Add discriminator to the union annotation
                    cls.__annotations__[field_name] = Annotated[field_type, Field(discriminator="union_variant")]

                    # Add union_variant field to each variant class
                    variant_types = get_args(field_type)
                    for variant_type in variant_types:
                        # Skip None (for Optional types)
                        if variant_type is type(None):
                            continue

                        # Add union_variant: Literal["ClassName"] to the variant
                        # if not already present
                        # NOTE(review): mutating __annotations__ of an
                        # already-built pydantic model only takes effect if the
                        # model is (re)built afterwards — presumably the
                        # variants are processed before pydantic finalizes
                        # them; confirm.
                        if (
                            hasattr(variant_type, "__annotations__")
                            and "union_variant" not in variant_type.__annotations__
                        ):
                            variant_type.__annotations__["union_variant"] = Literal[variant_type.__name__]
                            # Set default value
                            setattr(variant_type, "union_variant", variant_type.__name__)

    @classmethod
    def load_from_file(cls, json_file) -> Self:
        """Read a JSON file and validate its contents into ``cls``."""
        with open(json_file) as f:
            data = json.load(f)
        return cls.model_validate(data)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
from .dto.AdaptiveDataset import AdaptiveDatasetKind
|
|
48
|
+
from .dto.AdaptiveGrader import (
|
|
49
|
+
AdaptiveGrader as AdaptiveGrader,
|
|
50
|
+
)
|
|
51
|
+
from .dto.AdaptiveGrader import (
|
|
52
|
+
Judge as CustomJudge,
|
|
53
|
+
)
|
|
54
|
+
from .dto.AdaptiveGrader import (
|
|
55
|
+
JudgeExample as CustomJudgeExample,
|
|
56
|
+
)
|
|
57
|
+
from .dto.AdaptiveGrader import (
|
|
58
|
+
Prebuilt as PrebuiltJudge,
|
|
59
|
+
)
|
|
60
|
+
from .dto.AdaptiveGrader import (
|
|
61
|
+
PrebuiltConfigKey,
|
|
62
|
+
)
|
|
63
|
+
from .dto.AdaptiveGrader import (
|
|
64
|
+
Remote as RemoteRewardEndpoint,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
__all__ = [
|
|
68
|
+
"InputConfig",
|
|
69
|
+
"AdaptiveDatasetKind",
|
|
70
|
+
"AdaptiveGrader",
|
|
71
|
+
"CustomJudge",
|
|
72
|
+
"CustomJudgeExample",
|
|
73
|
+
"PrebuiltJudge",
|
|
74
|
+
"PrebuiltConfigKey",
|
|
75
|
+
"RemoteRewardEndpoint",
|
|
76
|
+
]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from typing import Any, Callable, TypeVar, overload
|
|
2
|
+
|
|
3
|
+
from harmony_client.runtime.context import RecipeContext
|
|
4
|
+
from harmony_client.runtime.data import InputConfig
|
|
5
|
+
|
|
6
|
+
IN = TypeVar("IN", bound=InputConfig)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Overload: recipe entry points that take a parsed input config plus the context.
@overload
def recipe_main[IN: InputConfig](func: Callable[[IN, RecipeContext], Any]): ...


# Overload: recipe entry points that take only the context.
@overload
def recipe_main(func: Callable[[RecipeContext], Any]): ...
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def recipe_main(func):
    """Tag ``func`` as the recipe entry point so the runner can discover it.

    Returns the function unchanged apart from the marker attribute.
    """
    setattr(func, "is_recipe_main", True)
    return func
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# generated by datamodel-codegen:
|
|
2
|
+
# filename: AdaptiveDataset.json
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Optional
|
|
7
|
+
from .base import DtoBaseModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AdaptiveDatasetKind(Enum):
    """Kind of samples a dataset contains (codegen'd from AdaptiveDataset.json)."""

    Prompt = "Prompt"
    Completion = "Completion"
    Metric = "Metric"
    Preference = "Preference"
    # Mixed: the dataset may combine several of the above sample kinds.
    Mixed = "Mixed"


class AdaptiveDataset(DtoBaseModel):
    """Reference to a registered dataset (codegen'd from AdaptiveDataset.json)."""

    dataset_key: str
    feedback_key: Optional[str] = None
    # Storage location/path of the dataset file.
    file: str
    id: str
    # Defaults to Mixed when the kind is not specified upstream.
    kind: Optional[AdaptiveDatasetKind] = AdaptiveDatasetKind.Mixed
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# generated by datamodel-codegen:
|
|
2
|
+
# filename: AdaptiveGrader.json
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
from typing import Annotated, Any, List, Literal, Optional, Union
|
|
6
|
+
from .base import DtoBaseModel
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
from pydantic import Field
|
|
9
|
+
from enum import Enum
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ChatMessage(DtoBaseModel):
    """A single chat turn (codegen'd from AdaptiveGrader.json)."""

    content: str
    metadata: Optional[Any] = None
    role: str


class Custom(DtoBaseModel):
    """Grader config variant: custom, user-implemented grading logic."""

    description: Optional[str] = None
    type: Literal["Custom"] = "Custom"


class JudgeExample(DtoBaseModel):
    """Few-shot example for a judge grader."""

    id: UUID
    input: List[ChatMessage]
    output: str
    # Serialized as "pass" in JSON; trailing underscore avoids the Python keyword.
    pass_: Annotated[bool, Field(alias="pass")]
    reasoning: Optional[str] = None


class PrebuiltConfigKey(Enum):
    """Identifiers for the built-in judge configurations."""

    AnswerRelevancy = "AnswerRelevancy"
    ContextRelevancy = "ContextRelevancy"
    Faithfulness = "Faithfulness"


class Remote(DtoBaseModel):
    """Grader config variant: remote reward endpoint identified by URL + version."""

    description: str
    type: Literal["Remote"] = "Remote"
    url: str
    version: str


class Judge(DtoBaseModel):
    """Grader config variant: LLM judge with criteria, templates and examples."""

    criteria: str
    examples: List[JudgeExample]
    model_key: str
    model_uri: str
    system_template: str
    type: Literal["Judge"] = "Judge"
    user_template: str


class Prebuilt(DtoBaseModel):
    """Grader config variant: prebuilt judge selected by PrebuiltConfigKey."""

    model_key: str
    model_uri: str
    prebuilt_config_key: PrebuiltConfigKey
    type: Literal["Prebuilt"] = "Prebuilt"


class AdaptiveGrader(DtoBaseModel):
    """Grader registration; ``config`` is a tagged union discriminated on ``type``."""

    config: Annotated[
        Union[Judge, Prebuilt, Remote, Custom], Field(discriminator="type")
    ]
    grader_id: UUID
    key: str
    metric_id: UUID
    name: str
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# generated by datamodel-codegen:
|
|
2
|
+
# filename: AdaptiveModel.json
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
from typing import Annotated, Optional
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
from .base import DtoBaseModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelParams(DtoBaseModel):
    """Optional model runtime parameters (codegen'd from AdaptiveModel.json)."""

    kv_cache_len: Annotated[Optional[int], Field(ge=0)] = None
    max_seq_len: Annotated[Optional[int], Field(ge=0)] = None
    # tp — presumably tensor-parallel degree; confirm against serving config.
    tp: Annotated[Optional[int], Field(ge=0)] = None


class AdaptiveModel(DtoBaseModel):
    """Reference to a model by storage path, with optional key and params."""

    model_key: Optional[str] = None
    params: Optional[ModelParams] = None
    path: str
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# generated by datamodel-codegen:
|
|
2
|
+
# filename: DatasetSampleFormats.json
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
from typing import Annotated, Any, Dict, List, Optional, Union
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
from .base import DtoBaseModel
|
|
8
|
+
from pydantic import Field, RootModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SampleMetadata(DtoBaseModel):
    """Per-sample bookkeeping (codegen'd from DatasetSampleFormats.json)."""

    created_at: int
    external_data: Optional[Any] = None
    id: UUID
    labels: Optional[Dict[str, Any]] = None
    model_id: Optional[UUID] = None


class Type(DtoBaseModel):
    """JSON-schema fragment: a typed field, optionally enum-constrained."""

    enum: Optional[List[str]] = None
    type: Optional[str] = None


class Value(DtoBaseModel):
    """JSON-schema fragment: plain value type."""

    type: Optional[str] = None


class Properties(DtoBaseModel):
    """JSON-schema fragment: properties of one turn-content variant."""

    type: Optional[Type] = None
    value: Optional[Value] = None


class OneOf(DtoBaseModel):
    """JSON-schema fragment: one alternative of the turn-content union."""

    properties: Optional[Properties] = None
    required: Optional[List[str]] = None
    type: Optional[str] = None


class Value1(DtoBaseModel):
    """JSON-schema fragment: value type with an optional format qualifier."""

    format: Optional[str] = None
    type: Optional[str] = None


class Properties1(DtoBaseModel):
    """JSON-schema fragment: properties variant using Value1."""

    type: Optional[Type] = None
    value: Optional[Value1] = None


class OneOf1(DtoBaseModel):
    """JSON-schema fragment: second alternative of the turn-content union."""

    properties: Optional[Properties1] = None
    required: Optional[List[str]] = None
    type: Optional[str] = None


class TurnContent(DtoBaseModel):
    """JSON-schema style union of the allowed turn-content shapes."""

    oneOf: Optional[List[Union[OneOf, OneOf1]]] = None


class TurnTuple(RootModel[List]):
    """A conversation turn encoded as a fixed-length-2 JSON array."""

    root: Annotated[List, Field(max_length=2, min_length=2)]


class DatasetMetricSample(DtoBaseModel):
    """Sample whose completion is scored by named numeric metrics."""

    completion: TurnTuple
    metadata: SampleMetadata
    metrics: Dict[str, float]
    prompt: List[TurnTuple]


class DatasetPreferenceSample(DtoBaseModel):
    """Pairwise-preference sample: good vs bad completion for one prompt."""

    bad_completion: TurnTuple
    good_completion: TurnTuple
    metadata: SampleMetadata
    metric: Optional[str] = "preference"
    prompt: List[TurnTuple]


class DatasetPromptSample(DtoBaseModel):
    """Prompt-only sample (no completion)."""

    metadata: SampleMetadata
    prompt: List[TurnTuple]


class DatasetSample(DtoBaseModel):
    """Prompt plus a single completion."""

    completion: TurnTuple
    metadata: SampleMetadata
    prompt: List[TurnTuple]


class DatasetSampleFormats(DtoBaseModel):
    """Envelope bundling one exemplar of each supported sample format."""

    completion: DatasetSample
    metric: DatasetMetricSample
    preference: DatasetPreferenceSample
    prompt: DatasetPromptSample
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from collections.abc import Awaitable, Callable
|
|
2
|
+
|
|
3
|
+
from harmony_client import TrainingModel
|
|
4
|
+
from harmony_client.artifacts.model_artifact import ModelArtifact
|
|
5
|
+
from harmony_client.runtime import RecipeContext
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
async def save_with_artifact(
    model: TrainingModel,
    model_name: str,
    inference_only: bool = True,
    ctx: RecipeContext | None = None,
    original_save_method: Callable[[TrainingModel, str, bool], Awaitable[str]] | None = None,
) -> str:
    """Save ``model`` via ``original_save_method`` and record a ModelArtifact.

    Returns the real model key produced by the underlying save. The artifact
    is only recorded when ``ctx`` is provided.

    Raises:
        ValueError: if ``original_save_method`` is not supplied.
    """
    if original_save_method is None:
        raise ValueError("original_save_method must be provided")

    real_model_key = await original_save_method(model, model_name, inference_only)

    if ctx is not None:
        # NOTE(review): the constructed artifact is deliberately discarded —
        # presumably ModelArtifact registers itself through constructor side
        # effects; confirm against ModelArtifact before "fixing" this.
        ModelArtifact(real_model_key, ctx)

    return real_model_key
|