truss 0.11.1rc4__py3-none-any.whl → 0.11.1rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic; see the package registry's advisory for this release for more details.

truss/base/constants.py CHANGED
@@ -18,6 +18,7 @@ SHARED_SERVING_AND_TRAINING_CODE_DIR: pathlib.Path = (
18
18
CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control"
CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss-chains" / "truss_chains"
TRUSS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss"
# Scaffold copied by `truss train init` when no examples are requested.
TRAINING_TEMPLATE_DIR: pathlib.Path = TEMPLATES_DIR / "train"
# Must be sorted ascendingly.
SUPPORTED_PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
23
24
 
truss/cli/train/core.py CHANGED
@@ -1,3 +1,4 @@
1
+ import base64
1
2
  import json
2
3
  import os
3
4
  import tarfile
@@ -8,6 +9,7 @@ from pathlib import Path
8
9
  from typing import Any, Callable, Dict, Optional, Tuple
9
10
 
10
11
  import click
12
+ import requests
11
13
  import rich
12
14
  from InquirerPy import inquirer
13
15
  from rich.text import Text
@@ -355,6 +357,7 @@ def download_training_job_data(
355
357
  temp_path.write_bytes(content)
356
358
 
357
359
  unzip_dir = output_dir / artifact_base_name
360
+ unzip_dir = Path(str(unzip_dir).replace(" ", "-"))
358
361
  if unzip_dir.exists():
359
362
  raise click.ClickException(
360
363
  f"Directory '{unzip_dir}' already exists. "
@@ -367,6 +370,7 @@ def download_training_job_data(
367
370
 
368
371
  return unzip_dir
369
372
  else:
373
+ target_path = Path(str(target_path).replace(" ", "-"))
370
374
  target_path.write_bytes(content)
371
375
  return target_path
372
376
 
@@ -417,6 +421,158 @@ def status_page_url(remote_url: str, training_job_id: str) -> str:
417
421
  return f"{remote_url}/training/jobs/{training_job_id}"
418
422
 
419
423
 
424
def _get_all_train_init_example_options(
    repo_id: str = "ml-cookbook",
    examples_subdir: str = "examples",
    token: Optional[str] = None,
) -> list[str]:
    """List the example names offered by ``truss train init``.

    Queries the GitHub contents API for
    ``basetenlabs/<repo_id>/<examples_subdir>`` and keeps only the entries whose
    type is ``"dir"``; each such directory name is one selectable example.

    Args:
        repo_id: Repository under the ``basetenlabs`` GitHub organization.
        examples_subdir: Path inside the repository that holds the examples.
        token: Optional GitHub token, sent as ``Authorization: token <token>``.

    Returns:
        The example directory names. On any ``requests`` error an explanatory
        message is echoed and an empty list is returned instead.
    """
    request_headers = {"Authorization": f"token {token}"} if token else {}
    contents_url = (
        f"https://api.github.com/repos/basetenlabs/{repo_id}/contents/{examples_subdir}"
    )

    try:
        resp = requests.get(contents_url, headers=request_headers)
        resp.raise_for_status()
        payload = resp.json()
        # A single object (rather than a list) is normalized to a one-item list.
        entries = payload if isinstance(payload, list) else [payload]
        return [entry["name"] for entry in entries if entry["type"] == "dir"]
    except requests.exceptions.RequestException as exc:
        click.echo(
            f"Error exploring directory: {exc}. Please file an issue at https://github.com/basetenlabs/truss/issues"
        )
        return []
456
+
457
+
458
def _get_train_init_example_info(
    repo_id: str = "ml-cookbook",
    examples_subdir: str = "examples",
    example_name: Optional[str] = None,
    token: Optional[str] = None,
) -> list[Dict[str, str]]:
    """Fetch the GitHub contents listing for one ml-cookbook example.

    The listing's entries carry the API ``url`` values that callers pass on to
    ``download_git_directory``.

    Args:
        repo_id: Repository under the ``basetenlabs`` GitHub organization.
        examples_subdir: Path inside the repository that holds the examples.
        example_name: Example directory to list. It is interpolated verbatim
            into the URL, so ``None`` becomes the literal path segment "None".
        token: Optional GitHub token, sent as ``Authorization: token <token>``.

    Returns:
        The listing entries (a single object is wrapped in a list). Returns an
        empty list silently on HTTP 404 (unknown example), and an empty list
        plus an echoed error message on any other ``requests`` failure.
    """
    request_headers = {"Authorization": f"token {token}"} if token else {}
    example_url = f"https://api.github.com/repos/basetenlabs/{repo_id}/contents/{examples_subdir}/{example_name}"

    def _report_failure(exc: Exception) -> list:
        click.echo(
            f"Error exploring directory: {exc}. Please file an issue at https://github.com/basetenlabs/truss/issues"
        )
        return []

    try:
        resp = requests.get(example_url, headers=request_headers)
        resp.raise_for_status()
        payload = resp.json()
        return payload if isinstance(payload, list) else [payload]
    except requests.exceptions.HTTPError as exc:
        # 404 just means "no such example"; stay quiet so the caller can print
        # its own not-found message.
        if resp.status_code == 404:
            return []
        return _report_failure(exc)
    except requests.exceptions.RequestException as exc:
        # Connection problems, timeouts and other non-HTTP request failures.
        return _report_failure(exc)
499
+
500
+
501
def download_git_directory(
    git_api_url: str, local_dir: str, token: Optional[str] = None
) -> bool:
    """
    Recursively download directory contents from a GitHub contents API URL.

    Special handling for a 'training' directory: if the listing contains a
    directory named 'training', only that directory is fetched and its contents
    are written directly into ``local_dir`` (no 'training' subdirectory is
    created). Other entries next to it are not downloaded.

    Args:
        git_api_url (str): Example format "https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/llama-finetune-8b-lora?ref=main"
        local_dir (str): Local directory to download this directory to.
        token (Optional[str]): GitHub token sent as ``Authorization: token <token>``.
            It is forwarded to every nested directory listing request too.

    Returns:
        bool: True when this listing and every nested directory were processed.
        False when a request, JSON parse, or file write failed here, or when any
        nested directory download returned False. Errors are printed, not
        raised. A file whose inline base64 ``content`` cannot be decoded only
        prints a warning and does not make the result False.
    """
    headers = {}
    if token:
        headers["Authorization"] = f"token {token}"
    try:
        response = requests.get(git_api_url, headers=headers)
        response.raise_for_status()
        items = response.json()

        # Handle single file case: the API returns one object instead of a list.
        if not isinstance(items, list):
            items = [items]

        # Create local directory
        print(f"Creating directory {local_dir}")
        os.makedirs(local_dir, exist_ok=True)

        # Separate a top-level 'training' directory from everything else.
        training_dir = None
        other_items = []
        for item in items:
            if item["name"] == "training" and item["type"] == "dir":
                training_dir = item
            else:
                other_items.append(item)

        # If training directory exists, download its contents directly to local_dir
        if training_dir:
            print(
                f"📁 Found training directory, downloading its contents to {local_dir}"
            )
            # Forward the token so authenticated access survives recursion.
            return download_git_directory(training_dir["url"], local_dir, token=token)

        # If no training directory, download all files normally. Nested failures
        # do not stop the remaining items (best effort) but are reported.
        all_succeeded = True
        for item in other_items:
            item_name = item["name"]
            local_item_path = os.path.join(local_dir, item_name)

            if item["type"] == "file":
                print(f"📄 Downloading {item_name}")
                if item.get("download_url"):
                    # Download file directly
                    file_response = requests.get(item["download_url"])
                    file_response.raise_for_status()
                    with open(local_item_path, "wb") as f:
                        f.write(file_response.content)
                elif item.get("content"):
                    # Decode base64 content (for small files)
                    try:
                        content = base64.b64decode(item["content"])
                        with open(local_item_path, "wb") as f:
                            f.write(content)
                    except Exception as e:
                        print(f"⚠️ Could not decode {item_name}: {e}")
            elif item["type"] == "dir":
                print(f"📁 Entering directory {item_name}")
                # Use the API URL from the response for subdirectories
                if not download_git_directory(
                    item["url"], local_item_path, token=token
                ):
                    all_succeeded = False
        return all_succeeded
    except Exception as e:
        print(f"Error processing response: {e}")
        return False
574
+
575
+
420
576
  def fetch_project_by_name_or_id(
421
577
  remote_provider: BasetenRemote, project_identifier: str
422
578
  ) -> dict:
@@ -296,10 +296,22 @@ def _get_checkpoint_ids_to_deploy(
296
296
  return checkpoint_ids
297
297
 
298
298
 
299
def _select_single_checkpoint(checkpoint_id_options: List[str]) -> List[str]:
    """Prompt the user to pick exactly one checkpoint from a single-select list.

    Returns:
        A one-element list holding the chosen checkpoint id, matching the return
        type of ``_select_multiple_checkpoints``.

    Raises:
        click.UsageError: If the prompt yields an empty selection.
    """
    prompt = inquirer.select(
        message="Select the checkpoints to deploy:", choices=checkpoint_id_options
    )
    selected = prompt.execute()
    if selected:
        return [selected]
    raise click.UsageError("A checkpoint must be selected.")
309
+
310
+
299
311
  def _select_multiple_checkpoints(checkpoint_id_options: List[str]) -> List[str]:
300
312
  """Select multiple checkpoints using interactive checkbox."""
301
313
  checkpoint_ids = inquirer.checkbox(
302
- message="Select the checkpoint to deploy. Use spacebar to select/deselect.",
314
+ message="Use spacebar to select/deselect checkpoints to deploy. Press enter when done.",
303
315
  choices=checkpoint_id_options,
304
316
  ).execute()
305
317
 
@@ -1,3 +1,4 @@
1
+ import os
1
2
  import sys
2
3
  from pathlib import Path
3
4
  from typing import Optional, cast
@@ -5,6 +6,7 @@ from typing import Optional, cast
5
6
  import rich_click as click
6
7
 
7
8
  import truss.cli.train.core as train_cli
9
+ from truss.base.constants import TRAINING_TEMPLATE_DIR
8
10
  from truss.cli import remote_cli
9
11
  from truss.cli.cli import push, truss_cli
10
12
  from truss.cli.logs import utils as cli_log_utils
@@ -25,6 +27,7 @@ from truss.cli.utils.output import console, error_console
25
27
  from truss.remote.baseten.core import get_training_job_logs_with_pagination
26
28
  from truss.remote.baseten.remote import BasetenRemote
27
29
  from truss.remote.remote_factory import RemoteFactory
30
+ from truss.util.path import copy_tree_path
28
31
  from truss_train import TrainingJob
29
32
 
30
33
 
@@ -381,6 +384,75 @@ def download_checkpoint_artifacts(job_id: Optional[str], remote: Optional[str])
381
384
  sys.exit(1)
382
385
 
383
386
 
387
@train.command(name="init")
@click.option("--list-examples", is_flag=True, help="List all available examples.")
@click.option(
    "--target-directory",
    type=str,
    required=False,
    help=(
        "Where to create the project. Defaults to ./truss-train-init for an empty "
        "project, or the current directory when downloading examples (each "
        "example is written to its own subdirectory)."
    ),
)
@click.option(
    "--examples",
    type=str,
    required=False,
    help="Comma-separated names of ml-cookbook examples to download.",
)
@common.common_options()
def init_training_job(
    list_examples: Optional[bool],
    target_directory: Optional[str],
    examples: Optional[str],
) -> None:
    """Initialize a training project, either empty or from ml-cookbook examples."""
    try:
        if list_examples:
            all_examples = train_cli._get_all_train_init_example_options()
            console.print("Available training examples:", style="bold")
            for example in all_examples:
                console.print(f"- {example}")
            console.print(
                "To launch, run `truss train init --examples <example1,example2>`",
                style="bold",
            )
            return

        # Accept "a, b" as well as "a,b", ignore empty segments (e.g. a stray or
        # trailing comma), and drop duplicates while keeping first-seen order.
        selected_options = list(
            dict.fromkeys(
                name.strip() for name in (examples or "").split(",") if name.strip()
            )
        )

        # No examples selected, initialize empty training project structure
        if not selected_options:
            if target_directory is None:
                target_directory = "truss-train-init"
            console.print(f"Initializing empty training project at {target_directory}")
            os.makedirs(target_directory)
            copy_tree_path(Path(TRAINING_TEMPLATE_DIR), Path(target_directory))
            console.print(
                f"✨ Empty training project initialized at {target_directory}",
                style="bold green",
            )
            return

        if target_directory is None:
            target_directory = os.getcwd()
        # Full example list, fetched lazily and at most once for error messages.
        known_examples: Optional[list[str]] = None
        for example_to_download in selected_options:
            download_info = train_cli._get_train_init_example_info(
                example_name=example_to_download
            )
            local_dir = os.path.join(target_directory, example_to_download)

            if not download_info:
                if known_examples is None:
                    known_examples = train_cli._get_all_train_init_example_options()
                error_console.print(
                    f"Example {example_to_download} not found in the ml-cookbook repository. Examples have to be one or more comma separated values from: {', '.join(known_examples)}"
                )
                continue

            # Download from the example's 'training' directory when it has one,
            # rather than assuming it is the first entry of the listing (another
            # entry such as a README may be listed first). Fall back to the first
            # entry, as before, when there is no 'training' directory.
            source_entry = next(
                (
                    entry
                    for entry in download_info
                    if entry.get("name") == "training" and entry.get("type") == "dir"
                ),
                download_info[0],
            )
            success = train_cli.download_git_directory(
                git_api_url=source_entry["url"], local_dir=local_dir
            )
            if success:
                console.print(
                    f"✨ Training directory for {example_to_download} initialized at {local_dir}",
                    style="bold green",
                )
            else:
                error_console.print(
                    f"Failed to initialize training artifacts to {local_dir}"
                )

    except Exception as e:
        error_console.print(f"Failed to initialize training artifacts: {str(e)}")
        sys.exit(1)
454
+
455
+
384
456
  @train.group(name="cache")
385
457
  def cache():
386
458
  """Cache-related subcommands for truss train"""
@@ -6,7 +6,7 @@ loguru>=0.7.2
6
6
  python-json-logger>=2.0.2
7
7
  tenacity>=8.1.0
8
8
  # To avoid divergence, this should follow the latest release.
9
- truss==0.11.1rc3
9
+ truss==0.11.1rc6
10
10
  uvicorn>=0.24.0
11
11
  uvloop>=0.19.0
12
12
  websockets>=10.0
@@ -0,0 +1,46 @@
1
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
from truss.base import truss_config

# Name of the training project that the job below is attached to (step 5).
PROJECT_NAME = "My-Baseten-Training-Project"
# Number of nodes for the job; passed to `Compute(node_count=...)`.
NUM_NODES = 1
# GPUs requested per node; passed to `AcceleratorSpec(count=...)`.
NUM_GPUS_PER_NODE = 1

# 1. Define a base image for your training job. You can also use
# private images via AWS IAM or GCP Service Account authentication.
BASE_IMAGE = "pytorch/pytorch:2.7.0-cuda12.8-cudnn9-runtime"

# 2. Define the Runtime Environment for the Training Job.
# This includes start commands and environment variables.
# Secrets from the Baseten workspace, like API keys, are referenced using
# `SecretReference` (see the commented-out HF_TOKEN example below).
training_runtime = definitions.Runtime(
    start_commands=[  # Commands that start training; here: make run.sh executable, then run it
        "/bin/sh -c 'chmod +x ./run.sh && ./run.sh'"
    ],
    environment_variables={
        # "HF_TOKEN": definitions.SecretReference(name="hf_access_token"),
        "HELLO": "WORLD"  # Placeholder variable; replace with your own.
    },
    cache_config=definitions.CacheConfig(
        enabled=False  # Set to True to enable caching between runs
    ),
    checkpointing_config=definitions.CheckpointingConfig(
        enabled=False  # Set to True to enable saving checkpoints on Baseten
    ),
)

# 3. Define the compute: node count, plus GPU type and count per node.
training_compute = definitions.Compute(
    node_count=NUM_NODES,
    accelerator=truss_config.AcceleratorSpec(
        accelerator=truss_config.Accelerator.H100, count=NUM_GPUS_PER_NODE
    ),
)

# 4. Assemble the training job from the image, compute, and runtime above.
training_job = definitions.TrainingJob(
    image=definitions.Image(base_image=BASE_IMAGE),
    compute=training_compute,
    runtime=training_runtime,
)

# 5. Attach the job to a named training project.
training_project = definitions.TrainingProject(name=PROJECT_NAME, job=training_job)
@@ -0,0 +1,11 @@
1
#!/bin/bash

# -e: exit immediately if a command exits with a non-zero status
# -u: treat references to unset variables as an error
# -x: print each command before executing it
set -eux

echo "Initializing model training environment..."
# TODO: Call your training logic below
echo "Placeholder: insert your model training logic below."
# e.g., python train_model.py --config config.yaml --epochs 10

echo "Training process completed (placeholder)."
@@ -584,7 +584,7 @@ def test_get_checkpoint_ids_to_deploy_full_checkpoints():
584
584
  mock_checkbox.assert_called_once()
585
585
  assert (
586
586
  mock_checkbox.call_args[1]["message"]
587
- == "Select the checkpoint to deploy. Use spacebar to select/deselect."
587
+ == "Use spacebar to select/deselect checkpoints to deploy. Press enter when done."
588
588
  )
589
589
  assert mock_checkbox.call_args[1]["choices"] == checkpoint_options
590
590
 
@@ -621,7 +621,7 @@ def test_get_checkpoint_ids_to_deploy_lora_checkpoints():
621
621
  mock_checkbox.assert_called_once()
622
622
  assert (
623
623
  mock_checkbox.call_args[1]["message"]
624
- == "Select the checkpoint to deploy. Use spacebar to select/deselect."
624
+ == "Use spacebar to select/deselect checkpoints to deploy. Press enter when done."
625
625
  )
626
626
  assert mock_checkbox.call_args[1]["choices"] == checkpoint_options
627
627
 
@@ -656,7 +656,7 @@ def test_get_checkpoint_ids_to_deploy_mixed_checkpoints():
656
656
  mock_checkbox.assert_called_once()
657
657
  assert (
658
658
  mock_checkbox.call_args[1]["message"]
659
- == "Select the checkpoint to deploy. Use spacebar to select/deselect."
659
+ == "Use spacebar to select/deselect checkpoints to deploy. Press enter when done."
660
660
  )
661
661
  assert mock_checkbox.call_args[1]["choices"] == checkpoint_options
662
662
 
@@ -0,0 +1,499 @@
1
+ from unittest.mock import Mock, call, mock_open, patch
2
+
3
+ import pytest
4
+ import requests
5
+
6
+ from truss.cli.train.core import (
7
+ _get_all_train_init_example_options,
8
+ _get_train_init_example_info,
9
+ download_git_directory,
10
+ )
11
+
12
+
13
class TestGetTrainInitExampleOptions:
    """Test cases for _get_all_train_init_example_options function.

    Patching "requests.get" / "click.echo" works because the code under test
    calls them through the module attributes (``requests.get(...)``,
    ``click.echo(...)``). With stacked @patch decorators the bottom one is
    applied first, so its mock is the first argument after ``self``.
    """

    @patch("requests.get")
    def test_successful_request_without_token(self, mock_get):
        """Test successful API call without authentication token"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = [
            {"name": "example1", "type": "dir"},
            {"name": "example2", "type": "dir"},
            {"name": "file1", "type": "file"},  # Should be filtered out
        ]
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_all_train_init_example_options()

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples",
            headers={},
        )
        assert len(result) == 2
        assert "example1" in result
        assert "example2" in result
        assert "file1" not in result  # Files should be filtered out

    @patch("requests.get")
    def test_successful_request_with_token(self, mock_get):
        """Test successful API call with authentication token"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = [{"name": "example1", "type": "dir"}]
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_all_train_init_example_options(token="test_token")

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples",
            headers={"Authorization": "token test_token"},
        )
        assert len(result) == 1
        assert "example1" in result

    @patch("requests.get")
    def test_custom_repo_and_subdir(self, mock_get):
        """Test with custom repository and subdirectory"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = []
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        _ = _get_all_train_init_example_options(
            repo_id="custom-repo", examples_subdir="custom-examples"
        )

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/custom-repo/contents/custom-examples",
            headers={},
        )

    @patch("requests.get")
    def test_single_item_response(self, mock_get):
        """Test when API returns a single item instead of a list"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = {"name": "single_example", "type": "dir"}
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_all_train_init_example_options()

        # Assert
        assert len(result) == 1
        assert "single_example" in result

    @patch("requests.get")
    @patch("click.echo")
    def test_request_exception_handling(self, mock_echo, mock_get):
        """Test handling of request exceptions"""
        # Arrange
        mock_get.side_effect = requests.exceptions.RequestException("Network error")

        # Act
        result = _get_all_train_init_example_options()

        # Assert
        mock_echo.assert_called_once_with(
            "Error exploring directory: Network error. Please file an issue at https://github.com/basetenlabs/truss/issues"
        )
        assert result == []

    @patch("requests.get")
    @patch("click.echo")
    def test_http_error_handling(self, mock_echo, mock_get):
        """Test handling of HTTP errors (this function has no 404 special case)"""
        # Arrange
        mock_response = Mock()
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
            "404 Not Found"
        )
        mock_get.return_value = mock_response

        # Act
        result = _get_all_train_init_example_options()

        # Assert
        mock_echo.assert_called_once_with(
            "Error exploring directory: 404 Not Found. Please file an issue at https://github.com/basetenlabs/truss/issues"
        )
        assert result == []

    @patch("requests.get")
    def test_filters_only_directories(self, mock_get):
        """Test that only directories are returned, files are filtered out"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = [
            {"name": "example1", "type": "dir"},
            {"name": "readme.md", "type": "file"},
            {"name": "example2", "type": "dir"},
            {"name": "config.json", "type": "file"},
        ]
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_all_train_init_example_options()

        # Assert
        assert len(result) == 2
        assert "example1" in result
        assert "example2" in result
        assert "readme.md" not in result
        assert "config.json" not in result
157
+
158
+
159
class TestGetTrainInitExampleInfo:
    """Test cases for _get_train_init_example_info function.

    With stacked @patch decorators the bottom one is applied first, so its mock
    is the first argument after ``self`` (e.g. ``mock_echo, mock_get``).
    """

    @patch("requests.get")
    def test_successful_request_without_token(self, mock_get):
        """Test successful API call without authentication token"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = [
            {"name": "file1.py", "type": "file"},
            {"name": "file2.py", "type": "file"},
        ]
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_train_init_example_info(example_name="test_example")

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/test_example",
            headers={},
        )
        assert len(result) == 2
        assert result[0]["name"] == "file1.py"
        assert result[1]["name"] == "file2.py"

    @patch("requests.get")
    def test_successful_request_with_token(self, mock_get):
        """Test successful API call with authentication token"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = [{"name": "file1.py", "type": "file"}]
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_train_init_example_info(
            example_name="test_example", token="test_token"
        )

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/test_example",
            headers={"Authorization": "token test_token"},
        )
        assert len(result) == 1

    @patch("requests.get")
    def test_custom_repo_and_subdir(self, mock_get):
        """Test with custom repository and subdirectory"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = []
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        _ = _get_train_init_example_info(
            repo_id="custom-repo",
            examples_subdir="custom-examples",
            example_name="test_example",
        )

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/custom-repo/contents/custom-examples/test_example",
            headers={},
        )

    @patch("requests.get")
    def test_single_item_response(self, mock_get):
        """Test when API returns a single item instead of a list"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = {"name": "single_file.py", "type": "file"}
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_train_init_example_info(example_name="test_example")

        # Assert
        assert len(result) == 1
        assert result[0]["name"] == "single_file.py"

    @patch("requests.get")
    @patch("click.echo")
    def test_404_error_returns_empty_list(self, mock_echo, mock_get):
        """Test that 404 errors return empty list without error message"""
        # Arrange
        mock_response = Mock()
        # The function reads status_code from the response returned by
        # requests.get (not from the exception), so set it on the mock.
        mock_response.status_code = 404
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
            "404 Not Found"
        )
        mock_get.return_value = mock_response

        # Act
        result = _get_train_init_example_info(example_name="nonexistent_example")

        # Assert
        mock_echo.assert_not_called()  # Should not echo error for 404
        assert result == []

    @patch("requests.get")
    @patch("click.echo")
    def test_other_http_error_handling(self, mock_echo, mock_get):
        """Test handling of non-404 HTTP errors"""
        # Arrange
        mock_response = Mock()
        mock_response.status_code = 500
        mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
            "500 Internal Server Error"
        )
        mock_get.return_value = mock_response

        # Act
        result = _get_train_init_example_info(example_name="test_example")

        # Assert
        mock_echo.assert_called_once_with(
            "Error exploring directory: 500 Internal Server Error. Please file an issue at https://github.com/basetenlabs/truss/issues"
        )
        assert result == []

    @patch("requests.get")
    @patch("click.echo")
    def test_request_exception_handling(self, mock_echo, mock_get):
        """Test handling of request exceptions"""
        # Arrange
        mock_get.side_effect = requests.exceptions.RequestException("Network error")

        # Act
        result = _get_train_init_example_info(example_name="test_example")

        # Assert
        mock_echo.assert_called_once_with(
            "Error exploring directory: Network error. Please file an issue at https://github.com/basetenlabs/truss/issues"
        )
        assert result == []

    @patch("requests.get")
    def test_none_example_name(self, mock_get):
        """Test with None as example_name (pins current behavior: "None" in URL)"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = []
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = _get_train_init_example_info(example_name=None)

        # Assert
        mock_get.assert_called_once_with(
            "https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/None",
            headers={},
        )
        assert result == []
319
+
320
+
321
class TestDownloadGitDirectory:
    """Test cases for download_git_directory function.

    ``requests.get`` side effects are consumed in call order: the directory
    listing first, then one GET per file ``download_url`` or nested listing.
    Expected paths use '/' and so assume a POSIX ``os.path.join``.
    """

    @patch("os.makedirs")
    @patch("requests.get")
    @patch("builtins.open", new_callable=mock_open)
    @patch("builtins.print")
    def test_download_files_without_training_dir(
        self, mock_print, mock_file, mock_get, mock_makedirs
    ):
        """Test downloading files without a training directory"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = [
            {
                "name": "file1.txt",
                "type": "file",
                "download_url": "https://example.com/file1.txt",
            },
            {
                "name": "file2.py",
                "type": "file",
                "download_url": "https://example.com/file2.py",
            },
        ]
        mock_response.raise_for_status.return_value = None

        # Mock file download responses
        file_response1 = Mock()
        file_response1.content = b"file1 content"
        file_response1.raise_for_status.return_value = None

        file_response2 = Mock()
        file_response2.content = b"file2 content"
        file_response2.raise_for_status.return_value = None

        mock_get.side_effect = [mock_response, file_response1, file_response2]

        # Act
        result = download_git_directory("https://api.github.com/test", "/local/dir")

        # Assert
        assert result is True
        mock_makedirs.assert_called_once_with("/local/dir", exist_ok=True)
        assert mock_get.call_count == 3
        assert mock_file.call_count == 2

    @patch("os.makedirs")
    @patch("requests.get")
    def test_download_with_training_directory(self, mock_get, mock_makedirs):
        """Test downloading when training directory is present"""
        # Arrange
        initial_response = Mock()
        initial_response.json.return_value = [
            {
                "name": "training",
                "type": "dir",
                "url": "https://api.github.com/training",
            },
            # Skipped: a sibling 'training' dir means only that dir is fetched,
            # which is why only two responses are queued below.
            {
                "name": "other_file.txt",
                "type": "file",
                "download_url": "https://example.com/other_file.txt",
            },
        ]
        initial_response.raise_for_status.return_value = None

        training_response = Mock()
        training_response.json.return_value = []
        training_response.raise_for_status.return_value = None

        mock_get.side_effect = [initial_response, training_response]

        # Act
        result = download_git_directory("https://api.github.com/test", "/local/dir")

        # Assert
        assert result is True
        # Should be called twice: once for initial dir, once for training contents
        assert mock_makedirs.call_count == 2

    @patch("os.makedirs")
    @patch("requests.get")
    def test_download_subdirectory_recursively(self, mock_get, mock_makedirs):
        """Test recursive download of subdirectories"""
        # Arrange
        initial_response = Mock()
        initial_response.json.return_value = [
            {"name": "subdir", "type": "dir", "url": "https://api.github.com/subdir"}
        ]
        initial_response.raise_for_status.return_value = None

        subdir_response = Mock()
        subdir_response.json.return_value = []
        subdir_response.raise_for_status.return_value = None

        mock_get.side_effect = [initial_response, subdir_response]

        # Act
        result = download_git_directory("https://api.github.com/test", "/local/dir")

        # Assert
        assert result is True
        expected_calls = [
            call("/local/dir", exist_ok=True),
            call("/local/dir/subdir", exist_ok=True),
        ]
        mock_makedirs.assert_has_calls(expected_calls)

    @patch("os.makedirs")
    @patch("requests.get")
    @patch("builtins.print")
    def test_download_with_authentication_token(
        self, mock_print, mock_get, mock_makedirs
    ):
        """Test download with authentication token"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = []
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        # Act
        result = download_git_directory(
            "https://api.github.com/test", "/local/dir", token="test_token"
        )

        # Assert
        assert result is True
        mock_get.assert_called_once_with(
            "https://api.github.com/test", headers={"Authorization": "token test_token"}
        )

    @patch("os.makedirs")
    @patch("requests.get")
    @patch("builtins.print")
    def test_download_single_file_response(self, mock_print, mock_get, mock_makedirs):
        """Test when API returns a single file instead of a list"""
        # Arrange
        mock_response = Mock()
        mock_response.json.return_value = {
            "name": "single_file.txt",
            "type": "file",
            "download_url": "https://example.com/single_file.txt",
        }
        mock_response.raise_for_status.return_value = None

        file_response = Mock()
        file_response.content = b"single file content"
        file_response.raise_for_status.return_value = None

        mock_get.side_effect = [mock_response, file_response]

        with patch("builtins.open", mock_open()) as mock_file:
            # Act
            result = download_git_directory("https://api.github.com/test", "/local/dir")

            # Assert
            assert result is True
            mock_file.assert_called_once_with("/local/dir/single_file.txt", "wb")

    @patch("os.makedirs")
    @patch("requests.get")
    @patch("builtins.print")
    def test_download_exception_handling(self, mock_print, mock_get, mock_makedirs):
        """Test exception handling during download (errors are printed, not raised)"""
        # Arrange
        mock_get.side_effect = Exception("Network error")

        # Act
        result = download_git_directory("https://api.github.com/test", "/local/dir")

        # Assert
        assert result is False
        mock_print.assert_called_with("Error processing response: Network error")
496
+
497
+
498
if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    pytest.main([__file__])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: truss
3
- Version: 0.11.1rc4
3
+ Version: 0.11.1rc6
4
4
  Summary: A seamless bridge from model development to model delivery
5
5
  Project-URL: Repository, https://github.com/basetenlabs/truss
6
6
  Project-URL: Homepage, https://truss.baseten.co
@@ -2,7 +2,7 @@ truss/__init__.py,sha256=CoUcP6vx_pocyemRmpbCPlndkHhdMkABAlr0ZXVuPCk,1163
2
2
  truss/api/__init__.py,sha256=spBAa_m1pItiid97iDLKPmumgAkSirPkv-E8RWMZyOk,5090
3
3
  truss/api/definitions.py,sha256=QAaIBqL59Q-R7HtLcXcoeCIWBN2HqOzApdFX0PpCq2s,1604
4
4
  truss/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- truss/base/constants.py,sha256=qwNNkd9EOAuiTxYLVccJaiPCNRayBAFvyj_GisYOT3I,3488
5
+ truss/base/constants.py,sha256=stc7WEm9neqoR6L2edzHXude7EkEmpnq-yAUXOULmPw,3536
6
6
  truss/base/custom_types.py,sha256=FUSIT2lPOQb6gfg6IzT63YBV8r8L6NIZ0D74Fp3e_jQ,2835
7
7
  truss/base/errors.py,sha256=zDVLEvseTChdPP0oNhBBQCtQUtZJUaof5zeWMIjqz6o,691
8
8
  truss/base/trt_llm_config.py,sha256=CRz3AqGDAyv8YpcBWXUrnfjvNAauyo3yf8ZOGVsSt6g,32782
@@ -11,20 +11,20 @@ truss/base/truss_spec.py,sha256=jFVF79CXoEEspl2kXBAPyi-rwISReIGTdobGpaIhwJw,5979
11
11
  truss/cli/chains_commands.py,sha256=bqOXQ-0RPS66vSP_OPQdJ5dvctGiVrsGoSUMbURGdSI,16970
12
12
  truss/cli/cli.py,sha256=PaMkuwXZflkU7sa1tEoT_Zmy-iBkEZs1m4IVqcieaeo,30367
13
13
  truss/cli/remote_cli.py,sha256=G_xCKRXzgkCmkiZJhUFfsv5YSVgde1jLA5LPQitpZgI,1905
14
- truss/cli/train_commands.py,sha256=GDye7yXGL_nQvXAlY5MWsdj5x0zYOvcQw0Ubn14TiRU,14365
14
+ truss/cli/train_commands.py,sha256=TZhtvofviWQF34pYppRCaQ6qayTsvPnx6afTrYbFpOM,17319
15
15
  truss/cli/logs/base_watcher.py,sha256=KKyd7lIrdaEeDVt8EtjMioSPGVpLyOcF0ewyzE_GGdQ,2785
16
16
  truss/cli/logs/model_log_watcher.py,sha256=NACcP-wkcaroYa2Cb9BZC7Yr0554WZa_FSM2LXOf4A8,1263
17
17
  truss/cli/logs/training_log_watcher.py,sha256=r6HRqrLnz-PiKTUXiDYYxg4ZnP8vYcXlEX1YmgHhzlo,1173
18
18
  truss/cli/logs/utils.py,sha256=z-U_FG4BUzdZLbE3BnXb4DZQ0zt3LSZ3PiQpLaDuc3o,1031
19
19
  truss/cli/train/common.py,sha256=xTR41U5FeSndXfNBBHF9wF5XwZH1sOIVFlv-XHjsKIU,1547
20
- truss/cli/train/core.py,sha256=dAmetxKqSc4bQPnVS_8WLfNsw1L7vLT2tU02BVwRPgc,20206
20
+ truss/cli/train/core.py,sha256=4vPnREmaJh8R_rlwR0_H5NRaXhdyY2g07w11uab-9qw,25908
21
21
  truss/cli/train/deploy_from_checkpoint_config.yml,sha256=mktaVrfhN8Kjx1UveC4xr-gTW-kjwbHvq6bx_LpO-Wg,371
22
22
  truss/cli/train/deploy_from_checkpoint_config_whisper.yml,sha256=6GbOorYC8ml0UyOUvuBpFO_fuYtYE646JqsalR-D4oY,406
23
23
  truss/cli/train/metrics_watcher.py,sha256=smz-zrEsBj_-wJHI0pAZ-EAPrvfCWzq1eQjGiFNM-Mk,12755
24
24
  truss/cli/train/poller.py,sha256=TGRzELxsicga0bEXewSX1ujw6lfPmDnHd6nr8zvOFO8,3550
25
25
  truss/cli/train/types.py,sha256=alGtr4Q71GeB65PpGMhsoKygw4k_ncR6MKIP1ioP8rI,951
26
26
  truss/cli/train/deploy_checkpoints/__init__.py,sha256=wL-M2yu8PxO2tFvjwshXAfPnB-5TlvsBp2v_bdzimRU,99
27
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=xfblHi3py7GDgY24NcuAaDKzcQeOm67rjtWOK6vAEe4,17352
27
+ truss/cli/train/deploy_checkpoints/deploy_checkpoints.py,sha256=KUaUl5a2lGy-l5rZZJv-rxIAsCJPRNC8h1909fXVKtE,17763
28
28
  truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py,sha256=6x5nS_HnWYtS9vi-Pg8akzrJk9L_agjvFhm5EFh1m6Y,1964
29
29
  truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py,sha256=FYRG5KTMlxEMZS-RA_m2gp1wuqWbSpqt2RhdQfLibhA,3968
30
30
  truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py,sha256=P91dIAzuhl2GlzmrWwCcYI7uCMT1Lm7C79JQHM_exN4,4442
@@ -72,7 +72,7 @@ truss/templates/cache_requirements.txt,sha256=xoPoJ-OVnf1z6oq_RVM3vCr3ionByyqMLj
72
72
  truss/templates/copy_cache_files.Dockerfile.jinja,sha256=Os5zFdYLZ_AfCRGq4RcpVTObOTwL7zvmwYcvOzd_Zqo,126
73
73
  truss/templates/docker_server_requirements.txt,sha256=PyhOPKAmKW1N2vLvTfLMwsEtuGpoRrbWuNo7tT6v2Mc,18
74
74
  truss/templates/server.Dockerfile.jinja,sha256=CUYnF_hgxPGq2re7__0UPWlwzOHMoFkxp6NVKi3U16s,7071
75
- truss/templates/control/requirements.txt,sha256=D2kIrXfCKlWl8LO7quTUlCFYuT3Dn_MVAlCG_0YjHQY,253
75
+ truss/templates/control/requirements.txt,sha256=MiVoU5n8GTj5ygN-iL6neTD1AIqhVYIaJAHkcnsHGvA,253
76
76
  truss/templates/control/control/application.py,sha256=jYeta6hWe1SkfLL3W4IDmdYjg3ZuKqI_UagWYs5RB_E,3793
77
77
  truss/templates/control/control/endpoints.py,sha256=VQ1lvZjFvR091yRkiFdvXw1Q7PiNGXT9rJwY7_sX6yg,11828
78
78
  truss/templates/control/control/server.py,sha256=R4Y219i1dcz0kkksN8obLoX-YXWGo9iW1igindyG50c,3128
@@ -112,6 +112,8 @@ truss/templates/shared/log_config.py,sha256=l9udyu4VKHZePlfK9LQEd5TOUUodPuehypsX
112
112
  truss/templates/shared/secrets_resolver.py,sha256=3prDe3Q06NTmUEe7KCW-W4TD1CzGck9lpDG789209z4,2110
113
113
  truss/templates/shared/serialization.py,sha256=_WC_2PPkRi-MdTwxwjG8LKQptnHi4sANfpOlKWevqWc,3736
114
114
  truss/templates/shared/util.py,sha256=dPgFF4iL_YkeC6Kf8tZUHJH60rbpskHwVPh0ONLGaQM,2222
115
+ truss/templates/train/config.py,sha256=aQJ3lsyVRlq6edjjZq4_Anz1bZVwkjLdclmZPJTdo1k,1626
116
+ truss/templates/train/run.sh,sha256=2rimigJOn6yg4DguRfOJWkzm77X-meNSYXnidLafqNg,346
115
117
  truss/templates/trtllm-audio/model/model.py,sha256=o38QqW57b1lf8O_td1lW_AojZZ8R_qAZCgzOWtoIse8,1619
116
118
  truss/templates/trtllm-audio/packages/sigint_patch.py,sha256=t6pYpVwgQsLCgcxQq7-V3scr9ZOiIxtYSpy9LCfdNTk,414
117
119
  truss/templates/trtllm-audio/packages/whisper_trt/__init__.py,sha256=5ZQfVlwtkWrnjYiuBIVSviYDhV-kksygDkHEWBS_ijM,7065
@@ -140,8 +142,9 @@ truss/tests/test_truss_handle.py,sha256=-xz9VXkecXDTslmQZ-dmUmQLnvD0uumRqHS2uvGl
140
142
  truss/tests/test_util.py,sha256=hs1bNMkXKEdoPRx4Nw-NAEdoibR92OubZuADGmbiYsQ,1344
141
143
  truss/tests/cli/test_cli.py,sha256=yfbVS5u1hnAmmA8mJ539vj3lhH-JVGUvC4Q_Mbort44,787
142
144
  truss/tests/cli/train/test_cache_view.py,sha256=aVRCh3atRpFbJqyYgq7N-vAW0DiKMftQ7ajUqO2ClOg,22606
143
- truss/tests/cli/train/test_deploy_checkpoints.py,sha256=wQZ3DPLPAyXE3iaQiyHJTBO15v_gXN44eDk1StYkKmM,44764
145
+ truss/tests/cli/train/test_deploy_checkpoints.py,sha256=lDk88uAUPYatJ30JKVVtJDdXv_zWNk1nxXFyUH6IVGw,44800
144
146
  truss/tests/cli/train/test_train_cli_core.py,sha256=vzYfxKdwoa3NaFMrVZbSg5qOoLXivMvZXN1ClQirGTQ,16148
147
+ truss/tests/cli/train/test_train_init.py,sha256=pv8BfyLlVG0QtdowTziITjKa_OE1KigatmAGx8XSZrM,17238
145
148
  truss/tests/cli/train/resources/test_deploy_from_checkpoint_config.yml,sha256=GF7r9l0KaeXiUYCPSBpeMPd2QG6PeWWyI12NdbqLOgc,1930
146
149
  truss/tests/contexts/image_builder/test_serving_image_builder.py,sha256=16niCXZnuxFHXYQw2vPFZ8svSZafkH5DT0Gx3Z9Xdd8,22377
147
150
  truss/tests/contexts/local_loader/test_load_local.py,sha256=D1qMH2IpYA2j5009v50QMgUnKdeOsX15ndkwXe10a4E,801
@@ -358,14 +361,14 @@ truss_chains/reference_code/reference_model.py,sha256=emH3hb23E_nbP98I37PGp1Xk1h
358
361
  truss_chains/remote_chainlet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
359
362
  truss_chains/remote_chainlet/model_skeleton.py,sha256=8ZReLOO2MLcdg7bNZ61C-6j-e68i2Z-fFlyV3sz0qH8,2376
360
363
  truss_chains/remote_chainlet/stub.py,sha256=Y2gDUzMY9WRaQNHIz-o4dfLUfFyYV9dUhIRQcfgrY8g,17209
361
- truss_chains/remote_chainlet/utils.py,sha256=xX1t3e-BsYkWrxQIqfKRl4PHGuVyW3oleWFQpXSAynI,22949
364
+ truss_chains/remote_chainlet/utils.py,sha256=RJ74JeB_jzq0wjzxkkVrcnoh_fdWhiq5-FtZTYQdgyQ,23260
362
365
  truss_train/__init__.py,sha256=7hE6j6-u6UGzCGaNp3CsCN0kAVjBus1Ekups-Bk0fi4,837
363
366
  truss_train/definitions.py,sha256=V985HhY4rdXL10DZxpFEpze9ScxzWErMht4WwaPknGU,6789
364
367
  truss_train/deployment.py,sha256=lWWANSuzBWu2M4oK4qD7n-oVR1JKdmw2Pn5BJQHg-Ck,3074
365
368
  truss_train/loader.py,sha256=0o66EjBaHc2YY4syxxHVR4ordJWs13lNXnKjKq2wq0U,1630
366
369
  truss_train/public_api.py,sha256=9N_NstiUlmBuLUwH_fNG_1x7OhGCytZLNvqKXBlStrM,1220
367
- truss-0.11.1rc4.dist-info/METADATA,sha256=PYD_kydnF-Z7GjTBOB0-JA0lQjQMtiBn7Y-30qyT7wY,6672
368
- truss-0.11.1rc4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
369
- truss-0.11.1rc4.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
370
- truss-0.11.1rc4.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
371
- truss-0.11.1rc4.dist-info/RECORD,,
370
+ truss-0.11.1rc6.dist-info/METADATA,sha256=FpPBQ0BmRG8IPSDsYnA7f9Su_I3L0dqTycCIsC7tKv4,6672
371
+ truss-0.11.1rc6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
372
+ truss-0.11.1rc6.dist-info/entry_points.txt,sha256=-MwKfHHQHQ6j0HqIgvxrz3CehCmczDLTD-OsRHnjjuU,130
373
+ truss-0.11.1rc6.dist-info/licenses/LICENSE,sha256=FTqGzu85i-uw1Gi8E_o0oD60bH9yQ_XIGtZbA1QUYiw,1064
374
+ truss-0.11.1rc6.dist-info/RECORD,,
@@ -382,13 +382,20 @@ def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, pydantic.BaseM
382
382
  # Error Propagation Utils. #############################################################
383
383
 
384
384
 
385
- def _handle_exception(exception: Exception) -> NoReturn:
386
- """Raises `HTTPException` with `RemoteErrorDetail`."""
385
+ # NB(nikhil): Deployed chainlets have access to FastAPI, but local testing doesn't necessarily
386
+ # have that dependency. We have a helpful error message via `utils.make_optional_import_error`
387
+ # for those cases.
388
+ def _safe_import_fastapi():
387
389
  try:
388
- import fastapi
390
+ import fastapi # noqa: F401
389
391
  except ImportError:
390
392
  raise utils.make_optional_import_error("fastapi")
391
393
 
394
+
395
+ def _handle_exception(exception: Exception) -> NoReturn:
396
+ """Raises `HTTPException` with `RemoteErrorDetail`."""
397
+ _safe_import_fastapi()
398
+
392
399
  if hasattr(exception, "__module__"):
393
400
  exception_module_name = exception.__module__
394
401
  else:
@@ -588,6 +595,7 @@ class WebsocketWrapperFastAPI:
588
595
  await self._websocket.close(code=code, reason=reason)
589
596
 
590
597
  async def receive(self) -> Union[str, bytes]:
598
+ _safe_import_fastapi()
591
599
  message = await self._websocket.receive()
592
600
 
593
601
  if message.get("type") == "websocket.disconnect":