baseten-loops 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,10 @@
1
+ .env
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ .venv/
6
+ dist/
7
+ *.egg-info/
8
+ .pytest_cache/
9
+ .mypy_cache/
10
+ .ruff_cache/
@@ -0,0 +1,13 @@
1
+ Metadata-Version: 2.4
2
+ Name: baseten-loops
3
+ Version: 0.1.0
4
+ Summary: SDK for Baseten training workers
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: datasets>=4.8.4
7
+ Requires-Dist: httpx>=0.24.1
8
+ Requires-Dist: pydantic>=2.10.0
9
+ Requires-Dist: tenacity>=8.0.0
10
+ Requires-Dist: transformers>=4.40.0
11
+ Requires-Dist: truss>=0.16.0
12
+ Provides-Extra: tinker
13
+ Requires-Dist: baseten-loops-tinker>=0.1.0; extra == 'tinker'
@@ -0,0 +1,65 @@
1
+ from baseten.loops.client import create_training_client
2
+ from baseten.loops.promise_client import (
3
+ RemoteOpError,
4
+ ServerShutdownError,
5
+ UnknownRequestError,
6
+ )
7
+ from baseten.loops.training_client import TrainingClient, OperationFuture
8
+ from baseten.loops.sampling_client import SamplingClient
9
+ from baseten.loops.service_client import ServiceClient
10
+ from baseten.loops.models import (
11
+ AdamParams,
12
+ Checkpoint,
13
+ CheckpointFile,
14
+ CheckpointFilesResponse,
15
+ Datum,
16
+ EncodedTextChunk,
17
+ ForwardBackwardOutput,
18
+ InitTrainerServerResponse,
19
+ ImageChunk,
20
+ LoadWeightsResponse,
21
+ ModelInput,
22
+ ModelInputChunk,
23
+ OptimStepResponse,
24
+ SampledSequence,
25
+ SampleResponse,
26
+ SampleResult,
27
+ SamplingParams,
28
+ SaveWeightsForSamplerResponse,
29
+ SaveWeightsResponse,
30
+ TensorData,
31
+ )
32
+
33
# Public API of the package. Every name listed here is imported at the top of
# this module, so ``from <package> import *`` re-exports exactly this set.
__all__ = [
    # Client factory and client classes
    "create_training_client",
    "ServiceClient",
    "TrainingClient",
    "SamplingClient",
    "OperationFuture",
    # Promise-protocol exceptions
    "RemoteOpError",
    "UnknownRequestError",
    "ServerShutdownError",
    # Types (pydantic models and aliases from the models module)
    "AdamParams",
    "Checkpoint",
    "CheckpointFile",
    "CheckpointFilesResponse",
    "Datum",
    "EncodedTextChunk",
    "ForwardBackwardOutput",
    "InitTrainerServerResponse",
    "ImageChunk",
    "LoadWeightsResponse",
    "ModelInput",
    "ModelInputChunk",
    "OptimStepResponse",
    "SampledSequence",
    "SampleResponse",
    "SampleResult",
    "SamplingParams",
    "SaveWeightsForSamplerResponse",
    "SaveWeightsResponse",
    "TensorData",
]
@@ -0,0 +1,150 @@
1
+ """Deploy a training worker and return a TrainingClient connected directly to it."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import uuid
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+ import httpx
11
+
12
+ from baseten.loops.training_client import TrainingClient
13
+ from truss.base import truss_config
14
+ from truss_train import definitions
15
+ from truss_train.public_api import push
16
+
17
+
18
# Workspace directory used when the caller does not pass ``workspace_root``:
# the ``server`` directory that is a sibling of the directory containing this
# module (resolved to an absolute path at import time).
_DEFAULT_WORKSPACE_ROOT = Path(__file__).resolve().parent.parent / "server"
19
+
20
+
21
def create_training_client(
    base_model: str,
    worker_url: str | None = None,
    *,
    api_key: str | None = None,
    gpu_count: int = 1,
    accelerator: str = "H100",
    training_gpus: Optional[list[int]] = None,
    inference_gpus: Optional[list[int]] = None,
    max_seq_len: int = 4096,
    worker_port: int = 8001,
    namespace: str = "default",
    remote: str = "baseten",
    workspace_root: Optional[Path] = None,
    deploy: bool = True,
    timeout: float = 600.0,
) -> TrainingClient:
    """Deploy a training worker and return a TrainingClient connected to it.

    With ``deploy=True`` (the default) a training job is pushed and the
    client is pointed at the worker URL derived from the returned job ID and
    ``namespace``. With ``deploy=False`` nothing is deployed and the client
    simply connects to ``worker_url``.

    Args:
        base_model: HuggingFace model ID (e.g. "Qwen/Qwen3-8B").
        worker_url: URL of the dp_worker to connect to. Only used when
            deploy=False; when deploy=True the URL is computed from the job
            ID and namespace and this argument is ignored.
        api_key: API key for authentication.
        gpu_count: Total number of GPUs to request.
        accelerator: GPU type (H100, H200, B200). Matched case-insensitively
            against ``truss_config.Accelerator``; an unrecognised name falls
            back to H100.
        training_gpus: GPU indices for training. Defaults to [0].
        inference_gpus: GPU indices for inference. Defaults to [0].
        max_seq_len: Maximum sequence length for training.
        worker_port: Port the dp_worker will listen on.
        namespace: K8s namespace where the training job runs.
        remote: Baseten remote name from .trussrc.
        workspace_root: Path to thinker workspace root. Defaults to the
            module-level ``_DEFAULT_WORKSPACE_ROOT``.
        deploy: If True, deploy a training job. If False, just connect to worker_url.
        timeout: HTTP timeout for training operations.

    Returns:
        A TrainingClient bound to the resolved worker URL.

    Raises:
        ValueError: If deploy=False and no worker_url was given.
        FileNotFoundError: If the workspace root does not exist.
    """
    if not deploy:
        if worker_url is None:
            raise ValueError("worker_url is required when deploy=False.")
        return TrainingClient(worker_url, api_key=api_key, timeout=timeout)

    suffix = uuid.uuid4().hex[:7]
    project_name = f"trainer-{base_model.replace('/', '-')}-{suffix}"
    print(f"Project: {project_name}")

    if training_gpus is None:
        training_gpus = [0]
    if inference_gpus is None:
        inference_gpus = [0]

    rl_config = {
        "model_id": base_model,
        "training": {
            "tensor_parallel_size": 1,
            "pipeline_parallel_size": 1,
            "max_length": max_seq_len,
            "gpus": training_gpus,
        },
        "inference": {
            "tensor_parallel_size": 1,
            "gpus": inference_gpus,
            "gpu_memory_utilization": 0.9,
        },
    }

    # Coerce so that a plain string path works as well as a Path instance.
    ws_root = (
        Path(workspace_root) if workspace_root is not None else _DEFAULT_WORKSPACE_ROOT
    )
    if not ws_root.exists():
        raise FileNotFoundError(
            f"Workspace root not found: {ws_root}. "
            "Pass workspace_root= pointing to the thinker repo."
        )

    # The config is written into the workspace so it is uploaded with the job
    # and read by the worker via $RL_CONFIG_PATH. Everything from the write
    # onwards sits inside the try so the temporary file is removed even when
    # building the project definition or pushing it fails.
    # NOTE(review): an rl_config.json already present in the workspace is
    # overwritten and then deleted — confirm that is acceptable.
    rl_config_path = ws_root / "rl_config.json"
    try:
        rl_config_path.write_text(json.dumps(rl_config), encoding="utf-8")

        # Unknown accelerator names silently fall back to H100.
        accel_enum = getattr(
            truss_config.Accelerator,
            accelerator.upper(),
            truss_config.Accelerator.H100,
        )

        project = definitions.TrainingProject(
            name=project_name,
            job=definitions.TrainingJob(
                compute=definitions.Compute(
                    accelerator=truss_config.AcceleratorSpec(
                        accelerator=accel_enum,
                        count=gpu_count,
                    ),
                ),
                runtime=definitions.Runtime(
                    start_commands=[
                        "apt-get update && apt-get install -y python3-dev curl",
                        "curl -LsSf https://astral.sh/uv/install.sh | sh",
                        ". $HOME/.local/bin/env && uv sync --extra worker",
                        f".venv/bin/python -m trainers_server.dp_worker.main --config $RL_CONFIG_PATH --port {worker_port}",
                    ],
                    environment_variables={
                        "RL_CONFIG_PATH": "rl_config.json",
                        "BASETEN_API_KEY": definitions.SecretReference(name="baseten_api_key"),
                    },
                ),
                image=definitions.Image(
                    base_image="nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu24.04",
                ),
                workspace=definitions.Workspace(
                    workspace_root=str(ws_root),
                    exclude_dirs=[
                        str(ws_root / ".venv"),
                        str(ws_root / ".git"),
                    ],
                ),
            ),
        )

        result = push(project, remote=remote, source_dir=ws_root)
        job_id = result["id"]
        print(f"Training Job ID: {job_id}")
    finally:
        rl_config_path.unlink(missing_ok=True)

    # Build the deterministic pod DNS for the worker.
    worker_host = (
        f"baseten-training-job-{job_id}-multinode-0"
        f".baseten-training-job-{job_id}-multinode"
        f".{namespace}.svc.cluster.local"
    )
    resolved_url = f"http://{worker_host}:{worker_port}"
    print(f"Worker URL: {resolved_url}")

    return TrainingClient(resolved_url, api_key=api_key, timeout=timeout)
@@ -0,0 +1,291 @@
1
+ """Types for the loops SDK — wire-compatible with trainers-server."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ from datetime import datetime
7
+ from typing import Annotated, Literal, Self, Union
8
+
9
+ from pydantic import BaseModel, Discriminator, Field, Tag
10
+
11
+
12
+ # ── Tensor / model-input primitives ─────────────────────────────────
13
+
14
+
15
class TensorData(BaseModel):
    """Tensor payload: element list, dtype name, and shape.

    ``from_torch`` and ``from_list`` store ``data`` flattened; ``to_torch``
    rebuilds the tensor by reshaping to ``shape``.
    """

    data: list
    dtype: str
    shape: list[int]

    def tolist(self) -> list:
        """Return the tensor data as a plain Python list."""
        return list(self.data)

    def to_torch(self):  # type: ignore[return]
        """Return the tensor data as a torch.Tensor with the correct shape and dtype.

        The dtype name is resolved as an attribute of ``torch`` (so every
        name ``from_torch`` can emit, e.g. ``uint8`` or ``int16``, round-trips).
        Names that do not resolve to a ``torch.dtype`` fall back to float32.
        """
        import torch

        torch_dtype = getattr(torch, self.dtype, None)
        if not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
        # The wire format encodes masked / inactive positions as JSON null
        # (since JSON has no NaN literal); the server emits them via
        # masked_fill(..., float("nan")) and serialisation flips NaN → None.
        # Convert back here so torch.tensor can ingest the list. Only floats
        # carry the None-as-NaN convention — int / bool tensors never have
        # masked positions.
        data = self.data
        if torch_dtype.is_floating_point and any(v is None for v in data):
            data = [float("nan") if v is None else v for v in data]
        return torch.tensor(data, dtype=torch_dtype).reshape(self.shape)

    @classmethod
    def from_torch(cls, tensor) -> Self:
        """Construct TensorData from a torch.Tensor."""
        dtype_name = str(tensor.dtype).replace("torch.", "")
        return cls(
            data=tensor.flatten().tolist(),
            dtype=dtype_name,
            shape=list(tensor.shape),
        )

    @classmethod
    def from_list(cls, data: list, dtype: str = "float32") -> Self:
        """Construct TensorData from a (possibly nested) rectangular list.

        The shape is inferred by following the first element at each nesting
        level, and the data is stored flattened, matching ``from_torch``.

        Raises:
            ValueError: If the nesting is ragged or inconsistent with the
                inferred shape.
        """
        shape: list[int] = []
        level = data
        while isinstance(level, list):
            shape.append(len(level))
            if not level:
                # An empty level has no first element to descend into;
                # stopping here also prevents looping forever on [].
                break
            level = level[0]

        # Peel off one nesting level per extra dimension.
        flat = data
        for _ in shape[1:]:
            if not all(isinstance(row, list) for row in flat):
                raise ValueError("from_list requires a rectangular nested list")
            flat = [item for row in flat for item in row]

        expected = 1
        for dim in shape:
            expected *= dim
        if len(flat) != expected or any(isinstance(v, list) for v in flat):
            raise ValueError("from_list requires a rectangular nested list")

        return cls(data=flat, dtype=dtype, shape=shape)
69
+
70
+
71
class EncodedTextChunk(BaseModel):
    """Model-input chunk holding a list of integer tokens."""

    type: Literal["encoded_text"] = "encoded_text"
    tokens: list[int]

    @property
    def length(self) -> int:
        """Number of tokens in this chunk."""
        token_count = len(self.tokens)
        return token_count
78
+
79
+
80
class ImageChunk(BaseModel):
    """Model-input chunk carrying an image as a base64 string."""

    type: Literal["image"] = "image"
    data: str  # base64-encoded
    format: str
    expected_tokens: int = 0

    @classmethod
    def from_bytes(cls, raw: bytes, fmt: str, expected_tokens: int = 0) -> Self:
        """Build a chunk from raw image bytes, base64-encoding them into ``data``."""
        encoded = base64.b64encode(raw).decode()
        return cls(data=encoded, format=fmt, expected_tokens=expected_tokens)

    def to_bytes(self) -> bytes:
        """Decode ``data`` from base64 back into raw bytes."""
        payload = self.data
        return base64.b64decode(payload)

    @property
    def length(self) -> int:
        """Length of this chunk, taken from ``expected_tokens``."""
        return self.expected_tokens
100
+
101
+
102
# Union of the chunk types a ModelInput can hold. The ``type`` field of each
# chunk ("encoded_text" or "image") is the discriminator pydantic uses to pick
# the concrete class when validating.
ModelInputChunk = Annotated[
    Union[
        Annotated[EncodedTextChunk, Tag("encoded_text")],
        Annotated[ImageChunk, Tag("image")],
    ],
    Discriminator("type"),
]
109
+
110
+
111
class ModelInput(BaseModel):
    """Ordered list of input chunks (encoded text and/or images)."""

    chunks: list[ModelInputChunk] = Field(default_factory=list)

    @classmethod
    def from_ints(cls, tokens: list[int]) -> Self:
        """Wrap ``tokens`` in a single EncodedTextChunk."""
        text_chunk = EncodedTextChunk(tokens=tokens)
        return cls(chunks=[text_chunk])

    @classmethod
    def empty(cls) -> Self:
        """Return an input with no chunks."""
        return cls(chunks=[])

    @property
    def length(self) -> int:
        """Sum of the ``length`` of every chunk."""
        total = 0
        for piece in self.chunks:
            total += piece.length
        return total

    def to_ints(self) -> list[int]:
        """Concatenate the tokens of all chunks.

        Raises:
            TypeError: If any chunk is not an EncodedTextChunk.
        """
        for piece in self.chunks:
            if not isinstance(piece, EncodedTextChunk):
                raise TypeError(f"Cannot convert {type(piece).__name__} to ints")
        return [token for piece in self.chunks for token in piece.tokens]

    def append(self, chunk: ModelInputChunk) -> Self:
        """Return a new instance with ``chunk`` added at the end; ``self`` is not modified."""
        extended = list(self.chunks)
        extended.append(chunk)
        return type(self)(chunks=extended)
136
+
137
+
138
class Datum(BaseModel):
    """A ModelInput paired with a mapping of named TensorData entries."""

    model_input: ModelInput
    # Keyed by name; presumably the inputs the server-side loss function
    # expects — confirm the accepted keys against the server.
    loss_fn_inputs: dict[str, TensorData] = Field(default_factory=dict)
141
+
142
+
143
+ # ── Sampling ─────────────────────────────────────────────────────────
144
+
145
+
146
class SamplingParams(BaseModel):
    """Settings for a sampling request; every field has a default."""

    max_tokens: int | None = None
    seed: int | None = None
    # A single string, a list of strings, or a list of ints
    # (presumably stop token ids — confirm against the server).
    stop: str | list[str] | list[int] | None = None
    temperature: float = 1.0
    # NOTE(review): -1 presumably means "no top-k filtering" — confirm.
    top_k: int = -1
    top_p: float = 1.0
153
+
154
+
155
class SampledSequence(BaseModel):
    """One sampled sequence: its tokens, optional logprobs, and a stop reason."""

    tokens: list[int] = Field(default_factory=list)
    # Presumably one entry per element of ``tokens`` — confirm with the server.
    logprobs: list[float] | None = None
    stop_reason: str = "length"
159
+
160
+
161
class PromptLogprobs(BaseModel):
    """Per-token log-probabilities for the prompt.

    ``logprobs[i]`` is the logprob of the actual prompt token at position
    ``i``. Index 0 is always ``None`` — the first token has no preceding
    context to score against.

    ``top_logprobs[i]`` (only populated when the caller passed
    ``topk_prompt_logprobs > 0``) maps ``token_id -> logprob`` for the top-k
    alternatives at position ``i``. The actual prompt token is also included
    even if it falls outside the top-k.
    """

    # One entry per prompt position; entries may be None.
    logprobs: list[float | None] = Field(default_factory=list)
    # None when top-k prompt logprobs were not requested.
    top_logprobs: list[dict[int, float] | None] | None = None
176
+
177
+
178
class SampleResult(BaseModel):
    """Result of a sampling call: the sampled sequences plus optional metadata."""

    sequences: list[SampledSequence] = Field(default_factory=list)
    # NOTE(review): presumably identifies the weights version that produced
    # the samples — confirm against the server.
    policy_version: int | None = None
    prompt_logprobs: PromptLogprobs | None = None
182
+
183
+
184
# Alias kept for API surface compatibility: ``SampleResponse`` is the very
# same class object as ``SampleResult``, not a subclass.
SampleResponse = SampleResult
186
+
187
+
188
+ # ── Optimizer ────────────────────────────────────────────────────────
189
+
190
+
191
class AdamParams(BaseModel):
    """Optimizer hyperparameters; every field has a default."""

    learning_rate: float = 1e-4
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-12
    weight_decay: float = 0.0
    grad_clip_norm: float = 0.0  # 0 = disabled
198
+
199
+
200
+ # ── Training response types ──────────────────────────────────────────
201
+
202
+
203
class ForwardBackwardOutput(BaseModel):
    """Output of a forward/backward call; every field has a default."""

    loss: float = 0.0
    loss_fn_output_type: str = ""
    # A list of name -> TensorData mappings (presumably one mapping per input
    # datum — confirm against the server).
    loss_fn_outputs: list[dict[str, TensorData]] = Field(default_factory=list)
    metrics: dict[str, float] = Field(default_factory=dict)
208
+
209
+
210
class OptimStepResponse(BaseModel):
    """Response for an optimizer step; ``metrics`` may be absent (None)."""

    metrics: dict[str, float] | None = None
212
+
213
+
214
class SaveWeightsResponse(BaseModel):
    """Response for a save-weights call; both fields default to an empty string."""

    mode: str = ""
    path: str = ""
217
+
218
+
219
class SaveWeightsForSamplerResponse(BaseModel):
    """Response for ``/save_weights_for_sampler`` (the trainer-side
    publish-LoRA endpoint).

    ``version`` is the controller's ``step_count`` at the moment of the
    call — clients don't pass it. Two syncs at the same version mean
    "no new weights since last check"; a sampler watching the URI keys
    its hot-swap on this number.

    ``path`` is the checkpoint subdirectory the adapters were written
    to (``{weight_sync_uri}/{name}/``).
    """

    version: int = 0
    path: str = ""
233
+
234
+
235
class LoadWeightsResponse(BaseModel):
    """Response for a load-weights call."""

    # NOTE(review): presumably the training step of the loaded weights — confirm.
    step: int = 0
237
+
238
+
239
class InitTrainerServerResponse(BaseModel):
    """Response for ``/init_trainer_server``."""

    step: int = 0
    # Required: has no default, so validation fails if it is missing.
    lora_rank: int
244
+
245
+
246
+ # ── Checkpoints ──────────────────────────────────────────────────────
247
+
248
+
249
class Checkpoint(BaseModel):
    """Metadata record describing a single checkpoint."""

    # Accept ``trainer_server_id`` by its field name as well as by its alias.
    model_config = {"populate_by_name": True}

    # Populated from the ``trainer_id`` key (alias).
    trainer_server_id: str = Field(alias="trainer_id")
    checkpoint_id: str
    checkpoint_type: str
    created_at: datetime
    base_model: str | None = None
    lora_adapter_config: dict | None = None
    size_bytes: int = 0
    sync_status: str | None = None
260
+
261
+
262
class CheckpointFile(BaseModel):
    """One file belonging to a checkpoint; all fields are required."""

    url: str
    relative_file_name: str
    size_bytes: int
    last_modified: datetime
267
+
268
+
269
class CheckpointFilesResponse(BaseModel):
    """A page of checkpoint files."""

    presigned_urls: list[CheckpointFile]
    # NOTE(review): None presumably means there are no further pages — confirm.
    next_page_token: int | None = None
    total_count: int
273
+
274
+
275
+ # ── Model info ───────────────────────────────────────────────────────
276
+
277
+
278
class ModelData(BaseModel):
    """Optional descriptive fields about a model; all default to None."""

    arch: str | None = None
    model_name: str | None = None
    tokenizer_id: str | None = None
282
+
283
+
284
class GetInfoResponse(BaseModel):
    """Model information response; every field has a default."""

    model_data: ModelData = Field(default_factory=ModelData)
    model_name: str | None = None
    is_lora: bool | None = None
    lora_rank: int | None = None
    max_seq_len: int | None = None

    # Clear pydantic's protected namespaces so the ``model_``-prefixed fields
    # above do not trigger a namespace-conflict warning.
    model_config = {"protected_namespaces": ()}