truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from unittest.mock import Mock, patch
|
|
2
2
|
|
|
3
|
-
from truss.cli.train.
|
|
3
|
+
from truss.cli.train.cache import (
|
|
4
4
|
calculate_directory_sizes,
|
|
5
5
|
create_file_summary_with_directory_sizes,
|
|
6
|
-
view_training_job_metrics,
|
|
7
6
|
)
|
|
7
|
+
from truss.cli.train.core import view_training_job_metrics
|
|
8
8
|
from truss.remote.baseten.custom_types import FileSummary
|
|
9
9
|
|
|
10
10
|
|
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
"""Tests for team parameter in training project creation.
|
|
2
|
+
|
|
3
|
+
This test suite covers all 8 scenarios for team resolution in truss train push:
|
|
4
|
+
1. --team PROVIDED: Valid team name, user has access
|
|
5
|
+
2. --team PROVIDED: Invalid team name (does not exist)
|
|
6
|
+
3. --team NOT PROVIDED: User has multiple teams, no existing project
|
|
7
|
+
4. --team NOT PROVIDED: User has multiple teams, existing project in exactly one team
|
|
8
|
+
5. --team NOT PROVIDED: User has multiple teams, existing project exists in multiple teams
|
|
9
|
+
6. --team NOT PROVIDED: User has exactly one team, no existing project
|
|
10
|
+
7. --team NOT PROVIDED: User has exactly one team, existing project matches the team
|
|
11
|
+
8. --team NOT PROVIDED: User has exactly one team, existing project exists in different team
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from unittest.mock import Mock, patch
|
|
16
|
+
|
|
17
|
+
from click.testing import CliRunner
|
|
18
|
+
|
|
19
|
+
from truss.cli.cli import truss_cli
|
|
20
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TestTeamParameter:
    """Test team parameter in training project creation.

    Each test patches the remote factory, the training-project loader and
    ``create_training_job``, invokes ``truss train push`` through click's
    ``CliRunner``, and then checks the exit code plus the ``team_name`` /
    ``team_id`` keyword arguments passed to job creation.
    """

    @staticmethod
    def _setup_mock_remote(teams):
        """Return a ``BasetenRemote`` mock whose ``api.get_teams()`` returns ``teams``.

        ``teams`` maps team name -> ``{"id": ..., "name": ...}``.
        """
        mock_remote = Mock(spec=BasetenRemote)
        mock_api = Mock()
        mock_remote.api = mock_api
        mock_api.get_teams.return_value = teams
        return mock_remote

    @staticmethod
    def _create_test_config():
        """Write a dummy config file into the system temp directory and return its path.

        The training-project loader is mocked in every test, so the file's
        content is irrelevant. NOTE(review): presumably the CLI only needs the
        path to exist — confirm. ``tempfile.gettempdir()`` is used instead of a
        hard-coded ``/tmp`` so the suite also runs where ``/tmp`` does not
        exist (e.g. Windows) and honours ``TMPDIR``.
        """
        import tempfile

        config_path = Path(tempfile.gettempdir()) / "test_config.py"
        config_path.parent.mkdir(parents=True, exist_ok=True)
        config_path.write_text("# dummy config")
        return config_path

    @staticmethod
    def _invoke_train_push(runner, config_path, team_name=None, remote="test_remote"):
        """Run ``truss train push <config> --remote <remote> [--team <team_name>]``."""
        args = ["train", "push", str(config_path), "--remote", remote]
        if team_name:
            args.extend(["--team", team_name])
        return runner.invoke(truss_cli, args)

    @staticmethod
    def _create_mock_training_project(name="test-project"):
        """Return a mock training project whose ``name`` attribute is ``name``."""
        mock_project = Mock()
        mock_project.name = name
        return mock_project

    @staticmethod
    def _setup_mock_loader(mock_import_project, training_project):
        """Make the patched loader's context manager yield ``training_project``."""
        mock_import_project.return_value.__enter__ = Mock(return_value=training_project)
        mock_import_project.return_value.__exit__ = Mock(return_value=None)

    @staticmethod
    def _setup_mock_status(mock_status):
        """Make the patched ``console.status`` usable as a no-op context manager."""
        mock_status.return_value.__enter__ = Mock(return_value=None)
        mock_status.return_value.__exit__ = Mock(return_value=None)

    @staticmethod
    def _create_mock_job_response(
        project_id="12345", project_name="test-project", job_id="job123"
    ):
        """Return a minimal training-job response dict as ``create_training_job`` would."""
        return {
            "id": job_id,
            "training_project": {"id": project_id, "name": project_name},
        }

    @staticmethod
    def _assert_training_job_called_with_team(
        mock_create_job, expected_team_name, training_project, expected_teams=None
    ):
        """Assert the single ``create_training_job`` call used the expected team.

        Checks that the third positional argument is ``training_project`` and
        that ``team_name`` matches. When ``expected_teams`` is supplied, the
        ``team_id`` keyword must equal the id of ``expected_team_name`` in it.
        """
        mock_create_job.assert_called_once()
        call_args = mock_create_job.call_args
        assert call_args[0][2] == training_project
        assert call_args[1]["team_name"] == expected_team_name
        # Verify team_id is resolved and passed correctly
        if expected_team_name and expected_teams:
            expected_team_id = expected_teams[expected_team_name]["id"]
            assert call_args[1]["team_id"] == expected_team_id
        elif expected_team_name is None:
            # If no team_name, team_id should also be None
            assert call_args[1]["team_id"] is None
        else:
            # team_name provided but no team table to check against: only
            # require that a team_id was passed at all.
            assert "team_id" in call_args[1]

    # SCENARIO 1: --team PROVIDED: Valid team name, user has access
    # CLI Command: truss train push /path/to/config.py --team "Team Alpha" --remote baseten_staging
    # Exit Code: 0, Error Message: None, Interactive Prompt: No, Existing Teams: ["team1"]
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_1_team_provided_valid_team_name(
        self, mock_import_project, mock_status, mock_remote_factory, mock_create_job
    ):
        """Scenario 1: --team PROVIDED with valid team name, user has access."""
        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
        training_project = self._create_mock_training_project()
        job_response = self._create_mock_job_response()

        mock_remote = self._setup_mock_remote(teams)
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path, team_name="Team Alpha")

        assert result.exit_code == 0
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Alpha", training_project, expected_teams=teams
        )

    # SCENARIO 2: --team PROVIDED: Invalid team name (does not exist)
    # CLI Command: truss train push /path/to/config.py --team "NonExistentTeam" --remote baseten_staging
    # Exit Code: 1, Error Message: Team does not exist, Interactive Prompt: No, Existing Teams: ["team1"]
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_2_team_provided_invalid_team_name(
        self, mock_import_project, mock_remote_factory
    ):
        """Scenario 2: --team PROVIDED with invalid team name that does not exist."""
        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
        training_project = self._create_mock_training_project()

        mock_remote = self._setup_mock_remote(teams)
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_loader(mock_import_project, training_project)

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(
            runner, config_path, team_name="NonExistentTeam"
        )

        assert result.exit_code == 1
        assert "does not exist" in result.output
        assert "NonExistentTeam" in result.output

    # SCENARIO 3: --team NOT PROVIDED: User has multiple teams, no existing project
    # CLI Command: truss train push /path/to/config.py --remote baseten_staging
    # Exit Code: 0, Error Message: None, Interactive Prompt: Yes, Existing Teams: ["team1", "team2", "team3"]
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.remote_cli.inquire_team")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_3_multiple_teams_no_existing_project(
        self,
        mock_import_project,
        mock_status,
        mock_inquire_team,
        mock_remote_factory,
        mock_create_job,
    ):
        """Scenario 3: --team NOT PROVIDED, user has multiple teams, no existing project."""
        teams = {
            "Team Alpha": {"id": "team1", "name": "Team Alpha"},
            "Team Beta": {"id": "team2", "name": "Team Beta"},
            "Team Gamma": {"id": "team3", "name": "Team Gamma"},
        }
        training_project = self._create_mock_training_project()
        job_response = self._create_mock_job_response()

        mock_remote = self._setup_mock_remote(teams)
        mock_remote.api.list_training_projects.return_value = []
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_inquire_team.return_value = "Team Beta"
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path)

        assert result.exit_code == 0
        mock_inquire_team.assert_called_once()
        assert mock_inquire_team.call_args[1]["existing_teams"] == teams
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Beta", training_project, expected_teams=teams
        )

    # SCENARIO 4: --team NOT PROVIDED: User has multiple teams, existing project in exactly one team
    # CLI Command: truss train push /path/to/config.py --remote baseten_staging
    # Exit Code: 0, Error Message: None, Interactive Prompt: No, Existing Teams: ["team1", "team2", "team3"]
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_4_multiple_teams_existing_project_in_one_team(
        self, mock_import_project, mock_status, mock_remote_factory, mock_create_job
    ):
        """Scenario 4: --team NOT PROVIDED, multiple teams, existing project in exactly one team."""
        teams = {
            "Team Alpha": {"id": "team1", "name": "Team Alpha"},
            "Team Beta": {"id": "team2", "name": "Team Beta"},
            "Team Gamma": {"id": "team3", "name": "Team Gamma"},
        }
        existing_project = {
            "id": "project123",
            "name": "existing-project",
            "team_name": "Team Beta",
        }
        training_project = self._create_mock_training_project(name="existing-project")
        job_response = self._create_mock_job_response(
            project_id="project123", project_name="existing-project"
        )

        mock_remote = self._setup_mock_remote(teams)
        mock_remote.api.list_training_projects.return_value = [existing_project]
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path)

        assert result.exit_code == 0
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Beta", training_project, expected_teams=teams
        )
        mock_remote.api.list_training_projects.assert_called_once()

    # SCENARIO 5: --team NOT PROVIDED: User has multiple teams, existing project exists in multiple teams
    # CLI Command: truss train push /path/to/config.py --remote baseten_staging
    # Exit Code: 0, Error Message: None, Interactive Prompt: Yes, Existing Teams: ["team1", "team2", "team3"]
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.remote_cli.inquire_team")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_5_multiple_teams_existing_project_in_multiple_teams(
        self,
        mock_import_project,
        mock_status,
        mock_inquire_team,
        mock_remote_factory,
        mock_create_job,
    ):
        """Scenario 5: --team NOT PROVIDED, multiple teams, existing project in multiple teams."""
        teams = {
            "Team Alpha": {"id": "team1", "name": "Team Alpha"},
            "Team Beta": {"id": "team2", "name": "Team Beta"},
            "Team Gamma": {"id": "team3", "name": "Team Gamma"},
        }
        existing_projects = [
            {"id": "project123", "name": "existing-project", "team_name": "Team Alpha"},
            {"id": "project456", "name": "existing-project", "team_name": "Team Beta"},
        ]
        training_project = self._create_mock_training_project(name="existing-project")
        job_response = self._create_mock_job_response(
            project_id="project123", project_name="existing-project"
        )

        mock_remote = self._setup_mock_remote(teams)
        mock_remote.api.list_training_projects.return_value = existing_projects
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_inquire_team.return_value = "Team Alpha"
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path)

        assert result.exit_code == 0
        mock_inquire_team.assert_called_once()
        assert mock_inquire_team.call_args[1]["existing_teams"] == teams
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Alpha", training_project, expected_teams=teams
        )

    # SCENARIO 6: --team NOT PROVIDED: User has exactly one team, no existing project
    # CLI Command: truss train push /path/to/config.py --remote baseten_staging
    # Exit Code: 0, Error Message: None, Interactive Prompt: No, Existing Teams: ["team1"]
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_6_single_team_no_existing_project(
        self, mock_import_project, mock_status, mock_remote_factory, mock_create_job
    ):
        """Scenario 6: --team NOT PROVIDED, user has exactly one team, no existing project."""
        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
        training_project = self._create_mock_training_project()
        job_response = self._create_mock_job_response()

        mock_remote = self._setup_mock_remote(teams)
        mock_remote.api.list_training_projects.return_value = []
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path)

        assert result.exit_code == 0
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Alpha", training_project, expected_teams=teams
        )

    # SCENARIO 7: --team NOT PROVIDED: User has exactly one team, existing project matches the team
    # CLI Command: truss train push /path/to/config.py --remote baseten_staging
    # Exit Code: 0, Error Message: None, Interactive Prompt: No, Existing Teams: ["team1"]
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_7_single_team_existing_project_matches_team(
        self, mock_import_project, mock_status, mock_remote_factory, mock_create_job
    ):
        """Scenario 7: --team NOT PROVIDED, single team, existing project matches the team."""
        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
        existing_project = {
            "id": "project123",
            "name": "existing-project",
            "team_name": "Team Alpha",
        }
        training_project = self._create_mock_training_project(name="existing-project")
        job_response = self._create_mock_job_response(
            project_id="project123", project_name="existing-project"
        )

        mock_remote = self._setup_mock_remote(teams)
        mock_remote.api.list_training_projects.return_value = [existing_project]
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path)

        assert result.exit_code == 0
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Alpha", training_project, expected_teams=teams
        )
        mock_remote.api.list_training_projects.assert_called_once()

    # SCENARIO 8: --team NOT PROVIDED: User has exactly one team, existing project exists in different team
    # CLI Command: truss train push /path/to/config.py --remote baseten_staging
    # Specified Exit Code: 1, but the current implementation exits 0 (see the
    # note in the test body). Error Message: None, Interactive Prompt: No,
    # Existing Teams: ["team1"]
    # Note: This scenario occurs when a project exists in a team the user doesn't have access to
    @patch("truss_train.deployment.create_training_job")
    @patch("truss.cli.train_commands.RemoteFactory.create")
    @patch("truss.cli.train_commands.console.status")
    @patch("truss_train.loader.import_training_project")
    def test_scenario_8_single_team_existing_project_different_team(
        self, mock_import_project, mock_status, mock_remote_factory, mock_create_job
    ):
        """Scenario 8: --team NOT PROVIDED, single team, existing project in different team."""
        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
        existing_project = {
            "id": "project123",
            "name": "existing-project",
            "team_name": "Team Other",  # Different team user doesn't have access to
        }
        training_project = self._create_mock_training_project(name="existing-project")
        job_response = self._create_mock_job_response(
            project_id="project123", project_name="existing-project"
        )

        mock_remote = self._setup_mock_remote(teams)
        mock_remote.api.list_training_projects.return_value = [existing_project]
        mock_remote_factory.return_value = mock_remote
        self._setup_mock_status(mock_status)
        self._setup_mock_loader(mock_import_project, training_project)
        mock_create_job.return_value = job_response

        runner = CliRunner()
        config_path = self._create_test_config()
        result = self._invoke_train_push(runner, config_path)

        # The scenario table specifies exit code 1 here, but enforcing that
        # would require backend validation. The current resolver falls back to
        # the user's single team, so the push succeeds (exit 0) with that team.
        assert result.exit_code == 0
        self._assert_training_job_called_with_team(
            mock_create_job, "Team Alpha", training_project, expected_teams=teams
        )
|
truss/tests/conftest.py
CHANGED
|
@@ -2,15 +2,19 @@ import contextlib
|
|
|
2
2
|
import copy
|
|
3
3
|
import importlib
|
|
4
4
|
import os
|
|
5
|
+
import pathlib
|
|
5
6
|
import shutil
|
|
6
7
|
import subprocess
|
|
7
8
|
import sys
|
|
9
|
+
import tempfile
|
|
8
10
|
import time
|
|
9
11
|
from pathlib import Path
|
|
10
12
|
from typing import Any, Dict
|
|
13
|
+
from unittest import mock
|
|
11
14
|
|
|
12
15
|
import pytest
|
|
13
16
|
import requests
|
|
17
|
+
import requests_mock
|
|
14
18
|
import yaml
|
|
15
19
|
|
|
16
20
|
from truss.base.custom_types import Example
|
|
@@ -20,6 +24,8 @@ from truss.contexts.image_builder.serving_image_builder import (
|
|
|
20
24
|
ServingImageBuilderContext,
|
|
21
25
|
)
|
|
22
26
|
from truss.contexts.local_loader.docker_build_emulator import DockerBuildEmulator
|
|
27
|
+
from truss.remote.baseten.core import ModelVersionHandle
|
|
28
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
23
29
|
from truss.truss_handle.build import init_directory
|
|
24
30
|
from truss.truss_handle.truss_handle import TrussHandle
|
|
25
31
|
|
|
@@ -856,3 +862,184 @@ def trtllm_spec_dec_config_lookahead_v1(trtllm_config) -> Dict[str, Any]:
|
|
|
856
862
|
}
|
|
857
863
|
}
|
|
858
864
|
return spec_dec_config
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
@pytest.fixture
def remote_url():
    """Base URL of the test Baseten remote; ``remote`` and ``remote_graphql_path`` build on it."""
    base_url = "http://test_remote.com"
    return base_url
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
@pytest.fixture
def truss_rc_content():
    """Text of a minimal ``.trussrc`` defining a single ``baseten`` remote.

    Leading/trailing whitespace is stripped. NOTE(review): the ``remote_url``
    written here (``http://test.com``) differs from the ``remote_url``
    fixture; presumably intentional — confirm before unifying.
    """
    raw = """
[baseten]
remote_provider = baseten
api_key = test_key
remote_url = http://test.com
"""
    return raw.strip()
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
@pytest.fixture
def remote_graphql_path(remote_url):
    """GraphQL endpoint (with trailing slash) under ``remote_url``."""
    return remote_url + "/graphql/"
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
@pytest.fixture
def remote(remote_url):
    """A ``BasetenRemote`` targeting ``remote_url`` with a placeholder API key."""
    api_key = "api_key"
    return BasetenRemote(remote_url, api_key)
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
@pytest.fixture
def model_response():
    """GraphQL payload describing a model named ``model_name`` with a primary version.

    Shape: ``{"data": {"model": {"name", "id", "primary_version": {"id"}}}}``.
    """
    primary_version = {"id": "version_id"}
    model = {
        "name": "model_name",
        "id": "model_id",
        "primary_version": primary_version,
    }
    return {"data": {"model": model}}
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
@pytest.fixture
def mock_model_version_handle():
    """A ``ModelVersionHandle`` using the same placeholder ids as ``model_response``."""
    handle_kwargs = {
        "model_id": "model_id",
        "version_id": "version_id",
        "hostname": "hostname",
    }
    return ModelVersionHandle(**handle_kwargs)
|
|
910
|
+
|
|
911
|
+
|
|
912
|
+
@pytest.fixture
def setup_push_mocks(model_response, remote_graphql_path, remote_url):
    """Return a callable that registers the HTTP mocks a model push goes through.

    The returned ``_setup(m)`` takes a ``requests_mock`` mocker and registers:

    * GraphQL responses, routed by a substring of the request's ``query``:
      the ``get_model`` query (``"model(name"``), ``truss_validation``,
      ``model_s3_upload_credentials`` and the
      ``create_model_version_from_truss`` mutation.
    * REST responses for the model upload endpoint and the truss blob
      credentials endpoint.

    All URLs derive from the ``remote_url`` / ``remote_graphql_path`` fixtures,
    so overriding ``remote_url`` moves every mock consistently.
    """

    def _query_contains(fragment):
        # Build an ``additional_matcher`` that matches GraphQL requests whose
        # JSON ``query`` contains ``fragment``. Bodiless or non-JSON requests,
        # and payloads without a string query, simply don't match instead of
        # raising inside requests_mock's matching loop.
        def _matcher(req):
            try:
                payload = req.json()
            except (TypeError, ValueError):
                return False
            if not isinstance(payload, dict):
                return False
            return fragment in (payload.get("query") or "")

        return _matcher

    def _setup(m):
        # Mock for get_model query - matches queries containing "model(name"
        m.post(
            remote_graphql_path,
            json=model_response,
            additional_matcher=_query_contains("model(name"),
        )
        # Mock for validate_truss query - matches queries containing "truss_validation"
        m.post(
            remote_graphql_path,
            json={"data": {"truss_validation": {"success": True, "details": "{}"}}},
            additional_matcher=_query_contains("truss_validation"),
        )
        # Mock for model_s3_upload_credentials query
        m.post(
            remote_graphql_path,
            json={
                "data": {
                    "model_s3_upload_credentials": {
                        "s3_bucket": "bucket",
                        "s3_key": "key",
                        "aws_access_key_id": "key_id",
                        "aws_secret_access_key": "secret",
                        "aws_session_token": "token",
                    }
                }
            },
            additional_matcher=_query_contains("model_s3_upload_credentials"),
        )
        m.post(
            f"{remote_url}/v1/models/model_id/upload",
            json={"s3_bucket": "bucket", "s3_key": "key"},
        )
        m.post(
            f"{remote_url}/v1/blobs/credentials/truss",
            json={
                "s3_bucket": "bucket",
                "s3_key": "key",
                "aws_access_key_id": "key_id",
                "aws_secret_access_key": "secret",
                "aws_session_token": "token",
            },
        )
        # Mock for create_model_version_from_truss mutation
        m.post(
            remote_graphql_path,
            json={
                "data": {
                    "create_model_version_from_truss": {
                        "model_version": {
                            "id": "version_id",
                            "oracle": {"id": "model_id", "hostname": "hostname"},
                        }
                    }
                }
            },
            additional_matcher=_query_contains("create_model_version_from_truss"),
        )

    return _setup
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
@pytest.fixture
def mock_baseten_requests(setup_push_mocks):
    """Yield an active ``requests_mock.Mocker`` with every push mock registered.

    HTTP interception ends when the fixture is torn down.
    """
    with requests_mock.Mocker() as mocker:
        setup_push_mocks(mocker)
        yield mocker
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
@pytest.fixture
def mock_remote_factory():
    """Fixture that mocks RemoteFactory.create and returns a configured mock remote.

    The yielded ``MagicMock`` remote's ``push`` returns a service mock with
    ``model_id == "model_id"`` and ``model_version_id == "version_id"``. The
    patch on ``RemoteFactory.create`` is undone on fixture teardown.
    """
    # Kept as a function-local import, as in the original fixture.
    # NOTE(review): presumably to avoid importing the remote factory at
    # conftest collection time — confirm.
    from truss.remote.remote_factory import RemoteFactory

    mock_service = mock.MagicMock()
    mock_service.model_id = "model_id"
    mock_service.model_version_id = "version_id"

    mock_remote = mock.MagicMock()
    mock_remote.push.return_value = mock_service

    # Uses the module-level ``mock`` import like the sibling fixtures instead
    # of re-importing ``unittest.mock`` locally.
    with mock.patch.object(RemoteFactory, "create", return_value=mock_remote):
        yield mock_remote
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
@pytest.fixture
|
|
1005
|
+
def temp_trussrc_dir(truss_rc_content):
|
|
1006
|
+
"""Fixture that creates a temporary directory with a .trussrc file."""
|
|
1007
|
+
with tempfile.TemporaryDirectory() as tmpdir:
|
|
1008
|
+
trussrc_path = pathlib.Path(tmpdir) / ".trussrc"
|
|
1009
|
+
trussrc_path.write_text(truss_rc_content)
|
|
1010
|
+
yield tmpdir
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
@pytest.fixture
|
|
1014
|
+
def mock_available_config_names():
|
|
1015
|
+
"""Fixture that patches RemoteFactory.get_available_config_names."""
|
|
1016
|
+
from unittest.mock import patch
|
|
1017
|
+
|
|
1018
|
+
with patch(
|
|
1019
|
+
"truss.api.RemoteFactory.get_available_config_names", return_value=["baseten"]
|
|
1020
|
+
):
|
|
1021
|
+
yield
|
|
1022
|
+
|
|
1023
|
+
|
|
1024
|
+
@pytest.fixture
|
|
1025
|
+
def mock_upload_truss():
|
|
1026
|
+
"""Fixture that patches upload_truss and returns a mock."""
|
|
1027
|
+
with mock.patch("truss.remote.baseten.remote.upload_truss") as mock_upload:
|
|
1028
|
+
mock_upload.return_value = "s3_key"
|
|
1029
|
+
yield mock_upload
|
|
1030
|
+
|
|
1031
|
+
|
|
1032
|
+
@pytest.fixture
|
|
1033
|
+
def mock_create_truss_service(mock_model_version_handle):
|
|
1034
|
+
"""Fixture that patches create_truss_service and returns a mock."""
|
|
1035
|
+
with mock.patch("truss.remote.baseten.remote.create_truss_service") as mock_create:
|
|
1036
|
+
mock_create.return_value = mock_model_version_handle
|
|
1037
|
+
yield mock_create
|
|
1038
|
+
|
|
1039
|
+
|
|
1040
|
+
@pytest.fixture
|
|
1041
|
+
def mock_truss_handle(custom_model_truss_dir_with_pre_and_post):
|
|
1042
|
+
from truss.truss_handle.truss_handle import TrussHandle
|
|
1043
|
+
|
|
1044
|
+
truss_handle = TrussHandle(custom_model_truss_dir_with_pre_and_post)
|
|
1045
|
+
return truss_handle
|
|
@@ -100,7 +100,8 @@ def flatten_cached_files(local_cache_files):
|
|
|
100
100
|
def test_correct_hf_files_accessed_for_caching():
|
|
101
101
|
model = "openai/whisper-small"
|
|
102
102
|
config = TrussConfig(
|
|
103
|
-
python_version="py39",
|
|
103
|
+
python_version="py39",
|
|
104
|
+
model_cache=ModelCache([ModelRepo(repo_id=model, use_volume=False)]),
|
|
104
105
|
)
|
|
105
106
|
|
|
106
107
|
with TemporaryDirectory() as tmp_dir:
|
|
@@ -137,7 +138,8 @@ def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files):
|
|
|
137
138
|
model = "gs://crazy-good-new-model-7b"
|
|
138
139
|
|
|
139
140
|
config = TrussConfig(
|
|
140
|
-
python_version="py39",
|
|
141
|
+
python_version="py39",
|
|
142
|
+
model_cache=ModelCache([ModelRepo(repo_id=model, use_volume=False)]),
|
|
141
143
|
)
|
|
142
144
|
|
|
143
145
|
with TemporaryDirectory() as tmp_dir:
|
|
@@ -172,7 +174,8 @@ def test_correct_s3_files_accessed_for_caching(mock_list_bucket_files):
|
|
|
172
174
|
model = "s3://crazy-good-new-model-7b"
|
|
173
175
|
|
|
174
176
|
config = TrussConfig(
|
|
175
|
-
python_version="py39",
|
|
177
|
+
python_version="py39",
|
|
178
|
+
model_cache=ModelCache([ModelRepo(repo_id=model, use_volume=False)]),
|
|
176
179
|
)
|
|
177
180
|
|
|
178
181
|
with TemporaryDirectory() as tmp_dir:
|
|
@@ -207,7 +210,8 @@ def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files):
|
|
|
207
210
|
model = "gs://crazy-good-new-model-7b/folder_a/folder_b"
|
|
208
211
|
|
|
209
212
|
config = TrussConfig(
|
|
210
|
-
python_version="py39",
|
|
213
|
+
python_version="py39",
|
|
214
|
+
model_cache=ModelCache([ModelRepo(repo_id=model, use_volume=False)]),
|
|
211
215
|
)
|
|
212
216
|
|
|
213
217
|
with TemporaryDirectory() as tmp_dir:
|
|
@@ -246,7 +250,8 @@ def test_correct_nested_s3_files_accessed_for_caching(mock_list_bucket_files):
|
|
|
246
250
|
model = "s3://crazy-good-new-model-7b/folder_a/folder_b"
|
|
247
251
|
|
|
248
252
|
config = TrussConfig(
|
|
249
|
-
python_version="py39",
|
|
253
|
+
python_version="py39",
|
|
254
|
+
model_cache=ModelCache([ModelRepo(repo_id=model, use_volume=False)]),
|
|
250
255
|
)
|
|
251
256
|
|
|
252
257
|
with TemporaryDirectory() as tmp_dir:
|