truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,38 +1,17 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import re
|
|
3
|
-
from dataclasses import dataclass
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
from typing import Dict, List, Optional
|
|
6
1
|
from unittest.mock import MagicMock, patch
|
|
7
2
|
|
|
8
3
|
import pytest
|
|
9
4
|
|
|
10
5
|
import truss_train.definitions as definitions
|
|
11
|
-
from truss.base import truss_config
|
|
12
|
-
from truss.cli.train.deploy_checkpoints import prepare_checkpoint_deploy
|
|
13
6
|
from truss.cli.train.deploy_checkpoints.deploy_checkpoints import (
|
|
14
7
|
_get_checkpoint_ids_to_deploy,
|
|
15
|
-
_render_truss_config_for_checkpoint_deployment,
|
|
16
8
|
hydrate_checkpoint,
|
|
17
9
|
)
|
|
18
10
|
from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
|
|
19
11
|
hydrate_full_checkpoint,
|
|
20
|
-
render_vllm_full_truss_config,
|
|
21
|
-
)
|
|
22
|
-
from truss.cli.train.deploy_checkpoints.deploy_lora_checkpoints import (
|
|
23
|
-
START_COMMAND_ENVVAR_NAME,
|
|
24
|
-
_get_lora_rank,
|
|
25
|
-
hydrate_lora_checkpoint,
|
|
26
|
-
render_vllm_lora_truss_config,
|
|
27
12
|
)
|
|
28
13
|
from truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints import (
|
|
29
|
-
VLLM_WHISPER_START_COMMAND,
|
|
30
14
|
hydrate_whisper_checkpoint,
|
|
31
|
-
render_vllm_whisper_truss_config,
|
|
32
|
-
)
|
|
33
|
-
from truss.cli.train.types import (
|
|
34
|
-
DeployCheckpointsConfigComplete,
|
|
35
|
-
PrepareCheckpointResult,
|
|
36
15
|
)
|
|
37
16
|
from truss_train.definitions import ModelWeightsFormat
|
|
38
17
|
|
|
@@ -125,413 +104,6 @@ def deploy_checkpoints_mock_checkbox(create_mock_prompt):
|
|
|
125
104
|
yield mock
|
|
126
105
|
|
|
127
106
|
|
|
128
|
-
def test_render_truss_config_for_checkpoint_deployment():
|
|
129
|
-
deploy_config = DeployCheckpointsConfigComplete(
|
|
130
|
-
checkpoint_details=definitions.CheckpointList(
|
|
131
|
-
checkpoints=[
|
|
132
|
-
definitions.LoRACheckpoint(
|
|
133
|
-
training_job_id="job123",
|
|
134
|
-
paths=["rank-0/checkpoint-1/"],
|
|
135
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
136
|
-
lora_details=definitions.LoRADetails(rank=16),
|
|
137
|
-
)
|
|
138
|
-
],
|
|
139
|
-
base_model_id="google/gemma-3-27b-it",
|
|
140
|
-
),
|
|
141
|
-
model_name="gemma-3-27b-it-vLLM-LORA",
|
|
142
|
-
compute=definitions.Compute(
|
|
143
|
-
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
|
|
144
|
-
),
|
|
145
|
-
runtime=definitions.DeployCheckpointsRuntime(
|
|
146
|
-
environment_variables={
|
|
147
|
-
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
|
|
148
|
-
}
|
|
149
|
-
),
|
|
150
|
-
deployment_name="gemma-3-27b-it-vLLM-LORA",
|
|
151
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
152
|
-
)
|
|
153
|
-
rendered_truss = _render_truss_config_for_checkpoint_deployment(deploy_config)
|
|
154
|
-
test_truss = truss_config.TrussConfig.from_yaml(
|
|
155
|
-
Path(
|
|
156
|
-
os.path.dirname(__file__),
|
|
157
|
-
"resources/test_deploy_from_checkpoint_config.yml",
|
|
158
|
-
)
|
|
159
|
-
)
|
|
160
|
-
assert test_truss.model_name == rendered_truss.model_name
|
|
161
|
-
assert (
|
|
162
|
-
test_truss.training_checkpoints.artifact_references[0].paths[0]
|
|
163
|
-
== rendered_truss.training_checkpoints.artifact_references[0].paths[0]
|
|
164
|
-
)
|
|
165
|
-
assert (
|
|
166
|
-
test_truss.training_checkpoints.artifact_references[0].paths
|
|
167
|
-
== rendered_truss.training_checkpoints.artifact_references[0].paths
|
|
168
|
-
)
|
|
169
|
-
assert (
|
|
170
|
-
test_truss.docker_server.start_command
|
|
171
|
-
== rendered_truss.docker_server.start_command
|
|
172
|
-
)
|
|
173
|
-
assert test_truss.resources.accelerator == rendered_truss.resources.accelerator
|
|
174
|
-
assert test_truss.secrets == rendered_truss.secrets
|
|
175
|
-
assert test_truss.training_checkpoints == rendered_truss.training_checkpoints
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def test_prepare_checkpoint_deploy_empty_config(
|
|
179
|
-
mock_remote,
|
|
180
|
-
deploy_checkpoints_mock_select,
|
|
181
|
-
deploy_checkpoints_mock_input,
|
|
182
|
-
deploy_checkpoints_mock_text,
|
|
183
|
-
deploy_checkpoints_mock_checkbox,
|
|
184
|
-
):
|
|
185
|
-
# Create empty config
|
|
186
|
-
empty_config = definitions.DeployCheckpointsConfig()
|
|
187
|
-
|
|
188
|
-
# Call function under test
|
|
189
|
-
result = prepare_checkpoint_deploy(
|
|
190
|
-
remote_provider=mock_remote,
|
|
191
|
-
checkpoint_deploy_config=empty_config,
|
|
192
|
-
project_id="project123",
|
|
193
|
-
job_id="job123",
|
|
194
|
-
)
|
|
195
|
-
|
|
196
|
-
assert isinstance(result, PrepareCheckpointResult)
|
|
197
|
-
assert result.checkpoint_deploy_config.model_name == "gemma-3-27b-it-vLLM-LORA"
|
|
198
|
-
assert (
|
|
199
|
-
result.checkpoint_deploy_config.checkpoint_details.base_model_id
|
|
200
|
-
== "google/gemma-3-27b-it"
|
|
201
|
-
)
|
|
202
|
-
assert len(result.checkpoint_deploy_config.checkpoint_details.checkpoints) == 1
|
|
203
|
-
checkpoint = result.checkpoint_deploy_config.checkpoint_details.checkpoints[0]
|
|
204
|
-
assert checkpoint.training_job_id == "job123"
|
|
205
|
-
assert isinstance(checkpoint, definitions.LoRACheckpoint)
|
|
206
|
-
assert checkpoint.lora_details.rank == 16
|
|
207
|
-
assert result.checkpoint_deploy_config.compute.accelerator.accelerator == "H100"
|
|
208
|
-
assert result.checkpoint_deploy_config.compute.accelerator.count == 4
|
|
209
|
-
assert (
|
|
210
|
-
result.checkpoint_deploy_config.runtime.environment_variables["HF_TOKEN"].name
|
|
211
|
-
== "hf_access_token"
|
|
212
|
-
)
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
def test_prepare_checkpoint_deploy_complete_config(
|
|
216
|
-
mock_remote,
|
|
217
|
-
deploy_checkpoints_mock_select,
|
|
218
|
-
deploy_checkpoints_mock_text,
|
|
219
|
-
deploy_checkpoints_mock_checkbox,
|
|
220
|
-
):
|
|
221
|
-
# Create complete config with all fields specified
|
|
222
|
-
complete_config = definitions.DeployCheckpointsConfig(
|
|
223
|
-
checkpoint_details=definitions.CheckpointList(
|
|
224
|
-
checkpoints=[
|
|
225
|
-
definitions.LoRACheckpoint(
|
|
226
|
-
training_job_id="job123",
|
|
227
|
-
paths=["job123/rank-0/checkpoint-1/"],
|
|
228
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
229
|
-
lora_details=definitions.LoRADetails(rank=32),
|
|
230
|
-
)
|
|
231
|
-
],
|
|
232
|
-
base_model_id="google/gemma-3-27b-it",
|
|
233
|
-
),
|
|
234
|
-
model_name="my-custom-model",
|
|
235
|
-
deployment_name="my-deployment",
|
|
236
|
-
compute=definitions.Compute(
|
|
237
|
-
accelerator=truss_config.AcceleratorSpec(accelerator="A100", count=2)
|
|
238
|
-
),
|
|
239
|
-
runtime=definitions.DeployCheckpointsRuntime(
|
|
240
|
-
environment_variables={
|
|
241
|
-
"HF_TOKEN": definitions.SecretReference(name="my_custom_secret"),
|
|
242
|
-
"CUSTOM_VAR": "custom_value",
|
|
243
|
-
}
|
|
244
|
-
),
|
|
245
|
-
)
|
|
246
|
-
|
|
247
|
-
# Call function under test
|
|
248
|
-
result = prepare_checkpoint_deploy(
|
|
249
|
-
remote_provider=mock_remote,
|
|
250
|
-
checkpoint_deploy_config=complete_config,
|
|
251
|
-
project_id="project123",
|
|
252
|
-
job_id="job123",
|
|
253
|
-
)
|
|
254
|
-
|
|
255
|
-
# Verify result
|
|
256
|
-
assert isinstance(result, PrepareCheckpointResult)
|
|
257
|
-
|
|
258
|
-
# Verify no prompts were called
|
|
259
|
-
deploy_checkpoints_mock_select.assert_not_called()
|
|
260
|
-
deploy_checkpoints_mock_text.assert_not_called()
|
|
261
|
-
deploy_checkpoints_mock_checkbox.assert_not_called()
|
|
262
|
-
|
|
263
|
-
# Verify config values were preserved
|
|
264
|
-
config = result.checkpoint_deploy_config
|
|
265
|
-
assert config.model_name == "my-custom-model"
|
|
266
|
-
assert config.deployment_name == "my-deployment"
|
|
267
|
-
|
|
268
|
-
# Verify checkpoint details
|
|
269
|
-
assert config.checkpoint_details.base_model_id == "google/gemma-3-27b-it"
|
|
270
|
-
assert len(config.checkpoint_details.checkpoints) == 1
|
|
271
|
-
checkpoint = config.checkpoint_details.checkpoints[0]
|
|
272
|
-
assert checkpoint.training_job_id == "job123"
|
|
273
|
-
assert checkpoint.model_weight_format == ModelWeightsFormat.LORA
|
|
274
|
-
assert isinstance(checkpoint, definitions.LoRACheckpoint)
|
|
275
|
-
assert checkpoint.lora_details.rank == 32
|
|
276
|
-
|
|
277
|
-
# Verify compute config
|
|
278
|
-
assert config.compute.accelerator.accelerator == "A100"
|
|
279
|
-
assert config.compute.accelerator.count == 2
|
|
280
|
-
|
|
281
|
-
# Verify runtime config
|
|
282
|
-
env_vars = config.runtime.environment_variables
|
|
283
|
-
assert env_vars["HF_TOKEN"].name == "my_custom_secret"
|
|
284
|
-
assert env_vars["CUSTOM_VAR"] == "custom_value"
|
|
285
|
-
|
|
286
|
-
# open the config.yaml file and verify the tensor parallel size is 2
|
|
287
|
-
# additional tests can be added to verify the config.yaml file is correct
|
|
288
|
-
truss_cfg = truss_config.TrussConfig.from_yaml(
|
|
289
|
-
Path(result.truss_directory, "config.yaml")
|
|
290
|
-
)
|
|
291
|
-
# Check that the start command is now the environment variable reference
|
|
292
|
-
assert truss_cfg.docker_server.start_command == "%(ENV_BT_DOCKER_SERVER_START_CMD)s"
|
|
293
|
-
# Check that the actual start command with tensor parallel size is in the environment variable
|
|
294
|
-
assert (
|
|
295
|
-
"--tensor-parallel-size 2"
|
|
296
|
-
in truss_cfg.environment_variables["BT_DOCKER_SERVER_START_CMD"]
|
|
297
|
-
)
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
def test_checkpoint_lora_rank_validation():
|
|
301
|
-
"""Test that LoRACheckpoint accepts valid LoRA rank values."""
|
|
302
|
-
valid_ranks = [8, 16, 32, 64, 128, 256, 320, 512]
|
|
303
|
-
|
|
304
|
-
for rank in valid_ranks:
|
|
305
|
-
checkpoint = definitions.LoRACheckpoint(
|
|
306
|
-
training_job_id="job123",
|
|
307
|
-
paths=["job123/rank-0/checkpoint-1/"],
|
|
308
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
309
|
-
lora_details=definitions.LoRADetails(rank=rank),
|
|
310
|
-
)
|
|
311
|
-
assert checkpoint.lora_details.rank == rank
|
|
312
|
-
|
|
313
|
-
invalid_ranks = [
|
|
314
|
-
1,
|
|
315
|
-
2,
|
|
316
|
-
4,
|
|
317
|
-
7,
|
|
318
|
-
9,
|
|
319
|
-
15,
|
|
320
|
-
17,
|
|
321
|
-
31,
|
|
322
|
-
33,
|
|
323
|
-
63,
|
|
324
|
-
65,
|
|
325
|
-
127,
|
|
326
|
-
129,
|
|
327
|
-
255,
|
|
328
|
-
257,
|
|
329
|
-
319,
|
|
330
|
-
321,
|
|
331
|
-
511,
|
|
332
|
-
513,
|
|
333
|
-
1000,
|
|
334
|
-
]
|
|
335
|
-
for rank in invalid_ranks:
|
|
336
|
-
with pytest.raises(ValueError, match=f"lora_rank \\({rank}\\) must be one of"):
|
|
337
|
-
definitions.LoRACheckpoint(
|
|
338
|
-
training_job_id="job123",
|
|
339
|
-
paths=["job123/rank-0/checkpoint-1/"],
|
|
340
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
341
|
-
lora_details=definitions.LoRADetails(rank=rank),
|
|
342
|
-
)
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
def test_get_lora_rank():
|
|
346
|
-
"""Test that get_lora_rank returns valid values from checkpoint response."""
|
|
347
|
-
# Test with valid rank from API
|
|
348
|
-
checkpoint_resp = {"lora_adapter_config": {"r": 64}}
|
|
349
|
-
assert _get_lora_rank(checkpoint_resp) == 64
|
|
350
|
-
# Test with missing lora_adapter_config (should use DEFAULT_LORA_RANK)
|
|
351
|
-
checkpoint_resp = {}
|
|
352
|
-
assert _get_lora_rank(checkpoint_resp) == 16 # DEFAULT_LORA_RANK
|
|
353
|
-
# Test with missing 'r' field (should use DEFAULT_LORA_RANK)
|
|
354
|
-
checkpoint_resp = {"lora_adapter_config": {}}
|
|
355
|
-
assert _get_lora_rank(checkpoint_resp) == 16 # DEFAULT_LORA_RANK
|
|
356
|
-
# Test with invalid rank from API
|
|
357
|
-
checkpoint_resp = {"lora_adapter_config": {"r": 1}}
|
|
358
|
-
with pytest.raises(
|
|
359
|
-
ValueError,
|
|
360
|
-
match=re.escape("LoRA rank 1 from checkpoint is not in allowed values"),
|
|
361
|
-
):
|
|
362
|
-
_get_lora_rank(checkpoint_resp)
|
|
363
|
-
# Test with another invalid rank
|
|
364
|
-
checkpoint_resp = {"lora_adapter_config": {"r": 1000}}
|
|
365
|
-
with pytest.raises(
|
|
366
|
-
ValueError,
|
|
367
|
-
match=re.escape("LoRA rank 1000 from checkpoint is not in allowed values"),
|
|
368
|
-
):
|
|
369
|
-
_get_lora_rank(checkpoint_resp)
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
def test_hydrate_lora_checkpoint():
|
|
373
|
-
"""Test that hydrate_lora_checkpoint creates proper LoRACheckpoint objects."""
|
|
374
|
-
job_id = "test_job_123"
|
|
375
|
-
checkpoint_id = "checkpoint-456"
|
|
376
|
-
checkpoint_data = {
|
|
377
|
-
"lora_adapter_config": {"r": 64},
|
|
378
|
-
"base_model": "google/gemma-3-27b-it",
|
|
379
|
-
"checkpoint_type": "lora",
|
|
380
|
-
}
|
|
381
|
-
|
|
382
|
-
result = hydrate_lora_checkpoint(job_id, checkpoint_id, checkpoint_data)
|
|
383
|
-
|
|
384
|
-
assert isinstance(result, definitions.LoRACheckpoint)
|
|
385
|
-
assert result.training_job_id == job_id
|
|
386
|
-
assert result.lora_details.rank == 64
|
|
387
|
-
assert len(result.paths) == 1
|
|
388
|
-
assert result.paths[0] == f"rank-0/{checkpoint_id}/"
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
def test_hydrate_checkpoint_dispatcher():
|
|
392
|
-
"""Test that hydrate_checkpoint properly dispatches to the right function based on checkpoint type."""
|
|
393
|
-
job_id = "test_job_123"
|
|
394
|
-
checkpoint_id = "checkpoint-456"
|
|
395
|
-
checkpoint_data = {
|
|
396
|
-
"lora_adapter_config": {"r": 32},
|
|
397
|
-
"base_model": "google/gemma-3-27b-it",
|
|
398
|
-
"checkpoint_type": "lora",
|
|
399
|
-
}
|
|
400
|
-
|
|
401
|
-
# Test LoRA checkpoint type
|
|
402
|
-
result = hydrate_checkpoint(job_id, checkpoint_id, checkpoint_data, "lora")
|
|
403
|
-
assert isinstance(result, definitions.LoRACheckpoint)
|
|
404
|
-
assert result.lora_details.rank == 32
|
|
405
|
-
|
|
406
|
-
# Test uppercase LoRA checkpoint type
|
|
407
|
-
result = hydrate_checkpoint(job_id, checkpoint_id, checkpoint_data, "LORA")
|
|
408
|
-
assert isinstance(result, definitions.LoRACheckpoint)
|
|
409
|
-
assert result.lora_details.rank == 32
|
|
410
|
-
|
|
411
|
-
# Test unsupported checkpoint type
|
|
412
|
-
with pytest.raises(ValueError, match="Unsupported checkpoint type: unsupported"):
|
|
413
|
-
hydrate_checkpoint(job_id, checkpoint_id, checkpoint_data, "unsupported")
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
def test_render_vllm_lora_truss_config():
|
|
417
|
-
"""Test that render_vllm_lora_truss_config creates proper TrussConfig for LoRA deployments."""
|
|
418
|
-
deploy_config = DeployCheckpointsConfigComplete(
|
|
419
|
-
checkpoint_details=definitions.CheckpointList(
|
|
420
|
-
checkpoints=[
|
|
421
|
-
definitions.LoRACheckpoint(
|
|
422
|
-
training_job_id="job123",
|
|
423
|
-
paths=["rank-0/checkpoint-1/"],
|
|
424
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
425
|
-
lora_details=definitions.LoRADetails(rank=64),
|
|
426
|
-
)
|
|
427
|
-
],
|
|
428
|
-
base_model_id="google/gemma-3-27b-it",
|
|
429
|
-
),
|
|
430
|
-
model_name="test-lora-model",
|
|
431
|
-
compute=definitions.Compute(
|
|
432
|
-
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
|
|
433
|
-
),
|
|
434
|
-
runtime=definitions.DeployCheckpointsRuntime(
|
|
435
|
-
environment_variables={
|
|
436
|
-
"HF_TOKEN": definitions.SecretReference(name="hf_token")
|
|
437
|
-
}
|
|
438
|
-
),
|
|
439
|
-
deployment_name="test-deployment",
|
|
440
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
441
|
-
)
|
|
442
|
-
|
|
443
|
-
result = render_vllm_lora_truss_config(deploy_config)
|
|
444
|
-
|
|
445
|
-
expected_vllm_command = 'sh -c "HF_TOKEN=$(cat /secrets/hf_token) vllm serve google/gemma-3-27b-it --port 8000 --tensor-parallel-size 2 --enable-lora --max-lora-rank 64 --dtype bfloat16 --lora-modules job123=/tmp/training_checkpoints/job123/rank-0/checkpoint-1"'
|
|
446
|
-
|
|
447
|
-
assert isinstance(result, truss_config.TrussConfig)
|
|
448
|
-
assert result.model_name == "test-lora-model"
|
|
449
|
-
assert result.docker_server is not None
|
|
450
|
-
assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
451
|
-
assert (
|
|
452
|
-
result.environment_variables[START_COMMAND_ENVVAR_NAME] == expected_vllm_command
|
|
453
|
-
)
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
def test_render_truss_config_delegation():
|
|
457
|
-
"""Test that _render_truss_config_for_checkpoint_deployment delegates correctly based on model weight format."""
|
|
458
|
-
deploy_config = DeployCheckpointsConfigComplete(
|
|
459
|
-
checkpoint_details=definitions.CheckpointList(
|
|
460
|
-
checkpoints=[
|
|
461
|
-
definitions.LoRACheckpoint(
|
|
462
|
-
training_job_id="job123",
|
|
463
|
-
paths=["rank-0/checkpoint-1/"],
|
|
464
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
465
|
-
lora_details=definitions.LoRADetails(rank=32),
|
|
466
|
-
)
|
|
467
|
-
],
|
|
468
|
-
base_model_id="google/gemma-3-27b-it",
|
|
469
|
-
),
|
|
470
|
-
model_name="test-model",
|
|
471
|
-
compute=definitions.Compute(
|
|
472
|
-
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
|
|
473
|
-
),
|
|
474
|
-
runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
|
|
475
|
-
deployment_name="test-deployment",
|
|
476
|
-
model_weight_format=ModelWeightsFormat.LORA,
|
|
477
|
-
)
|
|
478
|
-
|
|
479
|
-
# Test that it works for LoRA format
|
|
480
|
-
result = _render_truss_config_for_checkpoint_deployment(deploy_config)
|
|
481
|
-
assert isinstance(result, truss_config.TrussConfig)
|
|
482
|
-
expected_vllm_command = "vllm serve google/gemma-3-27b-it --port 8000 --tensor-parallel-size 4 --enable-lora --max-lora-rank 32 --dtype bfloat16 --lora-modules job123=/tmp/training_checkpoints/job123/rank-0/checkpoint-1"
|
|
483
|
-
assert (
|
|
484
|
-
expected_vllm_command in result.environment_variables[START_COMMAND_ENVVAR_NAME]
|
|
485
|
-
)
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
def test_render_vllm_full_truss_config():
|
|
489
|
-
"""Test that render_vllm_full_truss_config creates proper TrussConfig for full fine-tune deployments."""
|
|
490
|
-
deploy_config = DeployCheckpointsConfigComplete(
|
|
491
|
-
checkpoint_details=definitions.CheckpointList(
|
|
492
|
-
checkpoints=[
|
|
493
|
-
definitions.FullCheckpoint(
|
|
494
|
-
training_job_id="job123",
|
|
495
|
-
paths=["rank-0/checkpoint-1/"],
|
|
496
|
-
model_weight_format=ModelWeightsFormat.FULL,
|
|
497
|
-
)
|
|
498
|
-
],
|
|
499
|
-
base_model_id=None, # Not needed for full fine-tune
|
|
500
|
-
),
|
|
501
|
-
model_name="test-full-model",
|
|
502
|
-
compute=definitions.Compute(
|
|
503
|
-
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
|
|
504
|
-
),
|
|
505
|
-
runtime=definitions.DeployCheckpointsRuntime(
|
|
506
|
-
environment_variables={
|
|
507
|
-
"HF_TOKEN": definitions.SecretReference(name="hf_token")
|
|
508
|
-
}
|
|
509
|
-
),
|
|
510
|
-
deployment_name="test-deployment",
|
|
511
|
-
model_weight_format=ModelWeightsFormat.FULL,
|
|
512
|
-
)
|
|
513
|
-
|
|
514
|
-
result = render_vllm_full_truss_config(deploy_config)
|
|
515
|
-
expected_vllm_command = (
|
|
516
|
-
"sh -c 'HF_TOKEN=$(cat /secrets/hf_token) "
|
|
517
|
-
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
518
|
-
"if [ -f /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja ]; then "
|
|
519
|
-
"vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
|
|
520
|
-
"--chat-template /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja "
|
|
521
|
-
"--port 8000 --tensor-parallel-size 2 --dtype bfloat16; else "
|
|
522
|
-
"vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
|
|
523
|
-
"--port 8000 --tensor-parallel-size 2 --dtype bfloat16; fi'"
|
|
524
|
-
)
|
|
525
|
-
|
|
526
|
-
assert isinstance(result, truss_config.TrussConfig)
|
|
527
|
-
assert result.model_name == "test-full-model"
|
|
528
|
-
assert result.docker_server is not None
|
|
529
|
-
assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
530
|
-
assert (
|
|
531
|
-
result.environment_variables[START_COMMAND_ENVVAR_NAME] == expected_vllm_command
|
|
532
|
-
)
|
|
533
|
-
|
|
534
|
-
|
|
535
107
|
def test_hydrate_full_checkpoint():
|
|
536
108
|
"""Test that hydrate_full_checkpoint creates proper FullCheckpoint objects."""
|
|
537
109
|
job_id = "test_job_123"
|
|
@@ -543,8 +115,7 @@ def test_hydrate_full_checkpoint():
|
|
|
543
115
|
assert isinstance(result, definitions.FullCheckpoint)
|
|
544
116
|
assert result.training_job_id == job_id
|
|
545
117
|
assert result.model_weight_format == ModelWeightsFormat.FULL
|
|
546
|
-
assert
|
|
547
|
-
assert result.paths[0] == f"rank-0/{checkpoint_id}/"
|
|
118
|
+
assert result.checkpoint_name == checkpoint_id
|
|
548
119
|
|
|
549
120
|
|
|
550
121
|
def test_hydrate_checkpoint_dispatcher_full():
|
|
@@ -691,34 +262,6 @@ def test_get_checkpoint_ids_to_deploy_single_checkpoint():
|
|
|
691
262
|
assert result == ["checkpoint-1"]
|
|
692
263
|
|
|
693
264
|
|
|
694
|
-
def test_vllm_whisper_start_command_template():
|
|
695
|
-
"""Test that the VLLM_WHISPER_START_COMMAND template renders correctly."""
|
|
696
|
-
# Test with all variables
|
|
697
|
-
result = VLLM_WHISPER_START_COMMAND.render(
|
|
698
|
-
model_path="/path/to/model",
|
|
699
|
-
envvars="CUDA_VISIBLE_DEVICES=0",
|
|
700
|
-
specify_tensor_parallelism=4,
|
|
701
|
-
)
|
|
702
|
-
|
|
703
|
-
expected = (
|
|
704
|
-
"sh -c 'CUDA_VISIBLE_DEVICES=0 "
|
|
705
|
-
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
706
|
-
"vllm serve /path/to/model --port 8000 --tensor-parallel-size 4'"
|
|
707
|
-
)
|
|
708
|
-
assert result == expected
|
|
709
|
-
|
|
710
|
-
result = VLLM_WHISPER_START_COMMAND.render(
|
|
711
|
-
model_path="/path/to/model", envvars=None, specify_tensor_parallelism=1
|
|
712
|
-
)
|
|
713
|
-
|
|
714
|
-
expected = (
|
|
715
|
-
"sh -c '"
|
|
716
|
-
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
717
|
-
"vllm serve /path/to/model --port 8000 --tensor-parallel-size 1'"
|
|
718
|
-
)
|
|
719
|
-
assert result == expected
|
|
720
|
-
|
|
721
|
-
|
|
722
265
|
def test_hydrate_whisper_checkpoint():
|
|
723
266
|
"""Test that hydrate_whisper_checkpoint creates correct WhisperCheckpoint object."""
|
|
724
267
|
job_id = "test-job-123"
|
|
@@ -728,393 +271,6 @@ def test_hydrate_whisper_checkpoint():
|
|
|
728
271
|
result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
|
|
729
272
|
|
|
730
273
|
assert result.training_job_id == job_id
|
|
731
|
-
assert result.
|
|
274
|
+
assert result.checkpoint_name == checkpoint_id
|
|
732
275
|
assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
|
|
733
276
|
assert isinstance(result, definitions.WhisperCheckpoint)
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
@patch(
|
|
737
|
-
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
|
|
738
|
-
)
|
|
739
|
-
@patch(
|
|
740
|
-
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
|
|
741
|
-
)
|
|
742
|
-
@patch(
|
|
743
|
-
"truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
|
|
744
|
-
)
|
|
745
|
-
def test_render_vllm_whisper_truss_config(
|
|
746
|
-
mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
|
|
747
|
-
):
|
|
748
|
-
"""Test that render_vllm_whisper_truss_config renders truss config correctly."""
|
|
749
|
-
# Mock dependencies
|
|
750
|
-
mock_truss_config = MagicMock()
|
|
751
|
-
mock_truss_config.environment_variables = {}
|
|
752
|
-
mock_truss_config.docker_server = MagicMock()
|
|
753
|
-
mock_setup_base_config.return_value = mock_truss_config
|
|
754
|
-
|
|
755
|
-
mock_setup_env_vars.return_value = "HF_TOKEN=$(cat /secrets/hf_access_token)"
|
|
756
|
-
mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
|
|
757
|
-
|
|
758
|
-
# Create test config
|
|
759
|
-
deploy_config = DeployCheckpointsConfigComplete(
|
|
760
|
-
checkpoint_details=definitions.CheckpointList(
|
|
761
|
-
checkpoints=[
|
|
762
|
-
definitions.WhisperCheckpoint(
|
|
763
|
-
training_job_id="job123",
|
|
764
|
-
paths=["rank-0/checkpoint-1/"],
|
|
765
|
-
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
766
|
-
)
|
|
767
|
-
],
|
|
768
|
-
base_model_id="openai/whisper-large-v3",
|
|
769
|
-
),
|
|
770
|
-
model_name="whisper-large-v3-vLLM",
|
|
771
|
-
compute=definitions.Compute(
|
|
772
|
-
accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
|
|
773
|
-
),
|
|
774
|
-
runtime=definitions.DeployCheckpointsRuntime(
|
|
775
|
-
environment_variables={
|
|
776
|
-
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
|
|
777
|
-
}
|
|
778
|
-
),
|
|
779
|
-
deployment_name="whisper-large-v3-vLLM",
|
|
780
|
-
model_weight_format=definitions.ModelWeightsFormat.WHISPER,
|
|
781
|
-
)
|
|
782
|
-
|
|
783
|
-
result = render_vllm_whisper_truss_config(deploy_config)
|
|
784
|
-
|
|
785
|
-
mock_setup_base_config.assert_called_once_with(deploy_config)
|
|
786
|
-
mock_setup_env_vars.assert_called_once_with(mock_truss_config, deploy_config)
|
|
787
|
-
mock_build_full_checkpoint_string.assert_called_once_with(mock_truss_config)
|
|
788
|
-
|
|
789
|
-
assert result == mock_truss_config
|
|
790
|
-
|
|
791
|
-
expected_start_command = (
|
|
792
|
-
"sh -c 'HF_TOKEN=$(cat /secrets/hf_access_token) "
|
|
793
|
-
'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
|
|
794
|
-
"vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 4'"
|
|
795
|
-
)
|
|
796
|
-
assert (
|
|
797
|
-
result.environment_variables[START_COMMAND_ENVVAR_NAME]
|
|
798
|
-
== expected_start_command
|
|
799
|
-
)
|
|
800
|
-
|
|
801
|
-
assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
@patch(
    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
)
@patch(
    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
)
@patch(
    "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
)
def test_render_vllm_whisper_truss_config_with_envvars(
    mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
):
    """Check the Whisper start command built from patched collaborators.

    All three helpers used by ``render_vllm_whisper_truss_config`` are patched,
    so the only thing pinned here is the command string that ends up under
    ``START_COMMAND_ENVVAR_NAME`` in the returned config's environment
    variables. Mocks are injected bottom-up: the last ``@patch`` is the first
    parameter.
    """
    # Canned collaborator results: a stub config and the two string fragments
    # that are expected to show up verbatim inside the start command.
    stub_config = MagicMock(environment_variables={}, docker_server=MagicMock())
    mock_setup_base_config.return_value = stub_config
    mock_setup_env_vars.return_value = "CUDA_VISIBLE_DEVICES=0,1"
    mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"

    # Deploy config carrying one plain env var and one secret reference.
    whisper_format = definitions.ModelWeightsFormat.WHISPER
    checkpoint = definitions.WhisperCheckpoint(
        training_job_id="job123",
        paths=["rank-0/checkpoint-1/"],
        model_weight_format=whisper_format,
    )
    runtime_env = {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
    }
    deploy_config = DeployCheckpointsConfigComplete(
        checkpoint_details=definitions.CheckpointList(
            checkpoints=[checkpoint], base_model_id="openai/whisper-large-v3"
        ),
        model_name="whisper-large-v3-vLLM",
        compute=definitions.Compute(
            accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
        ),
        runtime=definitions.DeployCheckpointsRuntime(
            environment_variables=runtime_env
        ),
        deployment_name="whisper-large-v3-vLLM",
        model_weight_format=whisper_format,
    )

    rendered = render_vllm_whisper_truss_config(deploy_config)

    # The command starts with the patched env-var string and serves the patched
    # checkpoint path. NOTE(review): the tensor-parallel size of 2 presumably
    # comes from the accelerator count above -- confirm against the renderer.
    expected_start_command = "".join(
        [
            "sh -c 'CUDA_VISIBLE_DEVICES=0,1 ",
            'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && ',
            "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 2'",
        ]
    )
    actual_start_command = rendered.environment_variables[START_COMMAND_ENVVAR_NAME]
    assert actual_start_command == expected_start_command
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
@dataclass
class TestCase:
    """One row of the table that drives ``test_setup_base_truss_config``.

    A row pairs a deploy config with the values expected on the config that
    ``setup_base_truss_config`` returns for it, or with an error pattern when
    the call is expected to fail.
    """

    # The class name starts with "Test"; this flag tells pytest it is a plain
    # data holder and must not be collected as a test class.
    __test__ = False

    # Label echoed in the progress print and in every assertion message.
    desc: str
    # Config handed to setup_base_truss_config.
    input_config: DeployCheckpointsConfigComplete
    # Expected ``model_name`` of the returned config.
    expected_model_name: str
    # Expected ``docker_server.predict_endpoint``.
    expected_predict_endpoint: str
    # Expected accelerator type; a falsy value selects the "no accelerator" checks.
    expected_accelerator: Optional[str]
    # Expected accelerator count; only compared when an accelerator is expected.
    expected_accelerator_count: Optional[int]
    # Expected checkpoint paths, one per artifact reference, in order.
    expected_checkpoint_paths: List[str]
    # Environment variables that must be present with exactly these values.
    expected_environment_variables: Dict[str, str]
    # Pattern passed to ``pytest.raises(match=...)`` when the call should raise.
    should_raise: Optional[str] = None
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
def test_setup_base_truss_config():
    """Run ``setup_base_truss_config`` against a table of deploy configs.

    Every ``TestCase`` row either names an error pattern the call must raise,
    or lists the values checked on the returned ``TrussConfig``: model name,
    docker-server settings, checkpoint artifact paths, accelerator and
    environment variables.
    """
    from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
        setup_base_truss_config,
    )

    def lora_checkpoint(path, rank):
        """Build a LoRA checkpoint of job ``job123`` stored at *path*."""
        return definitions.LoRACheckpoint(
            training_job_id="job123",
            paths=[path],
            model_weight_format=ModelWeightsFormat.LORA,
            lora_details=definitions.LoRADetails(rank=rank),
        )

    def gpu(kind, count):
        """Build a compute spec with *count* accelerators of type *kind*."""
        return definitions.Compute(
            accelerator=truss_config.AcceleratorSpec(accelerator=kind, count=count)
        )

    def deploy_config(
        checkpoints, base_model_id, model_name, compute, deployment_name, weights
    ):
        """Assemble a complete deploy config with no runtime env vars."""
        return DeployCheckpointsConfigComplete(
            checkpoint_details=definitions.CheckpointList(
                checkpoints=checkpoints, base_model_id=base_model_id
            ),
            model_name=model_name,
            compute=compute,
            runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
            deployment_name=deployment_name,
            model_weight_format=weights,
        )

    # Every row expects the same three environment variables; each row gets
    # its own copy so the rows stay independent of one another.
    common_env = {
        "VLLM_LOGGING_LEVEL": "WARNING",
        "VLLM_USE_V1": "0",
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
    }
    gemma = "google/gemma-3-27b-it"
    chat_endpoint = "/v1/chat/completions"

    scenarios = [
        TestCase(
            desc="LoRA checkpoint with H100 accelerator",
            input_config=deploy_config(
                checkpoints=[lora_checkpoint("rank-0/checkpoint-1/", 32)],
                base_model_id=gemma,
                model_name="test-lora-model",
                compute=gpu("H100", 4),
                deployment_name="test-deployment",
                weights=ModelWeightsFormat.LORA,
            ),
            expected_model_name="test-lora-model",
            expected_predict_endpoint=chat_endpoint,
            expected_accelerator="H100",
            expected_accelerator_count=4,
            expected_checkpoint_paths=["rank-0/checkpoint-1/"],
            expected_environment_variables=dict(common_env),
        ),
        TestCase(
            desc="Whisper checkpoint with A100 accelerator",
            input_config=deploy_config(
                checkpoints=[
                    definitions.WhisperCheckpoint(
                        training_job_id="job123",
                        paths=["rank-0/checkpoint-1/"],
                        model_weight_format=definitions.ModelWeightsFormat.WHISPER,
                    )
                ],
                base_model_id="openai/whisper-large-v3",
                model_name="test-whisper-model",
                compute=gpu("A100", 2),
                deployment_name="test-whisper-deployment",
                weights=definitions.ModelWeightsFormat.WHISPER,
            ),
            expected_model_name="test-whisper-model",
            expected_predict_endpoint="/v1/audio/transcriptions",
            expected_accelerator="A100",
            expected_accelerator_count=2,
            expected_checkpoint_paths=["rank-0/checkpoint-1/"],
            expected_environment_variables=dict(common_env),
        ),
        TestCase(
            desc="Multiple LoRA checkpoints",
            input_config=deploy_config(
                checkpoints=[
                    lora_checkpoint("rank-0/checkpoint-1/", 16),
                    lora_checkpoint("rank-0/checkpoint-2/", 32),
                ],
                base_model_id=gemma,
                model_name="test-multi-checkpoint-model",
                compute=gpu("H100", 4),
                deployment_name="test-multi-deployment",
                weights=ModelWeightsFormat.LORA,
            ),
            expected_model_name="test-multi-checkpoint-model",
            expected_predict_endpoint=chat_endpoint,
            expected_accelerator="H100",
            expected_accelerator_count=4,
            expected_checkpoint_paths=["rank-0/checkpoint-1/", "rank-0/checkpoint-2/"],
            expected_environment_variables=dict(common_env),
        ),
        TestCase(
            desc="No accelerator specified",
            input_config=deploy_config(
                checkpoints=[lora_checkpoint("rank-0/checkpoint-1/", 16)],
                base_model_id=gemma,
                model_name="test-no-accelerator-model",
                compute=definitions.Compute(),  # deliberately left without an accelerator
                deployment_name="test-no-accelerator-deployment",
                weights=ModelWeightsFormat.LORA,
            ),
            expected_model_name="test-no-accelerator-model",
            expected_predict_endpoint=chat_endpoint,
            expected_accelerator=None,
            expected_accelerator_count=None,
            expected_checkpoint_paths=["rank-0/checkpoint-1/"],
            expected_environment_variables=dict(common_env),
        ),
    ]

    for case in scenarios:
        print(f"Running test case: {case.desc}")
        label = f"Test case '{case.desc}'"

        # Error rows: only the raised exception is checked.
        if case.should_raise:
            with pytest.raises(Exception, match=case.should_raise):
                setup_base_truss_config(case.input_config)
            continue

        config = setup_base_truss_config(case.input_config)

        # Basic structure.
        assert isinstance(config, truss_config.TrussConfig), (
            f"{label}: Result should be TrussConfig"
        )
        assert config.model_name == case.expected_model_name, (
            f"{label}: Model name mismatch"
        )

        # Docker server settings.
        server = config.docker_server
        assert server is not None, f"{label}: Docker server should not be None"
        assert server.start_command == 'sh -c ""', f"{label}: Start command mismatch"
        assert server.readiness_endpoint == "/health", (
            f"{label}: Readiness endpoint mismatch"
        )
        assert server.liveness_endpoint == "/health", (
            f"{label}: Liveness endpoint mismatch"
        )
        assert server.predict_endpoint == case.expected_predict_endpoint, (
            f"{label}: Predict endpoint mismatch"
        )
        assert server.server_port == 8000, f"{label}: Server port mismatch"

        # Training checkpoints: one artifact reference per expected path.
        checkpoints = config.training_checkpoints
        assert checkpoints is not None, (
            f"{label}: Training checkpoints should not be None"
        )
        references = checkpoints.artifact_references
        assert len(references) == len(case.expected_checkpoint_paths), (
            f"{label}: Number of checkpoint artifacts mismatch"
        )
        pairs = zip(references, case.expected_checkpoint_paths)
        for i, (reference, expected_path) in enumerate(pairs):
            assert reference.paths == [expected_path], (
                f"{label}: Checkpoint path {i} mismatch"
            )

        # Resources and accelerator.
        resources = config.resources
        assert resources is not None, f"{label}: Resources should not be None"
        accelerator = resources.accelerator
        if case.expected_accelerator:
            assert accelerator is not None, f"{label}: Accelerator should not be None"
            assert accelerator.accelerator == case.expected_accelerator, (
                f"{label}: Accelerator type mismatch"
            )
            assert accelerator.count == case.expected_accelerator_count, (
                f"{label}: Accelerator count mismatch"
            )
        else:
            # With no accelerator requested an AcceleratorSpec is still
            # present, just with its type left as None.
            assert accelerator is not None, f"{label}: Accelerator should exist"
            assert accelerator.accelerator is None, (
                f"{label}: Accelerator type should be None"
            )

        # Environment variables: each expected key must carry the exact value.
        for key, expected_value in case.expected_environment_variables.items():
            assert config.environment_variables[key] == expected_value, (
                f"{label}: Environment variable {key} mismatch"
            )
|