harmony-client 0.1.0__cp312-cp312-win_amd64.whl

@@ -0,0 +1,1615 @@
1
+ # ruff: noqa: E501, F401
2
+ """Harmony Client - Python bindings for the Adaptive Harmony ML orchestration platform.
3
+
4
+ This module provides the Python interface to interact with Harmony workers for model training
5
+ and inference operations. It includes thread management, model lifecycle operations, and
6
+ training/evaluation utilities.
7
+ """
8
+
9
+ from enum import Enum
10
+ from typing import Any, MutableSequence, NamedTuple, Optional, Sequence, Type, TypeVar
11
+
12
+ from adaptive_harmony.runtime import RecipeContext
13
+ from pydantic import BaseModel
14
+ from typing_extensions import Literal, Required, TypedDict
15
+
16
+ T = TypeVar("T", bound=BaseModel)
17
+
18
+ # Constants
19
+ DEFAULT_MAX_DRAFT_STEPS: int
20
+ """Default maximum number of draft steps for speculative decoding."""
21
+
22
+ class ImageFragment(TypedDict, total=False):
23
+ """Fragment representing an image in a conversation turn.
24
+
25
+ Attributes:
26
+ type: Must be "image"
27
+ url: URL or data URI of the image (supports data URLs, file paths, and HTTP URLs)
28
+ """
29
+
30
+ type: Required[Literal["image"]]
31
+ url: Required[str]
32
+
33
+ class TextFragment(TypedDict, total=False):
34
+ """Fragment representing text content in a conversation turn.
35
+
36
+ Attributes:
37
+ type: Must be "text"
38
+ text: The text content
39
+ """
40
+
41
+ type: Required[Literal["text"]]
42
+ text: Required[str]
43
+
44
+ Fragment = ImageFragment | TextFragment
45
+ """Union type for content fragments that can appear in conversation turns."""
46
+
47
+ class EvalSample:
48
+ """Represents a single evaluation sample with its interaction and grades.
49
+
50
+ An evaluation sample captures a conversation thread along with grades assigned
51
+ by one or more graders. Used for evaluation artifact creation and analysis.
52
+
53
+ Attributes:
54
+ interaction: The conversation interaction being evaluated
55
+ grades: List of grades from different graders
56
+ dataset_key: Key identifying the dataset this sample belongs to
57
+ id: Unique identifier for this evaluation sample (auto-generated UUID)
58
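+
+ Example:
+ A minimal sketch; the grader key and score are illustrative:
+ ```python
+ sample = EvalSample(
+ interaction=EvalSampleInteraction(thread, source="my-model"),
+ grades=[Grade(0.9, "accuracy-grader", reasoning="Correct answer")],
+ dataset_key="my-eval-set",
+ )
+ ```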
+ """
59
+
60
+ interaction: EvalSampleInteraction
61
+ grades: MutableSequence[Grade]
62
+ dataset_key: str
63
+ id: str
64
+
65
+ def __new__(
66
+ cls,
67
+ interaction: EvalSampleInteraction,
68
+ grades: MutableSequence[Grade],
69
+ dataset_key: str,
70
+ ) -> EvalSample: ...
71
+ def __repr__(self) -> str: ...
72
+
73
+ class EvalSampleInteraction:
74
+ """Represents a conversation interaction for evaluation.
75
+
76
+ Encapsulates a conversation thread and optionally tracks which model
77
+ or source generated the interaction.
78
+
79
+ Attributes:
80
+ thread: The conversation thread (messages between user and assistant)
81
+ source: Optional identifier for the model or system that generated this interaction
82
+ """
83
+
84
+ thread: StringThread
85
+ source: Optional[str]
86
+
87
+ def __new__(cls, thread: StringThread, source: Optional[str] = None) -> EvalSampleInteraction: ...
88
+ def __repr__(self) -> str: ...
89
+
90
+ class EvaluationArtifactBase:
91
+ """Base class for evaluation artifacts that can be registered with jobs.
92
+
93
+ Evaluation artifacts track the results of model evaluations, including
94
+ the evaluated samples and their grades. Can be uploaded to object storage.
95
+
96
+ Attributes:
97
+ artifact: The underlying job artifact with metadata
98
+ id: Unique identifier for this artifact
99
+ name: Human-readable name for the evaluation
100
+ kind: Type of artifact (always "eval" for evaluations)
101
+ uri: Optional URI where evaluation results are stored
102
+ """
103
+
104
+ artifact: JobArtifact
105
+ id: str
106
+ name: str
107
+ kind: str
108
+ uri: Optional[str]
109
+
110
+ def __new__(cls, name: str, uri: str, id: str, **py_kwargs) -> EvaluationArtifactBase: ...
111
+ def samples_to_adaptive_json(self, samples: MutableSequence[EvalSample]) -> list[str]:
112
+ """Convert evaluation samples to JSONL format for storage.
113
+
114
+ Args:
115
+ samples: List of evaluation samples to serialize
116
+
117
+ Returns:
118
+ List of JSON strings, one per sample, suitable for JSONL file format
119
+ """
120
+ ...
121
+ def __repr__(self) -> str: ...
122
+
123
+ class Grade:
124
+ """Represents a grade assigned to an evaluation sample by a grader.
125
+
126
+ Grades can include numeric scores and optional reasoning explaining
127
+ the score.
128
+
129
+ Attributes:
130
+ value: Numeric grade value (typically 0.0-1.0 range)
131
+ grader_key: Identifier for the grader that assigned this grade
132
+ reasoning: Optional explanation for why this grade was assigned
133
+ """
134
+
135
+ value: float
136
+ grader_key: str
137
+ reasoning: Optional[str]
138
+
139
+ def __new__(cls, value: float, grader_key: str, reasoning: Optional[str] = None) -> Grade: ...
140
+ def __repr__(self) -> str: ...
141
+
142
+ class JobArtifact:
143
+ """Represents an artifact produced by a training or evaluation job.
144
+
145
+ Job artifacts track outputs like trained models, evaluation results, datasets,
146
+ or custom artifacts. They can be registered with jobs to appear in the UI.
147
+
148
+ Attributes:
149
+ id: Unique identifier for this artifact
150
+ name: Human-readable name
151
+ kind: Type of artifact - "model", "eval", "dataset", or "custom"
152
+ metadata: Additional key-value metadata for the artifact
153
+ uri: Optional URI where the artifact is stored
154
+
155
+ Example:
156
+ ```python
157
+ artifact = JobArtifact(
158
+ id="my-model-v1",
159
+ name="Fine-tuned Model",
160
+ kind="model",
161
+ uri="s3://bucket/models/my-model",
162
+ checkpoint_step=1000,
163
+ loss=0.25
164
+ )
165
+ notifier.register_artifact(artifact)
166
+ ```
167
+ """
168
+
169
+ id: str
170
+ name: str
171
+ kind: str
172
+ metadata: dict[str, Any]
173
+ uri: Optional[str]
174
+
175
+ def __new__(
176
+ cls,
177
+ id: str,
178
+ name: str,
179
+ kind: str,
180
+ uri: Optional[str] = None,
181
+ **py_kwargs,
182
+ ) -> JobArtifact: ...
183
+ def __repr__(self) -> str: ...
184
+
185
+ class StringTurn(NamedTuple):
186
+ """A single turn in a conversation thread with string content.
187
+
188
+ Attributes:
189
+ role: The role of the entity that is creating the content
190
+ content: The text content of the turn
191
+ """
192
+
193
+ role: str
194
+ content: str
195
+
196
+ class TokenizedTurn(NamedTuple):
197
+ """A single turn in a conversation thread with tokenized content.
198
+
199
+ Attributes:
200
+ role: The role of the entity that is creating the content
201
+ content: The tokenized content as a list of token IDs
202
+ """
203
+
204
+ role: str
205
+ content: list[int]
206
+
207
+ class ModelConfigResponse:
208
+ """Configuration metadata for a model retrieved from the control plane.
209
+
210
+ Attributes:
211
+ model_id: Unique identifier for the model
212
+ model_key: Human-readable key for the model
213
+ path: File system or registry path to the model
214
+ tp: Optional tensor parallelism degree
215
+ kv_cache_len: Optional KV cache length for inference
216
+ max_seq_len: Optional maximum sequence length the model supports
217
+ """
218
+
219
+ model_id: str
220
+ model_key: str
221
+ path: str
222
+ tp: int | None
223
+ kv_cache_len: int | None
224
+ max_seq_len: int | None
225
+
226
+ class DatasetConfigResponse:
227
+ """Configuration metadata for a dataset retrieved from the control plane.
228
+
229
+ Attributes:
230
+ dataset_id: Unique identifier for the dataset
231
+ dataset_key: Human-readable key for the dataset
232
+ name: Display name of the dataset
233
+ file_path: Path to the dataset file
234
+ kind: Type of dataset (e.g., "jsonl", "parquet")
235
+ """
236
+
237
+ dataset_id: str
238
+ dataset_key: str
239
+ name: str
240
+ file_path: str
241
+ kind: str
242
+
243
+ class GraderConfigResponse:
244
+ """Configuration metadata for a grader retrieved from the control plane.
245
+
246
+ Attributes:
247
+ grader_id: Unique identifier for the grader
248
+ key: Human-readable key for the grader
249
+ name: Display name of the grader
250
+ harmony_url: URL of the harmony instance to use for grading
251
+ grader_config_json: JSON configuration for the grader
252
+ """
253
+
254
+ grader_id: str
255
+ key: str
256
+ name: str
257
+ harmony_url: str
258
+ grader_config_json: str
259
+
260
+ class HarmonyClient:
261
+ """Main client for interacting with Harmony workers.
262
+
263
+ The HarmonyClient is the primary interface for creating models and accessing
264
+ platform configuration. It manages the connection to Harmony workers and
265
+ provides access to model builders for spawning inference and training models.
266
+
267
+ Example:
268
+ ```python
269
+ from harmony_client import get_client
270
+
271
+ client = await get_client(
272
+ addr="ws://localhost:8080",
273
+ num_gpus=1,
274
+ api_key="my-key"
275
+ )
276
+
277
+ # Create a model builder
278
+ model = client.model("model_registry://llama-3.1-8b")
279
+
280
+ # Spawn an inference model
281
+ inf_model = await model.spawn_inference("my-model")
282
+
283
+ # Generate text
284
+ thread = StringThread([("user", "Hello!")])
285
+ result = await inf_model.generate(thread)
286
+ print(result.last_content())
287
+ ```
288
+ """
289
+ def model(self, path: str, kv_cache_len: int = 131072, tokens_to_generate: int = 2048) -> ModelBuilder:
290
+ """Create a model builder for spawning inference or training models.
291
+
292
+ Args:
293
+ path: Path to the model. Can be:
294
+ - Model registry key: "model_registry://llama-3.1-8b"
295
+ - External provider: "openai://gpt-4", "anthropic://claude-3-5-sonnet"
296
+ - URL with API key: "openai://gpt-4?api_key=sk-..."
297
+ kv_cache_len: KV cache length for inference (default: 131072). Ignored for external models, whose serving we do not control.
298
+ tokens_to_generate: Maximum tokens to generate (default: 2048)
299
+
300
+ Returns:
301
+ ModelBuilder that can be configured and spawned
302
+ """
303
+ ...
304
+ def session_id(self) -> str:
305
+ """Get the unique session ID for this client connection.
306
+
307
+ Returns:
308
+ Session UUID as a string
309
+ """
310
+ ...
311
+ async def get_grader_config(self, grader_key: str) -> GraderConfigResponse:
312
+ """Fetch grader configuration from the control plane.
313
+
314
+ Args:
315
+ grader_key: Key identifying the grader
316
+
317
+ Returns:
318
+ GraderConfigResponse with grader metadata
319
+
320
+ Raises:
321
+ RecipeError: If control_plane_url was not provided to get_client()
322
+ """
323
+ ...
324
+ async def get_dataset_config(self, dataset_key: str) -> DatasetConfigResponse:
325
+ """Fetch dataset configuration from the control plane.
326
+
327
+ Args:
328
+ dataset_key: Key identifying the dataset
329
+
330
+ Returns:
331
+ DatasetConfigResponse with dataset metadata
332
+
333
+ Raises:
334
+ RecipeError: If control_plane_url was not provided to get_client()
335
+ """
336
+ ...
337
+ async def get_model_config(self, model_key: str) -> ModelConfigResponse:
338
+ """Fetch model configuration from the control plane.
339
+
340
+ Args:
341
+ model_key: Key identifying the model
342
+
343
+ Returns:
344
+ ModelConfigResponse with model metadata
345
+
346
+ Raises:
347
+ RecipeError: If control_plane_url was not provided to get_client()
348
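+
+ Example:
+ A sketch, assuming the key exists in your model registry:
+ ```python
+ cfg = await client.get_model_config("llama-3.1-8b")
+ model = client.model(f"model_registry://{cfg.model_key}")
+ ```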
+ """
349
+ ...
350
+ def close(self):
351
+ """Close the client connection and release resources."""
352
+ ...
353
+ def __enter__(self) -> HarmonyClient: ...
354
+ def __exit__(self, exc_type, exc_value, traceback) -> None: ...
355
+
356
+ class InferenceModel:
357
+ """Model instance for running inference operations.
358
+
359
+ InferenceModel provides methods for text generation, tokenization, scoring,
360
+ and computing log probabilities. It can be created from a ModelBuilder using
361
+ spawn_inference().
362
+
363
+ The model is automatically deallocated when the object is garbage collected,
364
+ but you can explicitly call dealloc() to free GPU memory sooner.
365
+
366
+ Example:
367
+ ```python
368
+ model = client.model("model_registry://llama-3.1-8b")
369
+ inf_model = await model.spawn_inference("my-inf-model")
370
+
371
+ # Generate text
372
+ thread = StringThread([("user", "What is 2+2?")])
373
+ result = await inf_model.generate(thread)
374
+ print(result.last_content()) # "4"
375
+ ```
376
+ """
377
+ def is_scalar(self) -> bool:
378
+ """Check if this model has a scalar output head (for scoring/reward models).
379
+
380
+ Returns:
381
+ True if model outputs scalars instead of next-token probabilities
382
+ """
383
+ ...
384
+ async def dealloc(self) -> None:
385
+ """Deallocate the model from GPU memory.
386
+
387
+ Explicitly frees GPU resources. The model cannot be used after this call.
388
+ Models are automatically deallocated on garbage collection if not called.
389
+ """
390
+ ...
391
+ async def load(self, path: str) -> None:
392
+ """Load model weights from a checkpoint path.
393
+
394
+ Args:
395
+ path: Path to model weights
396
+ """
397
+ ...
398
+ async def generate(self, thread: StringThread, return_timings: bool = False) -> StringThread:
399
+ """Generate text completion for a conversation thread.
400
+
401
+ Appends an assistant turn with the generated completion to the thread.
402
+
403
+ Args:
404
+ thread: Conversation thread to continue
405
+ return_timings: If True, returns tuple of (thread, timings)
406
+
407
+ Returns:
408
+ Updated thread with assistant response, or (thread, timings) tuple
409
+
410
+ Example:
411
+ ```python
412
+ thread = StringThread([("user", "Hello!")])
413
+ result = await model.generate(thread)
414
+ print(result.last_content()) # Generated response
415
+ ```
416
+ """
417
+ ...
418
+ async def generate_tokens(self, thread: StringThread) -> TokenizedThread:
419
+ """Generate token IDs for a conversation thread.
420
+
421
+ Similar to generate() but returns tokenized output instead of strings.
422
+
423
+ Args:
424
+ thread: Conversation thread to continue
425
+
426
+ Returns:
427
+ TokenizedThread with generated token IDs
428
+ """
429
+ ...
430
+ async def generate_and_validate(
431
+ self,
432
+ thread: StringThread,
433
+ pydantic_model: Type[T],
434
+ max_parsing_retries: int = 1,
435
+ ) -> tuple[str, T]:
436
+ """Generate structured output validated against a Pydantic model.
437
+
438
+ Generates JSON output and parses it into the specified Pydantic model,
439
+ with automatic retries on parse failures.
440
+
441
+ Args:
442
+ thread: Conversation thread (should request JSON output)
443
+ pydantic_model: Pydantic model class to validate against
444
+ max_parsing_retries: Maximum retry attempts on parse failures
445
+
446
+ Returns:
447
+ Tuple of (raw_json_string, parsed_model_instance)
448
+
449
+ Example:
450
+ ```python
451
+ from pydantic import BaseModel
452
+
453
+ class Response(BaseModel):
454
+ answer: str
455
+ confidence: float
456
+
457
+ thread = StringThread([(
458
+ "user",
459
+ "Return JSON with 'answer' and 'confidence' fields"
460
+ )])
461
+ json_str, parsed = await model.generate_and_validate(
462
+ thread, Response
463
+ )
464
+ print(parsed.answer, parsed.confidence)
465
+ ```
466
+ """
467
+ ...
468
+ async def tokenize_thread(self, thread: StringThread) -> TokenizedThread:
469
+ """Convert a StringThread to a TokenizedThread.
470
+
471
+ Args:
472
+ thread: String-based conversation thread
473
+
474
+ Returns:
475
+ Tokenized version of the thread
476
+ """
477
+ ...
478
+ async def detokenize_thread(self, thread: TokenizedThread) -> StringThread:
479
+ """Convert a TokenizedThread back to a StringThread.
480
+
481
+ Args:
482
+ thread: Tokenized conversation thread
483
+
484
+ Returns:
485
+ String version of the thread
486
+ """
487
+ ...
488
+
489
+ # The outputs of serialize_* are: the tokens for the whole thread,
490
+ # the image tokens, and each token's training weight.
491
+ async def serialize_thread(
492
+ self, thread: StringThread
493
+ ) -> tuple[list[int], list[tuple[int, list[float]]], list[float]]:
494
+ """Serialize a thread into tokens with weights for training.
495
+
496
+ Returns:
497
+ Tuple of (token_ids, image_tokens, weights) where:
498
+ - token_ids: All tokens in the thread
499
+ - image_tokens: List of (position, embedding) for images
500
+ - weights: Per-token training weights (0.0 = skip, 1.0 = train)
501
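+
+ Example:
+ A sketch of unpacking the result:
+ ```python
+ token_ids, image_tokens, weights = await model.serialize_thread(thread)
+ assert len(weights) == len(token_ids)
+ ```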
+ """
502
+ ...
503
+ async def serialize_tokenized_thread(
504
+ self, thread: TokenizedThread
505
+ ) -> tuple[list[int], list[tuple[int, list[float]]], list[float]]:
506
+ """Serialize a tokenized thread with weights for training.
507
+
508
+ Returns:
509
+ Tuple of (token_ids, image_tokens, weights)
510
+ """
511
+ ...
512
+ async def logprobs(self, thread: StringThread) -> float:
513
+ """Compute the total log probability of a thread.
514
+
515
+ Returns the sum of log probabilities of all weighted tokens.
516
+
517
+ Args:
518
+ thread: Conversation thread to score
519
+
520
+ Returns:
521
+ Total log probability (summed across weighted tokens)
522
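+
+ Example:
+ ```python
+ thread = StringThread([
+ ("user", "What is 2+2?"),
+ ("assistant", "4"),
+ ]).with_weight_last_assistant_turn()
+ lp = await model.logprobs(thread)  # logprob of "4" given the prompt
+ ```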
+ """
523
+ ...
524
+ async def logprobs_per_token(self, thread: TokenizedThread) -> list[float]:
525
+ """Compute log probabilities for each weighted token in a thread.
526
+
527
+ Args:
528
+ thread: Tokenized conversation thread
529
+
530
+ Returns:
531
+ List of log probabilities for each weighted token
532
+ """
533
+ ...
534
+ async def score(self, thread: TokenizedThread) -> list[float]:
535
+ """Score each weighted token using a scalar head model.
536
+
537
+ Only works with scoring models (where is_scalar() returns True).
538
+
539
+ Args:
540
+ thread: Tokenized thread to score
541
+
542
+ Returns:
543
+ List of scalar scores for each weighted token
544
+ """
545
+ ...
546
+ async def score_last_token(self, thread: StringThread) -> float:
547
+ """Score the last token using a scalar head model.
548
+
549
+ Args:
550
+ thread: Conversation thread
551
+
552
+ Returns:
553
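+
+ Example:
+ A sketch, assuming a scoring model (is_scalar() returns True):
+ ```python
+ rm = await (client.model("model_registry://llama-3.1-8b")
+ .into_scoring_model()
+ .spawn_inference("rm"))
+ reward = await rm.score_last_token(thread)
+ ```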
+ Scalar score for the last token
554
+ """
555
+ ...
556
+ async def raw_string_create(self, prompt: str) -> str:
557
+ """Generate text from a raw string prompt (no conversation formatting).
558
+ This should not normally be used unless you know exactly what you are doing.
559
+
560
+ Args:
561
+ prompt: Raw text prompt
562
+
563
+ Returns:
564
+ Generated completion
565
+ """
566
+ ...
567
+ async def raw_token_create(self, prompt: Sequence[int]) -> list[int]:
568
+ """Generate tokens from raw token IDs (no conversation formatting).
569
+ This should not normally be used unless you know exactly what you are doing.
570
+
571
+ Args:
572
+ prompt: List of token IDs
573
+
574
+ Returns:
575
+ Generated token IDs
576
+ """
577
+ ...
578
+ async def tokenize(self, data: str) -> list[int]:
579
+ """Tokenize a string into token IDs.
580
+
581
+ Args:
582
+ data: Text to tokenize
583
+
584
+ Returns:
585
+ List of token IDs
586
+ """
587
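+
+ Example:
+ ```python
+ tokens = await model.tokenize("Hello, world!")
+ text = await model.detokenize(tokens)  # round-trips for typical text
+ ```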
+ ...
588
+ async def detokenize(self, data: Sequence[int]) -> str:
589
+ """Convert token IDs back to text.
590
+
591
+ Args:
592
+ data: List of token IDs
593
+
594
+ Returns:
595
+ Decoded text string
596
+ """
597
+ ...
598
+ async def char_to_token_rewards(self, text: str, char_rewards: Sequence[float]) -> list[float]:
599
+ """Map character-level rewards to token-level rewards.
600
+
601
+ Useful for reward modeling when you have character-granular feedback.
602
+
603
+ Args:
604
+ text: The text string
605
+ char_rewards: Reward value for each character
606
+
607
+ Returns:
608
+ Reward value for each token
609
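+
+ Example:
+ A sketch with a uniform character-level reward:
+ ```python
+ text = "The answer is 4."
+ char_rewards = [1.0] * len(text)
+ token_rewards = await model.char_to_token_rewards(text, char_rewards)
+ ```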
+ """
610
+ ...
611
+ async def raw_logprobs(self, tokens: Sequence[int]) -> list[float]:
612
+ """Compute log probabilities for raw token IDs.
613
+ This should not normally be used unless you know exactly what you are doing.
614
+
615
+ Args:
616
+ tokens: Sequence of token IDs
617
+
618
+ Returns:
619
+ Log probability for each token (first token gets 0.0)
620
+ """
621
+ ...
622
+ async def thread_with_tools(self, thread: StringThread, tool_uris: Sequence[str]) -> StringThread:
623
+ """Inject tool definitions into a conversation thread.
624
+ DOCS_TODO: add an example here.
625
+
626
+ Args:
627
+ thread: Base conversation thread
628
+ tool_uris: List of tool URIs to make available
629
+
630
+ Returns:
631
+ Thread with tool definitions injected
632
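+
+ Example:
+ A sketch; the tool URI scheme shown here is illustrative:
+ ```python
+ thread = StringThread([("user", "What's the weather in Paris?")])
+ thread = await model.thread_with_tools(thread, ["mcp://weather/current"])
+ result = await model.generate(thread)
+ ```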
+ """
633
+ ...
634
+ @staticmethod
635
+ def render_schema(pydantic_model: type[BaseModel], with_field_descriptions: bool = True) -> str:
636
+ """Render a Pydantic model as a JSON schema string.
637
+
638
+ Args:
639
+ pydantic_model: Pydantic model class
640
+ with_field_descriptions: Include field descriptions in schema
641
+
642
+ Returns:
643
+ JSON schema as a string
644
+ """
645
+ ...
646
+ @staticmethod
647
+ def render_pydantic_model(pydantic_model: BaseModel) -> str:
648
+ """Serialize a Pydantic model instance to JSON string.
649
+
650
+ Args:
651
+ pydantic_model: Pydantic model instance
652
+
653
+ Returns:
654
+ JSON representation
655
+ """
656
+ ...
657
+ def get_builder_args(self) -> dict[str, Any]:
658
+ """Get the configuration args used to build this model.
659
+
660
+ Returns:
661
+ Dictionary of model builder arguments
662
+ """
663
+ ...
664
+ def top_p(self, top_p: float) -> InferenceModel:
665
+ """Return a shallow copy with modified top_p sampling parameter.
666
+
667
+ Args:
668
+ top_p: Nucleus sampling probability (0.0-1.0)
669
+
670
+ Returns:
671
+ New InferenceModel instance with updated parameter
672
+ """
673
+ ...
674
+ def temperature(self, temperature: float) -> InferenceModel:
675
+ """Return a shallow copy with modified temperature sampling parameter.
676
+
677
+ Args:
678
+ temperature: Sampling temperature (>0, typically 0.1-2.0)
679
+
680
+ Returns:
681
+ New InferenceModel instance with updated parameter
682
+ """
683
+ ...
684
+ def max_gen_len(self, max_num_tokens: int) -> InferenceModel:
685
+ """Return a shallow copy with modified maximum generation length.
686
+
687
+ Args:
688
+ max_num_tokens: Maximum tokens to generate
689
+
690
+ Returns:
691
+ New InferenceModel instance with updated parameter
692
+ """
693
+ ...
694
+ def min_gen_len(self, min_num_tokens: int) -> InferenceModel:
695
+ """Return a shallow copy with modified minimum generation length.
696
+
697
+ Args:
698
+ min_num_tokens: Minimum tokens to generate
699
+
700
+ Returns:
701
+ New InferenceModel instance with updated parameter
702
+ """
703
+ ...
704
+
705
+ class InferenceSettings:
706
+ """Configuration for inference model spawning.
707
+
708
+ Attributes:
709
+ kv_cache_len: Length of the KV cache
710
+ tokens_to_generate: Maximum tokens to generate per request
711
+ """
712
+ def __new__(cls, kv_cache_len: int, tokens_to_generate: int) -> InferenceSettings: ...
713
+ def to_python_dict(self) -> dict[str, Any]:
714
+ """Convert settings to a dictionary.
715
+
716
+ Returns:
717
+ Dictionary representation of settings
718
+ """
719
+ ...
720
+
721
+ class ModelBuilder:
722
+ """Builder for configuring and spawning model instances.
723
+
724
+ ModelBuilder provides a fluent API for configuring model parameters
725
+ before spawning inference or training instances. Create via HarmonyClient.model().
726
+
727
+ Example:
728
+ ```python
729
+ # Basic inference model
730
+ model = client.model("model_registry://llama-3.1-8b")
731
+ inf = await model.spawn_inference("my-model")
732
+
733
+ # With tensor parallelism and temperature
734
+ model = (client.model("llama-3.1-70b")
735
+ .tp(4)
736
+ .temperature(0.7))
737
+ inf = await model.spawn_inference("my-70b")
738
+
739
+ # Training model with LoRA adapter
740
+ model = (client.model("model_registry://llama-3.1-8b")
741
+ .with_adapter())
742
+ train = await model.spawn_train("my-train", max_batch_size=4)
743
+
744
+ # With speculative decoding
745
+ model = (client.model("llama-3.1-70b")
746
+ .with_draft("llama-3.1-8b", num_draft_steps=4))
747
+ inf = await model.spawn_inference("fast-70b")
748
+ ```
749
+ """
750
+ async def spawn_train(self, name: str, max_batch_size: int) -> TrainingModel:
751
+ """Spawn a training model instance.
752
+
753
+ Args:
754
+ name: Unique name for this model instance
755
+ max_batch_size: Maximum batch size for training. Due to our automatic packing, this is also the maximum sequence length of a single sample.
756
+
757
+ Returns:
758
+ TrainingModel instance for fine-tuning operations
759
+ """
760
+ ...
761
+ async def spawn_inference(self, name: str) -> InferenceModel:
762
+ """Spawn an inference model instance.
763
+
764
+ Args:
765
+ name: Unique name for this model instance
766
+
767
+ Returns:
768
+ InferenceModel instance for generation operations
769
+ """
770
+ ...
771
+ def tp(self, tp: int) -> ModelBuilder:
772
+ """Set tensor parallelism degree.
773
+
774
+ Args:
775
+ tp: Number of GPUs to split the model across
776
+
777
+ Returns:
778
+ Updated builder
779
+ """
780
+ ...
781
+ def tools(self, tools: list[str]) -> ModelBuilder:
782
+ """Add tool calling support.
783
+
784
+ Args:
785
+ tools: List of tool URIs to make available
786
+
787
+ Returns:
788
+ Updated builder
789
+ """
790
+ ...
791
+ def api_key(self, api_key: str) -> ModelBuilder:
792
+ """Set API key for external model providers.
793
+
794
+ Args:
795
+ api_key: API key for provider (OpenAI, Anthropic, etc.)
796
+
797
+ Returns:
798
+ Updated builder
799
+ """
800
+ ...
801
+ def into_scoring_model(self) -> ModelBuilder:
802
+ """Configure model to use scalar output head for scoring. Useful when creating a reward/value model from a base language model.
803
+
804
+ Returns:
805
+ Updated builder configured for reward/value modeling
806
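+
+ Example:
+ A sketch; `thread` stands in for a quality-labeled conversation:
+ ```python
+ rm = await (client.model("model_registry://llama-3.1-8b")
+ .into_scoring_model()
+ .spawn_train("reward-model", max_batch_size=1024))
+ await rm.train_mse(thread, target=1.0)
+ ```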
+ """
807
+ ...
808
+ def with_adapter(self, use_adapter: bool = True) -> ModelBuilder:
809
+ """Configure adapter/LoRA mode for model spawning.
810
+
811
+ Args:
812
+ use_adapter: Controls adapter mode (default: True).
813
+ - True: Enable adapter mode. For FW (full-weight) models, creates a new adapter.
814
+ For LoRA models, uses the existing adapter (no-op).
815
+ - False: Disable adapter mode. For FW models, trains the full weights directly.
816
+ For LoRA models, this is an error (cannot fall back to FW from a LoRA).
817
+
818
+ If with_adapter() is not called at all:
819
+ - FW models default to training full weights (equivalent to with_adapter(False))
820
+ - LoRA models default to using the adapter (equivalent to with_adapter(True))
821
+
822
+ Returns:
823
+ Updated builder with adapter mode configured
824
+ """
825
+ ...
826
+ def with_draft(self, draft_model_path: str, num_draft_steps: int | None = None) -> ModelBuilder:
827
+ """Enable speculative decoding with a draft model.
828
+
829
+ Args:
830
+ draft_model_path: Path to smaller/faster draft model. Must share the same tokenizer as the main model.
831
+ num_draft_steps: Number of speculative tokens per step (default: DEFAULT_MAX_DRAFT_STEPS)
832
+
833
+ Returns:
834
+ Updated builder with speculative decoding enabled
835
+ """
836
+ ...
837
+ def to_python_dict(self) -> dict[str, Any]:
838
+ """Get builder configuration as a dictionary.
839
+
840
+ Returns:
841
+ Dictionary of builder parameters
842
+ """
843
+ ...
844
+ def extra_params(self, **params) -> ModelBuilder:
845
+ """Set additional provider-specific parameters.
846
+ DOCS_TODO: add examples for OpenAI, Anthropic, etc.
847
+
848
+ Args:
849
+ **params: Arbitrary parameters passed to the model provider
850
+
851
+ Returns:
852
+ Updated builder
853
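+
+ Example:
+ A sketch; which parameters are honored depends on the provider:
+ ```python
+ model = await (client.model("openai://gpt-4")
+ .extra_params(seed=42, presence_penalty=0.5)
+ .spawn_inference("gpt4-seeded"))
+ ```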
+ """
854
+ ...
855
+
856
+ class SerializedThread:
857
+ """Opaque type representing a fully serialized conversation thread.
858
+
859
+ Used internally for efficient communication between client and workers.
860
+ """
861
+
862
+ ...
863
+
864
+ class StatDuration:
865
+ """Statistics about tokens processed over a time period.
866
+
867
+ Tracks input, output, and training token counts and throughput rates.
868
+
869
+ Attributes:
870
+ duration_sec: Duration in seconds
871
+ num_input_tokens: Total input tokens processed
872
+ num_output_tokens: Total output tokens generated
873
+ num_trained_tokens: Total tokens used for training
874
+ num_input_tokens_per_s: Input token throughput
875
+ num_output_tokens_per_s: Output token throughput
876
+ num_trained_tokens_per_s: Training token throughput
877
+ """
878
+
879
+ duration_sec: float
880
+ num_input_tokens: int
881
+ num_output_tokens: int
882
+ num_trained_tokens: int
883
+ num_input_tokens_per_s: float
884
+ num_output_tokens_per_s: float
885
+ num_trained_tokens_per_s: float
886
+
887
+ def combine(self, other: StatDuration) -> StatDuration:
888
+ """Combine statistics from two time periods.
889
+
890
+ Args:
891
+ other: Another StatDuration to combine
892
+
893
+ Returns:
894
+ Combined statistics spanning both periods
895
+ """
896
+ ...
897
+
898
+ class StatInstant:
899
+ """Snapshot of model statistics at a point in time.
900
+
901
+ Used to compute statistics over time intervals via stats_since().
902
+ """
903
+ def stats_since(self, other: StatInstant) -> StatDuration:
904
+ """Compute statistics since another snapshot.
905
+
906
+ Args:
907
+ other: Earlier StatInstant snapshot
908
+
909
+ Returns:
910
+ StatDuration covering the time between snapshots
911
+ """
912
+ ...
913
+
914
+ class StringThread:
915
+ """Represents a conversation thread with string-based content.
916
+
917
+ StringThread is the primary way to represent conversations in Harmony. It consists
918
+ of a sequence of turns (user, assistant, system, tool) with text content. Each turn
919
+ has an associated weight that controls whether it's used for training.
920
+
921
+ Threads support multi-modal content via fragments (text and images).
922
+
923
+ Attributes:
924
+ metadata: Arbitrary Python object to attach to this thread
925
+
926
+ Example:
927
+ ```python
928
+ # Create a simple conversation
929
+ thread = StringThread([
930
+ ("user", "What is the capital of France?"),
931
+ ("assistant", "The capital of France is Paris.")
932
+ ])
933
+
934
+ # Add more turns
935
+ thread = thread.user("Tell me more about it.")
936
+
937
+ # Access turns
938
+ for role, content in thread.get_turns():
939
+ print(f"{role}: {content}")
940
+
941
+ # Multi-modal with images
942
+ thread = await StringThread.from_fragments([
943
+ ("user", [
944
+ {"type": "text", "text": "What's in this image?"},
945
+ {"type": "image", "url": "data:image/png;base64,..."}
946
+ ])
947
+ ], metadata=None)
948
+ ```
949
+ """
950
+
951
+ metadata: Any
952
+
953
+ def __new__(
954
+ cls,
955
+ turns: Optional[Sequence[tuple[str, str]]] = None,
956
+ metadata: Optional[Any] = None,
957
+ ) -> StringThread: ...
958
+ @staticmethod
959
+ async def from_fragments(raw_turns: Sequence[tuple[str, Sequence[Fragment]]], metadata: Any) -> StringThread:
960
+ """Create a thread from multi-modal fragments (text and images).
961
+
962
+ Args:
963
+ raw_turns: Sequence of (role, fragments) tuples
964
+ metadata: Optional metadata to attach
965
+
966
+ Returns:
967
+ StringThread with multi-modal content
968
+ """
969
+ ...
970
+ @staticmethod
971
+ async def from_dataset(raw_turns: Sequence[tuple[str, str | Sequence[Fragment]]], metadata: Any) -> StringThread:
972
+ """Create a thread from dataset format (strings or fragments).
973
+
974
+ Flexible constructor that accepts either plain strings or fragment lists
975
+ for each turn's content.
976
+
977
+ Args:
978
+ raw_turns: Sequence of (role, content) where content is string or fragments
979
+ metadata: Optional metadata to attach
980
+
981
+ Returns:
982
+ StringThread from dataset
983
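+
+ Example:
+ Mixing plain-string and fragment turns:
+ ```python
+ thread = await StringThread.from_dataset([
+ ("user", [
+ {"type": "text", "text": "Describe this image."},
+ {"type": "image", "url": "https://example.com/cat.png"},
+ ]),
+ ("assistant", "A cat sitting on a windowsill."),
+ ], metadata=None)
+ ```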
+ """
984
+ ...
985
+ @classmethod
986
+ def with_metadata(cls, turns: Sequence[tuple[str, str]], metadata: Any) -> StringThread: ...
987
+ def user(self, content: str) -> StringThread:
988
+ """Append a user turn to the thread.
989
+
990
+ Args:
991
+ content: User message text
992
+
993
+ Returns:
994
+ New thread with user turn appended
995
+ """
996
+ ...
997
+ def system(self, content: str) -> StringThread:
998
+ """Append a system turn to the thread.
999
+
1000
+ Args:
1001
+ content: System message text
1002
+
1003
+ Returns:
1004
+ New thread with system turn appended
1005
+ """
1006
+ ...
1007
+ def assistant(self, content: str) -> StringThread:
1008
+ """Append an assistant turn to the thread.
1009
+
1010
+ Args:
1011
+ content: Assistant message text
1012
+
1013
+ Returns:
1014
+ New thread with assistant turn appended
1015
+ """
1016
+ ...
1017
+ def tool(self, content: str) -> StringThread:
1018
+ """Append a tool result turn to the thread.
1019
+
1020
+ Args:
1021
+ content: Tool output text
1022
+
1023
+ Returns:
1024
+ New thread with tool turn appended
1025
+ """
1026
+ ...
1027
+ def last_content(self) -> str:
1028
+ """Get the content of the last turn.
1029
+
1030
+ Returns:
1031
+ Text content of the last turn
1032
+
1033
+ Raises:
1034
+ RecipeError: If thread is empty
1035
+ """
1036
+ ...
1037
+ def messages(self) -> list[tuple[str, str]]:
1038
+ """Get all turns except the last assistant turn (if present).
1039
+ DOCS_TODO: show example.
1040
+
1041
+ Useful for extracting the prompt portion of a thread.
1042
+
1043
+ Returns:
1044
+ List of (role, content) tuples excluding final assistant turn
1045
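+
+ Example:
+ ```python
+ thread = StringThread([("user", "What is 2+2?"), ("assistant", "4")])
+ thread.messages()  # [("user", "What is 2+2?")]
+ ```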
+ """
1046
+ ...
1047
+ def completion(self) -> str | None:
1048
+ """Get the completion if the last turn is from the assistant.
1049
+ DOCS_TODO: show example.
1050
+
1051
+ Returns:
1052
+ Assistant's response if last turn is assistant, None otherwise
1053
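+
+ Example:
+ ```python
+ thread = StringThread([("user", "What is 2+2?"), ("assistant", "4")])
+ thread.completion()  # "4"
+ thread.user("Why?").completion()  # None
+ ```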
+ """
1054
+ ...
1055
+ def get_turns(self) -> list[StringTurn]:
1056
+ """Get all turns as StringTurn namedtuples.
1057
+
1058
+ Returns:
1059
+ List of StringTurn(role, content) tuples
1060
+ """
1061
+ ...
1062
+ def get_fragments(self) -> list[tuple[str, list[Fragment]]]:
1063
+ """Get all turns as multi-modal fragments.
1064
+
1065
+ Returns:
1066
+ List of (role, fragments) tuples
1067
+ """
1068
+ ...
1069
+ @staticmethod
1070
+ def from_json(json_str: str) -> StringThread:
1071
+ """Deserialize a thread from JSON string.
1072
+
1073
+ Args:
1074
+ json_str: JSON representation of thread
1075
+
1076
+ Returns:
1077
+ Deserialized StringThread
1078
+ """
1079
+ ...
1080
+ def to_json(self) -> str:
1081
+ """Serialize thread to JSON string.
1082
+
1083
+ Returns:
1084
+ JSON representation of thread
1085
+ """
1086
+ ...
1087
+ def with_weight_all_assistant_turns(self) -> StringThread:
1088
+ """Mark all assistant turns for training, zero out others.
1089
+
1090
+ Returns:
1091
+ New thread with weights adjusted
1092
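+
+ Example:
+ A sketch, assuming a TrainingModel named `train_model`:
+ ```python
+ thread = StringThread([
+ ("user", "Hi"),
+ ("assistant", "Hello!"),
+ ("user", "How are you?"),
+ ("assistant", "Great, thanks."),
+ ]).with_weight_all_assistant_turns()
+ await train_model.train_language_modelling(thread)
+ ```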
+ """
1093
+ ...
1094
+ def with_weight_last_assistant_turn(self) -> StringThread:
1095
+ """Mark only the last assistant turn for training.
1096
+
1097
+ Returns:
1098
+ New thread with weights adjusted
1099
+ """
1100
+ ...
1101
+ def with_weight_assistant_turns_from_index(self, start_index: int) -> StringThread:
1102
+ """Mark assistant turns starting from index for training.
1103
+
1104
+ Args:
1105
+ start_index: Index of first assistant turn to weight (0-based)
1106
+
1107
+ Returns:
1108
+ New thread with weights adjusted
1109
+ """
1110
+ ...
1111
+ def uuid(self) -> str | None:
1112
+ """Get the UUID associated with this thread (if any).
1113
+
1114
+ Returns:
1115
+ UUID string or None
1116
+ """
1117
+ ...
1118
+ def __repr__(self) -> str: ...
1119
+
1120
+ class TokenizedThread:
1121
+ """Represents a conversation thread with tokenized content.
1122
+
1123
+ Similar to StringThread but stores token IDs instead of strings.
1124
+ Useful for training operations and when you need direct token-level control.
1125
+
1126
+ Attributes:
1127
+ metadata: Arbitrary Python object to attach to this thread
1128
+ """
1129
+
1130
+ metadata: Any
1131
+
1132
+ def user(self, content: Sequence[int]) -> TokenizedThread:
1133
+ """Append a user turn with token IDs.
1134
+
1135
+ Args:
1136
+ content: List of token IDs for user message
1137
+
1138
+ Returns:
1139
+ New thread with user turn appended
1140
+ """
1141
+ ...
1142
+ def assistant(self, content: Sequence[int]) -> TokenizedThread:
1143
+ """Append an assistant turn with token IDs.
1144
+
1145
+ Args:
1146
+ content: List of token IDs for assistant message
1147
+
1148
+ Returns:
1149
+ New thread with assistant turn appended
1150
+ """
1151
+ ...
1152
+ def last_content(self) -> Optional[list[int]]:
1153
+ """Get the token IDs of the last turn.
1154
+
1155
+ Returns:
1156
+ List of token IDs, or None if thread is empty
1157
+ """
1158
+ ...
1159
+ def len_last_turn(self) -> int:
1160
+ """Get the number of tokens in the last turn.
1161
+
1162
+ Returns:
1163
+ Token count of last turn
1164
+ """
1165
+ ...
1166
+ def get_turns(self) -> list[TokenizedTurn]:
1167
+ """Get all turns as TokenizedTurn namedtuples.
1168
+
1169
+ Returns:
1170
+ List of TokenizedTurn(role, content) tuples
1171
+ """
1172
+ ...
1173
+ def with_weight_all_assistant_turns(self) -> TokenizedThread:
1174
+ """Mark all assistant turns for training, zero out others.
1175
+
1176
+ Returns:
1177
+ New thread with weights adjusted
1178
+ """
1179
+ ...
1180
+ def with_weight_last_assistant_turn(self) -> TokenizedThread:
1181
+ """Mark only the last assistant turn for training.
1182
+
1183
+ Returns:
1184
+ New thread with weights adjusted
1185
+ """
1186
+ ...
1187
+ def with_weight_assistant_turns_from_index(self, start_index: int) -> TokenizedThread:
1188
+ """Mark assistant turns starting from index for training.
1189
+
1190
+ Args:
1191
+ start_index: Index of first assistant turn to weight (0-based)
1192
+
1193
+ Returns:
1194
+ New thread with weights adjusted
1195
+ """
1196
+ ...
1197
+ def uuid(self) -> str | None:
1198
+ """Get the UUID associated with this thread (if any).
1199
+
1200
+ Returns:
1201
+ UUID string or None
1202
+ """
1203
+ ...
1204
+ def __repr__(self) -> str: ...
1205
+
1206
+ class TrainingModel(InferenceModel):
1207
+ """Model instance for running training operations.
1208
+
1209
+ TrainingModel extends InferenceModel with training capabilities. It supports
1210
+ various training algorithms including supervised fine-tuning, PPO, DPO, GRPO,
1211
+ and reward modeling.
1212
+
1213
+ Inherits all inference methods from InferenceModel.
1214
+
1215
+ Example:
1216
+ ```python
1217
+ # Spawn a training model
1218
+ model = client.model("model_registry://llama-3.1-8b").with_adapter()
1219
+ train_model = await model.spawn_train("my-train", max_batch_size=1024)
1220
+
1221
+ # Supervised fine-tuning
1222
+ thread = StringThread([
1223
+ ("user", "What is 2+2?"),
1224
+ ("assistant", "4")
1225
+ ]).with_weight_last_assistant_turn()
1226
+
1227
+ await train_model.train_language_modelling(thread)
1228
+ await train_model.optim_step(lr=1e-5, wd=0.01, max_grad_norm=1.0)
1229
+
1230
+ # Save the fine-tuned model
1231
+ model_key = await train_model.save("my-fine-tuned-model")
1232
+ ```
1233
+ """
1234
+ async def clone_inf(self) -> InferenceModel:
1235
+ """Clone this model as an inference-only instance. Useful for all reinforcement learning algorithms that perform regularization with KL divergence to a reference model.
1236
+
1237
+ Returns:
1238
+ New InferenceModel sharing weights with this training model
1239
+ """
1240
+ ...
1241
+ def get_builder_args(self) -> dict[str, Any]: ...
1242
+ async def save(self, model_name: str, inference_only: bool = True, ctx: RecipeContext | None = None) -> str:
1243
+ """Save the model weights to the model registry.
1244
+
1245
+ Args:
1246
+ model_name: Name for the saved model
1247
+ inference_only: If True, save only inference weights (default). If False we save the entire optimizer state as well.
1248
+ ctx: Optional recipe context for tracking
1249
+
1250
+ Returns:
1251
+ Model key that can be used to load the model
1252
+ """
1253
+ ...
1254
+ async def optim_step(
1255
+ self, lr: float, wd: float, max_grad_norm: float, skip_nan_gradients: bool = False
1256
+ ) -> dict[str, float]:
1257
+ """Perform an optimizer step (update model weights).
1258
+
1259
+ Args:
1260
+ lr: Learning rate
1261
+ wd: Weight decay
1262
+ max_grad_norm: Maximum gradient norm for clipping
1263
+ skip_nan_gradients: If True, skip step on NaN gradients instead of erroring
1264
+
1265
+ Returns:
1266
+ Dictionary of training metrics (loss, grad_norm, etc.)
1267
+ """
1268
+ ...
1269
+ def get_optim_step(self) -> int:
1270
+ """Get the current optimizer step counter.
1271
+
1272
+ Returns:
1273
+ Current step number
1274
+ """
1275
+ ...
1276
+ def set_optim_step(self, step: int) -> None:
1277
+ """Set the optimizer step counter.
1278
+
1279
+ Args:
1280
+ step: Step number to set
1281
+ """
1282
+ ...
1283
+ def inf(self) -> InferenceModel:
1284
+ """Get inference-only view of this model.
1285
+
1286
+ Returns:
1287
+ InferenceModel view (no copy, shares weights)
1288
+ """
1289
+ ...
1290
+ async def train_language_modelling(self, thread: StringThread) -> None:
1291
+ """Train with standard language modeling objective (next-token prediction).
1292
+
1293
+ Args:
1294
+ thread: Conversation thread with weighted assistant turns
1295
+ """
1296
+ ...
1297
+ async def train_ppo(
1298
+ self,
1299
+ thread: TokenizedThread,
1300
+ trajectory_logprobs: Sequence[float],
1301
+ advantages: Sequence[float],
1302
+ clip_range: float,
1303
+ ) -> None:
1304
+ """Train with Proximal Policy Optimization (PPO) objective.
1305
+
1306
+ Args:
1307
+ thread: Tokenized conversation thread
1308
+ trajectory_logprobs: Log probabilities from trajectory policy
1309
+ advantages: Per-token advantage estimates
1310
+ clip_range: PPO clipping range (epsilon)
1311
+ """
1312
+ ...
1313
+ async def train_grpo(
1314
+ self,
1315
+ thread: TokenizedThread,
1316
+ trajectory_logprobs: Sequence[float],
1317
+ reference_logprobs: Sequence[float],
1318
+ advantages: Sequence[float],
1319
+ clip_range: float,
1320
+ kl_beta: float,
1321
+ ) -> None:
1322
+ """Train with Group Relative Policy Optimization (GRPO).
1323
+
1324
+ Args:
1325
+ thread: Tokenized conversation thread
1326
+ trajectory_logprobs: Log probabilities from trajectory policy
1327
+ reference_logprobs: Log probabilities from reference policy
1328
+ advantages: Per-token advantage estimates
1329
+ clip_range: Clipping range
1330
+ kl_beta: KL divergence penalty coefficient
1331
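+
+ Example:
+ A sketch of one GRPO update; `prompt_thread`, `normalized_reward`, and the
+ hyperparameter values are illustrative:
+ ```python
+ ref_model = await train_model.clone_inf()
+ tok = await train_model.generate_tokens(prompt_thread)
+ traj_lp = await train_model.logprobs_per_token(tok)
+ ref_lp = await ref_model.logprobs_per_token(tok)
+ advantages = [normalized_reward] * len(traj_lp)
+ await train_model.train_grpo(tok, traj_lp, ref_lp, advantages,
+ clip_range=0.2, kl_beta=0.04)
+ await train_model.optim_step(lr=1e-6, wd=0.0, max_grad_norm=1.0)
+ ```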
+ """
1332
+ ...
1333
+ async def train_gspo(
1334
+ self,
1335
+ thread: TokenizedThread,
1336
+ trajectory_logprobs: Sequence[float],
1337
+ reference_logprobs: Sequence[float],
1338
+ advantage: Sequence[float],
1339
+ left_clip: float,
1340
+ right_clip: float,
1341
+ kl_beta: float,
1342
+ ) -> None:
1343
+ """Train with Group Sampling Policy Optimization (GSPO).
1344
+
1345
+ Args:
1346
+ thread: Tokenized conversation thread
1347
+ trajectory_logprobs: Log probabilities from trajectory policy
1348
+ reference_logprobs: Log probabilities from reference policy
1349
+ advantage: Advantage estimates
1350
+ left_clip: Left clipping bound
1351
+ right_clip: Right clipping bound
1352
+ kl_beta: KL divergence penalty coefficient
1353
+ """
1354
+ ...
1355
+ async def train_trust_region_mse(
1356
+ self,
1357
+ thread: TokenizedThread,
1358
+ targets: Sequence[float],
1359
+ clip_center: Sequence[float],
1360
+ clip_range: float,
1361
+ ) -> None:
1362
+ """Train with trust-region constrained MSE loss (for value functions).
1363
+
1364
+ Args:
1365
+ thread: Tokenized conversation thread
1366
+ targets: Target values
1367
+ clip_center: Center values for clipping
1368
+ clip_range: Clipping range
1369
+ """
1370
+ ...
1371
+ async def train_mse(self, thread: StringThread, target: float) -> None:
1372
+ """Train with MSE loss on the last token (for reward models).
1373
+
1374
+ Args:
1375
+ thread: Conversation thread
1376
+ target: Target scalar value
1377
+ """
1378
+ ...
1379
+ async def train_mse_per_token(self, thread: TokenizedThread, targets: Sequence[float]) -> None:
1380
+ """Train with per-token MSE loss.
1381
+
1382
+ Args:
1383
+ thread: Tokenized conversation thread
1384
+ targets: Target value for each weighted token
1385
+ """
1386
+ ...
1387
+ async def train_ranking(self, pos_thread: StringThread, neg_thread: StringThread) -> None:
1388
+ """Train with ranking loss (positive example should score higher).
1389
+
1390
+ Args:
1391
+ pos_thread: Preferred/positive example
1392
+ neg_thread: Rejected/negative example
1393
+ """
1394
+ ...
1395
+ async def train_dpo(
1396
+ self,
1397
+ sample_pos: StringThread,
1398
+ sample_neg: StringThread,
1399
+ ref_logprobs_pos: float,
1400
+ ref_logprobs_neg: float,
1401
+ beta: float,
1402
+ ) -> None:
1403
+ """Train with Direct Preference Optimization (DPO).
1404
+
1405
+ Args:
1406
+ sample_pos: Preferred completion thread
1407
+ sample_neg: Rejected completion thread
1408
+ ref_logprobs_pos: Reference model logprobs for positive
1409
+ ref_logprobs_neg: Reference model logprobs for negative
1410
+ beta: DPO temperature parameter
1411
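+
+ Example:
+ A sketch; beta and the learning rate are illustrative:
+ ```python
+ ref = await train_model.clone_inf()
+ lp_pos = await ref.logprobs(sample_pos)
+ lp_neg = await ref.logprobs(sample_neg)
+ await train_model.train_dpo(sample_pos, sample_neg, lp_pos, lp_neg, beta=0.1)
+ await train_model.optim_step(lr=5e-7, wd=0.0, max_grad_norm=1.0)
+ ```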
+ """
1412
+ ...
1413
+ def top_p(self, top_p: float) -> TrainingModel: ...
1414
+ def temperature(self, temperature: float) -> TrainingModel: ...
1415
+ def max_gen_len(self, max_num_tokens: int) -> TrainingModel: ...
1416
+ def min_gen_len(self, min_num_tokens: int) -> TrainingModel: ...
1417
+
1418
+ class Thread(Enum):
1419
+ """Enum of thread types (internal use)."""
1420
+
1421
+ StringThread = ...
1422
+ TokenizedThread = ...
1423
+ SerializedThread = ...
1424
+
1425
+ class JobNotifier:
1426
+ """Helper class to report job progress to stdout.
1427
+
1428
+ By default logs progress to stdout. For integration with the Harmony platform,
1429
+ use HarmonyJobNotifier instead.
1430
+
1431
+ Example:
1432
+ ```python
1433
+ notifier = JobNotifier()
1434
+ notifier.register_stages(["training", "evaluation"])
1435
+
1436
+ stage = notifier.stage_notifier("training")
1437
+ stage.report_progress(tot_num_samples=1000, processed_num_samples=100)
1438
+ ```
1439
+ """
1440
+
1441
+ def __new__(cls) -> JobNotifier: ...
1442
+ def set_monitoring_link(self, monitoring_link: str) -> None:
1443
+ """Set a monitoring link (e.g., Weights & Biases URL).
1444
+
1445
+ Args:
1446
+ monitoring_link: URL to monitoring dashboard
1447
+ """
1448
+ ...
1449
+ def register_stages(self, stages: Sequence[str]) -> None:
1450
+ """Register the stages of this job.
1451
+
1452
+ Args:
1453
+ stages: Ordered list of stage names
1454
+ """
1455
+ ...
1456
+ def register_artifact(self, artifact: JobArtifact) -> None:
1457
+ """Register an artifact produced by this job.
1458
+
1459
+ Args:
1460
+ artifact: Artifact to register
1461
+ """
1462
+ ...
1463
+ def report_error(self, error: str) -> None:
1464
+ """Report an error that occurred during the job.
1465
+
1466
+ Args:
1467
+ error: Error message
1468
+ """
1469
+ ...
1470
+ def report_progress(
1471
+ self,
1472
+ stage: str,
1473
+ tot_num_samples: Optional[int] = None,
1474
+ processed_num_samples: Optional[int] = None,
1475
+ monitoring_link: Optional[str] = None,
1476
+ checkpoints: Optional[Sequence[str]] = None,
1477
+ ) -> None:
1478
+ """Report progress for a stage.
1479
+
1480
+ Args:
1481
+ stage: Stage name
1482
+ tot_num_samples: Total number of samples to process
1483
+ processed_num_samples: Number of samples processed so far
1484
+ monitoring_link: Optional monitoring dashboard URL
1485
+ checkpoints: Optional list of checkpoint identifiers
1486
+ """
1487
+ ...
1488
+ def stage_notifier(self, stage: str) -> StageNotifier:
1489
+ """Get a stage-specific notifier.
1490
+
1491
+ Args:
1492
+ stage: Stage name
1493
+
1494
+ Returns:
1495
+ StageNotifier for the given stage
1496
+ """
1497
+ ...
1498
+ def __repr__(self) -> str: ...
1499
+
1500
+ class HarmonyJobNotifier(JobNotifier):
1501
+ """Job notifier that reports progress to the Harmony platform.
1502
+
1503
+ Use this instead of JobNotifier when running jobs through the Harmony API
1504
+ to integrate with the UI and tracking system.
1505
+
1506
+ Example:
1507
+ ```python
1508
+ notifier = HarmonyJobNotifier(client, job_id)
1509
+ notifier.register_stages(["training", "evaluation"])
1510
+
1511
+ # Register artifacts
1512
+ artifact = JobArtifact(
1513
+ id="model-v1",
1514
+ name="Fine-tuned Model",
1515
+ kind="model",
1516
+ uri="s3://bucket/models/model-v1"
1517
+ )
1518
+ notifier.register_artifact(artifact)
1519
+
1520
+ # Report progress
1521
+ stage = notifier.stage_notifier("training")
1522
+ stage.report_progress(tot_num_samples=1000, processed_num_samples=500)
1523
+ ```
1524
+ """
1525
+ def __new__(cls, client: HarmonyClient, job_id: str) -> HarmonyJobNotifier: ...
1526
+ def set_monitoring_link(self, monitoring_link: str) -> None:
+ """Set a monitoring link (e.g., a Weights & Biases URL).
+
+ The link is forwarded to the Harmony platform and shown on the job's page in the UI.
+
+ Args:
+ monitoring_link: URL to monitoring dashboard
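+
+ Example:
+ The dashboard URL is illustrative:
+ ```python
+ notifier = HarmonyJobNotifier(client, job_id)
+ notifier.set_monitoring_link("https://wandb.ai/org/project/runs/abc123")
+ ```
+ """
+ ...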
1528
+
1529
+ class StageNotifier:
1530
+ """Helper class to report progress for a specific job stage.
1531
+
1532
+ Get an instance via JobNotifier.stage_notifier() or HarmonyJobNotifier.stage_notifier().
1533
+ Provides convenience methods for reporting progress without repeating the stage name.
1534
+ """
1535
+
1536
+ def set_monitoring_link(self, monitoring_link: str) -> None:
1537
+ """Set monitoring link for this stage.
1538
+
1539
+ Args:
1540
+ monitoring_link: URL to monitoring dashboard
1541
+ """
1542
+ ...
1543
+ def report_progress(
1544
+ self,
1545
+ tot_num_samples: Optional[int] = None,
1546
+ processed_num_samples: Optional[int] = None,
1547
+ monitoring_link: Optional[str] = None,
1548
+ checkpoints: Optional[Sequence[str]] = None,
1549
+ ) -> None:
1550
+ """Report progress for this stage.
1551
+
1552
+ Args:
1553
+ tot_num_samples: Total number of samples to process
1554
+ processed_num_samples: Number of samples processed so far
1555
+ monitoring_link: Optional monitoring dashboard URL
1556
+ checkpoints: Optional list of checkpoint identifiers
1557
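+
+ Example:
+ A sketch; `samples` stands in for your own data:
+ ```python
+ stage = notifier.stage_notifier("training")
+ for i, sample in enumerate(samples):
+ ...
+ stage.report_progress(tot_num_samples=len(samples), processed_num_samples=i + 1)
+ ```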
+ """
1558
+ ...
1559
+
1560
+ async def get_client(
1561
+ addr: str,
1562
+ num_gpus: int | None = None,
1563
+ api_key: str | None = None,
1564
+ use_case: str | None = None,
1565
+ compute_pool: str | None = None,
1566
+ job_id: str | None = None,
1567
+ default_headers: dict[str, str] | None = None,
1568
+ ttl_after_disconnect_s: int | None = 30,
1569
+ control_plane_url: str | None = None,
1570
+ control_plane_api_token: str | None = None,
1571
+ ) -> HarmonyClient:
1572
+ """Create and connect a HarmonyClient to workers.
1573
+
1574
+ This is the main entry point for connecting to the Harmony platform.
1575
+ The client manages the connection to GPU workers and provides access
1576
+ to model building and inference/training operations.
1577
+
1578
+ Args:
1579
+ addr: WebSocket address of the deployment (e.g., "ws://localhost:8080")
1580
+ num_gpus: Number of GPUs to request (None = use all available)
1581
+ api_key: API key for authentication (or set ADAPTIVE_API_KEY env var)
1582
+ use_case: Use case identifier (or set ADAPTIVE_USE_CASE env var, default: "default")
1583
+ compute_pool: Compute pool to use (or set ADAPTIVE_COMPUTE_POOL env var, default: "default")
1584
+ job_id: Job ID for tracking (or set ADAPTIVE_JOB_ID env var)
1585
+ default_headers: Additional HTTP headers for the connection
1586
+ ttl_after_disconnect_s: Keep session alive for this many seconds after disconnect (default: 30)
1587
+ control_plane_url: URL of control plane for fetching model/dataset/grader configs
1588
+ control_plane_api_token: API token for authenticating with the control plane
1589
+
1590
+ Returns:
1591
+ Connected HarmonyClient instance
1592
+
1593
+ Example:
1594
+ ```python
1595
+ from harmony_client import get_client
1596
+
1597
+ # Basic connection
1598
+ client = await get_client("ws://localhost:8080", num_gpus=1)
1599
+
1600
+ # With authentication
1601
+ client = await get_client(
1602
+ "wss://api.adaptive.com",
1603
+ num_gpus=4,
1604
+ api_key="my-api-key",
1605
+ control_plane_url="https://api.adaptive.com"
1606
+ )
1607
+
1608
+ # Use as context manager
1609
+ async with await get_client(...) as client:
1610
+ model = client.model("model_registry://llama-3.1-8b")
1611
+ # ... use model
1612
+ # Connection closed automatically
1613
+ ```
1614
+ """
1615
+ ...