truss 0.11.8rc12__py3-none-any.whl → 0.11.9rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of truss might be problematic.

truss/cli/train/core.py CHANGED
@@ -16,7 +16,7 @@ from rich.text import Text
 
 from truss.cli.train import common, deploy_checkpoints
 from truss.cli.train.metrics_watcher import MetricsWatcher
-from truss.cli.train.types import PrepareCheckpointArgs, PrepareCheckpointResult
+from truss.cli.train.types import DeployCheckpointArgs, DeployCheckpointsConfigComplete
 from truss.cli.utils import common as cli_common
 from truss.cli.utils.output import console
 from truss.remote.baseten.custom_types import (
@@ -242,35 +242,35 @@ def view_training_job_metrics(
     metrics_display.watch()
 
 
-def prepare_checkpoint_deploy(
-    remote_provider: BasetenRemote, args: PrepareCheckpointArgs
-) -> PrepareCheckpointResult:
+def create_model_version_from_inference_template(
+    remote_provider: BasetenRemote, args: DeployCheckpointArgs
+) -> DeployCheckpointsConfigComplete:
     if not args.deploy_config_path:
-        return deploy_checkpoints.prepare_checkpoint_deploy(
+        return deploy_checkpoints.create_model_version_from_inference_template(
             remote_provider, DeployCheckpointsConfig(), args.project_id, args.job_id
         )
     #### User provided a checkpoint deploy config file
     with loader.import_deploy_checkpoints_config(
         Path(args.deploy_config_path)
     ) as checkpoint_deploy:
-        return deploy_checkpoints.prepare_checkpoint_deploy(
+        return deploy_checkpoints.create_model_version_from_inference_template(
            remote_provider, checkpoint_deploy, args.project_id, args.job_id
        )
 
 
 def _get_checkpoint_names(
-    prepare_checkpoint_result: PrepareCheckpointResult,
+    checkpoint_deploy_config: DeployCheckpointsConfigComplete,
 ) -> list[str]:
     return [
         checkpoint.paths[0].strip("/").split("/")[-1]
-        for checkpoint in prepare_checkpoint_result.checkpoint_deploy_config.checkpoint_details.checkpoints
+        for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints
     ]
 
 
 def print_deploy_checkpoints_success_message(
-    prepare_checkpoint_result: PrepareCheckpointResult,
+    checkpoint_deploy_config: DeployCheckpointsConfigComplete,
 ):
-    checkpoint_names = _get_checkpoint_names(prepare_checkpoint_result)
+    checkpoint_names = _get_checkpoint_names(checkpoint_deploy_config)
     console.print(
         Text("\nTo run the model"),
         Text("ensure your `model` parameter is set to one of"),
@@ -279,7 +279,9 @@ def print_deploy_checkpoints_success_message(
             style="magenta",
         ),
         Text("in your request. An example request body might look like this:"),
-        Text(f"\n{{'model': {checkpoint_names[0]}, 'messages': [...]}}", style="green"),
+        Text(
+            f'\n{{"model": "{checkpoint_names[0]}", "messages": [...]}}', style="green"
+        ),
     )
 
 
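Note on the success-message hunk above: the old f-string rendered the example request body with Python-style single quotes and an unquoted checkpoint name, which is not valid JSON; the new version prints double-quoted JSON. A minimal sketch of the difference, using a hypothetical checkpoint name and an empty `messages` list in place of the `[...]` placeholder so both strings can actually be fed to a parser:

    import json

    checkpoint_name = "checkpoint-600"  # hypothetical name, for illustration only
    old_body = f"{{'model': {checkpoint_name}, 'messages': []}}"
    new_body = f'{{"model": "{checkpoint_name}", "messages": []}}'

    json.loads(new_body)  # parses: double-quoted keys and a quoted model name
    try:
        json.loads(old_body)
    except json.JSONDecodeError:
        print("the old single-quoted form is not valid JSON")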
truss/cli/train/deploy_checkpoints/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .deploy_checkpoints import prepare_checkpoint_deploy
+from .deploy_checkpoints import create_model_version_from_inference_template
 
-__all__ = ["prepare_checkpoint_deploy"]
+__all__ = ["create_model_version_from_inference_template"]
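Because the old name is no longer exported, downstream code importing `prepare_checkpoint_deploy` from this package will break on upgrade. A sketch of the required caller-side change (the caller itself is hypothetical):

    # Before (fails with ImportError against 0.11.9rc2):
    # from truss.cli.train.deploy_checkpoints import prepare_checkpoint_deploy

    # After:
    from truss.cli.train.deploy_checkpoints import (
        create_model_version_from_inference_template,
    )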
truss/cli/train/deploy_checkpoints/deploy_checkpoints.py CHANGED
@@ -1,7 +1,5 @@
 import re
-import tempfile
 from collections import OrderedDict
-from pathlib import Path
 from typing import List, Optional, Union
 
 import rich_click as click
@@ -9,10 +7,7 @@ from InquirerPy import inquirer
 
 from truss.base import truss_config
 from truss.cli.train import common
-from truss.cli.train.types import (
-    DeployCheckpointsConfigComplete,
-    PrepareCheckpointResult,
-)
+from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss.cli.utils.output import console
 from truss.remote.baseten.remote import BasetenRemote
 from truss_train.definitions import (
@@ -25,18 +20,9 @@ from truss_train.definitions import (
     SecretReference,
 )
 
-from .deploy_full_checkpoints import (
-    hydrate_full_checkpoint,
-    render_vllm_full_truss_config,
-)
-from .deploy_lora_checkpoints import (
-    hydrate_lora_checkpoint,
-    render_vllm_lora_truss_config,
-)
-from .deploy_whisper_checkpoints import (
-    hydrate_whisper_checkpoint,
-    render_vllm_whisper_truss_config,
-)
+from .deploy_full_checkpoints import hydrate_full_checkpoint
+from .deploy_lora_checkpoints import hydrate_lora_checkpoint
+from .deploy_whisper_checkpoints import hydrate_whisper_checkpoint
 
 HF_TOKEN_ENVVAR_NAME = "HF_TOKEN"
 # If we change this, make sure to update the logic in backend codebase
@@ -44,28 +30,162 @@ CHECKPOINT_PATTERN = re.compile(r".*checkpoint-\d+(?:-\d+)?$")
 ALLOWED_DEPLOYMENT_NAMES = re.compile(r"^[0-9a-zA-Z_\-\.]*$")
 
 
-def prepare_checkpoint_deploy(
+def create_model_version_from_inference_template(
     remote_provider: BasetenRemote,
     checkpoint_deploy_config: DeployCheckpointsConfig,
     project_id: Optional[str],
     job_id: Optional[str],
-) -> PrepareCheckpointResult:
+) -> DeployCheckpointsConfigComplete:
     checkpoint_deploy_config = _hydrate_deploy_config(
         checkpoint_deploy_config, remote_provider, project_id, job_id
     )
-    rendered_truss = _render_truss_config_for_checkpoint_deployment(
-        checkpoint_deploy_config
+
+    request_data = _build_inference_template_request(
+        checkpoint_deploy_config, remote_provider
     )
-    truss_directory = Path(tempfile.mkdtemp())
-    truss_config_path = truss_directory / "config.yaml"
-    rendered_truss.write_to_yaml_file(truss_config_path)
-    console.print(rendered_truss, style="green")
-    console.print(f"Writing truss config to {truss_config_path}", style="yellow")
-    return PrepareCheckpointResult(
-        truss_directory=truss_directory,
-        checkpoint_deploy_config=checkpoint_deploy_config,
+
+    # Call the GraphQL mutation to create model version from inference template
+    try:
+        result = remote_provider.api.create_model_version_from_inference_template(
+            request_data
+        )
+
+        if result and result.get("model_version"):
+            console.print(
+                f"Successfully created model version: {result['model_version']['name']}",
+                style="green",
+            )
+            console.print(
+                f"Model version ID: {result['model_version']['id']}", style="yellow"
+            )
+        else:
+            console.print(
+                "Warning: Unexpected response format from server", style="yellow"
+            )
+            console.print(f"Response: {result}", style="yellow")
+
+    except Exception as e:
+        console.print(f"Error creating model version: {e}", style="red")
+        raise
+
+    return checkpoint_deploy_config
+
+
+def _build_inference_template_request(
+    checkpoint_deploy_config: DeployCheckpointsConfigComplete,
+    remote_provider: BasetenRemote,
+) -> dict:
+    """
+    Build the GraphQL request data structure for createModelVersionFromInferenceTemplate mutation.
+    """
+    # Build weights sources
+    weights_sources = []
+    for checkpoint in checkpoint_deploy_config.checkpoint_details.checkpoints:
+        # Extract checkpoint name from the first path
+        checkpoint_name = (
+            checkpoint.paths[0].strip("/").split("/")[-1]
+            if checkpoint.paths
+            else "checkpoint"
+        )
+
+        weights_source = {
+            "weight_source_type": "B10_CHECKPOINTING",
+            "b10_training_checkpoint_weights_source": {
+                "checkpoint": {
+                    "training_job_id": checkpoint.training_job_id,
+                    "checkpoint_name": checkpoint_name,
+                }
+            },
+        }
+        weights_sources.append(weights_source)
+
+    # Build environment variables
+    environment_variables = []
+    for name, value in checkpoint_deploy_config.runtime.environment_variables.items():
+        if isinstance(value, SecretReference):
+            env_var = {"name": name, "value": value.name, "is_secret_reference": True}
+        else:
+            env_var = {"name": name, "value": str(value), "is_secret_reference": False}
+        environment_variables.append(env_var)
+
+    # Build inference stack
+    inference_stack = {
+        "stack_type": "VLLM",
+        "environment_variables": environment_variables,
+    }
+
+    # Get instance type ID from compute spec
+    instance_type_id = _get_instance_type_id(
+        checkpoint_deploy_config.compute, remote_provider
     )
 
+    # Build the complete request
+    request_data = {
+        "metadata": {"oracle_name": checkpoint_deploy_config.model_name},
+        "weights_sources": weights_sources,
+        "inference_stack": inference_stack,
+        "instance_type_id": instance_type_id,
+    }
+
+    return request_data
+
+
+def _get_instance_type_id(compute: Compute, remote_provider: BasetenRemote) -> str:
+    """
+    Get the instance type ID based on the compute specification.
+    Fetches available instance types from the API and maps compute specs to instance type IDs.
+    Only considers single-node instances (node_count == 1).
+    """
+    # step 1: fetch the instance types from the API
+    instance_types = remote_provider.api.get_instance_types()
+    # step 2: sort them into two different dictionaries, excluding multi-node instances:
+    cpu_instance_types = {
+        it.id: it for it in instance_types if it.gpu_count == 0 and it.node_count == 1
+    }
+    gpu_instance_types = {
+        it.id: it for it in instance_types if it.gpu_count > 0 and it.node_count == 1
+    }
+    # step 3: if compute is cpu, find the smallest such cpu that matches the compute request
+    if not compute.accelerator or compute.accelerator.accelerator is None:
+        compute_as_truss_config = compute.to_truss_config()
+        smallest_cpu_instance_type = None
+        for it in cpu_instance_types.values():
+            if (
+                it.millicpu_limit / 1000 >= compute.cpu_count
+                and it.memory_limit >= compute_as_truss_config.memory_in_bytes
+            ):
+                if (
+                    smallest_cpu_instance_type is None
+                    or it.millicpu_limit < smallest_cpu_instance_type.millicpu_limit
+                ):
+                    smallest_cpu_instance_type = it
+        if not smallest_cpu_instance_type:
+            raise ValueError(
+                f"Unable to find single-node instance type for {compute.cpu_count} CPU and {compute.memory} memory. Reach out to Baseten for support if this persists."
+            )
+        return smallest_cpu_instance_type.id
+    # step 4: if compute is gpu, find the smallest such gpu by instance type
+    else:
+        assert compute.accelerator.accelerator is not None
+        compute_as_truss_config = compute.to_truss_config()
+        smallest_gpu_instance_type = None
+        for it in gpu_instance_types.values():
+            if (
+                it.gpu_type == compute.accelerator.accelerator.value
+                and it.gpu_count >= compute.accelerator.count
+            ):
+                if (
+                    smallest_gpu_instance_type is None
+                    or it.gpu_count < smallest_gpu_instance_type.gpu_count
+                ):
+                    smallest_gpu_instance_type = it
+        if not smallest_gpu_instance_type:
+            raise ValueError(
+                f"Unable to find single-node instance type for {compute.accelerator}:{compute.accelerator.count}. Reach out to Baseten for support if this persists."
+            )
+        return smallest_gpu_instance_type.id
+
 
 def _validate_base_model_id(
     base_model_id: Optional[str], model_weight_format: ModelWeightsFormat
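For reference, `_build_inference_template_request` assembles a plain dict rather than a typed object. A sketch of its output for a single checkpoint, with hypothetical job, secret, and instance-type values:

    request_data = {
        "metadata": {"oracle_name": "my-finetune"},
        "weights_sources": [
            {
                "weight_source_type": "B10_CHECKPOINTING",
                "b10_training_checkpoint_weights_source": {
                    "checkpoint": {
                        "training_job_id": "job_123",
                        "checkpoint_name": "checkpoint-600",
                    }
                },
            }
        ],
        "inference_stack": {
            "stack_type": "VLLM",
            "environment_variables": [
                {"name": "HF_TOKEN", "value": "hf_access_token", "is_secret_reference": True}
            ],
        },
        "instance_type_id": "it_abc123",
    }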
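The smallest-fit search in `_get_instance_type_id` can be exercised in isolation. A minimal sketch with stubbed records (the fields mirror the attributes the function reads; the IDs and limits are made up):

    from dataclasses import dataclass

    @dataclass
    class FakeInstanceType:  # stub mirroring the attributes _get_instance_type_id reads
        id: str
        gpu_type: str
        gpu_count: int
        node_count: int
        millicpu_limit: int
        memory_limit: int

    instance_types = [
        FakeInstanceType("h100x4", "H100", 4, 1, 104_000, 976 * 2**30),
        FakeInstanceType("h100x1", "H100", 1, 1, 26_000, 244 * 2**30),
        FakeInstanceType("h100x8x2", "H100", 8, 2, 208_000, 1952 * 2**30),  # multi-node, excluded
    ]

    # Requesting one H100 matches both single-node records; the loop keeps the one
    # with the smallest gpu_count, so "h100x1" wins.
    single_node = [it for it in instance_types if it.gpu_count > 0 and it.node_count == 1]
    best = min(
        (it for it in single_node if it.gpu_type == "H100" and it.gpu_count >= 1),
        key=lambda it: it.gpu_count,
    )
    assert best.id == "h100x1"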
@@ -93,18 +213,12 @@ def _get_model_name(
         else ""
     )
 
-    model_name = inquirer.text(
+    return inquirer.text(
         message=f"Enter the model name for your {model_weight_format.value} model.",
         validate=lambda s: s and s.strip(),
         default=default,
     ).execute()
 
-    if model_weight_format == ModelWeightsFormat.FULL:
-        model_name += "-vLLM-Full"
-    elif model_weight_format == ModelWeightsFormat.LORA:
-        model_name += "-vLLM-LORA"
-    return model_name
-
 
 def _hydrate_deploy_config(
     deploy_config: DeployCheckpointsConfig,
@@ -123,53 +237,18 @@
     else:
         model_name = _get_model_name(model_weight_format, base_model_id)
 
-    compute = _ensure_compute_spec(deploy_config.compute)
+    compute = _ensure_compute_spec(deploy_config.compute, remote_provider)
 
     runtime = _ensure_runtime_config(deploy_config.runtime)
-    deployment_name = _ensure_deployment_name(
-        deploy_config.deployment_name, checkpoint_details.checkpoints
-    )
 
     return DeployCheckpointsConfigComplete(
         checkpoint_details=checkpoint_details,
         model_name=model_name,
-        deployment_name=deployment_name,
         runtime=runtime,
         compute=compute,
-        model_weight_format=model_weight_format.to_truss_config(),  # type: ignore[attr-defined]
     )
 
 
-def _ensure_deployment_name(
-    deploy_config_deployment_name: Optional[str], checkpoints: List[Checkpoint]
-) -> str:
-    if deploy_config_deployment_name:
-        return deploy_config_deployment_name
-
-    default_deployment_name = "checkpoint"
-
-    if checkpoints and checkpoints[0].paths:
-        first_checkpoint_name = checkpoints[0].paths[0].strip("/").split("/")[-1]
-
-        if ALLOWED_DEPLOYMENT_NAMES.match(first_checkpoint_name):
-            # Allow autoincrementing if the checkpoint matches both regexes
-            if (
-                CHECKPOINT_PATTERN.match(first_checkpoint_name)
-                and len(checkpoints) == 1
-            ):
-                return first_checkpoint_name
-
-    # If no valid autoincrementing checkpoint name is found, prompt the user
-    deployment_name = inquirer.text(
-        message="Enter the deployment name.", default=default_deployment_name
-    ).execute()
-
-    if not deployment_name:
-        raise click.UsageError("Deployment name is required.")
-
-    return deployment_name
-
-
 def hydrate_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict, checkpoint_type: str
 ) -> Checkpoint:
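For context on the deleted `_ensure_deployment_name`: it reused a checkpoint directory as the deployment name only when the name matched both module-level regexes and exactly one checkpoint was selected; otherwise it prompted. A small demonstration with the two patterns that remain in the module (the sample names are invented):

    import re

    CHECKPOINT_PATTERN = re.compile(r".*checkpoint-\d+(?:-\d+)?$")
    ALLOWED_DEPLOYMENT_NAMES = re.compile(r"^[0-9a-zA-Z_\-\.]*$")

    assert ALLOWED_DEPLOYMENT_NAMES.match("checkpoint-600")
    assert CHECKPOINT_PATTERN.match("checkpoint-600")      # was reused automatically
    assert not CHECKPOINT_PATTERN.match("final-weights")   # would have prompted the user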
@@ -190,26 +269,6 @@ def hydrate_checkpoint(
     )
 
 
-def _render_truss_config_for_checkpoint_deployment(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """
-    Render truss config for checkpoint deployment.
-    Currently supports LoRA checkpoints via vLLM, but can be extended for other formats.
-    """
-    # Delegate to specific rendering function based on model weight format
-    if checkpoint_deploy.model_weight_format == ModelWeightsFormat.LORA:
-        return render_vllm_lora_truss_config(checkpoint_deploy)
-    elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.FULL:
-        return render_vllm_full_truss_config(checkpoint_deploy)
-    elif checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER:
-        return render_vllm_whisper_truss_config(checkpoint_deploy)
-    else:
-        raise ValueError(
-            f"Unsupported model weight format: {checkpoint_deploy.model_weight_format}. Please upgrade to the latest Truss version to access the latest supported formats. Contact Baseten if you would like us to support additional formats."
-        )
-
-
 def _ensure_checkpoint_details(
     remote_provider: BasetenRemote,
     checkpoint_details: Optional[CheckpointList],
@@ -309,31 +368,77 @@ def _select_multiple_checkpoints(checkpoint_id_options: List[str]) -> List[str]:
     return checkpoint_ids
 
 
-def _ensure_compute_spec(compute: Optional[Compute]) -> Compute:
+def _ensure_compute_spec(
+    compute: Optional[Compute], remote_provider: BasetenRemote
+) -> Compute:
     if not compute:
         compute = Compute(cpu_count=0, memory="0Mi")
-    compute.accelerator = _get_accelerator_if_specified(compute.accelerator)
+    compute = _get_accelerator_if_specified(compute, remote_provider)
     return compute
 
 
 def _get_accelerator_if_specified(
-    user_input: Optional[truss_config.AcceleratorSpec],
-) -> Optional[truss_config.AcceleratorSpec]:
-    if user_input:
+    user_input: Optional[Compute], remote_provider: BasetenRemote
+) -> Compute:
+    if user_input and user_input.accelerator:
         return user_input
+
+    # Fetch available instance types to get valid GPU options
+    instance_types = remote_provider.api.get_instance_types()
+
+    # Extract unique accelerator types from instance types
+    accelerator_options = set()
+    for it in instance_types:
+        if it.gpu_type and it.gpu_count > 0:
+            accelerator_options.add(it.gpu_type)
+
+    # Convert to sorted list and add CPU option
+    choices = sorted(list(accelerator_options)) + [None]
+
+    if not choices or choices == [None]:
+        console.print("No GPU instance types available, using CPU", style="yellow")
+        return Compute(cpu_count=0, memory="0Mi", accelerator=None)
+
     # prompt user for accelerator
     gpu_type = inquirer.select(
         message="Select the GPU type to use for deployment. Select None for CPU.",
-        choices=[x.value for x in truss_config.Accelerator] + [None],
+        choices=choices,
     ).execute()
+
     if gpu_type is None:
-        return None
-    count = inquirer.text(
-        message="Enter the number of accelerators to use for deployment.",
-        default="1",
-        validate=lambda x: x.isdigit() and int(x) > 0 and int(x) <= 8,
-    ).execute()
-    return truss_config.AcceleratorSpec(accelerator=gpu_type, count=int(count))
+        return Compute(cpu_count=0, memory="0Mi", accelerator=None)
+
+    # Get available counts for the selected GPU type
+    available_counts = set()
+    for it in instance_types:
+        if it.gpu_type == gpu_type and it.gpu_count > 0:
+            available_counts.add(it.gpu_count)
+    if not available_counts:
+        raise ValueError(
+            f"No available counts for {gpu_type}. Reach out to Baseten for support if this persists."
+        )
+
+    if available_counts:
+        count_choices = sorted(list(available_counts))
+        count = inquirer.select(
+            message=f"Select the number of {gpu_type} GPUs to use for deployment.",
+            choices=count_choices,
+            default=str(count_choices[0]),
+        ).execute()
+    else:
+        count = inquirer.text(
+            message=f"Enter the number of {gpu_type} accelerators to use for deployment.",
+            default="1",
+            validate=lambda x: x.isdigit() and int(x) > 0 and int(x) <= 8,
+        ).execute()
+
+    return Compute(
+        cpu_count=0,
+        memory="0Mi",
+        accelerator=truss_config.AcceleratorSpec(
+            accelerator=gpu_type.replace("-", "_"), count=int(count)
+        ),
+    )
 
 
 def _get_base_model_id(user_input: Optional[str], checkpoint: dict) -> Optional[str]:
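The reworked `_get_accelerator_if_specified` derives its prompt choices from the live instance-type list instead of the static `truss_config.Accelerator` enum, and the trailing `gpu_type.replace("-", "_")` appears to normalize hyphenated API names into the underscore form `AcceleratorSpec` accepts. A sketch of the derivation over hypothetical (gpu_type, gpu_count) data:

    # Hypothetical pairs as they might come back from remote_provider.api.get_instance_types():
    observed = [("A10G", 1), ("H100", 1), ("H100", 4)]

    accelerator_options = {gpu for gpu, count in observed if count > 0}
    choices = sorted(accelerator_options) + [None]  # ["A10G", "H100", None]; None means CPU

    gpu_type = "H100"  # as if selected at the prompt
    available_counts = sorted({c for g, c in observed if g == gpu_type})  # [1, 4]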
truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py CHANGED
@@ -1,52 +1 @@
-import os
-from pathlib import Path
-
-from truss.base import truss_config
-from truss.cli.train.types import DeployCheckpointsConfigComplete
-from truss_train.definitions import ModelWeightsFormat, SecretReference
-
-START_COMMAND_ENVVAR_NAME = "BT_DOCKER_SERVER_START_CMD"
-
-
-def setup_base_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Set up the base truss config with common properties."""
-    truss_deploy_config = None
-    truss_base_file = (
-        "deploy_from_checkpoint_config_whisper.yml"
-        if checkpoint_deploy.model_weight_format == ModelWeightsFormat.WHISPER
-        else "deploy_from_checkpoint_config.yml"
-    )
-    truss_deploy_config = truss_config.TrussConfig.from_yaml(
-        Path(os.path.dirname(__file__), "..", truss_base_file)
-    )
-    if not truss_deploy_config.docker_server:
-        raise ValueError(
-            "Unexpected checkpoint deployment config: missing docker_server"
-        )
-
-    truss_deploy_config.model_name = checkpoint_deploy.model_name
-    truss_deploy_config.training_checkpoints = (
-        checkpoint_deploy.checkpoint_details.to_truss_config()
-    )
-    truss_deploy_config.resources = checkpoint_deploy.compute.to_truss_config()
-
-    return truss_deploy_config
-
-
-def setup_environment_variables_and_secrets(
-    truss_deploy_config: truss_config.TrussConfig,
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> str:
-    """Set up environment variables and secrets, return start command envvars string."""
-    start_command_envvars = ""
-
-    for key, value in checkpoint_deploy.runtime.environment_variables.items():
-        if isinstance(value, SecretReference):
-            truss_deploy_config.secrets[value.name] = "set token in baseten workspace"
-            start_command_envvars = f"{key}=$(cat /secrets/{value.name})"
-        else:
-            truss_deploy_config.environment_variables[key] = value
-
-    return start_command_envvars
+# This file is kept for potential future use but currently contains no active code
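The secret handling that lived in `setup_environment_variables_and_secrets` effectively moves into `_build_inference_template_request`: instead of baking a shell substitution into the server start command, the CLI now marks the variable as a secret reference in the request. A sketch of the two representations for one secret (the names mirror the `HF_TOKEN`/`hf_access_token` pair used elsewhere in this diff):

    # Old flow (removed): the secret became a shell substitution in the start command,
    # and the secret name was registered in the truss config.
    start_command_envvars = "HF_TOKEN=$(cat /secrets/hf_access_token)"

    # New flow: the same SecretReference is serialized into the GraphQL request.
    env_var = {"name": "HF_TOKEN", "value": "hf_access_token", "is_secret_reference": True}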
truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py CHANGED
@@ -1,66 +1,7 @@
 from pathlib import Path
 
-from jinja2 import Template
-
-from truss.base import truss_config
-from truss.cli.train.deploy_checkpoints.deploy_checkpoints_helpers import (
-    START_COMMAND_ENVVAR_NAME,
-)
-from truss.cli.train.types import DeployCheckpointsConfigComplete
 from truss_train.definitions import FullCheckpoint
 
-from .deploy_checkpoints_helpers import (
-    setup_base_truss_config,
-    setup_environment_variables_and_secrets,
-)
-
-# NB(aghilan): Transformers was recently changed to save a chat_template.jinja file instead of inside the tokenizer_config.json file.
-# Old Models will not have this file, so we check for it and use it if it exists.
-# vLLM will not automatically resolve the chat_template.jinja file, so we need to pass it to the start command.
-# This logic is needed for any models trained using Transformers v4.51.3 or later
-VLLM_FULL_START_COMMAND = Template(
-    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
-    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
-    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
-    "  vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
-    "  --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
-    "else "
-    "  vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
-    "fi'"
-)
-
-
-def render_vllm_full_truss_config(
-    checkpoint_deploy: DeployCheckpointsConfigComplete,
-) -> truss_config.TrussConfig:
-    """Render truss config specifically for full checkpoints using vLLM."""
-    truss_deploy_config = setup_base_truss_config(checkpoint_deploy)
-
-    start_command_envvars = setup_environment_variables_and_secrets(
-        truss_deploy_config, checkpoint_deploy
-    )
-
-    checkpoint_str = build_full_checkpoint_string(truss_deploy_config)
-
-    accelerator = checkpoint_deploy.compute.accelerator
-
-    start_command_args = {
-        "model_path": checkpoint_str,
-        "envvars": start_command_envvars,
-        "specify_tensor_parallelism": accelerator.count if accelerator else 1,
-    }
-    # Note: we set the start command as an environment variable in supervisord config.
-    # This is so that we don't have to change the supervisord config when the start command changes.
-    # Our goal is to reduce the number of times we need to rebuild the image, and allow us to deploy faster.
-    start_command = VLLM_FULL_START_COMMAND.render(**start_command_args)
-    truss_deploy_config.environment_variables[START_COMMAND_ENVVAR_NAME] = start_command
-    # Note: supervisord uses the convention %(ENV_VAR_NAME)s to access environment variable VAR_NAME
-    truss_deploy_config.docker_server.start_command = (  # type: ignore[union-attr]
-        f"%(ENV_{START_COMMAND_ENVVAR_NAME})s"
-    )
-
-    return truss_deploy_config
-
 
 def hydrate_full_checkpoint(
     job_id: str, checkpoint_id: str, checkpoint: dict
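For context on what the deleted rendering path produced: `VLLM_FULL_START_COMMAND` expanded to a single shell one-liner that chose the vLLM invocation based on whether a `chat_template.jinja` file exists next to the weights. A standalone sketch that re-renders the removed template (the checkpoint path and env var values are illustrative):

    from jinja2 import Template

    # Copy of the removed template, reproduced only to show its output.
    VLLM_FULL_START_COMMAND = Template(
        "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
        "if [ -f {{ model_path }}/chat_template.jinja ]; then "
        "  vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
        "  --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
        "else "
        "  vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
        "fi'"
    )

    print(
        VLLM_FULL_START_COMMAND.render(
            model_path="/checkpoints/checkpoint-600",  # hypothetical path
            envvars="HF_TOKEN=$(cat /secrets/hf_access_token)",
            specify_tensor_parallelism=1,
        )
    )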