truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -1,38 +1,17 @@
1
- import os
2
- import re
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
- from typing import Dict, List, Optional
6
1
  from unittest.mock import MagicMock, patch
7
2
 
8
3
  import pytest
9
4
 
10
5
  import truss_train.definitions as definitions
11
- from truss.base import truss_config
12
- from truss.cli.train.deploy_checkpoints import prepare_checkpoint_deploy
13
6
  from truss.cli.train.deploy_checkpoints.deploy_checkpoints import (
14
7
  _get_checkpoint_ids_to_deploy,
15
- _render_truss_config_for_checkpoint_deployment,
16
8
  hydrate_checkpoint,
17
9
  )
18
10
  from truss.cli.train.deploy_checkpoints.deploy_full_checkpoints import (
19
11
  hydrate_full_checkpoint,
20
- render_vllm_full_truss_config,
21
- )
22
- from truss.cli.train.deploy_checkpoints.deploy_lora_checkpoints import (
23
- START_COMMAND_ENVVAR_NAME,
24
- _get_lora_rank,
25
- hydrate_lora_checkpoint,
26
- render_vllm_lora_truss_config,
27
12
  )
28
13
  from truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints import (
29
- VLLM_WHISPER_START_COMMAND,
30
14
  hydrate_whisper_checkpoint,
31
- render_vllm_whisper_truss_config,
32
- )
33
- from truss.cli.train.types import (
34
- DeployCheckpointsConfigComplete,
35
- PrepareCheckpointResult,
36
15
  )
37
16
  from truss_train.definitions import ModelWeightsFormat
38
17
 
@@ -125,413 +104,6 @@ def deploy_checkpoints_mock_checkbox(create_mock_prompt):
125
104
  yield mock
126
105
 
127
106
 
128
- def test_render_truss_config_for_checkpoint_deployment():
129
- deploy_config = DeployCheckpointsConfigComplete(
130
- checkpoint_details=definitions.CheckpointList(
131
- checkpoints=[
132
- definitions.LoRACheckpoint(
133
- training_job_id="job123",
134
- paths=["rank-0/checkpoint-1/"],
135
- model_weight_format=ModelWeightsFormat.LORA,
136
- lora_details=definitions.LoRADetails(rank=16),
137
- )
138
- ],
139
- base_model_id="google/gemma-3-27b-it",
140
- ),
141
- model_name="gemma-3-27b-it-vLLM-LORA",
142
- compute=definitions.Compute(
143
- accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
144
- ),
145
- runtime=definitions.DeployCheckpointsRuntime(
146
- environment_variables={
147
- "HF_TOKEN": definitions.SecretReference(name="hf_access_token")
148
- }
149
- ),
150
- deployment_name="gemma-3-27b-it-vLLM-LORA",
151
- model_weight_format=ModelWeightsFormat.LORA,
152
- )
153
- rendered_truss = _render_truss_config_for_checkpoint_deployment(deploy_config)
154
- test_truss = truss_config.TrussConfig.from_yaml(
155
- Path(
156
- os.path.dirname(__file__),
157
- "resources/test_deploy_from_checkpoint_config.yml",
158
- )
159
- )
160
- assert test_truss.model_name == rendered_truss.model_name
161
- assert (
162
- test_truss.training_checkpoints.artifact_references[0].paths[0]
163
- == rendered_truss.training_checkpoints.artifact_references[0].paths[0]
164
- )
165
- assert (
166
- test_truss.training_checkpoints.artifact_references[0].paths
167
- == rendered_truss.training_checkpoints.artifact_references[0].paths
168
- )
169
- assert (
170
- test_truss.docker_server.start_command
171
- == rendered_truss.docker_server.start_command
172
- )
173
- assert test_truss.resources.accelerator == rendered_truss.resources.accelerator
174
- assert test_truss.secrets == rendered_truss.secrets
175
- assert test_truss.training_checkpoints == rendered_truss.training_checkpoints
176
-
177
-
178
- def test_prepare_checkpoint_deploy_empty_config(
179
- mock_remote,
180
- deploy_checkpoints_mock_select,
181
- deploy_checkpoints_mock_input,
182
- deploy_checkpoints_mock_text,
183
- deploy_checkpoints_mock_checkbox,
184
- ):
185
- # Create empty config
186
- empty_config = definitions.DeployCheckpointsConfig()
187
-
188
- # Call function under test
189
- result = prepare_checkpoint_deploy(
190
- remote_provider=mock_remote,
191
- checkpoint_deploy_config=empty_config,
192
- project_id="project123",
193
- job_id="job123",
194
- )
195
-
196
- assert isinstance(result, PrepareCheckpointResult)
197
- assert result.checkpoint_deploy_config.model_name == "gemma-3-27b-it-vLLM-LORA"
198
- assert (
199
- result.checkpoint_deploy_config.checkpoint_details.base_model_id
200
- == "google/gemma-3-27b-it"
201
- )
202
- assert len(result.checkpoint_deploy_config.checkpoint_details.checkpoints) == 1
203
- checkpoint = result.checkpoint_deploy_config.checkpoint_details.checkpoints[0]
204
- assert checkpoint.training_job_id == "job123"
205
- assert isinstance(checkpoint, definitions.LoRACheckpoint)
206
- assert checkpoint.lora_details.rank == 16
207
- assert result.checkpoint_deploy_config.compute.accelerator.accelerator == "H100"
208
- assert result.checkpoint_deploy_config.compute.accelerator.count == 4
209
- assert (
210
- result.checkpoint_deploy_config.runtime.environment_variables["HF_TOKEN"].name
211
- == "hf_access_token"
212
- )
213
-
214
-
215
- def test_prepare_checkpoint_deploy_complete_config(
216
- mock_remote,
217
- deploy_checkpoints_mock_select,
218
- deploy_checkpoints_mock_text,
219
- deploy_checkpoints_mock_checkbox,
220
- ):
221
- # Create complete config with all fields specified
222
- complete_config = definitions.DeployCheckpointsConfig(
223
- checkpoint_details=definitions.CheckpointList(
224
- checkpoints=[
225
- definitions.LoRACheckpoint(
226
- training_job_id="job123",
227
- paths=["job123/rank-0/checkpoint-1/"],
228
- model_weight_format=ModelWeightsFormat.LORA,
229
- lora_details=definitions.LoRADetails(rank=32),
230
- )
231
- ],
232
- base_model_id="google/gemma-3-27b-it",
233
- ),
234
- model_name="my-custom-model",
235
- deployment_name="my-deployment",
236
- compute=definitions.Compute(
237
- accelerator=truss_config.AcceleratorSpec(accelerator="A100", count=2)
238
- ),
239
- runtime=definitions.DeployCheckpointsRuntime(
240
- environment_variables={
241
- "HF_TOKEN": definitions.SecretReference(name="my_custom_secret"),
242
- "CUSTOM_VAR": "custom_value",
243
- }
244
- ),
245
- )
246
-
247
- # Call function under test
248
- result = prepare_checkpoint_deploy(
249
- remote_provider=mock_remote,
250
- checkpoint_deploy_config=complete_config,
251
- project_id="project123",
252
- job_id="job123",
253
- )
254
-
255
- # Verify result
256
- assert isinstance(result, PrepareCheckpointResult)
257
-
258
- # Verify no prompts were called
259
- deploy_checkpoints_mock_select.assert_not_called()
260
- deploy_checkpoints_mock_text.assert_not_called()
261
- deploy_checkpoints_mock_checkbox.assert_not_called()
262
-
263
- # Verify config values were preserved
264
- config = result.checkpoint_deploy_config
265
- assert config.model_name == "my-custom-model"
266
- assert config.deployment_name == "my-deployment"
267
-
268
- # Verify checkpoint details
269
- assert config.checkpoint_details.base_model_id == "google/gemma-3-27b-it"
270
- assert len(config.checkpoint_details.checkpoints) == 1
271
- checkpoint = config.checkpoint_details.checkpoints[0]
272
- assert checkpoint.training_job_id == "job123"
273
- assert checkpoint.model_weight_format == ModelWeightsFormat.LORA
274
- assert isinstance(checkpoint, definitions.LoRACheckpoint)
275
- assert checkpoint.lora_details.rank == 32
276
-
277
- # Verify compute config
278
- assert config.compute.accelerator.accelerator == "A100"
279
- assert config.compute.accelerator.count == 2
280
-
281
- # Verify runtime config
282
- env_vars = config.runtime.environment_variables
283
- assert env_vars["HF_TOKEN"].name == "my_custom_secret"
284
- assert env_vars["CUSTOM_VAR"] == "custom_value"
285
-
286
- # open the config.yaml file and verify the tensor parallel size is 2
287
- # additional tests can be added to verify the config.yaml file is correct
288
- truss_cfg = truss_config.TrussConfig.from_yaml(
289
- Path(result.truss_directory, "config.yaml")
290
- )
291
- # Check that the start command is now the environment variable reference
292
- assert truss_cfg.docker_server.start_command == "%(ENV_BT_DOCKER_SERVER_START_CMD)s"
293
- # Check that the actual start command with tensor parallel size is in the environment variable
294
- assert (
295
- "--tensor-parallel-size 2"
296
- in truss_cfg.environment_variables["BT_DOCKER_SERVER_START_CMD"]
297
- )
298
-
299
-
300
- def test_checkpoint_lora_rank_validation():
301
- """Test that LoRACheckpoint accepts valid LoRA rank values."""
302
- valid_ranks = [8, 16, 32, 64, 128, 256, 320, 512]
303
-
304
- for rank in valid_ranks:
305
- checkpoint = definitions.LoRACheckpoint(
306
- training_job_id="job123",
307
- paths=["job123/rank-0/checkpoint-1/"],
308
- model_weight_format=ModelWeightsFormat.LORA,
309
- lora_details=definitions.LoRADetails(rank=rank),
310
- )
311
- assert checkpoint.lora_details.rank == rank
312
-
313
- invalid_ranks = [
314
- 1,
315
- 2,
316
- 4,
317
- 7,
318
- 9,
319
- 15,
320
- 17,
321
- 31,
322
- 33,
323
- 63,
324
- 65,
325
- 127,
326
- 129,
327
- 255,
328
- 257,
329
- 319,
330
- 321,
331
- 511,
332
- 513,
333
- 1000,
334
- ]
335
- for rank in invalid_ranks:
336
- with pytest.raises(ValueError, match=f"lora_rank \\({rank}\\) must be one of"):
337
- definitions.LoRACheckpoint(
338
- training_job_id="job123",
339
- paths=["job123/rank-0/checkpoint-1/"],
340
- model_weight_format=ModelWeightsFormat.LORA,
341
- lora_details=definitions.LoRADetails(rank=rank),
342
- )
343
-
344
-
345
- def test_get_lora_rank():
346
- """Test that get_lora_rank returns valid values from checkpoint response."""
347
- # Test with valid rank from API
348
- checkpoint_resp = {"lora_adapter_config": {"r": 64}}
349
- assert _get_lora_rank(checkpoint_resp) == 64
350
- # Test with missing lora_adapter_config (should use DEFAULT_LORA_RANK)
351
- checkpoint_resp = {}
352
- assert _get_lora_rank(checkpoint_resp) == 16 # DEFAULT_LORA_RANK
353
- # Test with missing 'r' field (should use DEFAULT_LORA_RANK)
354
- checkpoint_resp = {"lora_adapter_config": {}}
355
- assert _get_lora_rank(checkpoint_resp) == 16 # DEFAULT_LORA_RANK
356
- # Test with invalid rank from API
357
- checkpoint_resp = {"lora_adapter_config": {"r": 1}}
358
- with pytest.raises(
359
- ValueError,
360
- match=re.escape("LoRA rank 1 from checkpoint is not in allowed values"),
361
- ):
362
- _get_lora_rank(checkpoint_resp)
363
- # Test with another invalid rank
364
- checkpoint_resp = {"lora_adapter_config": {"r": 1000}}
365
- with pytest.raises(
366
- ValueError,
367
- match=re.escape("LoRA rank 1000 from checkpoint is not in allowed values"),
368
- ):
369
- _get_lora_rank(checkpoint_resp)
370
-
371
-
372
- def test_hydrate_lora_checkpoint():
373
- """Test that hydrate_lora_checkpoint creates proper LoRACheckpoint objects."""
374
- job_id = "test_job_123"
375
- checkpoint_id = "checkpoint-456"
376
- checkpoint_data = {
377
- "lora_adapter_config": {"r": 64},
378
- "base_model": "google/gemma-3-27b-it",
379
- "checkpoint_type": "lora",
380
- }
381
-
382
- result = hydrate_lora_checkpoint(job_id, checkpoint_id, checkpoint_data)
383
-
384
- assert isinstance(result, definitions.LoRACheckpoint)
385
- assert result.training_job_id == job_id
386
- assert result.lora_details.rank == 64
387
- assert len(result.paths) == 1
388
- assert result.paths[0] == f"rank-0/{checkpoint_id}/"
389
-
390
-
391
- def test_hydrate_checkpoint_dispatcher():
392
- """Test that hydrate_checkpoint properly dispatches to the right function based on checkpoint type."""
393
- job_id = "test_job_123"
394
- checkpoint_id = "checkpoint-456"
395
- checkpoint_data = {
396
- "lora_adapter_config": {"r": 32},
397
- "base_model": "google/gemma-3-27b-it",
398
- "checkpoint_type": "lora",
399
- }
400
-
401
- # Test LoRA checkpoint type
402
- result = hydrate_checkpoint(job_id, checkpoint_id, checkpoint_data, "lora")
403
- assert isinstance(result, definitions.LoRACheckpoint)
404
- assert result.lora_details.rank == 32
405
-
406
- # Test uppercase LoRA checkpoint type
407
- result = hydrate_checkpoint(job_id, checkpoint_id, checkpoint_data, "LORA")
408
- assert isinstance(result, definitions.LoRACheckpoint)
409
- assert result.lora_details.rank == 32
410
-
411
- # Test unsupported checkpoint type
412
- with pytest.raises(ValueError, match="Unsupported checkpoint type: unsupported"):
413
- hydrate_checkpoint(job_id, checkpoint_id, checkpoint_data, "unsupported")
414
-
415
-
416
- def test_render_vllm_lora_truss_config():
417
- """Test that render_vllm_lora_truss_config creates proper TrussConfig for LoRA deployments."""
418
- deploy_config = DeployCheckpointsConfigComplete(
419
- checkpoint_details=definitions.CheckpointList(
420
- checkpoints=[
421
- definitions.LoRACheckpoint(
422
- training_job_id="job123",
423
- paths=["rank-0/checkpoint-1/"],
424
- model_weight_format=ModelWeightsFormat.LORA,
425
- lora_details=definitions.LoRADetails(rank=64),
426
- )
427
- ],
428
- base_model_id="google/gemma-3-27b-it",
429
- ),
430
- model_name="test-lora-model",
431
- compute=definitions.Compute(
432
- accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
433
- ),
434
- runtime=definitions.DeployCheckpointsRuntime(
435
- environment_variables={
436
- "HF_TOKEN": definitions.SecretReference(name="hf_token")
437
- }
438
- ),
439
- deployment_name="test-deployment",
440
- model_weight_format=ModelWeightsFormat.LORA,
441
- )
442
-
443
- result = render_vllm_lora_truss_config(deploy_config)
444
-
445
- expected_vllm_command = 'sh -c "HF_TOKEN=$(cat /secrets/hf_token) vllm serve google/gemma-3-27b-it --port 8000 --tensor-parallel-size 2 --enable-lora --max-lora-rank 64 --dtype bfloat16 --lora-modules job123=/tmp/training_checkpoints/job123/rank-0/checkpoint-1"'
446
-
447
- assert isinstance(result, truss_config.TrussConfig)
448
- assert result.model_name == "test-lora-model"
449
- assert result.docker_server is not None
450
- assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
451
- assert (
452
- result.environment_variables[START_COMMAND_ENVVAR_NAME] == expected_vllm_command
453
- )
454
-
455
-
456
- def test_render_truss_config_delegation():
457
- """Test that _render_truss_config_for_checkpoint_deployment delegates correctly based on model weight format."""
458
- deploy_config = DeployCheckpointsConfigComplete(
459
- checkpoint_details=definitions.CheckpointList(
460
- checkpoints=[
461
- definitions.LoRACheckpoint(
462
- training_job_id="job123",
463
- paths=["rank-0/checkpoint-1/"],
464
- model_weight_format=ModelWeightsFormat.LORA,
465
- lora_details=definitions.LoRADetails(rank=32),
466
- )
467
- ],
468
- base_model_id="google/gemma-3-27b-it",
469
- ),
470
- model_name="test-model",
471
- compute=definitions.Compute(
472
- accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
473
- ),
474
- runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
475
- deployment_name="test-deployment",
476
- model_weight_format=ModelWeightsFormat.LORA,
477
- )
478
-
479
- # Test that it works for LoRA format
480
- result = _render_truss_config_for_checkpoint_deployment(deploy_config)
481
- assert isinstance(result, truss_config.TrussConfig)
482
- expected_vllm_command = "vllm serve google/gemma-3-27b-it --port 8000 --tensor-parallel-size 4 --enable-lora --max-lora-rank 32 --dtype bfloat16 --lora-modules job123=/tmp/training_checkpoints/job123/rank-0/checkpoint-1"
483
- assert (
484
- expected_vllm_command in result.environment_variables[START_COMMAND_ENVVAR_NAME]
485
- )
486
-
487
-
488
- def test_render_vllm_full_truss_config():
489
- """Test that render_vllm_full_truss_config creates proper TrussConfig for full fine-tune deployments."""
490
- deploy_config = DeployCheckpointsConfigComplete(
491
- checkpoint_details=definitions.CheckpointList(
492
- checkpoints=[
493
- definitions.FullCheckpoint(
494
- training_job_id="job123",
495
- paths=["rank-0/checkpoint-1/"],
496
- model_weight_format=ModelWeightsFormat.FULL,
497
- )
498
- ],
499
- base_model_id=None, # Not needed for full fine-tune
500
- ),
501
- model_name="test-full-model",
502
- compute=definitions.Compute(
503
- accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
504
- ),
505
- runtime=definitions.DeployCheckpointsRuntime(
506
- environment_variables={
507
- "HF_TOKEN": definitions.SecretReference(name="hf_token")
508
- }
509
- ),
510
- deployment_name="test-deployment",
511
- model_weight_format=ModelWeightsFormat.FULL,
512
- )
513
-
514
- result = render_vllm_full_truss_config(deploy_config)
515
- expected_vllm_command = (
516
- "sh -c 'HF_TOKEN=$(cat /secrets/hf_token) "
517
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
518
- "if [ -f /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja ]; then "
519
- "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
520
- "--chat-template /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja "
521
- "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; else "
522
- "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
523
- "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; fi'"
524
- )
525
-
526
- assert isinstance(result, truss_config.TrussConfig)
527
- assert result.model_name == "test-full-model"
528
- assert result.docker_server is not None
529
- assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
530
- assert (
531
- result.environment_variables[START_COMMAND_ENVVAR_NAME] == expected_vllm_command
532
- )
533
-
534
-
535
107
  def test_hydrate_full_checkpoint():
536
108
  """Test that hydrate_full_checkpoint creates proper FullCheckpoint objects."""
537
109
  job_id = "test_job_123"
@@ -543,8 +115,7 @@ def test_hydrate_full_checkpoint():
543
115
  assert isinstance(result, definitions.FullCheckpoint)
544
116
  assert result.training_job_id == job_id
545
117
  assert result.model_weight_format == ModelWeightsFormat.FULL
546
- assert len(result.paths) == 1
547
- assert result.paths[0] == f"rank-0/{checkpoint_id}/"
118
+ assert result.checkpoint_name == checkpoint_id
548
119
 
549
120
 
550
121
  def test_hydrate_checkpoint_dispatcher_full():
@@ -691,34 +262,6 @@ def test_get_checkpoint_ids_to_deploy_single_checkpoint():
691
262
  assert result == ["checkpoint-1"]
692
263
 
693
264
 
694
- def test_vllm_whisper_start_command_template():
695
- """Test that the VLLM_WHISPER_START_COMMAND template renders correctly."""
696
- # Test with all variables
697
- result = VLLM_WHISPER_START_COMMAND.render(
698
- model_path="/path/to/model",
699
- envvars="CUDA_VISIBLE_DEVICES=0",
700
- specify_tensor_parallelism=4,
701
- )
702
-
703
- expected = (
704
- "sh -c 'CUDA_VISIBLE_DEVICES=0 "
705
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
706
- "vllm serve /path/to/model --port 8000 --tensor-parallel-size 4'"
707
- )
708
- assert result == expected
709
-
710
- result = VLLM_WHISPER_START_COMMAND.render(
711
- model_path="/path/to/model", envvars=None, specify_tensor_parallelism=1
712
- )
713
-
714
- expected = (
715
- "sh -c '"
716
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
717
- "vllm serve /path/to/model --port 8000 --tensor-parallel-size 1'"
718
- )
719
- assert result == expected
720
-
721
-
722
265
  def test_hydrate_whisper_checkpoint():
723
266
  """Test that hydrate_whisper_checkpoint creates correct WhisperCheckpoint object."""
724
267
  job_id = "test-job-123"
@@ -728,393 +271,6 @@ def test_hydrate_whisper_checkpoint():
728
271
  result = hydrate_whisper_checkpoint(job_id, checkpoint_id, checkpoint)
729
272
 
730
273
  assert result.training_job_id == job_id
731
- assert result.paths == [f"rank-0/{checkpoint_id}/"]
274
+ assert result.checkpoint_name == checkpoint_id
732
275
  assert result.model_weight_format == definitions.ModelWeightsFormat.WHISPER
733
276
  assert isinstance(result, definitions.WhisperCheckpoint)
734
-
735
-
736
- @patch(
737
- "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
738
- )
739
- @patch(
740
- "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
741
- )
742
- @patch(
743
- "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
744
- )
745
- def test_render_vllm_whisper_truss_config(
746
- mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
747
- ):
748
- """Test that render_vllm_whisper_truss_config renders truss config correctly."""
749
- # Mock dependencies
750
- mock_truss_config = MagicMock()
751
- mock_truss_config.environment_variables = {}
752
- mock_truss_config.docker_server = MagicMock()
753
- mock_setup_base_config.return_value = mock_truss_config
754
-
755
- mock_setup_env_vars.return_value = "HF_TOKEN=$(cat /secrets/hf_access_token)"
756
- mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
757
-
758
- # Create test config
759
- deploy_config = DeployCheckpointsConfigComplete(
760
- checkpoint_details=definitions.CheckpointList(
761
- checkpoints=[
762
- definitions.WhisperCheckpoint(
763
- training_job_id="job123",
764
- paths=["rank-0/checkpoint-1/"],
765
- model_weight_format=definitions.ModelWeightsFormat.WHISPER,
766
- )
767
- ],
768
- base_model_id="openai/whisper-large-v3",
769
- ),
770
- model_name="whisper-large-v3-vLLM",
771
- compute=definitions.Compute(
772
- accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=4)
773
- ),
774
- runtime=definitions.DeployCheckpointsRuntime(
775
- environment_variables={
776
- "HF_TOKEN": definitions.SecretReference(name="hf_access_token")
777
- }
778
- ),
779
- deployment_name="whisper-large-v3-vLLM",
780
- model_weight_format=definitions.ModelWeightsFormat.WHISPER,
781
- )
782
-
783
- result = render_vllm_whisper_truss_config(deploy_config)
784
-
785
- mock_setup_base_config.assert_called_once_with(deploy_config)
786
- mock_setup_env_vars.assert_called_once_with(mock_truss_config, deploy_config)
787
- mock_build_full_checkpoint_string.assert_called_once_with(mock_truss_config)
788
-
789
- assert result == mock_truss_config
790
-
791
- expected_start_command = (
792
- "sh -c 'HF_TOKEN=$(cat /secrets/hf_access_token) "
793
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
794
- "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 4'"
795
- )
796
- assert (
797
- result.environment_variables[START_COMMAND_ENVVAR_NAME]
798
- == expected_start_command
799
- )
800
-
801
- assert result.docker_server.start_command == f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
802
-
803
-
804
- @patch(
805
- "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_base_truss_config"
806
- )
807
- @patch(
808
- "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.setup_environment_variables_and_secrets"
809
- )
810
- @patch(
811
- "truss.cli.train.deploy_checkpoints.deploy_whisper_checkpoints.build_full_checkpoint_string"
812
- )
813
- def test_render_vllm_whisper_truss_config_with_envvars(
814
- mock_build_full_checkpoint_string, mock_setup_env_vars, mock_setup_base_config
815
- ):
816
- """Test that render_vllm_whisper_truss_config handles environment variables correctly."""
817
- # Mock dependencies
818
- mock_truss_config = MagicMock()
819
- mock_truss_config.environment_variables = {}
820
- mock_truss_config.docker_server = MagicMock()
821
- mock_setup_base_config.return_value = mock_truss_config
822
-
823
- mock_setup_env_vars.return_value = "CUDA_VISIBLE_DEVICES=0,1"
824
- mock_build_full_checkpoint_string.return_value = "/path/to/checkpoint"
825
-
826
- # Create test config with environment variables
827
- deploy_config = DeployCheckpointsConfigComplete(
828
- checkpoint_details=definitions.CheckpointList(
829
- checkpoints=[
830
- definitions.WhisperCheckpoint(
831
- training_job_id="job123",
832
- paths=["rank-0/checkpoint-1/"],
833
- model_weight_format=definitions.ModelWeightsFormat.WHISPER,
834
- )
835
- ],
836
- base_model_id="openai/whisper-large-v3",
837
- ),
838
- model_name="whisper-large-v3-vLLM",
839
- compute=definitions.Compute(
840
- accelerator=truss_config.AcceleratorSpec(accelerator="H100", count=2)
841
- ),
842
- runtime=definitions.DeployCheckpointsRuntime(
843
- environment_variables={
844
- "CUDA_VISIBLE_DEVICES": "0,1",
845
- "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
846
- }
847
- ),
848
- deployment_name="whisper-large-v3-vLLM",
849
- model_weight_format=definitions.ModelWeightsFormat.WHISPER,
850
- )
851
-
852
- # Call function under test
853
- result = render_vllm_whisper_truss_config(deploy_config)
854
-
855
- # Verify environment variables are included in start command
856
- expected_start_command = (
857
- "sh -c 'CUDA_VISIBLE_DEVICES=0,1 "
858
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
859
- "vllm serve /path/to/checkpoint --port 8000 --tensor-parallel-size 2'"
860
- )
861
- assert (
862
- result.environment_variables[START_COMMAND_ENVVAR_NAME]
863
- == expected_start_command
864
- )
865
-
866
-
867
- @dataclass
868
- class TestCase:
869
- """Test case for setup_base_truss_config function."""
870
-
871
- desc: str
872
- input_config: DeployCheckpointsConfigComplete
873
- expected_model_name: str
874
- expected_predict_endpoint: str
875
- expected_accelerator: Optional[str]
876
- expected_accelerator_count: Optional[int]
877
- expected_checkpoint_paths: List[str]
878
- expected_environment_variables: Dict[str, str]
879
- should_raise: Optional[str] = None # Error message if function should raise
880
-
881
- __test__ = False # Tell pytest this is not a test class
882
-
883
-
884
- def test_setup_base_truss_config():
885
- """Table-driven test for setup_base_truss_config function."""
886
- from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
887
- setup_base_truss_config,
888
- )
889
-
890
- # Define test cases
891
- test_cases = [
892
- TestCase(
893
- desc="LoRA checkpoint with H100 accelerator",
894
- input_config=DeployCheckpointsConfigComplete(
895
- checkpoint_details=definitions.CheckpointList(
896
- checkpoints=[
897
- definitions.LoRACheckpoint(
898
- training_job_id="job123",
899
- paths=["rank-0/checkpoint-1/"],
900
- model_weight_format=ModelWeightsFormat.LORA,
901
- lora_details=definitions.LoRADetails(rank=32),
902
- )
903
- ],
904
- base_model_id="google/gemma-3-27b-it",
905
- ),
906
- model_name="test-lora-model",
907
- compute=definitions.Compute(
908
- accelerator=truss_config.AcceleratorSpec(
909
- accelerator="H100", count=4
910
- )
911
- ),
912
- runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
913
- deployment_name="test-deployment",
914
- model_weight_format=ModelWeightsFormat.LORA,
915
- ),
916
- expected_model_name="test-lora-model",
917
- expected_predict_endpoint="/v1/chat/completions",
918
- expected_accelerator="H100",
919
- expected_accelerator_count=4,
920
- expected_checkpoint_paths=["rank-0/checkpoint-1/"],
921
- expected_environment_variables={
922
- "VLLM_LOGGING_LEVEL": "WARNING",
923
- "VLLM_USE_V1": "0",
924
- "HF_HUB_ENABLE_HF_TRANSFER": "1",
925
- },
926
- ),
927
- TestCase(
928
- desc="Whisper checkpoint with A100 accelerator",
929
- input_config=DeployCheckpointsConfigComplete(
930
- checkpoint_details=definitions.CheckpointList(
931
- checkpoints=[
932
- definitions.WhisperCheckpoint(
933
- training_job_id="job123",
934
- paths=["rank-0/checkpoint-1/"],
935
- model_weight_format=definitions.ModelWeightsFormat.WHISPER,
936
- )
937
- ],
938
- base_model_id="openai/whisper-large-v3",
939
- ),
940
- model_name="test-whisper-model",
941
- compute=definitions.Compute(
942
- accelerator=truss_config.AcceleratorSpec(
943
- accelerator="A100", count=2
944
- )
945
- ),
946
- runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
947
- deployment_name="test-whisper-deployment",
948
- model_weight_format=definitions.ModelWeightsFormat.WHISPER,
949
- ),
950
- expected_model_name="test-whisper-model",
951
- expected_predict_endpoint="/v1/audio/transcriptions",
952
- expected_accelerator="A100",
953
- expected_accelerator_count=2,
954
- expected_checkpoint_paths=["rank-0/checkpoint-1/"],
955
- expected_environment_variables={
956
- "VLLM_LOGGING_LEVEL": "WARNING",
957
- "VLLM_USE_V1": "0",
958
- "HF_HUB_ENABLE_HF_TRANSFER": "1",
959
- },
960
- ),
961
- TestCase(
962
- desc="Multiple LoRA checkpoints",
963
- input_config=DeployCheckpointsConfigComplete(
964
- checkpoint_details=definitions.CheckpointList(
965
- checkpoints=[
966
- definitions.LoRACheckpoint(
967
- training_job_id="job123",
968
- paths=["rank-0/checkpoint-1/"],
969
- model_weight_format=ModelWeightsFormat.LORA,
970
- lora_details=definitions.LoRADetails(rank=16),
971
- ),
972
- definitions.LoRACheckpoint(
973
- training_job_id="job123",
974
- paths=["rank-0/checkpoint-2/"],
975
- model_weight_format=ModelWeightsFormat.LORA,
976
- lora_details=definitions.LoRADetails(rank=32),
977
- ),
978
- ],
979
- base_model_id="google/gemma-3-27b-it",
980
- ),
981
- model_name="test-multi-checkpoint-model",
982
- compute=definitions.Compute(
983
- accelerator=truss_config.AcceleratorSpec(
984
- accelerator="H100", count=4
985
- )
986
- ),
987
- runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
988
- deployment_name="test-multi-deployment",
989
- model_weight_format=ModelWeightsFormat.LORA,
990
- ),
991
- expected_model_name="test-multi-checkpoint-model",
992
- expected_predict_endpoint="/v1/chat/completions",
993
- expected_accelerator="H100",
994
- expected_accelerator_count=4,
995
- expected_checkpoint_paths=["rank-0/checkpoint-1/", "rank-0/checkpoint-2/"],
996
- expected_environment_variables={
997
- "VLLM_LOGGING_LEVEL": "WARNING",
998
- "VLLM_USE_V1": "0",
999
- "HF_HUB_ENABLE_HF_TRANSFER": "1",
1000
- },
1001
- ),
1002
- TestCase(
1003
- desc="No accelerator specified",
1004
- input_config=DeployCheckpointsConfigComplete(
1005
- checkpoint_details=definitions.CheckpointList(
1006
- checkpoints=[
1007
- definitions.LoRACheckpoint(
1008
- training_job_id="job123",
1009
- paths=["rank-0/checkpoint-1/"],
1010
- model_weight_format=ModelWeightsFormat.LORA,
1011
- lora_details=definitions.LoRADetails(rank=16),
1012
- )
1013
- ],
1014
- base_model_id="google/gemma-3-27b-it",
1015
- ),
1016
- model_name="test-no-accelerator-model",
1017
- compute=definitions.Compute(), # No accelerator specified
1018
- runtime=definitions.DeployCheckpointsRuntime(environment_variables={}),
1019
- deployment_name="test-no-accelerator-deployment",
1020
- model_weight_format=ModelWeightsFormat.LORA,
1021
- ),
1022
- expected_model_name="test-no-accelerator-model",
1023
- expected_predict_endpoint="/v1/chat/completions",
1024
- expected_accelerator=None,
1025
- expected_accelerator_count=None,
1026
- expected_checkpoint_paths=["rank-0/checkpoint-1/"],
1027
- expected_environment_variables={
1028
- "VLLM_LOGGING_LEVEL": "WARNING",
1029
- "VLLM_USE_V1": "0",
1030
- "HF_HUB_ENABLE_HF_TRANSFER": "1",
1031
- },
1032
- ),
1033
- ]
1034
-
1035
- # Run test cases
1036
- for test_case in test_cases:
1037
- print(f"Running test case: {test_case.desc}")
1038
-
1039
- if test_case.should_raise:
1040
- # Test error cases
1041
- with pytest.raises(Exception, match=test_case.should_raise):
1042
- setup_base_truss_config(test_case.input_config)
1043
- else:
1044
- # Test success cases
1045
- result = setup_base_truss_config(test_case.input_config)
1046
-
1047
- # Verify basic structure
1048
- assert isinstance(result, truss_config.TrussConfig), (
1049
- f"Test case '{test_case.desc}': Result should be TrussConfig"
1050
- )
1051
- assert result.model_name == test_case.expected_model_name, (
1052
- f"Test case '{test_case.desc}': Model name mismatch"
1053
- )
1054
-
1055
- # Verify docker server configuration
1056
- assert result.docker_server is not None, (
1057
- f"Test case '{test_case.desc}': Docker server should not be None"
1058
- )
1059
- assert result.docker_server.start_command == 'sh -c ""', (
1060
- f"Test case '{test_case.desc}': Start command mismatch"
1061
- )
1062
- assert result.docker_server.readiness_endpoint == "/health", (
1063
- f"Test case '{test_case.desc}': Readiness endpoint mismatch"
1064
- )
1065
- assert result.docker_server.liveness_endpoint == "/health", (
1066
- f"Test case '{test_case.desc}': Liveness endpoint mismatch"
1067
- )
1068
- assert (
1069
- result.docker_server.predict_endpoint
1070
- == test_case.expected_predict_endpoint
1071
- ), f"Test case '{test_case.desc}': Predict endpoint mismatch"
1072
- assert result.docker_server.server_port == 8000, (
1073
- f"Test case '{test_case.desc}': Server port mismatch"
1074
- )
1075
-
1076
- # Verify training checkpoints
1077
- assert result.training_checkpoints is not None, (
1078
- f"Test case '{test_case.desc}': Training checkpoints should not be None"
1079
- )
1080
- assert len(result.training_checkpoints.artifact_references) == len(
1081
- test_case.expected_checkpoint_paths
1082
- ), f"Test case '{test_case.desc}': Number of checkpoint artifacts mismatch"
1083
-
1084
- for i, expected_path in enumerate(test_case.expected_checkpoint_paths):
1085
- artifact_ref = result.training_checkpoints.artifact_references[i]
1086
- assert artifact_ref.paths == [expected_path], (
1087
- f"Test case '{test_case.desc}': Checkpoint path {i} mismatch"
1088
- )
1089
-
1090
- # Verify resources
1091
- assert result.resources is not None, (
1092
- f"Test case '{test_case.desc}': Resources should not be None"
1093
- )
1094
-
1095
- if test_case.expected_accelerator:
1096
- assert result.resources.accelerator is not None, (
1097
- f"Test case '{test_case.desc}': Accelerator should not be None"
1098
- )
1099
- assert (
1100
- result.resources.accelerator.accelerator
1101
- == test_case.expected_accelerator
1102
- ), f"Test case '{test_case.desc}': Accelerator type mismatch"
1103
- assert (
1104
- result.resources.accelerator.count
1105
- == test_case.expected_accelerator_count
1106
- ), f"Test case '{test_case.desc}': Accelerator count mismatch"
1107
- else:
1108
- # When no accelerator is specified, it creates an AcceleratorSpec with None values
1109
- assert result.resources.accelerator is not None, (
1110
- f"Test case '{test_case.desc}': Accelerator should exist"
1111
- )
1112
- assert result.resources.accelerator.accelerator is None, (
1113
- f"Test case '{test_case.desc}': Accelerator type should be None"
1114
- )
1115
-
1116
- # Verify environment variables
1117
- for key, expected_value in test_case.expected_environment_variables.items():
1118
- assert result.environment_variables[key] == expected_value, (
1119
- f"Test case '{test_case.desc}': Environment variable {key} mismatch"
1120
- )