harmony-client 0.1.0__cp312-cp312-macosx_10_9_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- harmony_client/__init__.py +78 -0
- harmony_client/artifacts/__init__.py +5 -0
- harmony_client/artifacts/custom_artifact.py +46 -0
- harmony_client/artifacts/dataset_artifact.py +268 -0
- harmony_client/artifacts/model_artifact.py +34 -0
- harmony_client/file_storage.py +378 -0
- harmony_client/harmony_client.cpython-312-darwin.so +0 -0
- harmony_client/harmony_client.pyi +1615 -0
- harmony_client/internal/__init__.py +7 -0
- harmony_client/internal/eval_samples_html.py +122 -0
- harmony_client/internal/utils.py +9 -0
- harmony_client/logging_table.py +121 -0
- harmony_client/parameters/__init__.py +295 -0
- harmony_client/parameters/dataset_kinds.py +49 -0
- harmony_client/parameters/model_kinds.py +13 -0
- harmony_client/py.typed +0 -0
- harmony_client/runtime/__init__.py +29 -0
- harmony_client/runtime/context.py +191 -0
- harmony_client/runtime/data.py +76 -0
- harmony_client/runtime/decorators.py +19 -0
- harmony_client/runtime/dto/AdaptiveDataset.py +23 -0
- harmony_client/runtime/dto/AdaptiveGrader.py +68 -0
- harmony_client/runtime/dto/AdaptiveModel.py +19 -0
- harmony_client/runtime/dto/DatasetSampleFormats.py +93 -0
- harmony_client/runtime/dto/__init__.py +2 -0
- harmony_client/runtime/dto/base.py +7 -0
- harmony_client/runtime/model_artifact_save.py +23 -0
- harmony_client/runtime/runner.py +368 -0
- harmony_client/runtime/simple_notifier.py +21 -0
- harmony_client-0.1.0.dist-info/METADATA +38 -0
- harmony_client-0.1.0.dist-info/RECORD +32 -0
- harmony_client-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1615 @@
|
|
|
1
|
+
# ruff: noqa: E501, F401
|
|
2
|
+
"""Harmony Client - Python bindings for the Adaptive Harmony ML orchestration platform.
|
|
3
|
+
|
|
4
|
+
This module provides the Python interface to interact with Harmony workers for model training
|
|
5
|
+
and inference operations. It includes thread management, model lifecycle operations, and
|
|
6
|
+
training/evaluation utilities.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Any, MutableSequence, NamedTuple, Optional, Sequence, Type, TypeVar
|
|
11
|
+
|
|
12
|
+
from adaptive_harmony.runtime import RecipeContext
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
from typing_extensions import Literal, Required, TypedDict
|
|
15
|
+
|
|
16
|
+
# Type variable for Pydantic model classes; used by InferenceModel.generate_and_validate.
T = TypeVar("T", bound=BaseModel)

# Constants
DEFAULT_MAX_DRAFT_STEPS: int
"""Default maximum number of draft steps for speculative decoding."""
|
|
21
|
+
|
|
22
|
+
class ImageFragment(TypedDict, total=False):
|
|
23
|
+
"""Fragment representing an image in a conversation turn.
|
|
24
|
+
|
|
25
|
+
Attributes:
|
|
26
|
+
type: Must be "image"
|
|
27
|
+
url: URL or data URI of the image (supports data URLs, file paths, and HTTP URLs)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
type: Required[Literal["image"]]
|
|
31
|
+
url: Required[str]
|
|
32
|
+
|
|
33
|
+
class TextFragment(TypedDict, total=False):
|
|
34
|
+
"""Fragment representing text content in a conversation turn.
|
|
35
|
+
|
|
36
|
+
Attributes:
|
|
37
|
+
type: Must be "text"
|
|
38
|
+
text: The text content
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
type: Required[Literal["text"]]
|
|
42
|
+
text: Required[str]
|
|
43
|
+
|
|
44
|
+
Fragment = ImageFragment | TextFragment
"""Union type for content fragments that can appear in conversation turns."""
|
|
46
|
+
|
|
47
|
+
class EvalSample:
    """Represents a single evaluation sample with its interaction and grades.

    An evaluation sample captures a conversation thread along with grades assigned
    by one or more graders. Used for evaluation artifact creation and analysis.

    Attributes:
        interaction: The conversation interaction being evaluated
        grades: List of grades from different graders
        dataset_key: Key identifying the dataset this sample belongs to
        id: Unique identifier for this evaluation sample (auto-generated UUID)
    """

    interaction: EvalSampleInteraction
    grades: MutableSequence[Grade]
    dataset_key: str
    # `id` is exposed as an attribute but is not a constructor parameter.
    id: str

    def __new__(
        cls,
        interaction: EvalSampleInteraction,
        grades: MutableSequence[Grade],
        dataset_key: str,
    ) -> EvalSample: ...
    def __repr__(self) -> str: ...
|
|
72
|
+
|
|
73
|
+
class EvalSampleInteraction:
    """Represents a conversation interaction for evaluation.

    Encapsulates a conversation thread and optionally tracks which model
    or source generated the interaction.

    Attributes:
        thread: The conversation thread (messages between user and assistant)
        source: Optional identifier for the model or system that generated this interaction
    """

    thread: StringThread
    source: Optional[str]

    def __new__(cls, thread: StringThread, source: Optional[str] = None) -> EvalSampleInteraction: ...
    def __repr__(self) -> str: ...
|
|
89
|
+
|
|
90
|
+
class EvaluationArtifactBase:
    """Base class for evaluation artifacts that can be registered with jobs.

    Evaluation artifacts track the results of model evaluations, including
    the evaluated samples and their grades. Can be uploaded to object storage.

    Attributes:
        artifact: The underlying job artifact with metadata
        id: Unique identifier for this artifact
        name: Human-readable name for the evaluation
        kind: Type of artifact (always "eval" for evaluations)
        uri: Optional URI where evaluation results are stored
    """

    artifact: JobArtifact
    id: str
    name: str
    kind: str
    uri: Optional[str]

    # NOTE(review): the `uri` attribute is Optional[str] but the constructor declares
    # a required `uri: str` -- confirm which is intended against the native implementation.
    def __new__(cls, name: str, uri: str, id: str, **py_kwargs: Any) -> EvaluationArtifactBase: ...
    def samples_to_adaptive_json(self, samples: MutableSequence[EvalSample]) -> list[str]:
        """Convert evaluation samples to JSONL format for storage.

        Args:
            samples: List of evaluation samples to serialize

        Returns:
            List of JSON strings, one per sample, suitable for JSONL file format
        """
        ...
    def __repr__(self) -> str: ...
|
|
122
|
+
|
|
123
|
+
class Grade:
    """Represents a grade assigned to an evaluation sample by a grader.

    Grades can include numeric scores and optional reasoning explaining
    the score.

    Attributes:
        value: Numeric grade value (typically 0.0-1.0 range)
        grader_key: Identifier for the grader that assigned this grade
        reasoning: Optional explanation for why this grade was assigned
    """

    value: float
    grader_key: str
    reasoning: Optional[str]

    def __new__(cls, value: float, grader_key: str, reasoning: Optional[str] = None) -> Grade: ...
    def __repr__(self) -> str: ...
|
|
141
|
+
|
|
142
|
+
class JobArtifact:
    """Represents an artifact produced by a training or evaluation job.

    Job artifacts track outputs like trained models, evaluation results, datasets,
    or custom artifacts. They can be registered with jobs to appear in the UI.

    Attributes:
        id: Unique identifier for this artifact
        name: Human-readable name
        kind: Type of artifact - "model", "eval", "dataset", or "custom"
        metadata: Additional key-value metadata for the artifact
        uri: Optional URI where the artifact is stored

    Example:
        ```python
        artifact = JobArtifact(
            id="my-model-v1",
            name="Fine-tuned Model",
            kind="model",
            uri="s3://bucket/models/my-model",
            checkpoint_step=1000,
            loss=0.25
        )
        notifier.register_artifact(artifact)
        ```
    """

    id: str
    name: str
    kind: str
    metadata: dict[str, Any]
    uri: Optional[str]

    # NOTE(review): extra keyword arguments (see the class example) presumably populate
    # `metadata` -- confirm against the native implementation.
    def __new__(
        cls,
        id: str,
        name: str,
        kind: str,
        uri: Optional[str] = None,
        **py_kwargs: Any,
    ) -> JobArtifact: ...
    def __repr__(self) -> str: ...
|
|
184
|
+
|
|
185
|
+
class StringTurn(NamedTuple):
    """A single turn in a conversation thread with string content.

    Attributes:
        role: The role of the entity that is creating the content
        content: The text content of the turn
    """

    role: str
    content: str
|
|
195
|
+
|
|
196
|
+
class TokenizedTurn(NamedTuple):
    """A single turn in a conversation thread with tokenized content.

    Attributes:
        role: The role of the entity that is creating the content
        content: The tokenized content as a list of token IDs
    """

    role: str
    content: list[int]
|
|
206
|
+
|
|
207
|
+
class ModelConfigResponse:
|
|
208
|
+
"""Configuration metadata for a model retrieved from the control plane.
|
|
209
|
+
|
|
210
|
+
Attributes:
|
|
211
|
+
model_id: Unique identifier for the model
|
|
212
|
+
model_key: Human-readable key for the model
|
|
213
|
+
path: File system or registry path to the model
|
|
214
|
+
tp: Optional tensor parallelism degree
|
|
215
|
+
kv_cache_len: Optional KV cache length for inference
|
|
216
|
+
max_seq_len: Optional maximum sequence length the model supports
|
|
217
|
+
"""
|
|
218
|
+
|
|
219
|
+
model_id: str
|
|
220
|
+
model_key: str
|
|
221
|
+
path: str
|
|
222
|
+
tp: int | None
|
|
223
|
+
kv_cache_len: int | None
|
|
224
|
+
max_seq_len: int | None
|
|
225
|
+
|
|
226
|
+
class DatasetConfigResponse:
    """Configuration metadata for a dataset retrieved from the control plane.

    Attributes:
        dataset_id: Unique identifier for the dataset
        dataset_key: Human-readable key for the dataset
        name: Display name of the dataset
        file_path: Path to the dataset file
        kind: Type of dataset (e.g., "jsonl", "parquet")
    """

    dataset_id: str
    dataset_key: str
    name: str
    file_path: str
    kind: str
|
|
242
|
+
|
|
243
|
+
class GraderConfigResponse:
    """Configuration metadata for a grader retrieved from the control plane.

    Attributes:
        grader_id: Unique identifier for the grader
        key: Human-readable key for the grader
        name: Display name of the grader
        harmony_url: URL of the harmony instance to use for grading
        grader_config_json: JSON configuration for the grader
    """

    grader_id: str
    key: str
    name: str
    harmony_url: str
    grader_config_json: str
|
|
259
|
+
|
|
260
|
+
class HarmonyClient:
    """Main client for interacting with Harmony workers.

    The HarmonyClient is the primary interface for creating models and accessing
    platform configuration. It manages the connection to Harmony workers and
    provides access to model builders for spawning inference and training models.

    Example:
        ```python
        from harmony_client import get_client

        client = await get_client(
            addr="ws://localhost:8080",
            num_gpus=1,
            api_key="my-key"
        )

        # Create a model builder
        model = client.model("model_registry://llama-3.1-8b")

        # Spawn an inference model
        inf_model = await model.spawn_inference("my-model")

        # Generate text
        thread = StringThread([("user", "Hello!")])
        result = await inf_model.generate(thread)
        print(result.last_content())
        ```
    """
    def model(self, path: str, kv_cache_len: int = 131072, tokens_to_generate: int = 2048) -> ModelBuilder:
        """Create a model builder for spawning inference or training models.

        Args:
            path: Path to the model. Can be:
                - Model registry key: "model_registry://llama-3.1-8b"
                - External provider: "openai://gpt-4", "anthropic://claude-3-5-sonnet"
                - URL with API key: "openai://gpt-4?api_key=sk-..."
            kv_cache_len: KV cache length for inference (default: 131072). Will be ignored for external models as we do not control it.
            tokens_to_generate: Maximum tokens to generate (default: 2048)

        Returns:
            ModelBuilder that can be configured and spawned
        """
        ...
    def session_id(self) -> str:
        """Get the unique session ID for this client connection.

        Returns:
            Session UUID as a string
        """
        ...
    async def get_grader_config(self, grader_key: str) -> GraderConfigResponse:
        """Fetch grader configuration from the control plane.

        Args:
            grader_key: Key identifying the grader

        Returns:
            GraderConfigResponse with grader metadata

        Raises:
            RecipeError: If control_plane_url was not provided to get_client()
        """
        ...
    async def get_dataset_config(self, dataset_key: str) -> DatasetConfigResponse:
        """Fetch dataset configuration from the control plane.

        Args:
            dataset_key: Key identifying the dataset

        Returns:
            DatasetConfigResponse with dataset metadata

        Raises:
            RecipeError: If control_plane_url was not provided to get_client()
        """
        ...
    async def get_model_config(self, model_key: str) -> ModelConfigResponse:
        """Fetch model configuration from the control plane.

        Args:
            model_key: Key identifying the model

        Returns:
            ModelConfigResponse with model metadata

        Raises:
            RecipeError: If control_plane_url was not provided to get_client()
        """
        ...
    def close(self) -> None:
        """Close the client connection and release resources."""
        ...
    # Synchronous context-manager protocol; the exit arguments are typed `object`
    # because the stub does not constrain how they are used.
    def __enter__(self) -> HarmonyClient: ...
    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None: ...
|
|
355
|
+
|
|
356
|
+
class InferenceModel:
    """Model instance for running inference operations.

    InferenceModel provides methods for text generation, tokenization, scoring,
    and computing log probabilities. It can be created from a ModelBuilder using
    spawn_inference().

    The model is automatically deallocated when the object is garbage collected,
    but you can explicitly call dealloc() to free GPU memory sooner.

    Example:
        ```python
        model = client.model("model_registry://llama-3.1-8b")
        inf_model = await model.spawn_inference("my-inf-model")

        # Generate text
        thread = StringThread([("user", "What is 2+2?")])
        result = await inf_model.generate(thread)
        print(result.last_content())  # "4"
        ```
    """
    def is_scalar(self) -> bool:
        """Check if this model has a scalar output head (for scoring/reward models).

        Returns:
            True if model outputs scalars instead of next-token probabilities
        """
        ...
    async def dealloc(self) -> None:
        """Deallocate the model from GPU memory.

        Explicitly frees GPU resources. The model cannot be used after this call.
        Models are automatically deallocated on garbage collection if not called.
        """
        ...
    async def load(self, path: str) -> None:
        """Load model weights from a checkpoint path.

        Args:
            path: Path to model weights
        """
        ...
    # NOTE(review): the return annotation is `StringThread` only, while the docstring
    # describes a `(thread, timings)` tuple when return_timings=True. Confirm against the
    # native implementation and add overloads if the tuple form is real.
    async def generate(self, thread: StringThread, return_timings: bool = False) -> StringThread:
        """Generate text completion for a conversation thread.

        Appends an assistant turn with the generated completion to the thread.

        Args:
            thread: Conversation thread to continue
            return_timings: If True, returns tuple of (thread, timings)

        Returns:
            Updated thread with assistant response, or (thread, timings) tuple

        Example:
            ```python
            thread = StringThread([("user", "Hello!")])
            result = await model.generate(thread)
            print(result.last_content())  # Generated response
            ```
        """
        ...
    async def generate_tokens(self, thread: StringThread) -> TokenizedThread:
        """Generate token IDs for a conversation thread.

        Similar to generate() but returns tokenized output instead of strings.

        Args:
            thread: Conversation thread to continue

        Returns:
            TokenizedThread with generated token IDs
        """
        ...
    async def generate_and_validate(
        self,
        thread: StringThread,
        pydantic_model: Type[T],
        max_parsing_retries: int = 1,
    ) -> tuple[str, T]:
        """Generate structured output validated against a Pydantic model.

        Generates JSON output and parses it into the specified Pydantic model,
        with automatic retries on parse failures.

        Args:
            thread: Conversation thread (should request JSON output)
            pydantic_model: Pydantic model class to validate against
            max_parsing_retries: Maximum retry attempts on parse failures

        Returns:
            Tuple of (raw_json_string, parsed_model_instance)

        Example:
            ```python
            from pydantic import BaseModel

            class Response(BaseModel):
                answer: str
                confidence: float

            thread = StringThread([(
                "user",
                "Return JSON with 'answer' and 'confidence' fields"
            )])
            json_str, parsed = await model.generate_and_validate(
                thread, Response
            )
            print(parsed.answer, parsed.confidence)
            ```
        """
        ...
    async def tokenize_thread(self, thread: StringThread) -> TokenizedThread:
        """Convert a StringThread to a TokenizedThread.

        Args:
            thread: String-based conversation thread

        Returns:
            Tokenized version of the thread
        """
        ...
    async def detokenize_thread(self, thread: TokenizedThread) -> StringThread:
        """Convert a TokenizedThread back to a StringThread.

        Args:
            thread: Tokenized conversation thread

        Returns:
            String version of the thread
        """
        ...

    # The outputs of serialize_* are: a list of tokens for the whole thread,
    # a list with the image tokens, and a list with each token's weight.
    async def serialize_thread(
        self, thread: StringThread
    ) -> tuple[list[int], list[tuple[int, list[float]]], list[float]]:
        """Serialize a thread into tokens with weights for training.

        Returns:
            Tuple of (token_ids, image_tokens, weights) where:
            - token_ids: All tokens in the thread
            - image_tokens: List of (position, embedding) for images
            - weights: Per-token training weights (0.0 = skip, 1.0 = train)
        """
        ...
    async def serialize_tokenized_thread(
        self, thread: TokenizedThread
    ) -> tuple[list[int], list[tuple[int, list[float]]], list[float]]:
        """Serialize a tokenized thread with weights for training.

        Returns:
            Tuple of (token_ids, image_tokens, weights)
        """
        ...
    async def logprobs(self, thread: StringThread) -> float:
        """Compute the total log probability of a thread.

        Returns the sum of log probabilities of all weighted tokens.

        Args:
            thread: Conversation thread to score

        Returns:
            Total log probability (summed across weighted tokens)
        """
        ...
    async def logprobs_per_token(self, thread: TokenizedThread) -> list[float]:
        """Compute log probabilities for each weighted token in a thread.

        Args:
            thread: Tokenized conversation thread

        Returns:
            List of log probabilities for each weighted token
        """
        ...
    async def score(self, thread: TokenizedThread) -> list[float]:
        """Score each weighted token using a scalar head model.

        Only works with scoring models (where is_scalar() returns True).

        Args:
            thread: Tokenized thread to score

        Returns:
            List of scalar scores for each weighted token
        """
        ...
    async def score_last_token(self, thread: StringThread) -> float:
        """Score the last token using a scalar head model.

        Args:
            thread: Conversation thread

        Returns:
            Scalar score for the last token
        """
        ...
    async def raw_string_create(self, prompt: str) -> str:
        """Generate text from a raw string prompt (no conversation formatting).
        Should normally not be used unless you know exactly what you are doing.

        Args:
            prompt: Raw text prompt

        Returns:
            Generated completion
        """
        ...
    async def raw_token_create(self, prompt: Sequence[int]) -> list[int]:
        """Generate tokens from raw token IDs (no conversation formatting).
        Should normally not be used unless you know exactly what you are doing.

        Args:
            prompt: List of token IDs

        Returns:
            Generated token IDs
        """
        ...
    async def tokenize(self, data: str) -> list[int]:
        """Tokenize a string into token IDs.

        Args:
            data: Text to tokenize

        Returns:
            List of token IDs
        """
        ...
    async def detokenize(self, data: Sequence[int]) -> str:
        """Convert token IDs back to text.

        Args:
            data: List of token IDs

        Returns:
            Decoded text string
        """
        ...
    async def char_to_token_rewards(self, text: str, char_rewards: Sequence[float]) -> list[float]:
        """Map character-level rewards to token-level rewards.

        Useful for reward modeling when you have character-granular feedback.

        Args:
            text: The text string
            char_rewards: Reward value for each character

        Returns:
            Reward value for each token
        """
        ...
    async def raw_logprobs(self, tokens: Sequence[int]) -> list[float]:
        """Compute log probabilities for raw token IDs.
        Should normally not be used unless you know exactly what you are doing.

        Args:
            tokens: Sequence of token IDs

        Returns:
            Log probability for each token (first token gets 0.0)
        """
        ...
    # TODO(docs): add a usage example to the docstring below.
    async def thread_with_tools(self, thread: StringThread, tool_uris: Sequence[str]) -> StringThread:
        """Inject tool definitions into a conversation thread.

        Args:
            thread: Base conversation thread
            tool_uris: List of tool URIs to make available

        Returns:
            Thread with tool definitions injected
        """
        ...
    @staticmethod
    def render_schema(pydantic_model: type[BaseModel], with_field_descriptions: bool = True) -> str:
        """Render a Pydantic model as a JSON schema string.

        Args:
            pydantic_model: Pydantic model class
            with_field_descriptions: Include field descriptions in schema

        Returns:
            JSON schema as a string
        """
        ...
    @staticmethod
    def render_pydantic_model(pydantic_model: BaseModel) -> str:
        """Serialize a Pydantic model instance to JSON string.

        Args:
            pydantic_model: Pydantic model instance

        Returns:
            JSON representation
        """
        ...
    def get_builder_args(self) -> dict[str, Any]:
        """Get the configuration args used to build this model.

        Returns:
            Dictionary of model builder arguments
        """
        ...
    def top_p(self, top_p: float) -> InferenceModel:
        """Return a shallow copy with modified top_p sampling parameter.

        Args:
            top_p: Nucleus sampling probability (0.0-1.0)

        Returns:
            New InferenceModel instance with updated parameter
        """
        ...
    def temperature(self, temperature: float) -> InferenceModel:
        """Return a shallow copy with modified temperature sampling parameter.

        Args:
            temperature: Sampling temperature (>0, typically 0.1-2.0)

        Returns:
            New InferenceModel instance with updated parameter
        """
        ...
    def max_gen_len(self, max_num_tokens: int) -> InferenceModel:
        """Return a shallow copy with modified maximum generation length.

        Args:
            max_num_tokens: Maximum tokens to generate

        Returns:
            New InferenceModel instance with updated parameter
        """
        ...
    def min_gen_len(self, min_num_tokens: int) -> InferenceModel:
        """Return a shallow copy with modified minimum generation length.

        Args:
            min_num_tokens: Minimum tokens to generate

        Returns:
            New InferenceModel instance with updated parameter
        """
        ...
|
|
704
|
+
|
|
705
|
+
class InferenceSettings:
    """Configuration for inference model spawning.

    Attributes:
        kv_cache_len: Length of the KV cache
        tokens_to_generate: Maximum tokens to generate per request
    """
    def __new__(cls, kv_cache_len: int, tokens_to_generate: int) -> InferenceSettings: ...
    def to_python_dict(self) -> dict[str, Any]:
        """Convert settings to a dictionary.

        Returns:
            Dictionary representation of settings
        """
        ...
|
|
720
|
+
|
|
721
|
+
class ModelBuilder:
|
|
722
|
+
"""Builder for configuring and spawning model instances.
|
|
723
|
+
|
|
724
|
+
ModelBuilder provides a fluent API for configuring model parameters
|
|
725
|
+
before spawning inference or training instances. Create via HarmonyClient.model().
|
|
726
|
+
|
|
727
|
+
Example:
|
|
728
|
+
```python
|
|
729
|
+
# Basic inference model
|
|
730
|
+
model = client.model("model_registry://llama-3.1-8b")
|
|
731
|
+
inf = await model.spawn_inference("my-model")
|
|
732
|
+
|
|
733
|
+
# With tensor parallelism and temperature
|
|
734
|
+
model = (client.model("llama-3.1-70b")
|
|
735
|
+
.tp(4)
|
|
736
|
+
.temperature(0.7))
|
|
737
|
+
inf = await model.spawn_inference("my-70b")
|
|
738
|
+
|
|
739
|
+
# Training model with LoRA adapter
|
|
740
|
+
model = (client.model("model_registry://llama-3.1-8b")
|
|
741
|
+
.with_adapter())
|
|
742
|
+
train = await model.spawn_train("my-train", max_batch_size=4)
|
|
743
|
+
|
|
744
|
+
# With speculative decoding
|
|
745
|
+
model = (client.model("llama-3.1-70b")
|
|
746
|
+
.with_draft("llama-3.1-8b", num_draft_steps=4))
|
|
747
|
+
inf = await model.spawn_inference("fast-70b")
|
|
748
|
+
```
|
|
749
|
+
"""
|
|
750
|
+
async def spawn_train(self, name: str, max_batch_size: int) -> TrainingModel:
|
|
751
|
+
"""Spawn a training model instance.
|
|
752
|
+
|
|
753
|
+
Args:
|
|
754
|
+
name: Unique name for this model instance
|
|
755
|
+
max_batch_size: Maximum batch size for training. Due to our automatic packing this is also the maximum sequence length of a single sample.
|
|
756
|
+
|
|
757
|
+
Returns:
|
|
758
|
+
TrainingModel instance for fine-tuning operations
|
|
759
|
+
"""
|
|
760
|
+
...
|
|
761
|
+
async def spawn_inference(self, name: str) -> InferenceModel:
|
|
762
|
+
"""Spawn an inference model instance.
|
|
763
|
+
|
|
764
|
+
Args:
|
|
765
|
+
name: Unique name for this model instance
|
|
766
|
+
|
|
767
|
+
Returns:
|
|
768
|
+
InferenceModel instance for generation operations
|
|
769
|
+
"""
|
|
770
|
+
...
|
|
771
|
+
def tp(self, tp: int) -> ModelBuilder:
|
|
772
|
+
"""Set tensor parallelism degree.
|
|
773
|
+
|
|
774
|
+
Args:
|
|
775
|
+
tp: Number of GPUs to split the model across
|
|
776
|
+
|
|
777
|
+
Returns:
|
|
778
|
+
Updated builder
|
|
779
|
+
"""
|
|
780
|
+
...
|
|
781
|
+
def tools(self, tools: list[str]) -> ModelBuilder:
|
|
782
|
+
"""Add tool calling support.
|
|
783
|
+
|
|
784
|
+
Args:
|
|
785
|
+
tools: List of tool URIs to make available
|
|
786
|
+
|
|
787
|
+
Returns:
|
|
788
|
+
Updated builder
|
|
789
|
+
"""
|
|
790
|
+
...
|
|
791
|
+
def api_key(self, api_key: str) -> ModelBuilder:
|
|
792
|
+
"""Set API key for external model providers.
|
|
793
|
+
|
|
794
|
+
Args:
|
|
795
|
+
api_key: API key for provider (OpenAI, Anthropic, etc.)
|
|
796
|
+
|
|
797
|
+
Returns:
|
|
798
|
+
Updated builder
|
|
799
|
+
"""
|
|
800
|
+
...
|
|
801
|
+
def into_scoring_model(self) -> ModelBuilder:
|
|
802
|
+
"""Configure model to use scalar output head for scoring. Useful when creating a reward/value model from a base language model.
|
|
803
|
+
|
|
804
|
+
Returns:
|
|
805
|
+
Updated builder configured for reward/value modeling
|
|
806
|
+
"""
|
|
807
|
+
...
|
|
808
|
+
def with_adapter(self, use_adapter: bool = True) -> ModelBuilder:
|
|
809
|
+
"""Configure adapter/LoRA mode for model spawning.
|
|
810
|
+
|
|
811
|
+
Args:
|
|
812
|
+
use_adapter: Controls adapter mode (default: True).
|
|
813
|
+
- True: Enable adapter mode. For full-weight (FW) models, creates a new adapter.
|
|
814
|
+
For LoRA models, uses the existing adapter (no-op).
|
|
815
|
+
- False: Disable adapter mode. For FW models, trains the full weights directly.
|
|
816
|
+
For LoRA models, this is an error (cannot fall back to FW from a LoRA).
|
|
817
|
+
|
|
818
|
+
If with_adapter() is not called at all:
|
|
819
|
+
- FW models default to training full weights (equivalent to with_adapter(False))
|
|
820
|
+
- LoRA models default to using the adapter (equivalent to with_adapter(True))
|
|
821
|
+
|
|
822
|
+
Returns:
|
|
823
|
+
Updated builder with adapter mode configured
|
|
824
|
+
"""
|
|
825
|
+
...
|
|
826
|
+
def with_draft(self, draft_model_path: str, num_draft_steps: int | None = None) -> ModelBuilder:
|
|
827
|
+
"""Enable speculative decoding with a draft model.
|
|
828
|
+
|
|
829
|
+
Args:
|
|
830
|
+
draft_model_path: Path to smaller/faster draft model. Must share the same tokenizer as the main model.
|
|
831
|
+
num_draft_steps: Number of speculative tokens per step (default: None, which selects DEFAULT_MAX_DRAFT_STEPS)
|
|
832
|
+
|
|
833
|
+
Returns:
|
|
834
|
+
Updated builder with speculative decoding enabled
|
|
835
|
+
"""
|
|
836
|
+
...
|
|
837
|
+
def to_python_dict(self) -> dict[str, Any]:
|
|
838
|
+
"""Get builder configuration as a dictionary.
|
|
839
|
+
|
|
840
|
+
Returns:
|
|
841
|
+
Dictionary of builder parameters
|
|
842
|
+
"""
|
|
843
|
+
...
|
|
844
|
+
def extra_params(self, **params) -> ModelBuilder:
|
|
845
|
+
"""Set additional provider-specific parameters.
|
|
846
|
+
DOCS_TODO: add examples for OpenAI, Anthropic, etc.
|
|
847
|
+
|
|
848
|
+
Args:
|
|
849
|
+
**params: Arbitrary parameters passed to the model provider
|
|
850
|
+
|
|
851
|
+
Returns:
|
|
852
|
+
Updated builder
|
|
853
|
+
"""
|
|
854
|
+
...
|
|
855
|
+
|
|
856
|
+
class SerializedThread:
|
|
857
|
+
"""Opaque type representing a fully serialized conversation thread.
|
|
858
|
+
|
|
859
|
+
Used internally for efficient communication between client and workers.
|
|
860
|
+
"""
|
|
861
|
+
|
|
862
|
+
...
|
|
863
|
+
|
|
864
|
+
class StatDuration:
|
|
865
|
+
"""Statistics about tokens processed over a time period.
|
|
866
|
+
|
|
867
|
+
Tracks input, output, and training token counts and throughput rates.
|
|
868
|
+
|
|
869
|
+
Attributes:
|
|
870
|
+
duration_sec: Duration in seconds
|
|
871
|
+
num_input_tokens: Total input tokens processed
|
|
872
|
+
num_output_tokens: Total output tokens generated
|
|
873
|
+
num_trained_tokens: Total tokens used for training
|
|
874
|
+
num_input_tokens_per_s: Input token throughput
|
|
875
|
+
num_output_tokens_per_s: Output token throughput
|
|
876
|
+
num_trained_tokens_per_s: Training token throughput
|
|
877
|
+
"""
|
|
878
|
+
|
|
879
|
+
duration_sec: float
|
|
880
|
+
num_input_tokens: int
|
|
881
|
+
num_output_tokens: int
|
|
882
|
+
num_trained_tokens: int
|
|
883
|
+
num_input_tokens_per_s: float
|
|
884
|
+
num_output_tokens_per_s: float
|
|
885
|
+
num_trained_tokens_per_s: float
|
|
886
|
+
|
|
887
|
+
def combine(self, other: StatDuration) -> StatDuration:
|
|
888
|
+
"""Combine statistics from two time periods.
|
|
889
|
+
|
|
890
|
+
Args:
|
|
891
|
+
other: Another StatDuration to combine
|
|
892
|
+
|
|
893
|
+
Returns:
|
|
894
|
+
Combined statistics spanning both periods
|
|
895
|
+
"""
|
|
896
|
+
...
|
|
897
|
+
|
|
898
|
+
class StatInstant:
|
|
899
|
+
"""Snapshot of model statistics at a point in time.
|
|
900
|
+
|
|
901
|
+
Used to compute statistics over time intervals via stats_since().
|
|
902
|
+
"""
|
|
903
|
+
def stats_since(self, other: StatInstant) -> StatDuration:
|
|
904
|
+
"""Compute statistics since another snapshot.
|
|
905
|
+
|
|
906
|
+
Args:
|
|
907
|
+
other: Earlier StatInstant snapshot
|
|
908
|
+
|
|
909
|
+
Returns:
|
|
910
|
+
StatDuration covering the time between snapshots
|
|
911
|
+
"""
|
|
912
|
+
...
|
|
913
|
+
|
|
914
|
+
class StringThread:
|
|
915
|
+
"""Represents a conversation thread with string-based content.
|
|
916
|
+
|
|
917
|
+
StringThread is the primary way to represent conversations in Harmony. It consists
|
|
918
|
+
of a sequence of turns (user, assistant, system, tool) with text content. Each turn
|
|
919
|
+
has an associated weight that controls whether it's used for training.
|
|
920
|
+
|
|
921
|
+
Threads support multi-modal content via fragments (text and images).
|
|
922
|
+
|
|
923
|
+
Attributes:
|
|
924
|
+
metadata: Arbitrary Python object to attach to this thread
|
|
925
|
+
|
|
926
|
+
Example:
|
|
927
|
+
```python
|
|
928
|
+
# Create a simple conversation
|
|
929
|
+
thread = StringThread([
|
|
930
|
+
("user", "What is the capital of France?"),
|
|
931
|
+
("assistant", "The capital of France is Paris.")
|
|
932
|
+
])
|
|
933
|
+
|
|
934
|
+
# Add more turns
|
|
935
|
+
thread = thread.user("Tell me more about it.")
|
|
936
|
+
|
|
937
|
+
# Access turns
|
|
938
|
+
for role, content in thread.get_turns():
|
|
939
|
+
print(f"{role}: {content}")
|
|
940
|
+
|
|
941
|
+
# Multi-modal with images
|
|
942
|
+
thread = await StringThread.from_fragments([
|
|
943
|
+
("user", [
|
|
944
|
+
{"type": "text", "text": "What's in this image?"},
|
|
945
|
+
{"type": "image", "url": "data:image/png;base64,..."}
|
|
946
|
+
])
|
|
947
|
+
], metadata=None)
|
|
948
|
+
```
|
|
949
|
+
"""
|
|
950
|
+
|
|
951
|
+
metadata: Any
|
|
952
|
+
|
|
953
|
+
def __new__(
|
|
954
|
+
cls,
|
|
955
|
+
turns: Optional[Sequence[tuple[str, str]]] = None,
|
|
956
|
+
metadata: Optional[Any] = None,
|
|
957
|
+
) -> StringThread: ...
|
|
958
|
+
@staticmethod
|
|
959
|
+
async def from_fragments(raw_turns: Sequence[tuple[str, Sequence[Fragment]]], metadata: Any) -> StringThread:
|
|
960
|
+
"""Create a thread from multi-modal fragments (text and images).
|
|
961
|
+
|
|
962
|
+
Args:
|
|
963
|
+
raw_turns: Sequence of (role, fragments) tuples
|
|
964
|
+
metadata: Metadata to attach (required argument; pass None when there is none)
|
|
965
|
+
|
|
966
|
+
Returns:
|
|
967
|
+
StringThread with multi-modal content
|
|
968
|
+
"""
|
|
969
|
+
...
|
|
970
|
+
@staticmethod
|
|
971
|
+
async def from_dataset(raw_turns: Sequence[tuple[str, str | Sequence[Fragment]]], metadata: Any) -> StringThread:
|
|
972
|
+
"""Create a thread from dataset format (strings or fragments).
|
|
973
|
+
|
|
974
|
+
Flexible constructor that accepts either plain strings or fragment lists
|
|
975
|
+
for each turn's content.
|
|
976
|
+
|
|
977
|
+
Args:
|
|
978
|
+
raw_turns: Sequence of (role, content) where content is string or fragments
|
|
979
|
+
metadata: Metadata to attach (required argument; pass None when there is none)
|
|
980
|
+
|
|
981
|
+
Returns:
|
|
982
|
+
StringThread from dataset
|
|
983
|
+
"""
|
|
984
|
+
...
|
|
985
|
+
@classmethod
|
|
986
|
+
def with_metadata(cls, turns: Sequence[tuple[str, str]], metadata: Any) -> StringThread: ...
|
|
987
|
+
def user(self, content: str) -> StringThread:
|
|
988
|
+
"""Append a user turn to the thread.
|
|
989
|
+
|
|
990
|
+
Args:
|
|
991
|
+
content: User message text
|
|
992
|
+
|
|
993
|
+
Returns:
|
|
994
|
+
New thread with user turn appended
|
|
995
|
+
"""
|
|
996
|
+
...
|
|
997
|
+
def system(self, content: str) -> StringThread:
|
|
998
|
+
"""Append a system turn to the thread.
|
|
999
|
+
|
|
1000
|
+
Args:
|
|
1001
|
+
content: System message text
|
|
1002
|
+
|
|
1003
|
+
Returns:
|
|
1004
|
+
New thread with system turn appended
|
|
1005
|
+
"""
|
|
1006
|
+
...
|
|
1007
|
+
def assistant(self, content: str) -> StringThread:
|
|
1008
|
+
"""Append an assistant turn to the thread.
|
|
1009
|
+
|
|
1010
|
+
Args:
|
|
1011
|
+
content: Assistant message text
|
|
1012
|
+
|
|
1013
|
+
Returns:
|
|
1014
|
+
New thread with assistant turn appended
|
|
1015
|
+
"""
|
|
1016
|
+
...
|
|
1017
|
+
def tool(self, content: str) -> StringThread:
|
|
1018
|
+
"""Append a tool result turn to the thread.
|
|
1019
|
+
|
|
1020
|
+
Args:
|
|
1021
|
+
content: Tool output text
|
|
1022
|
+
|
|
1023
|
+
Returns:
|
|
1024
|
+
New thread with tool turn appended
|
|
1025
|
+
"""
|
|
1026
|
+
...
|
|
1027
|
+
def last_content(self) -> str:
|
|
1028
|
+
"""Get the content of the last turn.
|
|
1029
|
+
|
|
1030
|
+
Returns:
|
|
1031
|
+
Text content of the last turn
|
|
1032
|
+
|
|
1033
|
+
Raises:
|
|
1034
|
+
RecipeError: If thread is empty
|
|
1035
|
+
"""
|
|
1036
|
+
...
|
|
1037
|
+
def messages(self) -> list[tuple[str, str]]:
|
|
1038
|
+
"""Get all turns except the last assistant turn (if present).
|
|
1039
|
+
DOCS_TODO: show example.
|
|
1040
|
+
|
|
1041
|
+
Useful for extracting the prompt portion of a thread.
|
|
1042
|
+
|
|
1043
|
+
Returns:
|
|
1044
|
+
List of (role, content) tuples excluding final assistant turn
|
|
1045
|
+
"""
|
|
1046
|
+
...
|
|
1047
|
+
def completion(self) -> str | None:
|
|
1048
|
+
"""Get the completion if the last turn is from the assistant.
|
|
1049
|
+
DOCS_TODO: show example.
|
|
1050
|
+
|
|
1051
|
+
Returns:
|
|
1052
|
+
Assistant's response if last turn is assistant, None otherwise
|
|
1053
|
+
"""
|
|
1054
|
+
...
|
|
1055
|
+
def get_turns(self) -> list[StringTurn]:
|
|
1056
|
+
"""Get all turns as StringTurn namedtuples.
|
|
1057
|
+
|
|
1058
|
+
Returns:
|
|
1059
|
+
List of StringTurn(role, content) tuples
|
|
1060
|
+
"""
|
|
1061
|
+
...
|
|
1062
|
+
def get_fragments(self) -> list[tuple[str, list[Fragment]]]:
|
|
1063
|
+
"""Get all turns as multi-modal fragments.
|
|
1064
|
+
|
|
1065
|
+
Returns:
|
|
1066
|
+
List of (role, fragments) tuples
|
|
1067
|
+
"""
|
|
1068
|
+
...
|
|
1069
|
+
@staticmethod
|
|
1070
|
+
def from_json(json_str) -> StringThread:
|
|
1071
|
+
"""Deserialize a thread from JSON string.
|
|
1072
|
+
|
|
1073
|
+
Args:
|
|
1074
|
+
json_str: JSON representation of thread
|
|
1075
|
+
|
|
1076
|
+
Returns:
|
|
1077
|
+
Deserialized StringThread
|
|
1078
|
+
"""
|
|
1079
|
+
...
|
|
1080
|
+
def to_json(self) -> str:
|
|
1081
|
+
"""Serialize thread to JSON string.
|
|
1082
|
+
|
|
1083
|
+
Returns:
|
|
1084
|
+
JSON representation of thread
|
|
1085
|
+
"""
|
|
1086
|
+
...
|
|
1087
|
+
def with_weight_all_assistant_turns(self) -> StringThread:
|
|
1088
|
+
"""Mark all assistant turns for training, zero out others.
|
|
1089
|
+
|
|
1090
|
+
Returns:
|
|
1091
|
+
New thread with weights adjusted
|
|
1092
|
+
"""
|
|
1093
|
+
...
|
|
1094
|
+
def with_weight_last_assistant_turn(self) -> StringThread:
|
|
1095
|
+
"""Mark only the last assistant turn for training.
|
|
1096
|
+
|
|
1097
|
+
Returns:
|
|
1098
|
+
New thread with weights adjusted
|
|
1099
|
+
"""
|
|
1100
|
+
...
|
|
1101
|
+
def with_weight_assistant_turns_from_index(self, start_index: int) -> StringThread:
|
|
1102
|
+
"""Mark assistant turns starting from index for training.
|
|
1103
|
+
|
|
1104
|
+
Args:
|
|
1105
|
+
start_index: Index of first assistant turn to weight (0-based)
|
|
1106
|
+
|
|
1107
|
+
Returns:
|
|
1108
|
+
New thread with weights adjusted
|
|
1109
|
+
"""
|
|
1110
|
+
...
|
|
1111
|
+
def uuid(self) -> str | None:
|
|
1112
|
+
"""Get the UUID associated with this thread (if any).
|
|
1113
|
+
|
|
1114
|
+
Returns:
|
|
1115
|
+
UUID string or None
|
|
1116
|
+
"""
|
|
1117
|
+
...
|
|
1118
|
+
def __repr__(self) -> str: ...
|
|
1119
|
+
|
|
1120
|
+
class TokenizedThread:
|
|
1121
|
+
"""Represents a conversation thread with tokenized content.
|
|
1122
|
+
|
|
1123
|
+
Similar to StringThread but stores token IDs instead of strings.
|
|
1124
|
+
Useful for training operations and when you need direct token-level control.
|
|
1125
|
+
|
|
1126
|
+
Attributes:
|
|
1127
|
+
metadata: Arbitrary Python object to attach to this thread
|
|
1128
|
+
"""
|
|
1129
|
+
|
|
1130
|
+
metadata: Any
|
|
1131
|
+
|
|
1132
|
+
def user(self, content: Sequence[int]) -> TokenizedThread:
|
|
1133
|
+
"""Append a user turn with token IDs.
|
|
1134
|
+
|
|
1135
|
+
Args:
|
|
1136
|
+
content: List of token IDs for user message
|
|
1137
|
+
|
|
1138
|
+
Returns:
|
|
1139
|
+
New thread with user turn appended
|
|
1140
|
+
"""
|
|
1141
|
+
...
|
|
1142
|
+
def assistant(self, content: Sequence[int]) -> TokenizedThread:
|
|
1143
|
+
"""Append an assistant turn with token IDs.
|
|
1144
|
+
|
|
1145
|
+
Args:
|
|
1146
|
+
content: List of token IDs for assistant message
|
|
1147
|
+
|
|
1148
|
+
Returns:
|
|
1149
|
+
New thread with assistant turn appended
|
|
1150
|
+
"""
|
|
1151
|
+
...
|
|
1152
|
+
def last_content(self) -> Optional[list[int]]:
|
|
1153
|
+
"""Get the token IDs of the last turn.
|
|
1154
|
+
|
|
1155
|
+
Returns:
|
|
1156
|
+
List of token IDs, or None if thread is empty
|
|
1157
|
+
"""
|
|
1158
|
+
...
|
|
1159
|
+
def len_last_turn(self) -> int:
|
|
1160
|
+
"""Get the number of tokens in the last turn.
|
|
1161
|
+
|
|
1162
|
+
Returns:
|
|
1163
|
+
Token count of last turn
|
|
1164
|
+
"""
|
|
1165
|
+
...
|
|
1166
|
+
def get_turns(self) -> list[TokenizedTurn]:
|
|
1167
|
+
"""Get all turns as TokenizedTurn namedtuples.
|
|
1168
|
+
|
|
1169
|
+
Returns:
|
|
1170
|
+
List of TokenizedTurn(role, content) tuples
|
|
1171
|
+
"""
|
|
1172
|
+
...
|
|
1173
|
+
def with_weight_all_assistant_turns(self) -> TokenizedThread:
|
|
1174
|
+
"""Mark all assistant turns for training, zero out others.
|
|
1175
|
+
|
|
1176
|
+
Returns:
|
|
1177
|
+
New thread with weights adjusted
|
|
1178
|
+
"""
|
|
1179
|
+
...
|
|
1180
|
+
def with_weight_last_assistant_turn(self) -> TokenizedThread:
|
|
1181
|
+
"""Mark only the last assistant turn for training.
|
|
1182
|
+
|
|
1183
|
+
Returns:
|
|
1184
|
+
New thread with weights adjusted
|
|
1185
|
+
"""
|
|
1186
|
+
...
|
|
1187
|
+
def with_weight_assistant_turns_from_index(self, start_index: int) -> TokenizedThread:
|
|
1188
|
+
"""Mark assistant turns starting from index for training.
|
|
1189
|
+
|
|
1190
|
+
Args:
|
|
1191
|
+
start_index: Index of first assistant turn to weight (0-based)
|
|
1192
|
+
|
|
1193
|
+
Returns:
|
|
1194
|
+
New thread with weights adjusted
|
|
1195
|
+
"""
|
|
1196
|
+
...
|
|
1197
|
+
def uuid(self) -> str | None:
|
|
1198
|
+
"""Get the UUID associated with this thread (if any).
|
|
1199
|
+
|
|
1200
|
+
Returns:
|
|
1201
|
+
UUID string or None
|
|
1202
|
+
"""
|
|
1203
|
+
...
|
|
1204
|
+
def __repr__(self) -> str: ...
|
|
1205
|
+
|
|
1206
|
+
class TrainingModel(InferenceModel):
|
|
1207
|
+
"""Model instance for running training operations.
|
|
1208
|
+
|
|
1209
|
+
TrainingModel extends InferenceModel with training capabilities. It supports
|
|
1210
|
+
various training algorithms including supervised fine-tuning, PPO, DPO, GRPO,
|
|
1211
|
+
and reward modeling.
|
|
1212
|
+
|
|
1213
|
+
Inherits all inference methods from InferenceModel.
|
|
1214
|
+
|
|
1215
|
+
Example:
|
|
1216
|
+
```python
|
|
1217
|
+
# Spawn a training model
|
|
1218
|
+
model = client.model("model_registry://llama-3.1-8b").with_adapter()
|
|
1219
|
+
train_model = await model.spawn_train("my-train", max_batch_size=1024)
|
|
1220
|
+
|
|
1221
|
+
# Supervised fine-tuning
|
|
1222
|
+
thread = StringThread([
|
|
1223
|
+
("user", "What is 2+2?"),
|
|
1224
|
+
("assistant", "4")
|
|
1225
|
+
]).with_weight_last_assistant_turn()
|
|
1226
|
+
|
|
1227
|
+
await train_model.train_language_modelling(thread)
|
|
1228
|
+
await train_model.optim_step(lr=1e-5, wd=0.01, max_grad_norm=1.0)
|
|
1229
|
+
|
|
1230
|
+
# Save the fine-tuned model
|
|
1231
|
+
model_key = await train_model.save("my-fine-tuned-model")
|
|
1232
|
+
```
|
|
1233
|
+
"""
|
|
1234
|
+
async def clone_inf(self) -> InferenceModel:
|
|
1235
|
+
"""Clone this model as an inference-only instance. Useful for all reinforcement learning algorithms that perform regularization with KL divergence to a reference model.
|
|
1236
|
+
|
|
1237
|
+
Returns:
|
|
1238
|
+
New InferenceModel sharing weights with this training model
|
|
1239
|
+
"""
|
|
1240
|
+
...
|
|
1241
|
+
def get_builder_args(self) -> dict[str, Any]: ...
|
|
1242
|
+
async def save(self, model_name: str, inference_only: bool = True, ctx: RecipeContext | None = None) -> str:
|
|
1243
|
+
"""Save the model weights to the model registry.
|
|
1244
|
+
|
|
1245
|
+
Args:
|
|
1246
|
+
model_name: Name for the saved model
|
|
1247
|
+
inference_only: If True (default), save only the weights needed for inference. If False, also save the entire optimizer state.
|
|
1248
|
+
ctx: Optional recipe context for tracking
|
|
1249
|
+
|
|
1250
|
+
Returns:
|
|
1251
|
+
Model key that can be used to load the model
|
|
1252
|
+
"""
|
|
1253
|
+
...
|
|
1254
|
+
async def optim_step(
|
|
1255
|
+
self, lr: float, wd: float, max_grad_norm: float, skip_nan_gradients: bool = False
|
|
1256
|
+
) -> dict[str, float]:
|
|
1257
|
+
"""Perform an optimizer step (update model weights).
|
|
1258
|
+
|
|
1259
|
+
Args:
|
|
1260
|
+
lr: Learning rate
|
|
1261
|
+
wd: Weight decay
|
|
1262
|
+
max_grad_norm: Maximum gradient norm for clipping
|
|
1263
|
+
skip_nan_gradients: If True, skip step on NaN gradients instead of erroring
|
|
1264
|
+
|
|
1265
|
+
Returns:
|
|
1266
|
+
Dictionary of training metrics (loss, grad_norm, etc.)
|
|
1267
|
+
"""
|
|
1268
|
+
...
|
|
1269
|
+
def get_optim_step(self) -> int:
|
|
1270
|
+
"""Get the current optimizer step counter.
|
|
1271
|
+
|
|
1272
|
+
Returns:
|
|
1273
|
+
Current step number
|
|
1274
|
+
"""
|
|
1275
|
+
...
|
|
1276
|
+
def set_optim_step(self, step: int) -> None:
|
|
1277
|
+
"""Set the optimizer step counter.
|
|
1278
|
+
|
|
1279
|
+
Args:
|
|
1280
|
+
step: Step number to set
|
|
1281
|
+
"""
|
|
1282
|
+
...
|
|
1283
|
+
def inf(self) -> InferenceModel:
|
|
1284
|
+
"""Get inference-only view of this model.
|
|
1285
|
+
|
|
1286
|
+
Returns:
|
|
1287
|
+
InferenceModel view (no copy, shares weights)
|
|
1288
|
+
"""
|
|
1289
|
+
...
|
|
1290
|
+
async def train_language_modelling(self, thread: StringThread) -> None:
|
|
1291
|
+
"""Train with standard language modeling objective (next-token prediction).
|
|
1292
|
+
|
|
1293
|
+
Args:
|
|
1294
|
+
thread: Conversation thread with weighted assistant turns
|
|
1295
|
+
"""
|
|
1296
|
+
...
|
|
1297
|
+
async def train_ppo(
|
|
1298
|
+
self,
|
|
1299
|
+
thread: TokenizedThread,
|
|
1300
|
+
trajectory_logprobs: Sequence[float],
|
|
1301
|
+
advantages: Sequence[float],
|
|
1302
|
+
clip_range: float,
|
|
1303
|
+
) -> None:
|
|
1304
|
+
"""Train with Proximal Policy Optimization (PPO) objective.
|
|
1305
|
+
|
|
1306
|
+
Args:
|
|
1307
|
+
thread: Tokenized conversation thread
|
|
1308
|
+
trajectory_logprobs: Log probabilities from trajectory policy
|
|
1309
|
+
advantages: Per-token advantage estimates
|
|
1310
|
+
clip_range: PPO clipping range (epsilon)
|
|
1311
|
+
"""
|
|
1312
|
+
...
|
|
1313
|
+
async def train_grpo(
|
|
1314
|
+
self,
|
|
1315
|
+
thread: TokenizedThread,
|
|
1316
|
+
trajectory_logprobs: Sequence[float],
|
|
1317
|
+
reference_logprobs: Sequence[float],
|
|
1318
|
+
advantages: Sequence[float],
|
|
1319
|
+
clip_range: float,
|
|
1320
|
+
kl_beta: float,
|
|
1321
|
+
) -> None:
|
|
1322
|
+
"""Train with Group Relative Policy Optimization (GRPO).
|
|
1323
|
+
|
|
1324
|
+
Args:
|
|
1325
|
+
thread: Tokenized conversation thread
|
|
1326
|
+
trajectory_logprobs: Log probabilities from trajectory policy
|
|
1327
|
+
reference_logprobs: Log probabilities from reference policy
|
|
1328
|
+
advantages: Per-token advantage estimates
|
|
1329
|
+
clip_range: Clipping range
|
|
1330
|
+
kl_beta: KL divergence penalty coefficient
|
|
1331
|
+
"""
|
|
1332
|
+
...
|
|
1333
|
+
async def train_gspo(
|
|
1334
|
+
self,
|
|
1335
|
+
thread: TokenizedThread,
|
|
1336
|
+
trajectory_logprobs: Sequence[float],
|
|
1337
|
+
reference_logprobs: Sequence[float],
|
|
1338
|
+
advantage: Sequence[float],
|
|
1339
|
+
left_clip: float,
|
|
1340
|
+
right_clip: float,
|
|
1341
|
+
kl_beta: float,
|
|
1342
|
+
) -> None:
|
|
1343
|
+
"""Train with Group Sequence Policy Optimization (GSPO).
|
|
1344
|
+
|
|
1345
|
+
Args:
|
|
1346
|
+
thread: Tokenized conversation thread
|
|
1347
|
+
trajectory_logprobs: Log probabilities from trajectory policy
|
|
1348
|
+
reference_logprobs: Log probabilities from reference policy
|
|
1349
|
+
advantage: Advantage estimates
|
|
1350
|
+
left_clip: Left clipping bound
|
|
1351
|
+
right_clip: Right clipping bound
|
|
1352
|
+
kl_beta: KL divergence penalty coefficient
|
|
1353
|
+
"""
|
|
1354
|
+
...
|
|
1355
|
+
async def train_trust_region_mse(
|
|
1356
|
+
self,
|
|
1357
|
+
thread: TokenizedThread,
|
|
1358
|
+
targets: Sequence[float],
|
|
1359
|
+
clip_center: Sequence[float],
|
|
1360
|
+
clip_range: float,
|
|
1361
|
+
) -> None:
|
|
1362
|
+
"""Train with trust-region constrained MSE loss (for value functions).
|
|
1363
|
+
|
|
1364
|
+
Args:
|
|
1365
|
+
thread: Tokenized conversation thread
|
|
1366
|
+
targets: Target values
|
|
1367
|
+
clip_center: Center values for clipping
|
|
1368
|
+
clip_range: Clipping range
|
|
1369
|
+
"""
|
|
1370
|
+
...
|
|
1371
|
+
async def train_mse(self, thread: StringThread, target: float) -> None:
|
|
1372
|
+
"""Train with MSE loss on the last token (for reward models).
|
|
1373
|
+
|
|
1374
|
+
Args:
|
|
1375
|
+
thread: Conversation thread
|
|
1376
|
+
target: Target scalar value
|
|
1377
|
+
"""
|
|
1378
|
+
...
|
|
1379
|
+
async def train_mse_per_token(self, thread: TokenizedThread, targets: Sequence[float]) -> None:
|
|
1380
|
+
"""Train with per-token MSE loss.
|
|
1381
|
+
|
|
1382
|
+
Args:
|
|
1383
|
+
thread: Tokenized conversation thread
|
|
1384
|
+
targets: Target value for each weighted token
|
|
1385
|
+
"""
|
|
1386
|
+
...
|
|
1387
|
+
async def train_ranking(self, pos_thread: StringThread, neg_thread: StringThread) -> None:
|
|
1388
|
+
"""Train with ranking loss (positive example should score higher).
|
|
1389
|
+
|
|
1390
|
+
Args:
|
|
1391
|
+
pos_thread: Preferred/positive example
|
|
1392
|
+
neg_thread: Rejected/negative example
|
|
1393
|
+
"""
|
|
1394
|
+
...
|
|
1395
|
+
async def train_dpo(
|
|
1396
|
+
self,
|
|
1397
|
+
sample_pos: StringThread,
|
|
1398
|
+
sample_neg: StringThread,
|
|
1399
|
+
ref_logprobs_pos: float,
|
|
1400
|
+
ref_logprobs_neg: float,
|
|
1401
|
+
beta: float,
|
|
1402
|
+
) -> None:
|
|
1403
|
+
"""Train with Direct Preference Optimization (DPO).
|
|
1404
|
+
|
|
1405
|
+
Args:
|
|
1406
|
+
sample_pos: Preferred completion thread
|
|
1407
|
+
sample_neg: Rejected completion thread
|
|
1408
|
+
ref_logprobs_pos: Reference model log probability for the preferred (positive) thread
|
|
1409
|
+
ref_logprobs_neg: Reference model log probability for the rejected (negative) thread
|
|
1410
|
+
beta: DPO temperature parameter
|
|
1411
|
+
"""
|
|
1412
|
+
...
|
|
1413
|
+
def top_p(self, top_p: float) -> TrainingModel: ...
|
|
1414
|
+
def temperature(self, temperature: float) -> TrainingModel: ...
|
|
1415
|
+
def max_gen_len(self, max_num_tokens: int) -> TrainingModel: ...
|
|
1416
|
+
def min_gen_len(self, min_num_tokens: int) -> TrainingModel: ...
|
|
1417
|
+
|
|
1418
|
+
class Thread(Enum):
|
|
1419
|
+
"""Enum of thread types (internal use)."""
|
|
1420
|
+
|
|
1421
|
+
StringThread = ...
|
|
1422
|
+
TokenizedThread = ...
|
|
1423
|
+
SerializedThread = ...
|
|
1424
|
+
|
|
1425
|
+
class JobNotifier:
|
|
1426
|
+
"""Helper class to report job progress to stdout.
|
|
1427
|
+
|
|
1428
|
+
By default logs progress to stdout. For integration with the Harmony platform,
|
|
1429
|
+
use HarmonyJobNotifier instead.
|
|
1430
|
+
|
|
1431
|
+
Example:
|
|
1432
|
+
```python
|
|
1433
|
+
notifier = JobNotifier()
|
|
1434
|
+
notifier.register_stages(["training", "evaluation"])
|
|
1435
|
+
|
|
1436
|
+
stage = notifier.stage_notifier("training")
|
|
1437
|
+
stage.report_progress(tot_num_samples=1000, processed_num_samples=100)
|
|
1438
|
+
```
|
|
1439
|
+
"""
|
|
1440
|
+
|
|
1441
|
+
def __new__(cls) -> JobNotifier: ...
|
|
1442
|
+
def set_monitoring_link(self, monitoring_link: str) -> None:
|
|
1443
|
+
"""Set a monitoring link (e.g., Weights & Biases URL).
|
|
1444
|
+
|
|
1445
|
+
Args:
|
|
1446
|
+
monitoring_link: URL to monitoring dashboard
|
|
1447
|
+
"""
|
|
1448
|
+
...
|
|
1449
|
+
def register_stages(self, stages: Sequence[str]) -> None:
|
|
1450
|
+
"""Register the stages of this job.
|
|
1451
|
+
|
|
1452
|
+
Args:
|
|
1453
|
+
stages: Ordered list of stage names
|
|
1454
|
+
"""
|
|
1455
|
+
...
|
|
1456
|
+
def register_artifact(self, artifact: JobArtifact) -> None:
|
|
1457
|
+
"""Register an artifact produced by this job.
|
|
1458
|
+
|
|
1459
|
+
Args:
|
|
1460
|
+
artifact: Artifact to register
|
|
1461
|
+
"""
|
|
1462
|
+
...
|
|
1463
|
+
def report_error(self, error: str) -> None:
|
|
1464
|
+
"""Report an error that occurred during the job.
|
|
1465
|
+
|
|
1466
|
+
Args:
|
|
1467
|
+
error: Error message
|
|
1468
|
+
"""
|
|
1469
|
+
...
|
|
1470
|
+
def report_progress(
|
|
1471
|
+
self,
|
|
1472
|
+
stage: str,
|
|
1473
|
+
tot_num_samples: Optional[int] = None,
|
|
1474
|
+
processed_num_samples: Optional[int] = None,
|
|
1475
|
+
monitoring_link: Optional[str] = None,
|
|
1476
|
+
checkpoints: Optional[Sequence[str]] = None,
|
|
1477
|
+
) -> None:
|
|
1478
|
+
"""Report progress for a stage.
|
|
1479
|
+
|
|
1480
|
+
Args:
|
|
1481
|
+
stage: Stage name
|
|
1482
|
+
tot_num_samples: Total number of samples to process
|
|
1483
|
+
processed_num_samples: Number of samples processed so far
|
|
1484
|
+
monitoring_link: Optional monitoring dashboard URL
|
|
1485
|
+
checkpoints: Optional list of checkpoint identifiers
|
|
1486
|
+
"""
|
|
1487
|
+
...
|
|
1488
|
+
def stage_notifier(self, stage: str) -> StageNotifier:
|
|
1489
|
+
"""Get a stage-specific notifier.
|
|
1490
|
+
|
|
1491
|
+
Args:
|
|
1492
|
+
stage: Stage name
|
|
1493
|
+
|
|
1494
|
+
Returns:
|
|
1495
|
+
StageNotifier for the given stage
|
|
1496
|
+
"""
|
|
1497
|
+
...
|
|
1498
|
+
def __repr__(self) -> str: ...
|
|
1499
|
+
|
|
1500
|
+
class HarmonyJobNotifier(JobNotifier):
|
|
1501
|
+
"""Job notifier that reports progress to the Harmony platform.
|
|
1502
|
+
|
|
1503
|
+
Use this instead of JobNotifier when running jobs through the Harmony API
|
|
1504
|
+
to integrate with the UI and tracking system.
|
|
1505
|
+
|
|
1506
|
+
Example:
|
|
1507
|
+
```python
|
|
1508
|
+
notifier = HarmonyJobNotifier(client, job_id)
|
|
1509
|
+
notifier.register_stages(["training", "evaluation"])
|
|
1510
|
+
|
|
1511
|
+
# Register artifacts
|
|
1512
|
+
artifact = JobArtifact(
|
|
1513
|
+
id="model-v1",
|
|
1514
|
+
name="Fine-tuned Model",
|
|
1515
|
+
kind="model",
|
|
1516
|
+
uri="s3://bucket/models/model-v1"
|
|
1517
|
+
)
|
|
1518
|
+
notifier.register_artifact(artifact)
|
|
1519
|
+
|
|
1520
|
+
# Report progress
|
|
1521
|
+
stage = notifier.stage_notifier("training")
|
|
1522
|
+
stage.report_progress(tot_num_samples=1000, processed_num_samples=500)
|
|
1523
|
+
```
|
|
1524
|
+
"""
|
|
1525
|
+
def __new__(cls, client: HarmonyClient, job_id: str) -> HarmonyJobNotifier: ...
|
|
1526
|
+
# DOCS_TODO: add example and explain it will show up in the UI
|
|
1527
|
+
def set_monitoring_link(self, monitoring_link: str) -> None: ...
|
|
1528
|
+
|
|
1529
|
+
class StageNotifier:
|
|
1530
|
+
"""Helper class to report progress for a specific job stage.
|
|
1531
|
+
|
|
1532
|
+
Get an instance via JobNotifier.stage_notifier() or HarmonyJobNotifier.stage_notifier().
|
|
1533
|
+
Provides convenience methods for reporting progress without repeating the stage name.
|
|
1534
|
+
"""
|
|
1535
|
+
|
|
1536
|
+
def set_monitoring_link(self, monitoring_link: str) -> None:
|
|
1537
|
+
"""Set monitoring link for this stage.
|
|
1538
|
+
|
|
1539
|
+
Args:
|
|
1540
|
+
monitoring_link: URL to monitoring dashboard
|
|
1541
|
+
"""
|
|
1542
|
+
...
|
|
1543
|
+
def report_progress(
|
|
1544
|
+
self,
|
|
1545
|
+
tot_num_samples: Optional[int] = None,
|
|
1546
|
+
processed_num_samples: Optional[int] = None,
|
|
1547
|
+
monitoring_link: Optional[str] = None,
|
|
1548
|
+
checkpoints: Optional[Sequence[str]] = None,
|
|
1549
|
+
) -> None:
|
|
1550
|
+
"""Report progress for this stage.
|
|
1551
|
+
|
|
1552
|
+
Args:
|
|
1553
|
+
tot_num_samples: Total number of samples to process
|
|
1554
|
+
processed_num_samples: Number of samples processed so far
|
|
1555
|
+
monitoring_link: Optional monitoring dashboard URL
|
|
1556
|
+
checkpoints: Optional list of checkpoint identifiers
|
|
1557
|
+
"""
|
|
1558
|
+
...
|
|
1559
|
+
|
|
1560
|
+
async def get_client(
|
|
1561
|
+
addr: str,
|
|
1562
|
+
num_gpus: int | None = None,
|
|
1563
|
+
api_key: str | None = None,
|
|
1564
|
+
use_case: str | None = None,
|
|
1565
|
+
compute_pool: str | None = None,
|
|
1566
|
+
job_id: str | None = None,
|
|
1567
|
+
default_headers: dict[str, str] | None = None,
|
|
1568
|
+
ttl_after_disconnect_s: int | None = 30,
|
|
1569
|
+
control_plane_url: str | None = None,
|
|
1570
|
+
control_plane_api_token: str | None = None,
|
|
1571
|
+
) -> HarmonyClient:
|
|
1572
|
+
"""Create and connect a HarmonyClient to workers.
|
|
1573
|
+
|
|
1574
|
+
This is the main entry point for connecting to the Harmony platform.
|
|
1575
|
+
The client manages the connection to GPU workers and provides access
|
|
1576
|
+
to model building and inference/training operations.
|
|
1577
|
+
|
|
1578
|
+
Args:
|
|
1579
|
+
addr: WebSocket address of the deployment (e.g., "ws://localhost:8080")
|
|
1580
|
+
num_gpus: Number of GPUs to request (None = use all available)
|
|
1581
|
+
api_key: API key for authentication (or set ADAPTIVE_API_KEY env var)
|
|
1582
|
+
use_case: Use case identifier (or set ADAPTIVE_USE_CASE env var, default: "default")
|
|
1583
|
+
compute_pool: Compute pool to use (or set ADAPTIVE_COMPUTE_POOL env var, default: "default")
|
|
1584
|
+
job_id: Job ID for tracking (or set ADAPTIVE_JOB_ID env var)
|
|
1585
|
+
default_headers: Additional HTTP headers for the connection
|
|
1586
|
+
ttl_after_disconnect_s: Keep session alive for this many seconds after disconnect (default: 30)
|
|
1587
|
+
control_plane_url: URL of control plane for fetching model/dataset/grader configs
|
|
1588
|
+
control_plane_api_token: API token for authenticating with the control plane
|
|
1589
|
+
|
|
1590
|
+
Returns:
|
|
1591
|
+
Connected HarmonyClient instance
|
|
1592
|
+
|
|
1593
|
+
Example:
|
|
1594
|
+
```python
|
|
1595
|
+
from harmony_client import get_client
|
|
1596
|
+
|
|
1597
|
+
# Basic connection
|
|
1598
|
+
client = await get_client("ws://localhost:8080", num_gpus=1)
|
|
1599
|
+
|
|
1600
|
+
# With authentication
|
|
1601
|
+
client = await get_client(
|
|
1602
|
+
"wss://api.adaptive.com",
|
|
1603
|
+
num_gpus=4,
|
|
1604
|
+
api_key="my-api-key",
|
|
1605
|
+
control_plane_url="https://api.adaptive.com"
|
|
1606
|
+
)
|
|
1607
|
+
|
|
1608
|
+
# Use as context manager
|
|
1609
|
+
async with await get_client(...) as client:
|
|
1610
|
+
model = client.model("model_registry://llama-3.1-8b")
|
|
1611
|
+
# ... use model
|
|
1612
|
+
# Connection closed automatically
|
|
1613
|
+
```
|
|
1614
|
+
"""
|
|
1615
|
+
...
|