truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py
@@ -1,7 +1,6 @@
+import json
 import re
-import tempfile
 from collections import OrderedDict
-from pathlib import Path
 from typing import List, Optional, Union
 
 import rich_click as click
@@ -11,7 +10,8 @@ from truss.base import truss_config
 from truss.cli.train import common
 from truss.cli.train.types import (
     DeployCheckpointsConfigComplete,
-    PrepareCheckpointResult,
+    DeploySuccessModelVersion,
+    DeploySuccessResult,
 )
 from truss.cli.utils.output import console
 from truss.remote.baseten.remote import BasetenRemote
@@ -25,18 +25,9 @@ from truss_train.definitions import (
     SecretReference,
 )
 
-from .deploy_full_checkpoints import (
-    hydrate_full_checkpoint,
-    render_vllm_full_truss_config,
-)
-from .deploy_lora_checkpoints import (
-    hydrate_lora_checkpoint,
-    render_vllm_lora_truss_config,
-)
-from .deploy_whisper_checkpoints import (
-    hydrate_whisper_checkpoint,
-    render_vllm_whisper_truss_config,
-)
+from .deploy_full_checkpoints import hydrate_full_checkpoint
+from .deploy_lora_checkpoints import hydrate_lora_checkpoint
+from .deploy_whisper_checkpoints import hydrate_whisper_checkpoint
 
 HF_TOKEN_ENVVAR_NAME = "HF_TOKEN"
 # If we change this, make sure to update the logic in backend codebase
@@ -44,28 +35,184 @@ CHECKPOINT_PATTERN = re.compile(r".*checkpoint-\d+(?:-\d+)?$")
 ALLOWED_DEPLOYMENT_NAMES = re.compile(r"^[0-9a-zA-Z_\-\.]*$")
 
 
-def prepare_checkpoint_deploy(
+def create_model_version_from_inference_template(
     remote_provider: BasetenRemote,
     checkpoint_deploy_config: DeployCheckpointsConfig,
     project_id: Optional[str],
     job_id: Optional[str],
-) -> PrepareCheckpointResult:
+    dry_run: bool,
+) -> DeploySuccessResult:
     checkpoint_deploy_config = _hydrate_deploy_config(
         checkpoint_deploy_config, remote_provider, project_id, job_id
     )
-    rendered_truss = _render_truss_config_for_checkpoint_deployment(
-        checkpoint_deploy_config
-    )
-    truss_directory = Path(tempfile.mkdtemp())
-    truss_config_path = truss_directory / "config.yaml"
-    rendered_truss.write_to_yaml_file(truss_config_path)
-    console.print(rendered_truss, style="green")
-    console.print(f"Writing truss config to {truss_config_path}", style="yellow")
-    return PrepareCheckpointResult(
-        truss_directory=truss_directory,
+
+    request_data = _build_inference_template_request(
         checkpoint_deploy_config=checkpoint_deploy_config,
+        remote_provider=remote_provider,
+        dry_run=dry_run,
     )
 
+    # Call the GraphQL mutation to create model version from inference template
+    try:
+        result = remote_provider.api.create_model_version_from_inference_template(
+            request_data
+        )
+        truss_config_result = _get_truss_config_from_result(result)
+
+        model_version = None
+        if result and result.get("model_version"):
+            console.print(
+                f"Successfully created model version: {result['model_version']['name']}",
+                style="green",
+            )
+            console.print(
+                f"Model version ID: {result['model_version']['id']}", style="yellow"
+            )
+            model_version = DeploySuccessModelVersion.model_validate(
+                result["model_version"]
+            )
+        elif not dry_run:
+            console.print(
+                "Warning: Unexpected response format from server", style="yellow"
+            )
+            console.print(f"Response: {result}", style="yellow")
+
+    except Exception as e:
+        console.print(f"Error creating model version: {e}", style="red")
+        raise
+
+    return DeploySuccessResult(
+        deploy_config=checkpoint_deploy_config,
+        truss_config=truss_config_result,
+        model_version=model_version,
+    )
+
+
+def _get_truss_config_from_result(result: dict) -> Optional[truss_config.TrussConfig]:
+    if result and result.get("truss_config"):
+        truss_config_dict = json.loads(result["truss_config"])
+        return truss_config.TrussConfig.from_dict(truss_config_dict)
+    # Although this should never happen, we defensively allow ourselves to return None
+    # because we need a failure to handle the truss config doesn't necessarily mean we failed to deploy
+    # the model version.
+    console.print(
+        "No truss config returned. Reach out to Baseten for support if this persists.",
+        style="red",
+    )
+    return None
+
+
+def _build_inference_template_request(
+    checkpoint_deploy_config: DeployCheckpointsConfigComplete,
+    remote_provider: BasetenRemote,
+    dry_run: bool,
+) -> dict:
+    """
+    Build the GraphQL request data structure for createModelVersionFromInferenceTemplate mutation.
+    """
+
+    # Build weights sources
+    weights_sources = []
+    for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints:
+        # Extract checkpoint name from the first path
+        weights_source = {
+            "weight_source_type": "B10_CHECKPOINTING",
+            "b10_training_checkpoint_weights_source": {
+                "checkpoint": {
+                    "training_job_id": checkpoint.training_job_id,
+                    "checkpoint_name": checkpoint.checkpoint_name,
+                }
+            },
+        }
+        weights_sources.append(weights_source)
+
+    # Build environment variables
+    environment_variables = []
+    for name, value in checkpoint_deploy_config.runtime.environment_variables.items():
+        if isinstance(value, SecretReference):
+            env_var = {"name": name, "value": value.name, "is_secret_reference": True}
+        else:
+            env_var = {"name": name, "value": str(value), "is_secret_reference": False}
+        environment_variables.append(env_var)
+
+    # Build inference stack
+    inference_stack = {
+        "stack_type": "VLLM",
+        "environment_variables": environment_variables,
+    }
+
+    # Get instance type ID from compute spec
+    instance_type_id = _get_instance_type_id(
+        checkpoint_deploy_config.compute, remote_provider
+    )
+
+    # Build the complete request
+    request_data = {
+        "metadata": {"oracle_name": checkpoint_deploy_config.model_name},
+        "weights_sources": weights_sources,
+        "inference_stack": inference_stack,
+        "instance_type_id": instance_type_id,
+        "dry_run": dry_run,
+    }
+
+    return request_data
+
+
+def _get_instance_type_id(compute: Compute, remote_provider: BasetenRemote) -> str:
+    """
+    Get the instance type ID based on the compute specification.
+    Fetches available instance types from the API and maps compute specs to instance type IDs.
+    Only considers single-node instances (node_count == 1).
+    """
+    # step 1: fetch the instance types from the API
+    instance_types = remote_provider.api.get_instance_types()
+    # step 2: sort them into two different dictionaries, excluding multi-node instances:
+    cpu_instance_types = {
+        it.id: it for it in instance_types if it.gpu_count == 0 and it.node_count == 1
+    }
+    gpu_instance_types = {
+        it.id: it for it in instance_types if it.gpu_count > 0 and it.node_count == 1
+    }
+    # step 3: if compute is cpu, find the smallest such cpu that matches the compute request
+    if not compute.accelerator or compute.accelerator.accelerator is None:
+        compute_as_truss_config = compute.to_truss_config()
+        smallest_cpu_instance_type = None
+        for it in cpu_instance_types.values():
+            if (
+                it.millicpu_limit / 1000 >= compute.cpu_count
+                and it.memory_limit >= compute_as_truss_config.memory_in_bytes
+            ):
+                if (
+                    smallest_cpu_instance_type is None
+                    or it.millicpu_limit < smallest_cpu_instance_type.millicpu_limit
+                ):
+                    smallest_cpu_instance_type = it
+        if not smallest_cpu_instance_type:
+            raise ValueError(
+                f"Unable to find single-node instance type for {compute.cpu_count} CPU and {compute.memory} memory. Reach out to Baseten for support if this persists."
+            )
+        return smallest_cpu_instance_type.id
+    # step 4: if compute is gpu, find the smallest such gpu by instance type
+    else:
+        assert compute.accelerator.accelerator is not None
+        compute_as_truss_config = compute.to_truss_config()
+        smallest_gpu_instance_type = None
+        for it in gpu_instance_types.values():
+            if (
+                it.gpu_type == compute.accelerator.accelerator.value
+                and it.gpu_count >= compute.accelerator.count
+            ):
+                if (
+                    smallest_gpu_instance_type is None
+                    or it.gpu_count < smallest_gpu_instance_type.gpu_count
+                ):
+                    smallest_gpu_instance_type = it
+        if not smallest_gpu_instance_type:
+            raise ValueError(
+                f"Unable to find single-node instance type for {compute.accelerator}:{compute.accelerator.count}. Reach out to Baseten for support if this persists."
+            )
+        return smallest_gpu_instance_type.id
+
 
 def _validate_base_model_id(
     base_model_id: Optional[str], model_weight_format: ModelWeightsFormat
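
For orientation, here is a hand-written sketch of the payload shape that the new _build_inference_template_request above assembles for the createModelVersionFromInferenceTemplate mutation. The block is illustrative rather than part of the diff, and the job ID, checkpoint name, and instance type ID are hypothetical placeholders:

# Illustrative only: the shape of request_data per the hunk above.
# All concrete values here are hypothetical placeholders.
example_request_data = {
    "metadata": {"oracle_name": "my-finetuned-model"},
    "weights_sources": [
        {
            "weight_source_type": "B10_CHECKPOINTING",
            "b10_training_checkpoint_weights_source": {
                "checkpoint": {
                    "training_job_id": "job_123",         # hypothetical
                    "checkpoint_name": "checkpoint-500",  # hypothetical
                }
            },
        }
    ],
    "inference_stack": {
        "stack_type": "VLLM",
        "environment_variables": [
            # A SecretReference is flattened to its secret name:
            {"name": "HF_TOKEN", "value": "hf_access_token", "is_secret_reference": True},
        ],
    },
    "instance_type_id": "it_abc123",  # hypothetical, resolved by _get_instance_type_id
    "dry_run": False,
}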
@@ -93,18 +240,12 @@ def _get_model_name(
         else ""
     )
 
-    model_name = inquirer.text(
+    return inquirer.text(
         message=f"Enter the model name for your {model_weight_format.value} model.",
         validate=lambda s: s and s.strip(),
         default=default,
     ).execute()
 
-    if model_weight_format == ModelWeightsFormat.FULL:
-        model_name += "-vLLM-Full"
-    elif model_weight_format == ModelWeightsFormat.LORA:
-        model_name += "-vLLM-LORA"
-    return model_name
-
 
 def _hydrate_deploy_config(
     deploy_config: DeployCheckpointsConfig,
@@ -123,53 +264,18 @@ def _hydrate_deploy_config(
     else:
         model_name = _get_model_name(model_weight_format, base_model_id)
 
-    compute = _ensure_compute_spec(deploy_config.compute)
+    compute = _ensure_compute_spec(deploy_config.compute, remote_provider)
 
     runtime = _ensure_runtime_config(deploy_config.runtime)
-    deployment_name = _ensure_deployment_name(
-        deploy_config.deployment_name, checkpoint_details.checkpoints
-    )
 
     return DeployCheckpointsConfigComplete(
         checkpoint_details=checkpoint_details,
         model_name=model_name,
-        deployment_name=deployment_name,
         runtime=runtime,
         compute=compute,
-        model_weight_format=model_weight_format.to_truss_config(),  # type: ignore[attr-defined]
     )
 
 
-def _ensure_deployment_name(
-    deploy_config_deployment_name: Optional[str], checkpoints: List[Checkpoint]
-) -> str:
-    if deploy_config_deployment_name:
-        return deploy_config_deployment_name
-
-    default_deployment_name = "checkpoint"
-
-    if checkpoints and checkpoints[0].paths:
-        first_checkpoint_name = checkpoints[0].paths[0].strip("/").split("/")[-1]
-
-        if ALLOWED_DEPLOYMENT_NAMES.match(first_checkpoint_name):
-            # Allow autoincrementing if the checkpoint matches both regexes
-            if (
-                CHECKPOINT_PATTERN.match(first_checkpoint_name)
-                and len(checkpoints) == 1
-            ):
-                return first_checkpoint_name
-
-    # If no valid autoincrementing checkpoint name is found, prompt the user
-    deployment_name = inquirer.text(
-        message="Enter the deployment name.", default=default_deployment_name
-    ).execute()
-
-    if not deployment_name:
-        raise click.UsageError("Deployment name is required.")
-
-    return deployment_name
-
-
 def hydrate_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict, checkpoint_type: str
 ) -> Checkpoint:
@@ -190,26 +296,6 @@ def hydrate_checkpoint(
     )
 
 
-def _render_truss_config_for_checkpoint_deployment(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """
-    Render truss config for checkpoint deployment.
-    Currently supports LoRA checkpoints via vLLM, but can be extended for other formats.
-    """
-    # Delegate to specific rendering function based on model weight format
-    if checkpoint_deploy.model_weight_format == ModelWeightsFormat.LORA:
-        return render_vllm_lora_truss_config(checkpoint_deploy)
-    elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.FULL:
-        return render_vllm_full_truss_config(checkpoint_deploy)
-    elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER:
-        return render_vllm_whisper_truss_config(checkpoint_deploy)
-    else:
-        raise ValueError(
-            f"Unsupported model weight format: {checkpoint_deploy.model_weight_format}. Please upgrade to the latest Truss version to access the latest supported formats. Contact Baseten if you would like us to support additional formats."
-        )
-
-
 def _ensure_checkpoint_details(
     remote_provider: BasetenRemote,
     checkpoint_details: Optional[CheckpointList],
@@ -217,6 +303,7 @@ def _ensure_checkpoint_details(
     job_id: Optional[str],
 ) -> CheckpointList:
     if checkpoint_details and checkpoint_details.checkpoints:
+        # TODO: check here
         return _process_user_provided_checkpoints(checkpoint_details, remote_provider)
     else:
         return _prompt_user_for_checkpoint_details(
@@ -309,31 +396,77 @@ def _select_multiple_checkpoints(checkpoint_id_options: List[str]) -> List[str]:
     return checkpoint_ids
 
 
-def _ensure_compute_spec(
+def _ensure_compute_spec(
+    compute: Optional[Compute], remote_provider: BasetenRemote
+) -> Compute:
     if not compute:
         compute = Compute(cpu_count=0, memory="0Mi")
-    compute
+    compute = _get_accelerator_if_specified(compute, remote_provider)
     return compute
 
 
 def _get_accelerator_if_specified(
-    user_input: Optional[
-) ->
-    if user_input:
+    user_input: Optional[Compute], remote_provider: BasetenRemote
+) -> Compute:
+    if user_input and user_input.accelerator:
         return user_input
+
+    # Fetch available instance types to get valid GPU options
+    instance_types = remote_provider.api.get_instance_types()
+
+    # Extract unique accelerator types from instance types
+    accelerator_options = set()
+    for it in instance_types:
+        if it.gpu_type and it.gpu_count > 0:
+            accelerator_options.add(it.gpu_type)
+
+    # Convert to sorted list and add CPU option
+    choices = sorted(list(accelerator_options)) + [None]
+
+    if not choices or choices == [None]:
+        console.print("No GPU instance types available, using CPU", style="yellow")
+        return Compute(cpu_count=0, memory="0Mi", accelerator=None)
+
     # prompt user for accelerator
     gpu_type = inquirer.select(
         message="Select the GPU type to use for deployment. Select None for CPU.",
-        choices=
+        choices=choices,
     ).execute()
+
     if gpu_type is None:
-        return None
-
-
-
-
-
-
+        return Compute(cpu_count=0, memory="0Mi", accelerator=None)
+
+    # Get available counts for the selected GPU type
+    available_counts = set()
+    for it in instance_types:
+        if it.gpu_type == gpu_type and it.gpu_count > 0:
+            available_counts.add(it.gpu_count)
+    if not available_counts:
+        raise ValueError(
+            f"No available counts for {gpu_type}. Reach out to Baseten for support if this persists."
+        )
+
+    if available_counts:
+        count_choices = sorted(list(available_counts))
+        count = inquirer.select(
+            message=f"Select the number of {gpu_type} GPUs to use for deployment.",
+            choices=count_choices,
+            default=str(count_choices[0]),
+        ).execute()
+    else:
+        count = inquirer.text(
+            message=f"Enter the number of {gpu_type} accelerators to use for deployment.",
+            default="1",
+            validate=lambda x: x.isdigit() and int(x) > 0 and int(x) <= 8,
+        ).execute()
+
+    return Compute(
+        cpu_count=0,
+        memory="0Mi",
+        accelerator=truss_config.AcceleratorSpec(
+            accelerator=gpu_type.replace("-", "_"), count=int(count)
+        ),
+    )
 
 
 def _get_base_model_id(user_input: Optional[str], checkpoint: dict) -> Optional[str]:
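
Both _get_instance_type_id and the accelerator prompt above apply a smallest-fit rule: keep the single-node instance types that cover the requested spec, then pick the minimum. Below is a minimal self-contained sketch of that rule for the GPU path; InstanceType is a made-up stand-in for the objects returned by remote_provider.api.get_instance_types(), and the fleet data is hypothetical:

# A minimal sketch (not from the diff) of the smallest-fit selection rule.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class InstanceType:  # stand-in for the API's instance type objects
    id: str
    gpu_type: Optional[str]
    gpu_count: int
    node_count: int

def smallest_gpu_fit(
    instance_types: List[InstanceType], gpu_type: str, gpu_count: int
) -> Optional[InstanceType]:
    # Keep single-node instances whose GPU type matches and whose count covers the request,
    # then take the one with the fewest GPUs.
    candidates = [
        it
        for it in instance_types
        if it.node_count == 1 and it.gpu_type == gpu_type and it.gpu_count >= gpu_count
    ]
    return min(candidates, key=lambda it: it.gpu_count, default=None)

# e.g. requesting 2x H100 picks the 2-GPU box, not the 4-GPU one;
# the multi-node entry is excluded outright.
fleet = [
    InstanceType("h100x4", "H100", 4, 1),
    InstanceType("h100x2", "H100", 2, 1),
    InstanceType("h100x8-mn", "H100", 8, 2),  # multi-node, excluded
]
assert smallest_gpu_fit(fleet, "H100", 2).id == "h100x2"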
truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py

@@ -1,52 +1 @@
-import os
-from pathlib import Path
-
-from truss.base import truss_config
-from truss.cli.train.types import DeployCheckpointsConfigComplete
-from truss_train.definitions import ModelWeightsFormat, SecretReference
-
-START_COMMAND_ENVVAR_NAME = "BT_DOCKER_SERVER_START_CMD"
-
-
-def setup_base_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Set up the base truss config with common properties."""
-    truss_deploy_config = None
-    truss_base_file = (
-        "deploy_from_checkpoint_config_whisper.yml"
-        if checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER
-        else "deploy_from_checkpoint_config.yml"
-    )
-    truss_deploy_config = truss_config.TrussConfig.from_yaml(
-        Path(os.path.dirname(__file__), "..", truss_base_file)
-    )
-    if not truss_deploy_config.docker_server:
-        raise ValueError(
-            "Unexpected checkpoint deployment config: missing docker_server"
-        )
-
-    truss_deploy_config.model_name = checkpoint_deploy.model_name
-    truss_deploy_config.training_checkpoints = (
-        checkpoint_deploy.checkpoint_details.to_truss_config()
-    )
-    truss_deploy_config.resources = checkpoint_deploy.compute.to_truss_config()
-
-    return truss_deploy_config
-
-
-def setup_environment_variables_and_secrets(
-    truss_deploy_config: truss_config.TrussConfig,
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> str:
-    """Set up environment variables and secrets, return start command envvars string."""
-    start_command_envvars = ""
-
-    for key, value in checkpoint_deploy.runtime.environment_variables.items():
-        if isinstance(value, SecretReference):
-            truss_deploy_config.secrets[value.name] = "set token in baseten workspace"
-            start_command_envvars = f"{key}=$(cat /secrets/{value.name})"
-        else:
-            truss_deploy_config.environment_variables[key] = value
-
-    return start_command_envvars
+# This file is kept for potential future use but currently contains no active code
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py

@@ -1,94 +1,9 @@
-from pathlib import Path
-
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import FullCheckpoint
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-# NB(aghilan): Transformers was recently changed to save a chat_template.jinja file instead of inside the tokenizer_config.json file.
-# Old Models will not have this file, so we check for it and use it if it exists.
-# vLLM will not automatically resolve the chat_template.jinja file, so we need to pass it to the start command.
-# This logic is needed for any models trained using Transformers v4.51.3 or later
-VLLM_FULL_START_COMMAND = Template(
-    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
-    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
-    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
-    " vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
-    " --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
-    "else "
-    " vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
-    "fi'"
-)
-
-
-def render_vllm_full_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for full checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
-
-    accelerator = checkpoint_deploy.compute.accelerator
-
-    start_command_args = {
-        "model_path": checkpoint_str,
-        "envvars": start_command_envvars,
-        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
-    }
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    start_command = VLLM_FULL_START_COMMAND.render(**start_command_args)
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-
-    return truss_deploy_config
-
 
 def hydrate_full_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
 ) -> FullCheckpoint:
     """Create a Checkpoint object for full model weights."""
     # NOTE: Slash at the end is important since it means the checkpoint is a directory
-
-    return FullCheckpoint(training_job_id=job_id, paths=paths)
-
-
-def build_full_checkpoint_string(truss_deploy_config) -> str:
-    """Build checkpoint string from artifact references for full checkpoints.
-
-    Args:
-        truss_deploy_config: The truss deploy configuration containing training checkpoints.
-
-    Returns:
-        A space-separated string of checkpoint paths.
-    """
-    checkpoint_parts = []
-    for (
-        truss_checkpoint
-    ) in truss_deploy_config.training_checkpoints.artifact_references:  # type: ignore
-        ckpt_path = Path(
-            truss_deploy_config.training_checkpoints.download_folder,  # type: ignore
-            truss_checkpoint.training_job_id,
-            truss_checkpoint.paths[0],
-        )
-        checkpoint_parts.append(str(ckpt_path))
-
-    return " ".join(checkpoint_parts)
+    return FullCheckpoint(training_job_id=job_id, checkpoint_name=checkpoint_id)