baseten-loops 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baseten_loops-0.1.0/.gitignore +10 -0
- baseten_loops-0.1.0/PKG-INFO +13 -0
- baseten_loops-0.1.0/baseten/loops/__init__.py +65 -0
- baseten_loops-0.1.0/baseten/loops/client.py +150 -0
- baseten_loops-0.1.0/baseten/loops/models.py +291 -0
- baseten_loops-0.1.0/baseten/loops/promise_client.py +367 -0
- baseten_loops-0.1.0/baseten/loops/retry.py +100 -0
- baseten_loops-0.1.0/baseten/loops/sampling_client.py +392 -0
- baseten_loops-0.1.0/baseten/loops/service_client.py +271 -0
- baseten_loops-0.1.0/baseten/loops/templates/__init__.py +0 -0
- baseten_loops-0.1.0/baseten/loops/templates/inference/config.yaml +19 -0
- baseten_loops-0.1.0/baseten/loops/templates/training/config.py +45 -0
- baseten_loops-0.1.0/baseten/loops/training_client.py +759 -0
- baseten_loops-0.1.0/pyproject.toml +52 -0
- baseten_loops-0.1.0/tests/acceptance_math_rl.py +158 -0
- baseten_loops-0.1.0/tests/acceptance_multiturn_rl.py +206 -0
- baseten_loops-0.1.0/tests/e2e.py +428 -0
- baseten_loops-0.1.0/tests/e2e_job/config.py +55 -0
- baseten_loops-0.1.0/tests/test_get_checkpoint_archive_url.py +145 -0
- baseten_loops-0.1.0/tests/test_integration.py +319 -0
- baseten_loops-0.1.0/tests/test_list_checkpoints.py +129 -0
- baseten_loops-0.1.0/tests/test_promise_client.py +393 -0
- baseten_loops-0.1.0/tests/test_service_client_init.py +58 -0
- baseten_loops-0.1.0/tests/verify_worker.py +118 -0
- baseten_loops-0.1.0/uv.lock +2676 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: baseten-loops
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: SDK for Baseten training workers
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Requires-Dist: datasets>=4.8.4
|
|
7
|
+
Requires-Dist: httpx>=0.24.1
|
|
8
|
+
Requires-Dist: pydantic>=2.10.0
|
|
9
|
+
Requires-Dist: tenacity>=8.0.0
|
|
10
|
+
Requires-Dist: transformers>=4.40.0
|
|
11
|
+
Requires-Dist: truss>=0.16.0
|
|
12
|
+
Provides-Extra: tinker
|
|
13
|
+
Requires-Dist: baseten-loops-tinker>=0.1.0; extra == 'tinker'
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from baseten.loops.client import create_training_client
|
|
2
|
+
from baseten.loops.promise_client import (
|
|
3
|
+
RemoteOpError,
|
|
4
|
+
ServerShutdownError,
|
|
5
|
+
UnknownRequestError,
|
|
6
|
+
)
|
|
7
|
+
from baseten.loops.training_client import TrainingClient, OperationFuture
|
|
8
|
+
from baseten.loops.sampling_client import SamplingClient
|
|
9
|
+
from baseten.loops.service_client import ServiceClient
|
|
10
|
+
from baseten.loops.models import (
|
|
11
|
+
AdamParams,
|
|
12
|
+
Checkpoint,
|
|
13
|
+
CheckpointFile,
|
|
14
|
+
CheckpointFilesResponse,
|
|
15
|
+
Datum,
|
|
16
|
+
EncodedTextChunk,
|
|
17
|
+
ForwardBackwardOutput,
|
|
18
|
+
InitTrainerServerResponse,
|
|
19
|
+
ImageChunk,
|
|
20
|
+
LoadWeightsResponse,
|
|
21
|
+
ModelInput,
|
|
22
|
+
ModelInputChunk,
|
|
23
|
+
OptimStepResponse,
|
|
24
|
+
SampledSequence,
|
|
25
|
+
SampleResponse,
|
|
26
|
+
SampleResult,
|
|
27
|
+
SamplingParams,
|
|
28
|
+
SaveWeightsForSamplerResponse,
|
|
29
|
+
SaveWeightsResponse,
|
|
30
|
+
TensorData,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Public API of ``baseten.loops``; every name listed here is imported above.
# NOTE(review): PromptLogprobs, ModelData and GetInfoResponse from
# baseten.loops.models are not re-exported — confirm that is intentional.
__all__ = [
    # Client factory and client classes
    "create_training_client",
    "ServiceClient",
    "TrainingClient",
    "SamplingClient",
    "OperationFuture",
    # Promise-protocol exceptions
    "RemoteOpError",
    "UnknownRequestError",
    "ServerShutdownError",
    # Types
    "AdamParams",
    "Checkpoint",
    "CheckpointFile",
    "CheckpointFilesResponse",
    "Datum",
    "EncodedTextChunk",
    "ForwardBackwardOutput",
    "InitTrainerServerResponse",
    "ImageChunk",
    "LoadWeightsResponse",
    "ModelInput",
    "ModelInputChunk",
    "OptimStepResponse",
    "SampledSequence",
    "SampleResponse",
    "SampleResult",
    "SamplingParams",
    "SaveWeightsForSamplerResponse",
    "SaveWeightsResponse",
    "TensorData",
]
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Deploy a training worker and return a TrainingClient connected directly to it."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import uuid
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from baseten.loops.training_client import TrainingClient
|
|
13
|
+
from truss.base import truss_config
|
|
14
|
+
from truss_train import definitions
|
|
15
|
+
from truss_train.public_api import push
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Default workspace uploaded with the training job: a ``server`` directory that
# sits beside this package's ``loops`` directory (i.e. ``baseten/server``).
# NOTE(review): the published baseten-loops 0.1.0 sdist ships no
# ``baseten/server`` directory, so installed copies will likely need an explicit
# ``workspace_root=`` — confirm.
_DEFAULT_WORKSPACE_ROOT = Path(__file__).resolve().parent.parent / "server"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_training_client(
    base_model: str,
    worker_url: str | None = None,
    *,
    api_key: str | None = None,
    gpu_count: int = 1,
    accelerator: str = "H100",
    training_gpus: Optional[list[int]] = None,
    inference_gpus: Optional[list[int]] = None,
    max_seq_len: int = 4096,
    worker_port: int = 8001,
    namespace: str = "default",
    remote: str = "baseten",
    workspace_root: Optional[Path] = None,
    deploy: bool = True,
    timeout: float = 600.0,
) -> TrainingClient:
    """Deploy a training worker and return a TrainingClient connected to it.

    When ``deploy`` is True, this pushes a Baseten training job. The job's
    start commands install dependencies with ``uv`` and launch
    ``trainers_server.dp_worker.main`` from ``workspace_root``. The returned
    client targets the job's ``multinode-0`` pod through its in-cluster DNS name.

    Args:
        base_model: HuggingFace model ID (e.g. "Qwen/Qwen3-8B").
        worker_url: URL of the dp_worker to connect to. Required when
            ``deploy=False``. Ignored when ``deploy=True``, because the URL
            is then derived from the job ID, ``namespace`` and ``worker_port``.
        api_key: API key for authentication, forwarded to TrainingClient.
        gpu_count: Total number of GPUs to request.
        accelerator: GPU type name (e.g. H100, H200, B200). It is matched
            case-insensitively against ``truss_config.Accelerator``.
            Unrecognised names fall back to H100.
        training_gpus: GPU indices for training. Defaults to [0].
        inference_gpus: GPU indices for inference. Defaults to [0].
        max_seq_len: Maximum sequence length for training.
        worker_port: Port the dp_worker will listen on.
        namespace: K8s namespace where the training job runs.
        remote: Baseten remote name from .trussrc.
        workspace_root: Path (a ``str`` is also accepted) to the thinker
            workspace root that is uploaded as the job's workspace. Defaults
            to ``_DEFAULT_WORKSPACE_ROOT``.
        deploy: If True, deploy a training job. If False, just connect to
            ``worker_url``.
        timeout: HTTP timeout for training operations.

    Returns:
        A TrainingClient pointed at the (given or derived) worker URL.

    Raises:
        ValueError: ``deploy`` is False and no ``worker_url`` was given.
        FileNotFoundError: the workspace root is not an existing directory.
    """
    if not deploy:
        if worker_url is None:
            raise ValueError("worker_url is required when deploy=False")
        return TrainingClient(worker_url, api_key=api_key, timeout=timeout)

    suffix = uuid.uuid4().hex[:7]
    project_name = f"trainer-{base_model.replace('/', '-')}-{suffix}"
    print(f"Project: {project_name}")

    rl_config = _build_rl_config(
        base_model,
        max_seq_len=max_seq_len,
        training_gpus=[0] if training_gpus is None else training_gpus,
        inference_gpus=[0] if inference_gpus is None else inference_gpus,
    )

    ws_root = Path(workspace_root) if workspace_root else _DEFAULT_WORKSPACE_ROOT
    if not ws_root.is_dir():
        raise FileNotFoundError(
            f"Workspace root not found: {ws_root}. "
            "Pass workspace_root= pointing to the thinker repo."
        )

    rl_config_path = ws_root / "rl_config.json"
    # The workspace is the user's checkout. If it already has an
    # rl_config.json, restore that file afterwards instead of deleting it.
    previous_config = rl_config_path.read_bytes() if rl_config_path.exists() else None
    try:
        # The file is written inside the try so that a failure while building
        # or pushing the project never leaves the generated config behind.
        rl_config_path.write_text(json.dumps(rl_config))
        project = _build_training_project(
            project_name,
            ws_root,
            accelerator=accelerator,
            gpu_count=gpu_count,
            worker_port=worker_port,
        )
        result = push(project, remote=remote, source_dir=ws_root)
        job_id = result["id"]
        print(f"Training Job ID: {job_id}")
    finally:
        if previous_config is None:
            rl_config_path.unlink(missing_ok=True)
        else:
            rl_config_path.write_bytes(previous_config)

    resolved_url = _worker_url_for_job(job_id, namespace, worker_port)
    print(f"Worker URL: {resolved_url}")

    return TrainingClient(resolved_url, api_key=api_key, timeout=timeout)


def _build_rl_config(
    base_model: str,
    *,
    max_seq_len: int,
    training_gpus: list[int],
    inference_gpus: list[int],
) -> dict:
    """Build the JSON payload the dp_worker is started with via ``--config``."""
    return {
        "model_id": base_model,
        "training": {
            "tensor_parallel_size": 1,
            "pipeline_parallel_size": 1,
            "max_length": max_seq_len,
            "gpus": training_gpus,
        },
        "inference": {
            "tensor_parallel_size": 1,
            "gpus": inference_gpus,
            "gpu_memory_utilization": 0.9,
        },
    }


def _build_training_project(
    project_name: str,
    ws_root: Path,
    *,
    accelerator: str,
    gpu_count: int,
    worker_port: int,
) -> definitions.TrainingProject:
    """Assemble the truss TrainingProject that installs deps and starts the dp_worker."""
    # Unrecognised accelerator names deliberately fall back to H100.
    accel_enum = getattr(
        truss_config.Accelerator,
        accelerator.upper(),
        truss_config.Accelerator.H100,
    )

    return definitions.TrainingProject(
        name=project_name,
        job=definitions.TrainingJob(
            compute=definitions.Compute(
                accelerator=truss_config.AcceleratorSpec(
                    accelerator=accel_enum,
                    count=gpu_count,
                ),
            ),
            runtime=definitions.Runtime(
                start_commands=[
                    "apt-get update && apt-get install -y python3-dev curl",
                    "curl -LsSf https://astral.sh/uv/install.sh | sh",
                    ". $HOME/.local/bin/env && uv sync --extra worker",
                    f".venv/bin/python -m trainers_server.dp_worker.main --config $RL_CONFIG_PATH --port {worker_port}",
                ],
                environment_variables={
                    # Relative to the uploaded workspace, where
                    # create_training_client writes rl_config.json.
                    "RL_CONFIG_PATH": "rl_config.json",
                    "BASETEN_API_KEY": definitions.SecretReference(name="baseten_api_key"),
                },
            ),
            image=definitions.Image(
                base_image="nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04",
            ),
            workspace=definitions.Workspace(
                workspace_root=str(ws_root),
                exclude_dirs=[
                    str(ws_root / ".venv"),
                    str(ws_root / ".git"),
                ],
            ),
        ),
    )


def _worker_url_for_job(job_id, namespace: str, worker_port: int) -> str:
    """Return the in-cluster HTTP URL of the job's multinode-0 pod (deterministic pod DNS)."""
    worker_host = (
        f"baseten-training-job-{job_id}-multinode-0"
        f".baseten-training-job-{job_id}-multinode"
        f".{namespace}.svc.cluster.local"
    )
    return f"http://{worker_host}:{worker_port}"
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
"""Types for the loops SDK — wire-compatible with trainers-server."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Annotated, Literal, Self, Union
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, Discriminator, Field, Tag
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ── Tensor / model-input primitives ─────────────────────────────────
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TensorData(BaseModel):
    """A dense tensor serialised as a flat value list plus a dtype name and shape.

    ``data`` holds the elements in row-major order, the order produced by
    ``tensor.flatten().tolist()`` in :meth:`from_torch`. ``shape`` gives the
    dimensions needed to rebuild the tensor. ``dtype`` is a torch dtype name
    without the ``torch.`` prefix (e.g. ``"float32"``).
    """

    data: list
    dtype: str
    shape: list[int]

    def tolist(self) -> list:
        """Return the tensor data as a plain Python list (a shallow copy of ``data``)."""
        return list(self.data)

    def to_torch(self):  # type: ignore[return]
        """Return the tensor data as a torch.Tensor with the correct shape and dtype.

        Unrecognised ``dtype`` names fall back to ``torch.float32``.
        """
        import torch

        # Covers every dtype name from_torch emits for the common integer,
        # float and bool tensors, so round-trips keep their dtype.
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
            "float64": torch.float64,
            "int8": torch.int8,
            "int16": torch.int16,
            "int32": torch.int32,
            "int64": torch.int64,
            "uint8": torch.uint8,
            "bool": torch.bool,
        }
        torch_dtype = dtype_map.get(self.dtype, torch.float32)
        # The wire format encodes masked / inactive positions as JSON null
        # (since JSON has no NaN literal); the server emits them via
        # masked_fill(..., float("nan")) and serialisation flips NaN → None.
        # Convert back here so torch.tensor can ingest the list. Only floats
        # carry the None-as-NaN convention — int / bool tensors never have
        # masked positions.
        data = self.data
        if torch_dtype.is_floating_point and any(v is None for v in data):
            data = [float("nan") if v is None else v for v in data]
        return torch.tensor(data, dtype=torch_dtype).reshape(self.shape)

    @classmethod
    def from_torch(cls, tensor) -> Self:
        """Construct TensorData from a torch.Tensor (flattened in row-major order)."""
        dtype_name = str(tensor.dtype).replace("torch.", "")
        return cls(
            data=tensor.flatten().tolist(),
            dtype=dtype_name,
            shape=list(tensor.shape),
        )

    @classmethod
    def from_list(cls, data: list, dtype: str = "float32") -> Self:
        """Build TensorData from a (possibly nested) list or tuple of values.

        The shape is inferred by following the first element at each nesting
        level. The values are flattened in row-major order so that ``data``
        has the same layout as :meth:`from_torch` produces.

        Raises:
            ValueError: the nesting is ragged, i.e. the number of leaf values
                does not match the inferred shape.
        """
        import math

        shape: list[int] = []
        level = data
        while isinstance(level, (list, tuple)):
            shape.append(len(level))
            if not level:
                # An empty dimension has no first element to descend into;
                # stop here instead of looping forever on ``[]``.
                break
            level = level[0]

        def _flatten(items):
            # Yield leaf values depth-first, i.e. in row-major order.
            for item in items:
                if isinstance(item, (list, tuple)):
                    yield from _flatten(item)
                else:
                    yield item

        flat = list(_flatten(data))
        expected = math.prod(shape)
        if len(flat) != expected:
            raise ValueError(
                f"Ragged nested list: inferred shape {shape} implies "
                f"{expected} values but found {len(flat)}"
            )
        return cls(data=flat, dtype=dtype, shape=shape)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class EncodedTextChunk(BaseModel):
    """A run of already-tokenized text, one of the chunk kinds in ``ModelInput``."""

    # Discriminator value ``ModelInputChunk`` uses to select this class.
    type: Literal["encoded_text"] = "encoded_text"
    # Token ids of this chunk.
    tokens: list[int]

    @property
    def length(self) -> int:
        """Number of tokens in the chunk."""
        return len(self.tokens)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ImageChunk(BaseModel):
    """An image chunk of a ``ModelInput``, carried on the wire as base64 text."""

    # Discriminator value ``ModelInputChunk`` uses to select this class.
    type: Literal["image"] = "image"
    data: str  # base64-encoded
    # Image format label supplied by the caller; stored verbatim, not validated
    # here (presumably e.g. "png" / "jpeg" — confirm with the server).
    format: str
    # Number of model tokens the caller expects the image to occupy; this is
    # what ``length`` reports (0 when not supplied).
    expected_tokens: int = 0

    @classmethod
    def from_bytes(cls, raw: bytes, fmt: str, expected_tokens: int = 0) -> Self:
        """Build a chunk from raw image bytes, base64-encoding them into ``data``."""
        return cls(
            data=base64.b64encode(raw).decode(),
            format=fmt,
            expected_tokens=expected_tokens,
        )

    def to_bytes(self) -> bytes:
        """Decode ``data`` back into the raw image bytes."""
        return base64.b64decode(self.data)

    @property
    def length(self) -> int:
        """Token count of this chunk: the caller-supplied ``expected_tokens``."""
        return self.expected_tokens
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Tagged union of the chunk kinds a ModelInput may hold. Pydantic dispatches on
# each chunk's ``type`` field: "encoded_text" -> EncodedTextChunk,
# "image" -> ImageChunk.
ModelInputChunk = Annotated[
    Union[
        Annotated[EncodedTextChunk, Tag("encoded_text")],
        Annotated[ImageChunk, Tag("image")],
    ],
    Discriminator("type"),
]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class ModelInput(BaseModel):
    """An ordered sequence of input chunks (token runs and/or images)."""

    chunks: list[ModelInputChunk] = Field(default_factory=list)

    @classmethod
    def from_ints(cls, tokens: list[int]) -> Self:
        """Wrap a flat token list in a single ``EncodedTextChunk``."""
        text_chunk = EncodedTextChunk(tokens=tokens)
        return cls(chunks=[text_chunk])

    @classmethod
    def empty(cls) -> Self:
        """Return an input with no chunks (``chunks`` is explicitly set to [])."""
        return cls(chunks=[])

    @property
    def length(self) -> int:
        """Total token count over all chunks; images contribute their ``expected_tokens``."""
        total = 0
        for piece in self.chunks:
            total += piece.length
        return total

    def to_ints(self) -> list[int]:
        """Concatenate the tokens of every chunk into one flat list.

        Raises:
            TypeError: if any chunk is not an ``EncodedTextChunk``.
        """
        for piece in self.chunks:
            if not isinstance(piece, EncodedTextChunk):
                raise TypeError(f"Cannot convert {type(piece).__name__} to ints")
        return [token for piece in self.chunks for token in piece.tokens]

    def append(self, chunk: ModelInputChunk) -> Self:
        """Return a new input with *chunk* added at the end; ``self`` is left unchanged."""
        extended = list(self.chunks)
        extended.append(chunk)
        return type(self)(chunks=extended)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Datum(BaseModel):
    """One training example: a model input plus named loss-function input tensors."""

    model_input: ModelInput
    # Extra tensors for the loss function, keyed by name. Which keys are
    # required depends on the server-side loss function (not visible here).
    loss_fn_inputs: dict[str, TensorData] = Field(default_factory=dict)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# ── Sampling ─────────────────────────────────────────────────────────
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class SamplingParams(BaseModel):
    """Generation settings sent with a sampling request."""

    # Generation length cap; None = not set by the client.
    max_tokens: int | None = None
    # RNG seed; None = not set by the client.
    seed: int | None = None
    # Stop condition given as a string, a list of strings, or a list of token ids.
    stop: str | list[str] | list[int] | None = None
    temperature: float = 1.0
    # -1 presumably disables top-k filtering (vLLM convention) — confirm server-side.
    top_k: int = -1
    top_p: float = 1.0
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class SampledSequence(BaseModel):
    """One generated completion."""

    # Generated token ids.
    tokens: list[int] = Field(default_factory=list)
    # Per-token log-probabilities; None when the response omits them.
    logprobs: list[float] | None = None
    # Why generation stopped; "length" when the response omits the field.
    stop_reason: str = "length"
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class PromptLogprobs(BaseModel):
    """Per-token log-probabilities for the prompt.

    ``logprobs[i]`` is the logprob of the actual prompt token at position
    ``i``. Index 0 is always ``None`` — the first token has no preceding
    context to score against.

    ``top_logprobs[i]`` (only populated when the caller passed
    ``topk_prompt_logprobs > 0``) maps ``token_id -> logprob`` for the top-k
    alternatives at position ``i``. The actual prompt token is also included
    even if it falls outside the top-k.
    """

    logprobs: list[float | None] = Field(default_factory=list)
    # None when top-k prompt logprobs were not requested. JSON object keys
    # arrive as strings and are coerced to int token ids by pydantic.
    top_logprobs: list[dict[int, float] | None] | None = None
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class SampleResult(BaseModel):
    """Result of a sampling request."""

    # Generated completions.
    sequences: list[SampledSequence] = Field(default_factory=list)
    # Version of the policy weights that produced the samples, if reported.
    policy_version: int | None = None
    # Prompt log-probabilities; None when the response does not include them.
    prompt_logprobs: PromptLogprobs | None = None


# Alias kept for API surface compatibility.
SampleResponse = SampleResult
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# ── Optimizer ────────────────────────────────────────────────────────
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class AdamParams(BaseModel):
    """Optimizer hyper-parameters sent with an optimizer step.

    Field names presumably follow standard Adam/AdamW semantics
    (beta1/beta2 moment decays, eps stability term) — confirm server-side.
    """

    learning_rate: float = 1e-4
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-12
    weight_decay: float = 0.0
    grad_clip_norm: float = 0.0  # 0 = disabled
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# ── Training response types ──────────────────────────────────────────
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class ForwardBackwardOutput(BaseModel):
    """Result of a forward/backward pass."""

    loss: float = 0.0
    # Label for the kind of outputs in ``loss_fn_outputs``; "" when absent.
    loss_fn_output_type: str = ""
    # Output tensors keyed by name — presumably one dict per input Datum; confirm.
    loss_fn_outputs: list[dict[str, TensorData]] = Field(default_factory=list)
    # Scalar metrics keyed by name.
    metrics: dict[str, float] = Field(default_factory=dict)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class OptimStepResponse(BaseModel):
    """Result of an optimizer step."""

    # Scalar metrics keyed by name; None when the response carries none.
    metrics: dict[str, float] | None = None
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class SaveWeightsResponse(BaseModel):
    """Result of a save-weights call; both fields default to "" when absent."""

    # Save mode as reported by the server.
    mode: str = ""
    # Location the weights were written to.
    path: str = ""
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class SaveWeightsForSamplerResponse(BaseModel):
    """Response for ``/save_weights_for_sampler`` (the trainer-side
    publish-LoRA endpoint).

    ``version`` is the controller's ``step_count`` at the moment of the
    call — clients don't pass it. Two syncs at the same version mean
    "no new weights since last check"; a sampler watching the URI keys
    its hot-swap on this number.

    ``path`` is the checkpoint subdirectory the adapters were written
    to (``{weight_sync_uri}/{name}/``).
    """
    # Defaults to 0 when the response omits it.
    version: int = 0
    # Defaults to "" when the response omits it.
    path: str = ""
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class LoadWeightsResponse(BaseModel):
    """Result of a load-weights call."""

    # Step reported after loading; 0 when the response omits it.
    step: int = 0
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class InitTrainerServerResponse(BaseModel):
    """Response for ``/init_trainer_server``."""

    # Defaults to 0 when the response omits it.
    step: int = 0
    # Required: validation fails if the response lacks it.
    lora_rank: int
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# ── Checkpoints ──────────────────────────────────────────────────────
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class Checkpoint(BaseModel):
    """Metadata describing one stored checkpoint."""

    # Let callers construct with either the field name ``trainer_server_id``
    # or its wire alias ``trainer_id``.
    model_config = {"populate_by_name": True}

    # Read from / dumped (with by_alias=True) as ``trainer_id``.
    trainer_server_id: str = Field(alias="trainer_id")
    checkpoint_id: str
    checkpoint_type: str
    created_at: datetime
    base_model: str | None = None
    # Adapter configuration as an untyped dict, if present.
    lora_adapter_config: dict | None = None
    # 0 when the response omits a size.
    size_bytes: int = 0
    sync_status: str | None = None
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class CheckpointFile(BaseModel):
    """One file of a checkpoint, as listed in ``CheckpointFilesResponse``."""

    # Download URL for the file (listed under ``presigned_urls``).
    url: str
    # File path relative to the checkpoint root.
    relative_file_name: str
    size_bytes: int
    last_modified: datetime
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class CheckpointFilesResponse(BaseModel):
    """One page of a checkpoint's file listing."""

    presigned_urls: list[CheckpointFile]
    # Token for fetching the next page; presumably None on the last page — confirm.
    next_page_token: int | None = None
    # Total number of files, across all pages presumably — confirm.
    total_count: int
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# ── Model info ───────────────────────────────────────────────────────
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class ModelData(BaseModel):
    """Model metadata nested in ``GetInfoResponse.model_data``; every field is optional."""

    # NOTE(review): unlike GetInfoResponse, this class does not disable
    # pydantic's protected namespaces, so ``model_name`` relies on the relaxed
    # default of pydantic>=2.10 (the declared minimum) — confirm.
    arch: str | None = None
    model_name: str | None = None
    tokenizer_id: str | None = None
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class GetInfoResponse(BaseModel):
    """Response payload describing the served model; every field is optional."""

    # Nested metadata; an all-None ModelData when the response omits it.
    model_data: ModelData = Field(default_factory=ModelData)
    model_name: str | None = None
    is_lora: bool | None = None
    lora_rank: int | None = None
    max_seq_len: int | None = None

    # Allow the ``model_*`` field names above without pydantic's
    # protected-namespace warning.
    model_config = {"protected_namespaces": ()}
|