truss 0.11.18rc500__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/truss_config.py +10 -3
- truss/cli/chains_commands.py +39 -1
- truss/cli/cli.py +35 -5
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +19 -143
- truss/cli/train_commands.py +69 -11
- truss/cli/utils/common.py +40 -3
- truss/remote/baseten/api.py +58 -5
- truss/remote/baseten/core.py +22 -4
- truss/remote/baseten/remote.py +24 -2
- truss/templates/control/control/helpers/inference_server_process_controller.py +3 -1
- truss/templates/server/requirements.txt +1 -1
- truss/templates/server.Dockerfile.jinja +10 -10
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +44 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +10 -1
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/METADATA +1 -1
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/RECORD +50 -38
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +9 -4
- truss_chains/private_types.py +15 -0
- truss_train/definitions.py +3 -1
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.18rc500.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
"""Tests for team parameter in chain deployment.
|
|
2
|
+
|
|
3
|
+
This test suite covers all 8 scenarios for team resolution in truss chains push:
|
|
4
|
+
1. --team PROVIDED: Valid team name, user has access, user has 1 existing team
|
|
5
|
+
2. --team PROVIDED: Valid team name, user has access, user has multiple existing teams
|
|
6
|
+
3. --team PROVIDED: Invalid team name (does not exist)
|
|
7
|
+
4. --team NOT PROVIDED: User has multiple teams, no existing chain
|
|
8
|
+
5. --team NOT PROVIDED: User has multiple teams, existing chain exists in multiple teams
|
|
9
|
+
6. --team NOT PROVIDED: User has multiple teams, existing chain in exactly one team
|
|
10
|
+
7. --team NOT PROVIDED: User has exactly one team, no existing chain
|
|
11
|
+
8. --team NOT PROVIDED: User has exactly one team, existing chain matches the team
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from unittest.mock import Mock, patch
|
|
16
|
+
|
|
17
|
+
from click.testing import CliRunner
|
|
18
|
+
|
|
19
|
+
from truss.cli.cli import truss_cli
|
|
20
|
+
from truss.remote.baseten.remote import BasetenRemote
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TestChainsTeamParameter:
    """Test team parameter in chain deployment using Given-When-Then format.

    Every scenario patches ``ChainletImporter.import_target`` and
    ``RemoteFactory.create`` so no chain source is really imported and no
    real remote is contacted. Scenarios that expect a successful deploy also
    patch ``deployment_client.push`` and make its returned mock pass
    ``isinstance(..., BasetenChainService)`` checks for the duration of the
    CLI invocation.
    """

    # ------------------------------------------------------------------
    # Given: test data
    # ------------------------------------------------------------------

    @staticmethod
    def _given_single_team():
        """Given: The ``get_teams`` mapping for a user with only "Team Alpha".

        A fresh dict is built on every call so scenarios cannot leak state
        into each other through a shared mutable fixture.
        """
        return {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}

    @staticmethod
    def _given_multiple_teams():
        """Given: The ``get_teams`` mapping for a user with three teams."""
        return {
            "Team Alpha": {"id": "team1", "name": "Team Alpha"},
            "Team Beta": {"id": "team2", "name": "Team Beta"},
            "Team Gamma": {"id": "team3", "name": "Team Gamma"},
        }

    @staticmethod
    def _given_existing_chain(chain_id, team_name):
        """Given: An existing chain record named "TestChain" owned by ``team_name``."""
        return {"id": chain_id, "name": "TestChain", "team": {"name": team_name}}

    # ------------------------------------------------------------------
    # Given: mocks
    # ------------------------------------------------------------------

    @staticmethod
    def _given_mock_remote(teams):
        """Given: A mock remote provider whose ``api.get_teams`` returns ``teams``."""
        mock_remote = Mock(spec=BasetenRemote)
        mock_api = Mock()
        mock_remote.api = mock_api
        mock_api.get_teams.return_value = teams
        return mock_remote

    @staticmethod
    def _given_test_chain_file():
        """Given: A test chain file written into the current working directory.

        Call this inside ``CliRunner.isolated_filesystem()``. Each test then
        writes into its own temporary directory, which is removed afterwards.
        This replaces a fixed ``/tmp`` path that concurrent test runs would
        share and that was never cleaned up.

        Returns:
            The absolute path of the written chain file.
        """
        chain_path = Path.cwd() / "test_chain.py"
        chain_path.write_text(
            """
from truss_chains import Chainlet, mark_entrypoint

@mark_entrypoint
class TestChain(Chainlet[str, str]):
    def run(self, inp: str) -> str:
        return inp
"""
        )
        return chain_path

    @staticmethod
    def _given_mock_chainlet():
        """Given: A mock chainlet with no explicit chain name, displayed as "TestChain"."""
        mock_chainlet = Mock()
        mock_meta_data = Mock()
        mock_meta_data.chain_name = None
        mock_chainlet.meta_data = mock_meta_data
        mock_chainlet.display_name = "TestChain"
        return mock_chainlet

    @staticmethod
    def _given_mock_chainlet_importer(mock_import_target, chainlet):
        """Given: ``import_target`` returns a context manager yielding ``chainlet``."""
        context_manager = Mock()
        context_manager.__enter__ = Mock(return_value=chainlet)
        context_manager.__exit__ = Mock(return_value=None)
        mock_import_target.return_value = context_manager

    @staticmethod
    def _given_mock_chain_service():
        """Given: A mock chain service with one ACTIVE entrypoint chainlet."""
        mock_service = Mock()
        mock_service.name = "TestChain"
        mock_service.status_page_url = "https://app.baseten.co/chains/test123/overview"
        mock_service.run_remote_url = "https://app.baseten.co/chains/test123/run_remote"
        mock_service.is_websocket = False
        mock_chainlet_info = Mock()
        mock_chainlet_info.is_entrypoint = True
        mock_chainlet_info.name = "TestChain"
        mock_chainlet_info.status = "ACTIVE"
        mock_chainlet_info.logs_url = "https://app.baseten.co/chains/test123/logs"
        mock_service.get_info.return_value = [mock_chainlet_info]
        return mock_service

    def _given_push_setup(
        self,
        mock_import_target,
        mock_remote_factory,
        mock_push,
        teams,
        existing_chains=None,
    ):
        """Given: Remote, chainlet importer and push mocks wired for a deploy.

        Args:
            mock_import_target: Patched ``ChainletImporter.import_target``.
            mock_remote_factory: Patched ``RemoteFactory.create``.
            mock_push: Patched ``deployment_client.push``.
            teams: Mapping returned by ``api.get_teams``.
            existing_chains: When not ``None``, installed as the return value
                of ``api.get_chains``. Otherwise ``get_chains`` is left as a
                default ``Mock``.

        Returns:
            ``(mock_remote, chain_service)`` for further assertions.
        """
        mock_remote = self._given_mock_remote(teams)
        if existing_chains is not None:
            mock_remote.api.get_chains.return_value = existing_chains
        mock_remote_factory.return_value = mock_remote
        self._given_mock_chainlet_importer(
            mock_import_target, self._given_mock_chainlet()
        )
        chain_service = self._given_mock_chain_service()
        mock_push.return_value = chain_service
        return mock_remote, chain_service

    @staticmethod
    def _patch_isinstance_for_mock_service(chain_service):
        """Make ``chain_service`` pass ``isinstance(..., BasetenChainService)``.

        Returns a ``unittest.mock.patch`` context manager over
        ``builtins.isinstance``. The real builtin is restored on exit even if
        the body raises, so no manual restore step is needed.
        """
        from truss_chains.deployment import deployment_client

        baseten_service_class = deployment_client.BasetenChainService
        # Bind the real builtin now. Inside the replacement, the bare name
        # ``isinstance`` would resolve to the patched function and recurse.
        real_isinstance = isinstance

        def fake_isinstance(obj, cls):
            if obj is chain_service and cls == baseten_service_class:
                return True
            return real_isinstance(obj, cls)

        return patch("builtins.isinstance", new=fake_isinstance)

    # ------------------------------------------------------------------
    # When
    # ------------------------------------------------------------------

    @staticmethod
    def _when_invoke_chains_push(
        runner, chain_path, team_name=None, remote="test_remote"
    ):
        """When: Invoking the chains push command."""
        args = ["chains", "push", str(chain_path), "--remote", remote]
        if team_name:
            args.extend(["--team", team_name])
        return runner.invoke(truss_cli, args)

    def _when_invoke_chains_push_with_mock_service(self, chain_service, team_name=None):
        """When: Invoking chains push in an isolated directory.

        ``chain_service`` masquerades as a ``BasetenChainService`` only while
        the CLI runs.
        """
        runner = CliRunner()
        with self._patch_isinstance_for_mock_service(
            chain_service
        ), runner.isolated_filesystem():
            chain_path = self._given_test_chain_file()
            return self._when_invoke_chains_push(
                runner, chain_path, team_name=team_name
            )

    # ------------------------------------------------------------------
    # Then
    # ------------------------------------------------------------------

    @staticmethod
    def _then_assert_exit_code(result, expected_exit_code):
        """Then: Assert the exit code, printing CLI output on mismatch."""
        assert result.exit_code == expected_exit_code, (
            f"Expected exit code {expected_exit_code}, got {result.exit_code}. "
            f"Output: {result.output}"
        )

    @staticmethod
    def _then_assert_push_called_with_team(
        mock_push, expected_team_id, expected_chain_name
    ):
        """Then: Assert push was called once with the expected options."""
        mock_push.assert_called_once()
        call_args = mock_push.call_args
        options = call_args[0][1]  # Second positional argument is options
        assert options.chain_name == expected_chain_name
        assert options.team_id == expected_team_id

    # ------------------------------------------------------------------
    # Scenarios
    # ------------------------------------------------------------------

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_1_team_provided_valid_team_name_single_team(
        self, mock_import_target, mock_remote_factory, mock_push
    ):
        """
        Given: User has 1 team ("Team Alpha") with id "team1"
        When: User runs chains push with --team "Team Alpha"
        Then: Chain is deployed with team_id="team1" and exit code 0
        """
        _, chain_service = self._given_push_setup(
            mock_import_target, mock_remote_factory, mock_push, self._given_single_team()
        )

        result = self._when_invoke_chains_push_with_mock_service(
            chain_service, team_name="Team Alpha"
        )

        self._then_assert_exit_code(result, 0)
        self._then_assert_push_called_with_team(mock_push, "team1", "TestChain")

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_2_team_provided_valid_team_name_multiple_teams(
        self, mock_import_target, mock_remote_factory, mock_push
    ):
        """
        Given: User has 3 teams with "Team Alpha" having id "team1"
        When: User runs chains push with --team "Team Alpha"
        Then: Chain is deployed with team_id="team1" and exit code 0
        """
        _, chain_service = self._given_push_setup(
            mock_import_target,
            mock_remote_factory,
            mock_push,
            self._given_multiple_teams(),
        )

        result = self._when_invoke_chains_push_with_mock_service(
            chain_service, team_name="Team Alpha"
        )

        self._then_assert_exit_code(result, 0)
        self._then_assert_push_called_with_team(mock_push, "team1", "TestChain")

    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_3_team_provided_invalid_team_name(
        self, mock_import_target, mock_remote_factory
    ):
        """
        Given: User has 1 team ("Team Alpha"), but not "NonExistentTeam"
        When: User runs chains push with --team "NonExistentTeam"
        Then: Command fails with exit code 1 and error message about team not existing
        """
        mock_remote_factory.return_value = self._given_mock_remote(
            self._given_single_team()
        )
        self._given_mock_chainlet_importer(
            mock_import_target, self._given_mock_chainlet()
        )

        runner = CliRunner()
        with runner.isolated_filesystem():
            chain_path = self._given_test_chain_file()
            result = self._when_invoke_chains_push(
                runner, chain_path, team_name="NonExistentTeam"
            )

        self._then_assert_exit_code(result, 1)
        assert "does not exist" in result.output
        assert "NonExistentTeam" in result.output

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss.cli.remote_cli.inquire_team")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_4_multiple_teams_no_existing_chain(
        self, mock_import_target, mock_inquire_team, mock_remote_factory, mock_push
    ):
        """
        Given: User has 3 teams, no existing chain named "TestChain"
        When: User runs chains push without --team and selects "Team Beta"
        Then: Chain is deployed with team_id="team2" and exit code 0
        """
        teams = self._given_multiple_teams()
        _, chain_service = self._given_push_setup(
            mock_import_target,
            mock_remote_factory,
            mock_push,
            teams,
            existing_chains=[],
        )
        mock_inquire_team.return_value = "Team Beta"

        result = self._when_invoke_chains_push_with_mock_service(chain_service)

        self._then_assert_exit_code(result, 0)
        mock_inquire_team.assert_called_once()
        assert mock_inquire_team.call_args[1]["existing_teams"] == teams
        self._then_assert_push_called_with_team(mock_push, "team2", "TestChain")

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss.cli.remote_cli.inquire_team")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_5_multiple_teams_existing_chain_in_multiple_teams(
        self, mock_import_target, mock_inquire_team, mock_remote_factory, mock_push
    ):
        """
        Given: User has 3 teams, existing chain "TestChain" in "Team Alpha" and "Team Beta"
        When: User runs chains push without --team and selects "Team Alpha"
        Then: Chain is deployed with team_id="team1" and exit code 0
        """
        teams = self._given_multiple_teams()
        existing_chains = [
            self._given_existing_chain("chain123", "Team Alpha"),
            self._given_existing_chain("chain456", "Team Beta"),
        ]
        _, chain_service = self._given_push_setup(
            mock_import_target,
            mock_remote_factory,
            mock_push,
            teams,
            existing_chains=existing_chains,
        )
        mock_inquire_team.return_value = "Team Alpha"

        result = self._when_invoke_chains_push_with_mock_service(chain_service)

        self._then_assert_exit_code(result, 0)
        mock_inquire_team.assert_called_once()
        assert mock_inquire_team.call_args[1]["existing_teams"] == teams
        self._then_assert_push_called_with_team(mock_push, "team1", "TestChain")

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_6_multiple_teams_existing_chain_in_one_team(
        self, mock_import_target, mock_remote_factory, mock_push
    ):
        """
        Given: User has 3 teams, existing chain "TestChain" only in "Team Beta"
        When: User runs chains push without --team
        Then: Chain is deployed with team_id="team2" (auto-inferred) and exit code 0
        """
        mock_remote, chain_service = self._given_push_setup(
            mock_import_target,
            mock_remote_factory,
            mock_push,
            self._given_multiple_teams(),
            existing_chains=[self._given_existing_chain("chain123", "Team Beta")],
        )

        result = self._when_invoke_chains_push_with_mock_service(chain_service)

        self._then_assert_exit_code(result, 0)
        self._then_assert_push_called_with_team(mock_push, "team2", "TestChain")
        mock_remote.api.get_chains.assert_called()

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_7_single_team_no_existing_chain(
        self, mock_import_target, mock_remote_factory, mock_push
    ):
        """
        Given: User has 1 team ("Team Alpha"), no existing chain
        When: User runs chains push without --team
        Then: Chain is deployed with team_id="team1" (auto-inferred) and exit code 0
        """
        _, chain_service = self._given_push_setup(
            mock_import_target,
            mock_remote_factory,
            mock_push,
            self._given_single_team(),
            existing_chains=[],
        )

        result = self._when_invoke_chains_push_with_mock_service(chain_service)

        self._then_assert_exit_code(result, 0)
        self._then_assert_push_called_with_team(mock_push, "team1", "TestChain")

    @patch("truss_chains.deployment.deployment_client.push")
    @patch("truss.cli.chains_commands.RemoteFactory.create")
    @patch("truss_chains.framework.ChainletImporter.import_target")
    def test_scenario_8_single_team_existing_chain_matches_team(
        self, mock_import_target, mock_remote_factory, mock_push
    ):
        """
        Given: User has 1 team ("Team Alpha"), existing chain "TestChain" in "Team Alpha"
        When: User runs chains push without --team
        Then: Chain is deployed with team_id="team1" (auto-inferred) and exit code 0
        """
        mock_remote, chain_service = self._given_push_setup(
            mock_import_target,
            mock_remote_factory,
            mock_push,
            self._given_single_team(),
            existing_chains=[self._given_existing_chain("chain123", "Team Alpha")],
        )

        result = self._when_invoke_chains_push_with_mock_service(chain_service)

        self._then_assert_exit_code(result, 0)
        self._then_assert_push_called_with_team(mock_push, "team1", "TestChain")
        mock_remote.api.get_chains.assert_called()
|
|
@@ -98,3 +98,47 @@ def test_chains_push_help_includes_disable_chain_download():
|
|
|
98
98
|
|
|
99
99
|
assert result.exit_code == 0
|
|
100
100
|
assert "--disable-chain-download" in result.output
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_chains_push_with_deployment_name_flag():
    """Test that --deployment-name flag is properly parsed and passed through.

    Chainlet import and ``deployment_client.push`` are patched, so the test
    only checks that the CLI forwards ``--deployment-name`` into the push
    options (second positional argument of ``push``).
    """
    runner = CliRunner()

    mock_entrypoint_cls = Mock()
    mock_entrypoint_cls.meta_data.chain_name = "test_chain"
    mock_entrypoint_cls.display_name = "TestChain"

    mock_service = Mock()
    mock_service.run_remote_url = "http://test.com/run_remote"
    mock_service.is_websocket = False

    with patch(
        "truss_chains.framework.ChainletImporter.import_target"
    ) as mock_importer, patch(
        "truss_chains.deployment.deployment_client.push"
    ) as mock_push:
        mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
        mock_push.return_value = mock_service

        result = runner.invoke(
            truss_cli,
            [
                "chains",
                "push",
                "test_chain.py",
                "--deployment-name",
                "custom_deployment",
                "--remote",
                "test_remote",
                "--publish",
                "--dryrun",
            ],
        )

        # Include the CLI output so a failure is diagnosable, matching the
        # other chains CLI tests.
        assert result.exit_code == 0, (
            f"Expected exit code 0, got {result.exit_code}. Output: {result.output}"
        )

        mock_push.assert_called_once()
        options = mock_push.call_args[0][1]

        assert hasattr(options, "deployment_name")
        assert options.deployment_name == "custom_deployment"
|
truss/tests/cli/test_cli.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from unittest.mock import Mock, patch
|
|
1
|
+
from unittest.mock import MagicMock, Mock, patch
|
|
2
2
|
|
|
3
3
|
from click.testing import CliRunner
|
|
4
4
|
|
|
@@ -23,3 +23,136 @@ def test_push_with_grpc_transport_fails_for_development_deployment():
|
|
|
23
23
|
"Truss with gRPC transport cannot be used as a development deployment"
|
|
24
24
|
in result.output
|
|
25
25
|
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_cli_push_passes_deploy_timeout_minutes_to_create_truss_service(
    custom_model_truss_dir_with_pre_and_post,
    remote,
    mock_baseten_requests,
    mock_upload_truss,
    mock_create_truss_service,
):
    """``--deploy-timeout-minutes`` is forwarded to ``create_truss_service``.

    ``get_teams`` is replaced via ``patch.object`` rather than by assignment.
    The ``remote`` fixture's API object is therefore restored when the test
    ends instead of being mutated permanently.
    """
    runner = CliRunner()
    with patch("truss.cli.cli.RemoteFactory.create", return_value=remote), patch.object(
        remote.api, "get_teams", return_value={}
    ), patch("truss.cli.cli.resolve_model_team_name", return_value=(None, None)):
        result = runner.invoke(
            truss_cli,
            [
                "push",
                str(custom_model_truss_dir_with_pre_and_post),
                "--remote",
                "baseten",
                "--model-name",
                "model_name",
                "--publish",
                "--deploy-timeout-minutes",
                "450",
            ],
        )

    assert result.exit_code == 0, (
        f"Expected exit code 0, got {result.exit_code}. Output: {result.output}"
    )
    mock_create_truss_service.assert_called_once()
    _, kwargs = mock_create_truss_service.call_args
    assert kwargs["deploy_timeout_minutes"] == 450
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_cli_push_passes_none_deploy_timeout_minutes_when_not_specified(
    custom_model_truss_dir_with_pre_and_post,
    remote,
    mock_baseten_requests,
    mock_upload_truss,
    mock_create_truss_service,
):
    """Without ``--deploy-timeout-minutes``, ``create_truss_service`` gets ``None``.

    ``get_teams`` is replaced via ``patch.object`` so the ``remote``
    fixture's API object is restored when the test ends.
    """
    runner = CliRunner()
    with patch("truss.cli.cli.RemoteFactory.create", return_value=remote), patch.object(
        remote.api, "get_teams", return_value={}
    ), patch("truss.cli.cli.resolve_model_team_name", return_value=(None, None)):
        result = runner.invoke(
            truss_cli,
            [
                "push",
                str(custom_model_truss_dir_with_pre_and_post),
                "--remote",
                "baseten",
                "--model-name",
                "model_name",
                "--publish",
            ],
        )

    assert result.exit_code == 0, (
        f"Expected exit code 0, got {result.exit_code}. Output: {result.output}"
    )
    mock_create_truss_service.assert_called_once()
    _, kwargs = mock_create_truss_service.call_args
    # .get(): a missing kwarg and an explicit None are both acceptable.
    assert kwargs.get("deploy_timeout_minutes") is None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def test_cli_push_integration_deploy_timeout_minutes_propagated(
    custom_model_truss_dir_with_pre_and_post,
    remote,
    mock_baseten_requests,
    mock_upload_truss,
    mock_create_truss_service,
):
    """``--deploy-timeout-minutes`` and ``--environment`` both reach ``create_truss_service``.

    ``get_teams`` is replaced via ``patch.object`` so the ``remote``
    fixture's API object is restored when the test ends.
    """
    runner = CliRunner()
    with patch("truss.cli.cli.RemoteFactory.create", return_value=remote), patch.object(
        remote.api, "get_teams", return_value={}
    ), patch("truss.cli.cli.resolve_model_team_name", return_value=(None, None)):
        result = runner.invoke(
            truss_cli,
            [
                "push",
                str(custom_model_truss_dir_with_pre_and_post),
                "--remote",
                "baseten",
                "--model-name",
                "model_name",
                "--publish",
                "--environment",
                "staging",
                "--deploy-timeout-minutes",
                "750",
            ],
        )

    assert result.exit_code == 0, (
        f"Expected exit code 0, got {result.exit_code}. Output: {result.output}"
    )
    mock_create_truss_service.assert_called_once()
    _, kwargs = mock_create_truss_service.call_args
    assert kwargs["deploy_timeout_minutes"] == 750
    assert kwargs["environment"] == "staging"
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def test_cli_push_api_integration_deploy_timeout_minutes_propagated(
    custom_model_truss_dir_with_pre_and_post,
    mock_remote_factory,
    temp_trussrc_dir,
    mock_available_config_names,
):
    """``--deploy-timeout-minutes`` reaches the remote's ``push`` call.

    ``mock_remote_factory`` stands in for the remote. Its ``push`` returns a
    ``MagicMock`` service carrying model and version ids.
    """
    mock_service = MagicMock()
    mock_service.model_id = "model_id"
    mock_service.model_version_id = "version_id"
    mock_remote_factory.push.return_value = mock_service

    runner = CliRunner()
    with patch(
        "truss.cli.cli.RemoteFactory.get_available_config_names",
        return_value=["baseten"],
    ):
        result = runner.invoke(
            truss_cli,
            [
                "push",
                str(custom_model_truss_dir_with_pre_and_post),
                "--remote",
                "baseten",
                "--model-name",
                "test_model",
                "--deploy-timeout-minutes",
                "1200",
            ],
        )

    # Surface the CLI output on failure, consistent with the sibling tests.
    assert result.exit_code == 0, (
        f"Expected exit code 0, got {result.exit_code}. Output: {result.output}"
    )
    mock_remote_factory.push.assert_called_once()
    _, push_kwargs = mock_remote_factory.push.call_args
    assert push_kwargs.get("deploy_timeout_minutes") == 1200
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from truss.cli.utils import common
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_normalize_iso_timestamp_handles_nanoseconds():
    """A nine-digit fraction comes back with six digits; "+0000" becomes "+00:00"."""
    raw_timestamp = "2025-11-17 05:05:06.000000000 +0000"
    expected = "2025-11-17 05:05:06.000000+00:00"
    assert common._normalize_iso_timestamp(raw_timestamp) == expected
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_normalize_iso_timestamp_handles_z_suffix_and_short_fraction():
    """A trailing "Z" is rewritten as "+00:00"; a six-digit fraction is kept."""
    expected = "2025-11-17T05:05:06.123456+00:00"
    assert common._normalize_iso_timestamp("2025-11-17T05:05:06.123456Z") == expected
|