truss 0.11.6rc102__py3-none-any.whl → 0.11.24rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- truss/api/__init__.py +5 -2
- truss/base/constants.py +1 -0
- truss/base/trt_llm_config.py +14 -3
- truss/base/truss_config.py +19 -4
- truss/cli/chains_commands.py +49 -1
- truss/cli/cli.py +38 -7
- truss/cli/logs/base_watcher.py +31 -12
- truss/cli/logs/model_log_watcher.py +24 -1
- truss/cli/remote_cli.py +29 -0
- truss/cli/resolvers/chain_team_resolver.py +82 -0
- truss/cli/resolvers/model_team_resolver.py +90 -0
- truss/cli/resolvers/training_project_team_resolver.py +81 -0
- truss/cli/train/cache.py +332 -0
- truss/cli/train/core.py +57 -163
- truss/cli/train/deploy_checkpoints/__init__.py +2 -2
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +236 -103
- truss/cli/train/deploy_checkpoints/deploy_checkpoints_helpers.py +1 -52
- truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py +1 -86
- truss/cli/train/deploy_checkpoints/deploy_lora_checkpoints.py +1 -85
- truss/cli/train/deploy_checkpoints/deploy_whisper_checkpoints.py +1 -56
- truss/cli/train/types.py +18 -9
- truss/cli/train_commands.py +180 -35
- truss/cli/utils/common.py +40 -3
- truss/contexts/image_builder/serving_image_builder.py +17 -4
- truss/remote/baseten/api.py +215 -9
- truss/remote/baseten/core.py +63 -7
- truss/remote/baseten/custom_types.py +1 -0
- truss/remote/baseten/remote.py +42 -2
- truss/remote/baseten/service.py +0 -7
- truss/remote/baseten/utils/transfer.py +5 -2
- truss/templates/base.Dockerfile.jinja +8 -4
- truss/templates/control/control/application.py +51 -26
- truss/templates/control/control/endpoints.py +1 -5
- truss/templates/control/control/helpers/inference_server_process_controller.py +10 -4
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +33 -18
- truss/templates/control/control/server.py +1 -1
- truss/templates/control/requirements.txt +1 -2
- truss/templates/docker_server/proxy.conf.jinja +13 -0
- truss/templates/docker_server/supervisord.conf.jinja +2 -1
- truss/templates/no_build.Dockerfile.jinja +1 -0
- truss/templates/server/requirements.txt +2 -3
- truss/templates/server/truss_server.py +2 -5
- truss/templates/server.Dockerfile.jinja +12 -12
- truss/templates/shared/lazy_data_resolver.py +214 -2
- truss/templates/shared/util.py +6 -5
- truss/tests/cli/chains/test_chains_team_parameter.py +443 -0
- truss/tests/cli/test_chains_cli.py +144 -0
- truss/tests/cli/test_cli.py +134 -1
- truss/tests/cli/test_cli_utils_common.py +11 -0
- truss/tests/cli/test_model_team_resolver.py +279 -0
- truss/tests/cli/train/test_cache_view.py +240 -3
- truss/tests/cli/train/test_deploy_checkpoints.py +2 -846
- truss/tests/cli/train/test_train_cli_core.py +2 -2
- truss/tests/cli/train/test_train_team_parameter.py +395 -0
- truss/tests/conftest.py +187 -0
- truss/tests/contexts/image_builder/test_serving_image_builder.py +10 -5
- truss/tests/remote/baseten/test_api.py +122 -3
- truss/tests/remote/baseten/test_chain_upload.py +294 -0
- truss/tests/remote/baseten/test_core.py +86 -0
- truss/tests/remote/baseten/test_remote.py +216 -288
- truss/tests/remote/baseten/test_service.py +56 -0
- truss/tests/templates/control/control/conftest.py +20 -0
- truss/tests/templates/control/control/test_endpoints.py +4 -0
- truss/tests/templates/control/control/test_server.py +8 -24
- truss/tests/templates/control/control/test_server_integration.py +4 -2
- truss/tests/test_config.py +21 -12
- truss/tests/test_data/server.Dockerfile +3 -1
- truss/tests/test_data/test_build_commands_truss/__init__.py +0 -0
- truss/tests/test_data/test_build_commands_truss/config.yaml +14 -0
- truss/tests/test_data/test_build_commands_truss/model/model.py +12 -0
- truss/tests/test_data/test_build_commands_truss/packages/constants/constants.py +1 -0
- truss/tests/test_data/test_truss_server_model_cache_v1/config.yaml +1 -0
- truss/tests/test_model_inference.py +13 -0
- truss/tests/util/test_env_vars.py +8 -3
- truss/util/__init__.py +0 -0
- truss/util/env_vars.py +19 -8
- truss/util/error_utils.py +37 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/METADATA +2 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/RECORD +88 -70
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/WHEEL +1 -1
- truss_chains/deployment/deployment_client.py +16 -4
- truss_chains/private_types.py +18 -0
- truss_chains/public_api.py +3 -0
- truss_train/definitions.py +6 -4
- truss_train/deployment.py +43 -21
- truss_train/public_api.py +4 -2
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/entry_points.txt +0 -0
- {truss-0.11.6rc102.dist-info → truss-0.11.24rc2.dist-info}/licenses/LICENSE +0 -0
truss/tests/cli/test_cli.py
CHANGED
@@ -1,4 +1,4 @@
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 from click.testing import CliRunner
 
@@ -23,3 +23,136 @@ def test_push_with_grpc_transport_fails_for_development_deployment():
         "Truss with gRPC transport cannot be used as a development deployment"
         in result.output
     )
+
+
+def test_cli_push_passes_deploy_timeout_minutes_to_create_truss_service(
+    custom_model_truss_dir_with_pre_and_post,
+    remote,
+    mock_baseten_requests,
+    mock_upload_truss,
+    mock_create_truss_service,
+):
+    runner = CliRunner()
+    with patch("truss.cli.cli.RemoteFactory.create", return_value=remote):
+        remote.api.get_teams = Mock(return_value={})
+        with patch("truss.cli.cli.resolve_model_team_name", return_value=(None, None)):
+            result = runner.invoke(
+                truss_cli,
+                [
+                    "push",
+                    str(custom_model_truss_dir_with_pre_and_post),
+                    "--remote",
+                    "baseten",
+                    "--model-name",
+                    "model_name",
+                    "--publish",
+                    "--deploy-timeout-minutes",
+                    "450",
+                ],
+            )
+
+    assert result.exit_code == 0
+    mock_create_truss_service.assert_called_once()
+    _, kwargs = mock_create_truss_service.call_args
+    assert kwargs["deploy_timeout_minutes"] == 450
+
+
+def test_cli_push_passes_none_deploy_timeout_minutes_when_not_specified(
+    custom_model_truss_dir_with_pre_and_post,
+    remote,
+    mock_baseten_requests,
+    mock_upload_truss,
+    mock_create_truss_service,
+):
+    runner = CliRunner()
+    with patch("truss.cli.cli.RemoteFactory.create", return_value=remote):
+        remote.api.get_teams = Mock(return_value={})
+        with patch("truss.cli.cli.resolve_model_team_name", return_value=(None, None)):
+            result = runner.invoke(
+                truss_cli,
+                [
+                    "push",
+                    str(custom_model_truss_dir_with_pre_and_post),
+                    "--remote",
+                    "baseten",
+                    "--model-name",
+                    "model_name",
+                    "--publish",
+                ],
+            )
+
+    assert result.exit_code == 0
+    mock_create_truss_service.assert_called_once()
+    _, kwargs = mock_create_truss_service.call_args
+    assert kwargs.get("deploy_timeout_minutes") is None
+
+
+def test_cli_push_integration_deploy_timeout_minutes_propagated(
+    custom_model_truss_dir_with_pre_and_post,
+    remote,
+    mock_baseten_requests,
+    mock_upload_truss,
+    mock_create_truss_service,
+):
+    runner = CliRunner()
+    with patch("truss.cli.cli.RemoteFactory.create", return_value=remote):
+        remote.api.get_teams = Mock(return_value={})
+        with patch("truss.cli.cli.resolve_model_team_name", return_value=(None, None)):
+            result = runner.invoke(
+                truss_cli,
+                [
+                    "push",
+                    str(custom_model_truss_dir_with_pre_and_post),
+                    "--remote",
+                    "baseten",
+                    "--model-name",
+                    "model_name",
+                    "--publish",
+                    "--environment",
+                    "staging",
+                    "--deploy-timeout-minutes",
+                    "750",
+                ],
+            )
+
+    assert result.exit_code == 0
+    mock_create_truss_service.assert_called_once()
+    _, kwargs = mock_create_truss_service.call_args
+    assert kwargs["deploy_timeout_minutes"] == 750
+    assert kwargs["environment"] == "staging"
+
+
+def test_cli_push_api_integration_deploy_timeout_minutes_propagated(
+    custom_model_truss_dir_with_pre_and_post,
+    mock_remote_factory,
+    temp_trussrc_dir,
+    mock_available_config_names,
+):
+    mock_service = MagicMock()
+    mock_service.model_id = "model_id"
+    mock_service.model_version_id = "version_id"
+    mock_remote_factory.push.return_value = mock_service
+
+    runner = CliRunner()
+    with patch(
+        "truss.cli.cli.RemoteFactory.get_available_config_names",
+        return_value=["baseten"],
+    ):
+        result = runner.invoke(
+            truss_cli,
+            [
+                "push",
+                str(custom_model_truss_dir_with_pre_and_post),
+                "--remote",
+                "baseten",
+                "--model-name",
+                "test_model",
+                "--deploy-timeout-minutes",
+                "1200",
+            ],
+        )
+
+    assert result.exit_code == 0
+    mock_remote_factory.push.assert_called_once()
+    _, push_kwargs = mock_remote_factory.push.call_args
+    assert push_kwargs.get("deploy_timeout_minutes") == 1200
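The new tests above check that a --deploy-timeout-minutes value given on the command line reaches the remote push call unchanged, and that omitting it yields None. As a rough illustration of that kind of plumbing (this is not the actual truss CLI code; the command, the _EchoRemote helper, and the option wiring below are hypothetical), a minimal Click command forwarding such a flag could look like:

import click


class _EchoRemote:
    # Hypothetical stand-in for a remote provider; it only prints what it
    # would push so the flag propagation is visible.
    def push(self, target_directory, deploy_timeout_minutes=None):
        click.echo(
            f"push({target_directory!r}, deploy_timeout_minutes={deploy_timeout_minutes})"
        )


@click.command()
@click.argument("target_directory")
@click.option("--deploy-timeout-minutes", type=int, default=None)
def push(target_directory, deploy_timeout_minutes):
    # Forward the option verbatim; when the flag is omitted, Click supplies
    # None, matching the `is None` assertion in the tests above.
    _EchoRemote().push(target_directory, deploy_timeout_minutes=deploy_timeout_minutes)


if __name__ == "__main__":
    push()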
truss/tests/cli/test_cli_utils_common.py
ADDED
@@ -0,0 +1,11 @@
+from truss.cli.utils import common
+
+
+def test_normalize_iso_timestamp_handles_nanoseconds():
+    normalized = common._normalize_iso_timestamp("2025-11-17 05:05:06.000000000 +0000")
+    assert normalized == "2025-11-17 05:05:06.000000+00:00"
+
+
+def test_normalize_iso_timestamp_handles_z_suffix_and_short_fraction():
+    normalized = common._normalize_iso_timestamp("2025-11-17T05:05:06.123456Z")
+    assert normalized == "2025-11-17T05:05:06.123456+00:00"
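These two assertions pin down the observable behavior of truss.cli.utils.common._normalize_iso_timestamp: fractional seconds are truncated to microsecond precision and a trailing "Z" or "+0000"-style offset becomes "+00:00". A minimal sketch consistent with just those two cases (the body below is an assumption; the real helper may be implemented differently) is:

import re


def _normalize_iso_timestamp(raw: str) -> str:
    # Truncate fractional seconds to microsecond precision (6 digits).
    raw = re.sub(r"\.(\d{6})\d+", r".\1", raw)
    # Normalize a trailing "Z" suffix to an explicit UTC offset.
    raw = raw.replace("Z", "+00:00")
    # Normalize a " +0000"-style offset to "+00:00".
    raw = re.sub(r"\s*([+-])(\d{2})(\d{2})$", r"\1\2:\3", raw)
    return raw


assert (
    _normalize_iso_timestamp("2025-11-17 05:05:06.000000000 +0000")
    == "2025-11-17 05:05:06.000000+00:00"
)
assert (
    _normalize_iso_timestamp("2025-11-17T05:05:06.123456Z")
    == "2025-11-17T05:05:06.123456+00:00"
)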
truss/tests/cli/test_model_team_resolver.py
ADDED
@@ -0,0 +1,279 @@
+"""Tests for team parameter in model push.
+
+This test suite covers all 8 scenarios for team resolution in truss push:
+1. --team PROVIDED: Valid team name, user has access
+2. --team PROVIDED: Invalid team name (does not exist)
+3. --team NOT PROVIDED: User has multiple teams, no existing model
+4. --team NOT PROVIDED: User has multiple teams, existing model in exactly one team
+5. --team NOT PROVIDED: User has multiple teams, existing model exists in multiple teams
+6. --team NOT PROVIDED: User has exactly one team, no existing model
+7. --team NOT PROVIDED: User has exactly one team, existing model matches the team
+8. --team NOT PROVIDED: User has exactly one team, existing model exists in different team
+"""
+
+from unittest.mock import Mock, patch
+
+import click
+import pytest
+
+from truss.cli.resolvers.model_team_resolver import resolve_model_team_name
+from truss.remote.baseten.remote import BasetenRemote
+
+
+class TestModelTeamResolver:
+    """Test team parameter resolution for model push."""
+
+    @staticmethod
+    def _setup_mock_remote(teams):
+        mock_remote = Mock(spec=BasetenRemote)
+        mock_api = Mock()
+        mock_remote.api = mock_api
+        mock_api.get_teams.return_value = teams
+        return mock_remote
+
+    @pytest.mark.parametrize(
+        "provided_team_name,expected_team_name,expected_team_id,should_raise",
+        [
+            # SCENARIO 1: Valid team name
+            ("Team Alpha", "Team Alpha", "team1", False),
+            # SCENARIO 2: Invalid team name
+            ("NonExistentTeam", None, None, True),
+        ],
+    )
+    def test_team_provided_scenarios(
+        self, provided_team_name, expected_team_name, expected_team_id, should_raise
+    ):
+        """Test scenarios when --team is provided."""
+        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
+        mock_remote = self._setup_mock_remote(teams)
+
+        if should_raise:
+            with pytest.raises(click.ClickException) as exc_info:
+                resolve_model_team_name(
+                    remote_provider=mock_remote,
+                    provided_team_name=provided_team_name,
+                    existing_teams=teams,
+                )
+            assert "does not exist" in str(exc_info.value)
+            assert provided_team_name in str(exc_info.value)
+        else:
+            team_name, team_id = resolve_model_team_name(
+                remote_provider=mock_remote,
+                provided_team_name=provided_team_name,
+                existing_teams=teams,
+            )
+            assert team_name == expected_team_name
+            assert team_id == expected_team_id
+        mock_remote.api.get_teams.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "scenario_num,teams,models_response,existing_model_name,inquire_return,expected_team_name,expected_team_id,should_prompt",
+        [
+            # SCENARIO 3: Multiple teams, no existing model
+            (
+                3,
+                {
+                    "Team Alpha": {"id": "team1", "name": "Team Alpha"},
+                    "Team Beta": {"id": "team2", "name": "Team Beta"},
+                    "Team Gamma": {"id": "team3", "name": "Team Gamma"},
+                },
+                {"models": []},  # No models exist
+                "non-existent-model",
+                "Team Beta",
+                "Team Beta",
+                "team2",
+                True,
+            ),
+            # SCENARIO 4: Multiple teams, existing model in exactly one team
+            (
+                4,
+                {
+                    "Team Alpha": {"id": "team1", "name": "Team Alpha"},
+                    "Team Beta": {"id": "team2", "name": "Team Beta"},
+                    "Team Gamma": {"id": "team3", "name": "Team Gamma"},
+                },
+                {
+                    "models": [
+                        {
+                            "id": "model1",
+                            "name": "existing-model",
+                            "team": {"id": "team2", "name": "Team Beta"},
+                        }
+                    ]
+                },
+                "existing-model",
+                None,
+                "Team Beta",
+                "team2",
+                False,
+            ),
+            # SCENARIO 5: Multiple teams, existing model in multiple teams
+            (
+                5,
+                {
+                    "Team Alpha": {"id": "team1", "name": "Team Alpha"},
+                    "Team Beta": {"id": "team2", "name": "Team Beta"},
+                    "Team Gamma": {"id": "team3", "name": "Team Gamma"},
+                },
+                {
+                    "models": [
+                        {
+                            "id": "model1",
+                            "name": "existing-model",
+                            "team": {"id": "team1", "name": "Team Alpha"},
+                        },
+                        {
+                            "id": "model2",
+                            "name": "existing-model",
+                            "team": {"id": "team2", "name": "Team Beta"},
+                        },
+                    ]
+                },
+                "existing-model",
+                "Team Alpha",
+                "Team Alpha",
+                "team1",
+                True,
+            ),
+            # SCENARIO 6: Single team, no existing model
+            (
+                6,
+                {"Team Alpha": {"id": "team1", "name": "Team Alpha"}},
+                {"models": []},  # No models exist
+                "non-existent-model",
+                None,
+                "Team Alpha",
+                "team1",
+                False,
+            ),
+            # SCENARIO 7: Single team, existing model matches the team
+            (
+                7,
+                {"Team Alpha": {"id": "team1", "name": "Team Alpha"}},
+                {
+                    "models": [
+                        {
+                            "id": "model1",
+                            "name": "existing-model",
+                            "team": {"id": "team1", "name": "Team Alpha"},
+                        }
+                    ]
+                },
+                "existing-model",
+                None,
+                "Team Alpha",
+                "team1",
+                False,
+            ),
+            # SCENARIO 8: Single team, existing model in different team
+            (
+                8,
+                {"Team Alpha": {"id": "team1", "name": "Team Alpha"}},
+                {
+                    "models": [
+                        {
+                            "id": "model1",
+                            "name": "existing-model",
+                            "team": {"id": "team2", "name": "Team Other"},
+                        }
+                    ]
+                },
+                "existing-model",
+                None,
+                "Team Alpha",
+                "team1",
+                False,
+            ),
+        ],
+    )
+    @patch("truss.cli.resolvers.model_team_resolver.remote_cli.inquire_team")
+    def test_team_not_provided_scenarios(
+        self,
+        mock_inquire_team,
+        scenario_num,
+        teams,
+        models_response,
+        existing_model_name,
+        inquire_return,
+        expected_team_name,
+        expected_team_id,
+        should_prompt,
+    ):
+        """Test scenarios when --team is NOT provided."""
+        mock_remote = self._setup_mock_remote(teams)
+        if inquire_return:
+            mock_inquire_team.return_value = inquire_return
+
+        mock_remote.api.models.return_value = models_response
+
+        team_name, team_id = resolve_model_team_name(
+            remote_provider=mock_remote,
+            provided_team_name=None,
+            existing_model_name=existing_model_name,
+            existing_teams=teams,
+        )
+
+        assert team_name == expected_team_name
+        assert team_id == expected_team_id
+        if should_prompt:
+            mock_inquire_team.assert_called_once_with(existing_teams=teams)
+        else:
+            mock_inquire_team.assert_not_called()
+        if existing_model_name:
+            mock_remote.api.models.assert_called_once()
+
+    @pytest.mark.parametrize(
+        "existing_teams_param,should_call_get_teams",
+        [(None, True), ({"Team Alpha": {"id": "team1", "name": "Team Alpha"}}, False)],
+    )
+    def test_get_teams_called_when_existing_teams_none(
+        self, existing_teams_param, should_call_get_teams
+    ):
+        """Test that get_teams is called when existing_teams is not provided."""
+        teams = {"Team Alpha": {"id": "team1", "name": "Team Alpha"}}
+        mock_remote = self._setup_mock_remote(teams)
+
+        team_name, team_id = resolve_model_team_name(
+            remote_provider=mock_remote,
+            provided_team_name="Team Alpha",
+            existing_teams=existing_teams_param,
+        )
+
+        assert team_name == "Team Alpha"
+        assert team_id == "team1"
+        if should_call_get_teams:
+            mock_remote.api.get_teams.assert_called_once()
+        else:
+            mock_remote.api.get_teams.assert_not_called()
+
+    @pytest.mark.parametrize(
+        "existing_model_name,should_call_models_api",
+        [(None, False), ("some-model", True)],
+    )
+    @patch("truss.cli.resolvers.model_team_resolver.remote_cli.inquire_team")
+    def test_existing_model_name_scenarios(
+        self, mock_inquire_team, existing_model_name, should_call_models_api
+    ):
+        """Test behavior with different existing_model_name values."""
+        teams = {
+            "Team Alpha": {"id": "team1", "name": "Team Alpha"},
+            "Team Beta": {"id": "team2", "name": "Team Beta"},
+        }
+        mock_remote = self._setup_mock_remote(teams)
+        mock_inquire_team.return_value = "Team Beta"
+        mock_remote.api.models.return_value = {"models": []}
+
+        team_name, team_id = resolve_model_team_name(
+            remote_provider=mock_remote,
+            provided_team_name=None,
+            existing_model_name=existing_model_name,
+            existing_teams=teams,
+        )
+
+        assert team_name == "Team Beta"
+        assert team_id == "team2"
+        mock_inquire_team.assert_called_once_with(existing_teams=teams)
+        if should_call_models_api:
+            mock_remote.api.models.assert_called_once()
+        else:
+            mock_remote.api.models.assert_not_called()
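Taken together, the eight scenarios in the docstring and the parametrized cases above imply a resolution order: an explicit --team must name an accessible team, otherwise a team that already owns a model of the same name wins, otherwise a lone accessible team is used silently, and only then is the user prompted. A sketch of that decision flow, inferred solely from these tests (the function signature comes from the tests, but the body, the import path for remote_cli, and the exception wording are assumptions, not the actual resolver code), is:

import click

# Assumed import path; the tests only show that the resolver module exposes a
# `remote_cli` attribute with an `inquire_team` function.
from truss.cli import remote_cli


def resolve_model_team_name(
    remote_provider,
    provided_team_name=None,
    existing_model_name=None,
    existing_teams=None,
):
    teams = existing_teams if existing_teams is not None else remote_provider.api.get_teams()

    if provided_team_name is not None:
        if provided_team_name not in teams:
            raise click.ClickException(f"Team '{provided_team_name}' does not exist.")
        return provided_team_name, teams[provided_team_name]["id"]

    # No --team given: see whether a model with this name already pins a team
    # the user can access.
    if existing_model_name:
        models = remote_provider.api.models().get("models", [])
        owning = {
            m["team"]["name"]: m["team"]["id"]
            for m in models
            if m["name"] == existing_model_name and m["team"]["name"] in teams
        }
        if len(owning) == 1:
            return next(iter(owning.items()))

    # Fall back: a single accessible team is used directly; otherwise prompt.
    if len(teams) == 1:
        only_team = next(iter(teams))
        return only_team, teams[only_team]["id"]
    chosen = remote_cli.inquire_team(existing_teams=teams)
    return chosen, teams[chosen]["id"]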
truss/tests/cli/train/test_cache_view.py
CHANGED
@@ -3,7 +3,7 @@ from unittest.mock import Mock
 import click
 import pytest
 
-from truss.cli.train.core import (
+from truss.cli.train.cache import (
     SORT_BY_FILEPATH,
     SORT_BY_MODIFIED,
     SORT_BY_PERMISSIONS,
@@ -12,8 +12,8 @@ from truss.cli.train.core import (
     SORT_ORDER_ASC,
     SORT_ORDER_DESC,
     view_cache_summary,
-    view_cache_summary_by_project,
 )
+from truss.cli.train.core import view_cache_summary_by_project
 from truss.remote.baseten.remote import BasetenRemote
 
 
@@ -106,7 +106,9 @@ def test_view_cache_summary_empty_files(capsys):
     mock_api.get_cache_summary.assert_called_once_with("proj123")
 
     captured = capsys.readouterr()
-
+    # Empty files should still show the table with 0 files
+    assert "Cache summary for project: proj123" in captured.out
+    assert "Total files: 0" in captured.out
 
 
 def test_view_cache_summary_api_error(capsys):
@@ -710,3 +712,238 @@ def test_view_cache_summary_sort_by_permissions_desc(capsys):
 
     assert directory_pos < script_pos
     assert script_pos < config_pos
+
+
+# Tests for individual viewers
+def test_cli_table_viewer_with_data(capsys):
+    """Test CLITableViewer outputs table correctly with data."""
+    from truss.cli.train.cache import CLITableViewer
+    from truss.remote.baseten.custom_types import (
+        FileSummary,
+        FileSummaryWithTotalSize,
+        GetCacheSummaryResponseV1,
+    )
+
+    viewer = CLITableViewer()
+    cache_data = GetCacheSummaryResponseV1(
+        timestamp="2024-01-01T12:00:00Z",
+        project_id="proj123",
+        file_summaries=[
+            FileSummary(
+                path="test/file.txt",
+                size_bytes=1024,
+                modified="2024-01-01T10:00:00Z",
+                file_type="file",
+                permissions="-rw-r--r--",
+            )
+        ],
+    )
+    files_with_total_sizes = [
+        FileSummaryWithTotalSize(
+            file_summary=cache_data.file_summaries[0], total_size=1024
+        )
+    ]
+
+    viewer.output_cache_summary(
+        cache_data, files_with_total_sizes, 1024, "1.02 KB", "proj123"
+    )
+
+    captured = capsys.readouterr()
+    assert "Cache summary for project: proj123" in captured.out
+    assert "test/file.txt" in captured.out
+    assert "1.02 KB" in captured.out
+    assert "Total files: 1" in captured.out
+
+
+def test_cli_table_viewer_empty_files(capsys):
+    """Test CLITableViewer handles empty files correctly."""
+    from truss.cli.train.cache import CLITableViewer
+    from truss.remote.baseten.custom_types import GetCacheSummaryResponseV1
+
+    viewer = CLITableViewer()
+    cache_data = GetCacheSummaryResponseV1(
+        timestamp="2024-01-01T12:00:00Z", project_id="proj123", file_summaries=[]
+    )
+
+    viewer.output_cache_summary(cache_data, [], 0, "0 B", "proj123")
+
+    captured = capsys.readouterr()
+    assert "Cache summary for project: proj123" in captured.out
+    assert "Total files: 0" in captured.out
+    assert "Total size: 0 B" in captured.out
+
+
+def test_cli_table_viewer_no_cache(capsys):
+    """Test CLITableViewer handles no cache message."""
+    from truss.cli.train.cache import CLITableViewer
+
+    viewer = CLITableViewer()
+    viewer.output_no_cache_message("proj123")
+
+    captured = capsys.readouterr()
+    assert "No cache summary found for this project." in captured.out
+
+
+def test_csv_viewer_with_data(capsys):
+    """Test CSVViewer outputs CSV correctly with data."""
+    from truss.cli.train.cache import CSVViewer
+    from truss.remote.baseten.custom_types import (
+        FileSummary,
+        FileSummaryWithTotalSize,
+        GetCacheSummaryResponseV1,
+    )
+
+    viewer = CSVViewer()
+    cache_data = GetCacheSummaryResponseV1(
+        timestamp="2024-01-01T12:00:00Z",
+        project_id="proj123",
+        file_summaries=[
+            FileSummary(
+                path="test/file.txt",
+                size_bytes=2048,
+                modified="2024-01-01T10:00:00Z",
+                file_type="file",
+                permissions="-rw-r--r--",
+            )
+        ],
+    )
+    files_with_total_sizes = [
+        FileSummaryWithTotalSize(
+            file_summary=cache_data.file_summaries[0], total_size=2048
+        )
+    ]
+
+    viewer.output_cache_summary(
+        cache_data, files_with_total_sizes, 2048, "2.05 KB", "proj123"
+    )
+
+    captured = capsys.readouterr()
+    lines = captured.out.strip().split("\n")
+    assert len(lines) == 2  # Header + 1 data row
+    assert "File Path" in lines[0]
+    assert "Size (bytes)" in lines[0]
+    assert "test/file.txt" in lines[1]
+    assert "2048" in lines[1]
+    assert "2.05 KB" in lines[1]
+
+
+def test_csv_viewer_empty_files(capsys):
+    """Test CSVViewer handles empty files correctly (just headers)."""
+    from truss.cli.train.cache import CSVViewer
+    from truss.remote.baseten.custom_types import GetCacheSummaryResponseV1
+
+    viewer = CSVViewer()
+    cache_data = GetCacheSummaryResponseV1(
+        timestamp="2024-01-01T12:00:00Z", project_id="proj123", file_summaries=[]
+    )
+
+    viewer.output_cache_summary(cache_data, [], 0, "0 B", "proj123")
+
+    captured = capsys.readouterr()
+    lines = captured.out.strip().split("\n")
+    assert len(lines) == 1  # Just header row
+    assert "File Path" in lines[0]
+    assert "Size (bytes)" in lines[0]
+
+
+def test_csv_viewer_no_cache(capsys):
+    """Test CSVViewer handles no cache (outputs empty CSV with headers)."""
+    from truss.cli.train.cache import CSVViewer
+
+    viewer = CSVViewer()
+    viewer.output_no_cache_message("proj123")
+
+    captured = capsys.readouterr()
+    lines = captured.out.strip().split("\n")
+    assert len(lines) == 1  # Just header row
+    assert "File Path" in lines[0]
+
+
+def test_json_viewer_with_data(capsys):
+    """Test JSONViewer outputs JSON correctly with data."""
+    import json
+
+    from truss.cli.train.cache import JSONViewer
+    from truss.remote.baseten.custom_types import (
+        FileSummary,
+        FileSummaryWithTotalSize,
+        GetCacheSummaryResponseV1,
+    )
+
+    viewer = JSONViewer()
+    cache_data = GetCacheSummaryResponseV1(
+        timestamp="2024-01-01T12:00:00Z",
+        project_id="proj123",
+        file_summaries=[
+            FileSummary(
+                path="test/file.txt",
+                size_bytes=3072,
+                modified="2024-01-01T10:00:00Z",
+                file_type="file",
+                permissions="-rw-r--r--",
+            )
+        ],
+    )
+    files_with_total_sizes = [
+        FileSummaryWithTotalSize(
+            file_summary=cache_data.file_summaries[0], total_size=3072
+        )
+    ]
+
+    viewer.output_cache_summary(
+        cache_data, files_with_total_sizes, 3072, "3.07 KB", "proj123"
+    )
+
+    captured = capsys.readouterr()
+    output = json.loads(captured.out)
+    assert output["timestamp"] == "2024-01-01T12:00:00Z"
+    assert output["project_id"] == "proj123"
+    assert output["total_files"] == 1
+    assert output["total_size_bytes"] == 3072
+    assert output["total_size_human_readable"] == "3.07 KB"
+    assert len(output["files"]) == 1
+    assert output["files"][0]["path"] == "test/file.txt"
+    assert output["files"][0]["size_bytes"] == 3072
+
+
+def test_json_viewer_empty_files(capsys):
+    """Test JSONViewer handles empty files correctly."""
+    import json
+
+    from truss.cli.train.cache import JSONViewer
+    from truss.remote.baseten.custom_types import GetCacheSummaryResponseV1
+
+    viewer = JSONViewer()
+    cache_data = GetCacheSummaryResponseV1(
+        timestamp="2024-01-01T12:00:00Z", project_id="proj123", file_summaries=[]
+    )
+
+    viewer.output_cache_summary(cache_data, [], 0, "0 B", "proj123")
+
+    captured = capsys.readouterr()
+    output = json.loads(captured.out)
+    assert output["timestamp"] == "2024-01-01T12:00:00Z"
+    assert output["project_id"] == "proj123"
+    assert output["total_files"] == 0
+    assert output["total_size_bytes"] == 0
+    assert output["total_size_human_readable"] == "0 B"
+    assert output["files"] == []
+
+
+def test_json_viewer_no_cache(capsys):
+    """Test JSONViewer handles no cache (outputs empty JSON structure)."""
+    import json
+
+    from truss.cli.train.cache import JSONViewer
+
+    viewer = JSONViewer()
+    viewer.output_no_cache_message("proj123")
+
+    captured = capsys.readouterr()
+    output = json.loads(captured.out)
+    assert output["timestamp"] == ""
+    assert output["project_id"] == "proj123"
+    assert output["total_files"] == 0
+    assert output["total_size_bytes"] == 0
+    assert output["total_size_human_readable"] == "0 B"
+    assert output["files"] == []
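All three viewer classes exercised above (CLITableViewer, CSVViewer, JSONViewer) are driven through the same two calls: output_cache_summary(cache_data, files_with_total_sizes, total_size_bytes, total_size_human_readable, project_id) and output_no_cache_message(project_id). A minimal sketch of that shared contract, with the CacheSummaryViewer name and the parameter names guessed from the positional arguments in the tests (the real truss.cli.train.cache module may use an ABC or plain duck typing instead), is:

from typing import Protocol, Sequence

from truss.remote.baseten.custom_types import (
    FileSummaryWithTotalSize,
    GetCacheSummaryResponseV1,
)


class CacheSummaryViewer(Protocol):
    """Hypothetical protocol; method names match the tests, parameter names are inferred."""

    def output_cache_summary(
        self,
        cache_data: GetCacheSummaryResponseV1,
        files_with_total_sizes: Sequence[FileSummaryWithTotalSize],
        total_size_bytes: int,
        total_size_human_readable: str,
        project_id: str,
    ) -> None: ...

    def output_no_cache_message(self, project_id: str) -> None: ...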