truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88):
  1. truss/api/__init__.py +5 -2
  2. truss/base/constants.py +1 -0
  3. truss/base/trt_llm_config.py +14 -3
  4. truss/base/truss_config.py +19 -4
  5. truss/cli/chains_commands.py +49 -1
  6. truss/cli/cli.py +38 -7
  7. truss/cli/logs/base_watcher.py +31 -12
  8. truss/cli/logs/model_log_watcher.py +24 -1
  9. truss/cli/remote_cli.py +29 -0
  10. truss/cli/resolvers/chain_team_resolver.py +82 -0
  11. truss/cli/resolvers/model_team_resolver.py +90 -0
  12. truss/cli/resolvers/training_project_team_resolver.py +81 -0
  13. truss/cli/train/cache.py +332 -0
  14. truss/cli/train/core.py +57 -163
  15. truss/cli/train/deploy_checkpoints/__init__.py +2 -2
  16. truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
  17. truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
  18. truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
  19. truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
  20. truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
  21. truss/cli/train/types.py +18 -9
  22. truss/cli/train_commands.py +180 -35
  23. truss/cli/utils/common.py +40 -3
  24. truss/contexts/image_builder/serving_image_builder.py +17 -4
  25. truss/remote/baseten/api.py +215 -9
  26. truss/remote/baseten/core.py +63 -7
  27. truss/remote/baseten/custom_types.py +1 -0
  28. truss/remote/baseten/remote.py +42 -2
  29. truss/remote/baseten/service.py +0 -7
  30. truss/remote/baseten/utils/transfer.py +5 -2
  31. truss/templates/base.Dockerfile.jinja +8 -4
  32. truss/templates/control/control/application.py +51 -26
  33. truss/templates/control/control/endpoints.py +1 -5
  34. truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
  35. truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
  36. truss/templates/control/control/server.py +1 -1
  37. truss/templates/control/requirements.txt +1 -2
  38. truss/templates/docker_server/proxy.conf.jinja +13 -0
  39. truss/templates/docker_server/supervisord.conf.jinja +2 -1
  40. truss/templates/no_build.Dockerfile.jinja +1 -0
  41. truss/templates/server/requirements.txt +2 -3
  42. truss/templates/server/truss_server.py +2 -5
  43. truss/templates/server.Dockerfile.jinja +12 -12
  44. truss/templates/shared/lazy_data_resolver.py +214 -2
  45. truss/templates/shared/util.py +6 -5
  46. truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
  47. truss/tests/cli/test_chains_cli.py +144 -0
  48. truss/tests/cli/test_cli.py +134 -1
  49. truss/tests/cli/test_cli_utils_common.py +11 -0
  50. truss/tests/cli/test_model_team_resolver.py +279 -0
  51. truss/tests/cli/train/test_cache_view.py +240 -3
  52. truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
  53. truss/tests/cli/train/test_train_cli_core.py +2 -2
  54. truss/tests/cli/train/test_train_team_parameter.py +395 -0
  55. truss/tests/conftest.py +187 -0
  56. truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
  57. truss/tests/remote/baseten/test_api.py +122 -3
  58. truss/tests/remote/baseten/test_chain_upload.py +294 -0
  59. truss/tests/remote/baseten/test_core.py +86 -0
  60. truss/tests/remote/baseten/test_remote.py +216 -288
  61. truss/tests/remote/baseten/test_service.py +56 -0
  62. truss/tests/templates/control/control/conftest.py +20 -0
  63. truss/tests/templates/control/control/test_endpoints.py +4 -0
  64. truss/tests/templates/control/control/test_server.py +8 -24
  65. truss/tests/templates/control/control/test_server_integration.py +4 -2
  66. truss/tests/test_config.py +21 -12
  67. truss/tests/test_data/server.Dockerfile +3 -1
  68. truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
  69. truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
  70. truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
  71. truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
  72. truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
  73. truss/tests/test_model_inference.py +13 -0
  74. truss/tests/util/test_env_vars.py +8 -3
  75. truss/util/__init__.py +0 -0
  76. truss/util/env_vars.py +19 -8
  77. truss/util/error_utils.py +37 -0
  78. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
  79. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
  80. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
  81. truss_chains/deployment/deployment_client.py +16 -4
  82. truss_chains/private_types.py +18 -0
  83. truss_chains/public_api.py +3 -0
  84. truss_train/definitions.py +6 -4
  85. truss_train/deployment.py +43 -21
  86. truss_train/public_api.py +4 -2
  87. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
  88. {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,6 @@
1
+ import json
1
2
  import re
2
- import tempfile
3
3
  from collections import OrderedDict
4
- from pathlib import Path
5
4
  from typing import List, Optional, Union
6
5
 
7
6
  import rich_click as click
@@ -11,7 +10,8 @@ from truss.base import truss_config
11
10
  from truss.cli.train import common
12
11
  from truss.cli.train.types import (
13
12
  DeployCheckpointsConfigComplete,
14
- PrepareCheckpointResult,
13
+ DeploySuccessModelVersion,
14
+ DeploySuccessResult,
15
15
  )
16
16
  from truss.cli.utils.output import console
17
17
  from truss.remote.baseten.remote import BasetenRemote
@@ -25,18 +25,9 @@ from truss_train.definitions import (
25
25
  SecretReference,
26
26
  )
27
27
 
28
- from .deploy_full_checkpoints import (
29
- hydrate_full_checkpoint,
30
- render_vllm_full_truss_config,
31
- )
32
- from .deploy_lora_checkpoints import (
33
- hydrate_lora_checkpoint,
34
- render_vllm_lora_truss_config,
35
- )
36
- from .deploy_whisper_checkpoints import (
37
- hydrate_whisper_checkpoint,
38
- render_vllm_whisper_truss_config,
39
- )
28
+ from .deploy_full_checkpoints import hydrate_full_checkpoint
29
+ from .deploy_lora_checkpoints import hydrate_lora_checkpoint
30
+ from .deploy_whisper_checkpoints import hydrate_whisper_checkpoint
40
31
 
41
32
  HF_TOKEN_ENVVAR_NAME = "HF_TOKEN"
42
33
  # If we change this, make sure to update the logic in backend codebase
@@ -44,28 +35,184 @@ CHECKPOINT_PATTERN = re.compile(r".*checkpoint-\d+(?:-\d+)?$")
44
35
  ALLOWED_DEPLOYMENT_NAMES = re.compile(r"^[0-9a-zA-Z_\-\.]*$")
45
36
 
46
37
 
47
- def prepare_checkpoint_deploy(
38
+ def create_model_version_from_inference_template(
48
39
  remote_provider: BasetenRemote,
49
40
  checkpoint_deploy_config: DeployCheckpointsConfig,
50
41
  project_id: Optional[str],
51
42
  job_id: Optional[str],
52
- ) -> PrepareCheckpointResult:
43
+ dry_run: bool,
44
+ ) -> DeploySuccessResult:
53
45
  checkpoint_deploy_config = _hydrate_deploy_config(
54
46
  checkpoint_deploy_config, remote_provider, project_id, job_id
55
47
  )
56
- rendered_truss = _render_truss_config_for_checkpoint_deployment(
57
- checkpoint_deploy_config
58
- )
59
- truss_directory = Path(tempfile.mkdtemp())
60
- truss_config_path = truss_directory / "config.yaml"
61
- rendered_truss.write_to_yaml_file(truss_config_path)
62
- console.print(rendered_truss, style="green")
63
- console.print(f"Writing truss config to {truss_config_path}", style="yellow")
64
- return PrepareCheckpointResult(
65
- truss_directory=truss_directory,
48
+
49
+ request_data = _build_inference_template_request(
66
50
  checkpoint_deploy_config=checkpoint_deploy_config,
51
+ remote_provider=remote_provider,
52
+ dry_run=dry_run,
67
53
  )
68
54
 
55
+ # Call the GraphQL mutation to create model version from inference template
56
+ try:
57
+ result = remote_provider.api.create_model_version_from_inference_template(
58
+ request_data
59
+ )
60
+ truss_config_result = _get_truss_config_from_result(result)
61
+
62
+ model_version = None
63
+ if result and result.get("model_version"):
64
+ console.print(
65
+ f"Successfully created model version: {result['model_version']['name']}",
66
+ style="green",
67
+ )
68
+ console.print(
69
+ f"Model version ID: {result['model_version']['id']}", style="yellow"
70
+ )
71
+ model_version = DeploySuccessModelVersion.model_validate(
72
+ result["model_version"]
73
+ )
74
+ elif not dry_run:
75
+ console.print(
76
+ "Warning: Unexpected response format from server", style="yellow"
77
+ )
78
+ console.print(f"Response: {result}", style="yellow")
79
+
80
+ except Exception as e:
81
+ console.print(f"Error creating model version: {e}", style="red")
82
+ raise
83
+
84
+ return DeploySuccessResult(
85
+ deploy_config=checkpoint_deploy_config,
86
+ truss_config=truss_config_result,
87
+ model_version=model_version,
88
+ )
89
+
90
+
91
+ def _get_truss_config_from_result(result: dict) -> Optional[truss_config.TrussConfig]:
92
+ if result and result.get("truss_config"):
93
+ truss_config_dict = json.loads(result["truss_config"])
94
+ return truss_config.TrussConfig.from_dict(truss_config_dict)
95
+ # Although this should never happen, we defensively allow ourselves to return None
96
+ # because we need a failure to handle the truss config doesn't necessarily mean we failed to deploy
97
+ # the model version.
98
+ console.print(
99
+ "No truss config returned. Reach out to Baseten for support if this persists.",
100
+ style="red",
101
+ )
102
+ return None
103
+
104
+
105
+ def _build_inference_template_request(
106
+ checkpoint_deploy_config: DeployCheckpointsConfigComplete,
107
+ remote_provider: BasetenRemote,
108
+ dry_run: bool,
109
+ ) -> dict:
110
+ """
111
+ Build the GraphQL request data structure for createModelVersionFromInferenceTemplate mutation.
112
+ """
113
+
114
+ # Build weights sources
115
+ weights_sources = []
116
+ for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints:
117
+ # Extract checkpoint name from the first path
118
+ weights_source = {
119
+ "weight_source_type": "B10_CHECKPOINTING",
120
+ "b10_training_checkpoint_weights_source": {
121
+ "checkpoint": {
122
+ "training_job_id": checkpoint.training_job_id,
123
+ "checkpoint_name": checkpoint.checkpoint_name,
124
+ }
125
+ },
126
+ }
127
+ weights_sources.append(weights_source)
128
+
129
+ # Build environment variables
130
+ environment_variables = []
131
+ for name, value in checkpoint_deploy_config.runtime.environment_variables.items():
132
+ if isinstance(value, SecretReference):
133
+ env_var = {"name": name, "value": value.name, "is_secret_reference": True}
134
+ else:
135
+ env_var = {"name": name, "value": str(value), "is_secret_reference": False}
136
+ environment_variables.append(env_var)
137
+
138
+ # Build inference stack
139
+ inference_stack = {
140
+ "stack_type": "VLLM",
141
+ "environment_variables": environment_variables,
142
+ }
143
+
144
+ # Get instance type ID from compute spec
145
+ instance_type_id = _get_instance_type_id(
146
+ checkpoint_deploy_config.compute, remote_provider
147
+ )
148
+
149
+ # Build the complete request
150
+ request_data = {
151
+ "metadata": {"oracle_name": checkpoint_deploy_config.model_name},
152
+ "weights_sources": weights_sources,
153
+ "inference_stack": inference_stack,
154
+ "instance_type_id": instance_type_id,
155
+ "dry_run": dry_run,
156
+ }
157
+
158
+ return request_data
159
+
160
+
161
+ def _get_instance_type_id(compute: Compute, remote_provider: BasetenRemote) -> str:
162
+ """
163
+ Get the instance type ID based on the compute specification.
164
+ Fetches available instance types from the API and maps compute specs to instance type IDs.
165
+ Only considers single-node instances (node_count == 1).
166
+ """
167
+ # step 1: fetch the instance types from the API
168
+ instance_types = remote_provider.api.get_instance_types()
169
+ # step 2: sort them into two different dictionaries, excluding multi-node instances:
170
+ cpu_instance_types = {
171
+ it.id: it for it in instance_types if it.gpu_count == 0 and it.node_count == 1
172
+ }
173
+ gpu_instance_types = {
174
+ it.id: it for it in instance_types if it.gpu_count > 0 and it.node_count == 1
175
+ }
176
+ # step 3: if compute is cpu, find the smallest such cpu that matches the compute request
177
+ if not compute.accelerator or compute.accelerator.accelerator is None:
178
+ compute_as_truss_config = compute.to_truss_config()
179
+ smallest_cpu_instance_type = None
180
+ for it in cpu_instance_types.values():
181
+ if (
182
+ it.millicpu_limit / 1000 >= compute.cpu_count
183
+ and it.memory_limit >= compute_as_truss_config.memory_in_bytes
184
+ ):
185
+ if (
186
+ smallest_cpu_instance_type is None
187
+ or it.millicpu_limit < smallest_cpu_instance_type.millicpu_limit
188
+ ):
189
+ smallest_cpu_instance_type = it
190
+ if not smallest_cpu_instance_type:
191
+ raise ValueError(
192
+ f"Unable to find single-node instance type for {compute.cpu_count} CPU and {compute.memory} memory. Reach out to Baseten for support if this persists."
193
+ )
194
+ return smallest_cpu_instance_type.id
195
+ # step 4: if compute is gpu, find the smallest such gpu by instance type
196
+ else:
197
+ assert compute.accelerator.accelerator is not None
198
+ compute_as_truss_config = compute.to_truss_config()
199
+ smallest_gpu_instance_type = None
200
+ for it in gpu_instance_types.values():
201
+ if (
202
+ it.gpu_type == compute.accelerator.accelerator.value
203
+ and it.gpu_count >= compute.accelerator.count
204
+ ):
205
+ if (
206
+ smallest_gpu_instance_type is None
207
+ or it.gpu_count < smallest_gpu_instance_type.gpu_count
208
+ ):
209
+ smallest_gpu_instance_type = it
210
+ if not smallest_gpu_instance_type:
211
+ raise ValueError(
212
+ f"Unable to find single-node instance type for {compute.accelerator}:{compute.accelerator.count}. Reach out to Baseten for support if this persists."
213
+ )
214
+ return smallest_gpu_instance_type.id
215
+
69
216
 
70
217
  def _validate_base_model_id(
71
218
  base_model_id: Optional[str], model_weight_format: ModelWeightsFormat
@@ -93,18 +240,12 @@ def _get_model_name(
93
240
  else ""
94
241
  )
95
242
 
96
- model_name = inquirer.text(
243
+ return inquirer.text(
97
244
  message=f"Enter the model name for your {model_weight_format.value} model.",
98
245
  validate=lambda s: s and s.strip(),
99
246
  default=default,
100
247
  ).execute()
101
248
 
102
- if model_weight_format == ModelWeightsFormat.FULL:
103
- model_name += "-vLLM-Full"
104
- elif model_weight_format == ModelWeightsFormat.LORA:
105
- model_name += "-vLLM-LORA"
106
- return model_name
107
-
108
249
 
109
250
  def _hydrate_deploy_config(
110
251
  deploy_config: DeployCheckpointsConfig,
@@ -123,53 +264,18 @@ def _hydrate_deploy_config(
123
264
  else:
124
265
  model_name = _get_model_name(model_weight_format, base_model_id)
125
266
 
126
- compute = _ensure_compute_spec(deploy_config.compute)
267
+ compute = _ensure_compute_spec(deploy_config.compute, remote_provider)
127
268
 
128
269
  runtime = _ensure_runtime_config(deploy_config.runtime)
129
- deployment_name = _ensure_deployment_name(
130
- deploy_config.deployment_name, checkpoint_details.checkpoints
131
- )
132
270
 
133
271
  return DeployCheckpointsConfigComplete(
134
272
  checkpoint_details=checkpoint_details,
135
273
  model_name=model_name,
136
- deployment_name=deployment_name,
137
274
  runtime=runtime,
138
275
  compute=compute,
139
- model_weight_format=model_weight_format.to_truss_config(), # type: ignore[attr-defined]
140
276
  )
141
277
 
142
278
 
143
- def _ensure_deployment_name(
144
- deploy_config_deployment_name: Optional[str], checkpoints: List[Checkpoint]
145
- ) -> str:
146
- if deploy_config_deployment_name:
147
- return deploy_config_deployment_name
148
-
149
- default_deployment_name = "checkpoint"
150
-
151
- if checkpoints and checkpoints[0].paths:
152
- first_checkpoint_name = checkpoints[0].paths[0].strip("/").split("/")[-1]
153
-
154
- if ALLOWED_DEPLOYMENT_NAMES.match(first_checkpoint_name):
155
- # Allow autoincrementing if the checkpoint matches both regexes
156
- if (
157
- CHECKPOINT_PATTERN.match(first_checkpoint_name)
158
- and len(checkpoints) == 1
159
- ):
160
- return first_checkpoint_name
161
-
162
- # If no valid autoincrementing checkpoint name is found, prompt the user
163
- deployment_name = inquirer.text(
164
- message="Enter the deployment name.", default=default_deployment_name
165
- ).execute()
166
-
167
- if not deployment_name:
168
- raise click.UsageError("Deployment name is required.")
169
-
170
- return deployment_name
171
-
172
-
173
279
  def hydrate_checkpoint(
174
280
  job_id: str, checkpoint_id: str, checkpoint: dict, checkpoint_type: str
175
281
  ) -> Checkpoint:
@@ -190,26 +296,6 @@ def hydrate_checkpoint(
190
296
  )
191
297
 
192
298
 
193
- def _render_truss_config_for_checkpoint_deployment(
194
- checkpoint_deploy: DeployCheckpointsConfigComplete,
195
- ) -> truss_config.TrussConfig:
196
- """
197
- Render truss config for checkpoint deployment.
198
- Currently supports LoRA checkpoints via vLLM, but can be extended for other formats.
199
- """
200
- # Delegate to specific rendering function based on model weight format
201
- if checkpoint_deploy.model_weight_format == ModelWeightsFormat.LORA:
202
- return render_vllm_lora_truss_config(checkpoint_deploy)
203
- elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.FULL:
204
- return render_vllm_full_truss_config(checkpoint_deploy)
205
- elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER:
206
- return render_vllm_whisper_truss_config(checkpoint_deploy)
207
- else:
208
- raise ValueError(
209
- f"Unsupported model weight format: {checkpoint_deploy.model_weight_format}. Please upgrade to the latest Truss version to access the latest supported formats. Contact Baseten if you would like us to support additional formats."
210
- )
211
-
212
-
213
299
  def _ensure_checkpoint_details(
214
300
  remote_provider: BasetenRemote,
215
301
  checkpoint_details: Optional[CheckpointList],
@@ -217,6 +303,7 @@ def _ensure_checkpoint_details(
217
303
  job_id: Optional[str],
218
304
  ) -> CheckpointList:
219
305
  if checkpoint_details and checkpoint_details.checkpoints:
306
+ # TODO: check here
220
307
  return _process_user_provided_checkpoints(checkpoint_details, remote_provider)
221
308
  else:
222
309
  return _prompt_user_for_checkpoint_details(
@@ -309,31 +396,77 @@ def _select_multiple_checkpoints(checkpoint_id_options: List[str]) -> List[str]:
309
396
  return checkpoint_ids
310
397
 
311
398
 
312
- def _ensure_compute_spec(compute: Optional[Compute]) -> Compute:
399
+ def _ensure_compute_spec(
400
+ compute: Optional[Compute], remote_provider: BasetenRemote
401
+ ) -> Compute:
313
402
  if not compute:
314
403
  compute = Compute(cpu_count=0, memory="0Mi")
315
- compute.accelerator = _get_accelerator_if_specified(compute.accelerator)
404
+ compute = _get_accelerator_if_specified(compute, remote_provider)
316
405
  return compute
317
406
 
318
407
 
319
408
  def _get_accelerator_if_specified(
320
- user_input: Optional[truss_config.AcceleratorSpec],
321
- ) -> Optional[truss_config.AcceleratorSpec]:
322
- if user_input:
409
+ user_input: Optional[Compute], remote_provider: BasetenRemote
410
+ ) -> Compute:
411
+ if user_input and user_input.accelerator:
323
412
  return user_input
413
+
414
+ # Fetch available instance types to get valid GPU options
415
+ instance_types = remote_provider.api.get_instance_types()
416
+
417
+ # Extract unique accelerator types from instance types
418
+ accelerator_options = set()
419
+ for it in instance_types:
420
+ if it.gpu_type and it.gpu_count > 0:
421
+ accelerator_options.add(it.gpu_type)
422
+
423
+ # Convert to sorted list and add CPU option
424
+ choices = sorted(list(accelerator_options)) + [None]
425
+
426
+ if not choices or choices == [None]:
427
+ console.print("No GPU instance types available, using CPU", style="yellow")
428
+ return Compute(cpu_count=0, memory="0Mi", accelerator=None)
429
+
324
430
  # prompt user for accelerator
325
431
  gpu_type = inquirer.select(
326
432
  message="Select the GPU type to use for deployment. Select None for CPU.",
327
- choices=[x.value for x in truss_config.Accelerator] + [None],
433
+ choices=choices,
328
434
  ).execute()
435
+
329
436
  if gpu_type is None:
330
- return None
331
- count = inquirer.text(
332
- message="Enter the number of accelerators to use for deployment.",
333
- default="1",
334
- validate=lambda x: x.isdigit() and int(x) > 0 and int(x) <= 8,
335
- ).execute()
336
- return truss_config.AcceleratorSpec(accelerator=gpu_type, count=int(count))
437
+ return Compute(cpu_count=0, memory="0Mi", accelerator=None)
438
+
439
+ # Get available counts for the selected GPU type
440
+ available_counts = set()
441
+ for it in instance_types:
442
+ if it.gpu_type == gpu_type and it.gpu_count > 0:
443
+ available_counts.add(it.gpu_count)
444
+ if not available_counts:
445
+ raise ValueError(
446
+ f"No available counts for {gpu_type}. Reach out to Baseten for support if this persists."
447
+ )
448
+
449
+ if available_counts:
450
+ count_choices = sorted(list(available_counts))
451
+ count = inquirer.select(
452
+ message=f"Select the number of {gpu_type} GPUs to use for deployment.",
453
+ choices=count_choices,
454
+ default=str(count_choices[0]),
455
+ ).execute()
456
+ else:
457
+ count = inquirer.text(
458
+ message=f"Enter the number of {gpu_type} accelerators to use for deployment.",
459
+ default="1",
460
+ validate=lambda x: x.isdigit() and int(x) > 0 and int(x) <= 8,
461
+ ).execute()
462
+
463
+ return Compute(
464
+ cpu_count=0,
465
+ memory="0Mi",
466
+ accelerator=truss_config.AcceleratorSpec(
467
+ accelerator=gpu_type.replace("-", "_"), count=int(count)
468
+ ),
469
+ )
337
470
 
338
471
 
339
472
  def _get_base_model_id(user_input: Optional[str], checkpoint: dict) -> Optional[str]:
@@ -1,52 +1 @@
1
- import os
2
- from pathlib import Path
3
-
4
- from truss.base import truss_config
5
- from truss.cli.train.types import DeployCheckpointsConfigComplete
6
- from truss_train.definitions import ModelWeightsFormat, SecretReference
7
-
8
- START_COMMAND_ENVVAR_NAME = "BT_DOCKER_SERVER_START_CMD"
9
-
10
-
11
- def setup_base_truss_config(
12
- checkpoint_deploy: DeployCheckpointsConfigComplete,
13
- ) -> truss_config.TrussConfig:
14
- """Set up the base truss config with common properties."""
15
- truss_deploy_config = None
16
- truss_base_file = (
17
- "deploy_from_checkpoint_config_whisper.yml"
18
- if checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER
19
- else "deploy_from_checkpoint_config.yml"
20
- )
21
- truss_deploy_config = truss_config.TrussConfig.from_yaml(
22
- Path(os.path.dirname(__file__), "..", truss_base_file)
23
- )
24
- if not truss_deploy_config.docker_server:
25
- raise ValueError(
26
- "Unexpected checkpoint deployment config: missing docker_server"
27
- )
28
-
29
- truss_deploy_config.model_name = checkpoint_deploy.model_name
30
- truss_deploy_config.training_checkpoints = (
31
- checkpoint_deploy.checkpoint_details.to_truss_config()
32
- )
33
- truss_deploy_config.resources = checkpoint_deploy.compute.to_truss_config()
34
-
35
- return truss_deploy_config
36
-
37
-
38
- def setup_environment_variables_and_secrets(
39
- truss_deploy_config: truss_config.TrussConfig,
40
- checkpoint_deploy: DeployCheckpointsConfigComplete,
41
- ) -> str:
42
- """Set up environment variables and secrets, return start command envvars string."""
43
- start_command_envvars = ""
44
-
45
- for key, value in checkpoint_deploy.runtime.environment_variables.items():
46
- if isinstance(value, SecretReference):
47
- truss_deploy_config.secrets[value.name] = "set token in baseten workspace"
48
- start_command_envvars = f"{key}=$(cat /secrets/{value.name})"
49
- else:
50
- truss_deploy_config.environment_variables[key] = value
51
-
52
- return start_command_envvars
1
+ # This file is kept for potential future use but currently contains no active code
@@ -1,94 +1,9 @@
1
- from pathlib import Path
2
-
3
- from jinja2 import Template
4
-
5
- from truss.base import truss_config
6
- from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
7
- START_COMMAND_ENVVAR_NAME,
8
- )
9
- from truss.cli.train.types import DeployCheckpointsConfigComplete
10
1
  from truss_train.definitions import FullCheckpoint
11
2
 
12
- from .deploy_checkpoints_helpers import (
13
- setup_base_truss_config,
14
- setup_environment_variables_and_secrets,
15
- )
16
-
17
- # NB(aghilan): Transformers was recently changed to save a chat_template.jinja file instead of inside the tokenizer_config.json file.
18
- # Old Models will not have this file, so we check for it and use it if it exists.
19
- # vLLM will not automatically resolve the chat_template.jinja file, so we need to pass it to the start command.
20
- # This logic is needed for any models trained using Transformers v4.51.3 or later
21
- VLLM_FULL_START_COMMAND = Template(
22
- "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
23
- 'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
24
- "if [ -f {{ model_path }}/chat_template.jinja ]; then "
25
- " vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
26
- " --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
27
- "else "
28
- " vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
29
- "fi'"
30
- )
31
-
32
-
33
- def render_vllm_full_truss_config(
34
- checkpoint_deploy: DeployCheckpointsConfigComplete,
35
- ) -> truss_config.TrussConfig:
36
- """Render truss config specifically for full checkpoints using vLLM."""
37
- truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
38
-
39
- start_command_envvars = setup_environment_variables_and_secrets(
40
- truss_deploy_config, checkpoint_deploy
41
- )
42
-
43
- checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
44
-
45
- accelerator = checkpoint_deploy.compute.accelerator
46
-
47
- start_command_args = {
48
- "model_path": checkpoint_str,
49
- "envvars": start_command_envvars,
50
- "specify_tensor_parallelism": accelerator.count if accelerator else 1,
51
- }
52
- # Note: we set the start command as an environment variable in supervisord config.
53
- # This is so that we don't have to change the supervisord config when the start command changes.
54
- # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
55
- start_command = VLLM_FULL_START_COMMAND.render(**start_command_args)
56
- truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
57
- # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
58
- truss_deploy_config.docker_server.start_command = ( # type: ignore[union-attr]
59
- f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
60
- )
61
-
62
- return truss_deploy_config
63
-
64
3
 
65
4
  def hydrate_full_checkpoint(
66
5
  job_id: str, checkpoint_id: str, checkpoint: dict
67
6
  ) -> FullCheckpoint:
68
7
  """Create a Checkpoint object for full model weights."""
69
8
  # NOTE: Slash at the end is important since it means the checkpoint is a directory
70
- paths = [f"rank-0/{checkpoint_id}/"]
71
- return FullCheckpoint(training_job_id=job_id, paths=paths)
72
-
73
-
74
- def build_full_checkpoint_string(truss_deploy_config) -> str:
75
- """Build checkpoint string from artifact references for full checkpoints.
76
-
77
- Args:
78
- truss_deploy_config: The truss deploy configuration containing training checkpoints.
79
-
80
- Returns:
81
- A space-separated string of checkpoint paths.
82
- """
83
- checkpoint_parts = []
84
- for (
85
- truss_checkpoint
86
- ) in truss_deploy_config.training_checkpoints.artifact_references: # type: ignore
87
- ckpt_path = Path(
88
- truss_deploy_config.training_checkpoints.download_folder, # type: ignore
89
- truss_checkpoint.training_job_id,
90
- truss_checkpoint.paths[0],
91
- )
92
- checkpoint_parts.append(str(ckpt_path))
93
-
94
- return " ".join(checkpoint_parts)
9
+ return FullCheckpoint(training_job_id=job_id, checkpoint_name=checkpoint_id)