truss 0.11.2rc502__py3-none-any.whl → 0.11.2rc504__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of truss might be problematic. Click here for more details.
- truss/base/constants.py +3 -0
- truss/cli/chains_commands.py +20 -7
- truss/cli/train/core.py +156 -0
- truss/cli/train/deploy_checkpoints/deploy_checkpoints.py +1 -1
- truss/cli/train_commands.py +72 -0
- truss/templates/base.Dockerfile.jinja +1 -3
- truss/templates/control/control/endpoints.py +82 -33
- truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py +1 -19
- truss/templates/control/requirements.txt +1 -1
- truss/templates/server/common/errors.py +1 -0
- truss/templates/server/truss_server.py +5 -3
- truss/templates/server.Dockerfile.jinja +2 -4
- truss/templates/train/config.py +46 -0
- truss/templates/train/run.sh +11 -0
- truss/tests/cli/train/test_deploy_checkpoints.py +3 -3
- truss/tests/cli/train/test_train_init.py +499 -0
- truss/tests/templates/control/control/test_endpoints.py +20 -14
- {truss-0.11.2rc502.dist-info → truss-0.11.2rc504.dist-info}/METADATA +1 -1
- {truss-0.11.2rc502.dist-info → truss-0.11.2rc504.dist-info}/RECORD +29 -26
- truss_chains/deployment/code_gen.py +5 -1
- truss_chains/deployment/deployment_client.py +45 -7
- truss_chains/public_types.py +6 -3
- truss_chains/remote_chainlet/utils.py +46 -7
- truss_train/__init__.py +4 -0
- truss_train/definitions.py +47 -2
- truss_train/restore_from_checkpoint.py +42 -0
- truss/templates/server/entrypoint.sh +0 -16
- {truss-0.11.2rc502.dist-info → truss-0.11.2rc504.dist-info}/WHEEL +0 -0
- {truss-0.11.2rc502.dist-info → truss-0.11.2rc504.dist-info}/entry_points.txt +0 -0
- {truss-0.11.2rc502.dist-info → truss-0.11.2rc504.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
|
|
3
|
+
# Exit immediately if a command exits with a non-zero status
|
|
4
|
+
set -eux
|
|
5
|
+
|
|
6
|
+
echo "Initializing model training environment..."
|
|
7
|
+
# TODO: Call your training logic below
|
|
8
|
+
echo "Placeholder: insert your model training logic below."
|
|
9
|
+
# e.g., python train_model.py --config config.yaml --epochs 10
|
|
10
|
+
|
|
11
|
+
echo "Training process completed (placeholder)."
|
|
@@ -584,7 +584,7 @@ def test_get_checkpoint_ids_to_deploy_full_checkpoints():
|
|
|
584
584
|
mock_checkbox.assert_called_once()
|
|
585
585
|
assert (
|
|
586
586
|
mock_checkbox.call_args[1]["message"]
|
|
587
|
-
== "
|
|
587
|
+
== "Use spacebar to select/deselect checkpoints to deploy. Press enter when done."
|
|
588
588
|
)
|
|
589
589
|
assert mock_checkbox.call_args[1]["choices"] == checkpoint_options
|
|
590
590
|
|
|
@@ -621,7 +621,7 @@ def test_get_checkpoint_ids_to_deploy_lora_checkpoints():
|
|
|
621
621
|
mock_checkbox.assert_called_once()
|
|
622
622
|
assert (
|
|
623
623
|
mock_checkbox.call_args[1]["message"]
|
|
624
|
-
== "
|
|
624
|
+
== "Use spacebar to select/deselect checkpoints to deploy. Press enter when done."
|
|
625
625
|
)
|
|
626
626
|
assert mock_checkbox.call_args[1]["choices"] == checkpoint_options
|
|
627
627
|
|
|
@@ -656,7 +656,7 @@ def test_get_checkpoint_ids_to_deploy_mixed_checkpoints():
|
|
|
656
656
|
mock_checkbox.assert_called_once()
|
|
657
657
|
assert (
|
|
658
658
|
mock_checkbox.call_args[1]["message"]
|
|
659
|
-
== "
|
|
659
|
+
== "Use spacebar to select/deselect checkpoints to deploy. Press enter when done."
|
|
660
660
|
)
|
|
661
661
|
assert mock_checkbox.call_args[1]["choices"] == checkpoint_options
|
|
662
662
|
|
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
from unittest.mock import Mock, call, mock_open, patch
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import requests
|
|
5
|
+
|
|
6
|
+
from truss.cli.train.core import (
|
|
7
|
+
_get_all_train_init_example_options,
|
|
8
|
+
_get_train_init_example_info,
|
|
9
|
+
download_git_directory,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TestGetTrainInitExampleOptions:
|
|
14
|
+
"""Test cases for _get_train_init_example_options function"""
|
|
15
|
+
|
|
16
|
+
@patch("requests.get")
|
|
17
|
+
def test_successful_request_without_token(self, mock_get):
|
|
18
|
+
"""Test successful API call without authentication token"""
|
|
19
|
+
# Arrange
|
|
20
|
+
mock_response = Mock()
|
|
21
|
+
mock_response.json.return_value = [
|
|
22
|
+
{"name": "example1", "type": "dir"},
|
|
23
|
+
{"name": "example2", "type": "dir"},
|
|
24
|
+
{"name": "file1", "type": "file"}, # Should be filtered out
|
|
25
|
+
]
|
|
26
|
+
mock_response.raise_for_status.return_value = None
|
|
27
|
+
mock_get.return_value = mock_response
|
|
28
|
+
|
|
29
|
+
# Act
|
|
30
|
+
result = _get_all_train_init_example_options()
|
|
31
|
+
|
|
32
|
+
# Assert
|
|
33
|
+
mock_get.assert_called_once_with(
|
|
34
|
+
"https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples",
|
|
35
|
+
headers={},
|
|
36
|
+
)
|
|
37
|
+
assert len(result) == 2
|
|
38
|
+
assert "example1" in result
|
|
39
|
+
assert "example2" in result
|
|
40
|
+
assert "file1" not in result # Files should be filtered out
|
|
41
|
+
|
|
42
|
+
@patch("requests.get")
|
|
43
|
+
def test_successful_request_with_token(self, mock_get):
|
|
44
|
+
"""Test successful API call with authentication token"""
|
|
45
|
+
# Arrange
|
|
46
|
+
mock_response = Mock()
|
|
47
|
+
mock_response.json.return_value = [{"name": "example1", "type": "dir"}]
|
|
48
|
+
mock_response.raise_for_status.return_value = None
|
|
49
|
+
mock_get.return_value = mock_response
|
|
50
|
+
|
|
51
|
+
# Act
|
|
52
|
+
result = _get_all_train_init_example_options(token="test_token")
|
|
53
|
+
|
|
54
|
+
# Assert
|
|
55
|
+
mock_get.assert_called_once_with(
|
|
56
|
+
"https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples",
|
|
57
|
+
headers={"Authorization": "token test_token"},
|
|
58
|
+
)
|
|
59
|
+
assert len(result) == 1
|
|
60
|
+
assert "example1" in result
|
|
61
|
+
|
|
62
|
+
@patch("requests.get")
|
|
63
|
+
def test_custom_repo_and_subdir(self, mock_get):
|
|
64
|
+
"""Test with custom repository and subdirectory"""
|
|
65
|
+
# Arrange
|
|
66
|
+
mock_response = Mock()
|
|
67
|
+
mock_response.json.return_value = []
|
|
68
|
+
mock_response.raise_for_status.return_value = None
|
|
69
|
+
mock_get.return_value = mock_response
|
|
70
|
+
|
|
71
|
+
# Act
|
|
72
|
+
_ = _get_all_train_init_example_options(
|
|
73
|
+
repo_id="custom-repo", examples_subdir="custom-examples"
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
# Assert
|
|
77
|
+
mock_get.assert_called_once_with(
|
|
78
|
+
"https://api.github.com/repos/basetenlabs/custom-repo/contents/custom-examples",
|
|
79
|
+
headers={},
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
@patch("requests.get")
|
|
83
|
+
def test_single_item_response(self, mock_get):
|
|
84
|
+
"""Test when API returns a single item instead of a list"""
|
|
85
|
+
# Arrange
|
|
86
|
+
mock_response = Mock()
|
|
87
|
+
mock_response.json.return_value = {"name": "single_example", "type": "dir"}
|
|
88
|
+
mock_response.raise_for_status.return_value = None
|
|
89
|
+
mock_get.return_value = mock_response
|
|
90
|
+
|
|
91
|
+
# Act
|
|
92
|
+
result = _get_all_train_init_example_options()
|
|
93
|
+
|
|
94
|
+
# Assert
|
|
95
|
+
assert len(result) == 1
|
|
96
|
+
assert "single_example" in result
|
|
97
|
+
|
|
98
|
+
@patch("requests.get")
|
|
99
|
+
@patch("click.echo")
|
|
100
|
+
def test_request_exception_handling(self, mock_echo, mock_get):
|
|
101
|
+
"""Test handling of request exceptions"""
|
|
102
|
+
# Arrange
|
|
103
|
+
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
|
104
|
+
|
|
105
|
+
# Act
|
|
106
|
+
result = _get_all_train_init_example_options()
|
|
107
|
+
|
|
108
|
+
# Assert
|
|
109
|
+
mock_echo.assert_called_once_with(
|
|
110
|
+
"Error exploring directory: Network error. Please file an issue at https://github.com/basetenlabs/truss/issues"
|
|
111
|
+
)
|
|
112
|
+
assert result == []
|
|
113
|
+
|
|
114
|
+
@patch("requests.get")
|
|
115
|
+
@patch("click.echo")
|
|
116
|
+
def test_http_error_handling(self, mock_echo, mock_get):
|
|
117
|
+
"""Test handling of HTTP errors"""
|
|
118
|
+
# Arrange
|
|
119
|
+
mock_response = Mock()
|
|
120
|
+
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
|
121
|
+
"404 Not Found"
|
|
122
|
+
)
|
|
123
|
+
mock_get.return_value = mock_response
|
|
124
|
+
|
|
125
|
+
# Act
|
|
126
|
+
result = _get_all_train_init_example_options()
|
|
127
|
+
|
|
128
|
+
# Assert
|
|
129
|
+
mock_echo.assert_called_once_with(
|
|
130
|
+
"Error exploring directory: 404 Not Found. Please file an issue at https://github.com/basetenlabs/truss/issues"
|
|
131
|
+
)
|
|
132
|
+
assert result == []
|
|
133
|
+
|
|
134
|
+
@patch("requests.get")
|
|
135
|
+
def test_filters_only_directories(self, mock_get):
|
|
136
|
+
"""Test that only directories are returned, files are filtered out"""
|
|
137
|
+
# Arrange
|
|
138
|
+
mock_response = Mock()
|
|
139
|
+
mock_response.json.return_value = [
|
|
140
|
+
{"name": "example1", "type": "dir"},
|
|
141
|
+
{"name": "readme.md", "type": "file"},
|
|
142
|
+
{"name": "example2", "type": "dir"},
|
|
143
|
+
{"name": "config.json", "type": "file"},
|
|
144
|
+
]
|
|
145
|
+
mock_response.raise_for_status.return_value = None
|
|
146
|
+
mock_get.return_value = mock_response
|
|
147
|
+
|
|
148
|
+
# Act
|
|
149
|
+
result = _get_all_train_init_example_options()
|
|
150
|
+
|
|
151
|
+
# Assert
|
|
152
|
+
assert len(result) == 2
|
|
153
|
+
assert "example1" in result
|
|
154
|
+
assert "example2" in result
|
|
155
|
+
assert "readme.md" not in result
|
|
156
|
+
assert "config.json" not in result
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class TestGetTrainInitExampleInfo:
|
|
160
|
+
"""Test cases for _get_train_init_example_info function"""
|
|
161
|
+
|
|
162
|
+
@patch("requests.get")
|
|
163
|
+
def test_successful_request_without_token(self, mock_get):
|
|
164
|
+
"""Test successful API call without authentication token"""
|
|
165
|
+
# Arrange
|
|
166
|
+
mock_response = Mock()
|
|
167
|
+
mock_response.json.return_value = [
|
|
168
|
+
{"name": "file1.py", "type": "file"},
|
|
169
|
+
{"name": "file2.py", "type": "file"},
|
|
170
|
+
]
|
|
171
|
+
mock_response.raise_for_status.return_value = None
|
|
172
|
+
mock_get.return_value = mock_response
|
|
173
|
+
|
|
174
|
+
# Act
|
|
175
|
+
result = _get_train_init_example_info(example_name="test_example")
|
|
176
|
+
|
|
177
|
+
# Assert
|
|
178
|
+
mock_get.assert_called_once_with(
|
|
179
|
+
"https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/test_example",
|
|
180
|
+
headers={},
|
|
181
|
+
)
|
|
182
|
+
assert len(result) == 2
|
|
183
|
+
assert result[0]["name"] == "file1.py"
|
|
184
|
+
assert result[1]["name"] == "file2.py"
|
|
185
|
+
|
|
186
|
+
@patch("requests.get")
|
|
187
|
+
def test_successful_request_with_token(self, mock_get):
|
|
188
|
+
"""Test successful API call with authentication token"""
|
|
189
|
+
# Arrange
|
|
190
|
+
mock_response = Mock()
|
|
191
|
+
mock_response.json.return_value = [{"name": "file1.py", "type": "file"}]
|
|
192
|
+
mock_response.raise_for_status.return_value = None
|
|
193
|
+
mock_get.return_value = mock_response
|
|
194
|
+
|
|
195
|
+
# Act
|
|
196
|
+
result = _get_train_init_example_info(
|
|
197
|
+
example_name="test_example", token="test_token"
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
# Assert
|
|
201
|
+
mock_get.assert_called_once_with(
|
|
202
|
+
"https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/test_example",
|
|
203
|
+
headers={"Authorization": "token test_token"},
|
|
204
|
+
)
|
|
205
|
+
assert len(result) == 1
|
|
206
|
+
|
|
207
|
+
@patch("requests.get")
|
|
208
|
+
def test_custom_repo_and_subdir(self, mock_get):
|
|
209
|
+
"""Test with custom repository and subdirectory"""
|
|
210
|
+
# Arrange
|
|
211
|
+
mock_response = Mock()
|
|
212
|
+
mock_response.json.return_value = []
|
|
213
|
+
mock_response.raise_for_status.return_value = None
|
|
214
|
+
mock_get.return_value = mock_response
|
|
215
|
+
|
|
216
|
+
# Act
|
|
217
|
+
_ = _get_train_init_example_info(
|
|
218
|
+
repo_id="custom-repo",
|
|
219
|
+
examples_subdir="custom-examples",
|
|
220
|
+
example_name="test_example",
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
# Assert
|
|
224
|
+
mock_get.assert_called_once_with(
|
|
225
|
+
"https://api.github.com/repos/basetenlabs/custom-repo/contents/custom-examples/test_example",
|
|
226
|
+
headers={},
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
@patch("requests.get")
|
|
230
|
+
def test_single_item_response(self, mock_get):
|
|
231
|
+
"""Test when API returns a single item instead of a list"""
|
|
232
|
+
# Arrange
|
|
233
|
+
mock_response = Mock()
|
|
234
|
+
mock_response.json.return_value = {"name": "single_file.py", "type": "file"}
|
|
235
|
+
mock_response.raise_for_status.return_value = None
|
|
236
|
+
mock_get.return_value = mock_response
|
|
237
|
+
|
|
238
|
+
# Act
|
|
239
|
+
result = _get_train_init_example_info(example_name="test_example")
|
|
240
|
+
|
|
241
|
+
# Assert
|
|
242
|
+
assert len(result) == 1
|
|
243
|
+
assert result[0]["name"] == "single_file.py"
|
|
244
|
+
|
|
245
|
+
@patch("requests.get")
|
|
246
|
+
@patch("click.echo")
|
|
247
|
+
def test_404_error_returns_empty_list(self, mock_echo, mock_get):
|
|
248
|
+
"""Test that 404 errors return empty list without error message"""
|
|
249
|
+
# Arrange
|
|
250
|
+
mock_response = Mock()
|
|
251
|
+
mock_response.status_code = 404
|
|
252
|
+
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
|
253
|
+
"404 Not Found"
|
|
254
|
+
)
|
|
255
|
+
mock_get.return_value = mock_response
|
|
256
|
+
|
|
257
|
+
# Act
|
|
258
|
+
result = _get_train_init_example_info(example_name="nonexistent_example")
|
|
259
|
+
|
|
260
|
+
# Assert
|
|
261
|
+
mock_echo.assert_not_called() # Should not echo error for 404
|
|
262
|
+
assert result == []
|
|
263
|
+
|
|
264
|
+
@patch("requests.get")
|
|
265
|
+
@patch("click.echo")
|
|
266
|
+
def test_other_http_error_handling(self, mock_echo, mock_get):
|
|
267
|
+
"""Test handling of non-404 HTTP errors"""
|
|
268
|
+
# Arrange
|
|
269
|
+
mock_response = Mock()
|
|
270
|
+
mock_response.status_code = 500
|
|
271
|
+
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
|
|
272
|
+
"500 Internal Server Error"
|
|
273
|
+
)
|
|
274
|
+
mock_get.return_value = mock_response
|
|
275
|
+
|
|
276
|
+
# Act
|
|
277
|
+
result = _get_train_init_example_info(example_name="test_example")
|
|
278
|
+
|
|
279
|
+
# Assert
|
|
280
|
+
mock_echo.assert_called_once_with(
|
|
281
|
+
"Error exploring directory: 500 Internal Server Error. Please file an issue at https://github.com/basetenlabs/truss/issues"
|
|
282
|
+
)
|
|
283
|
+
assert result == []
|
|
284
|
+
|
|
285
|
+
@patch("requests.get")
|
|
286
|
+
@patch("click.echo")
|
|
287
|
+
def test_request_exception_handling(self, mock_echo, mock_get):
|
|
288
|
+
"""Test handling of request exceptions"""
|
|
289
|
+
# Arrange
|
|
290
|
+
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
|
291
|
+
|
|
292
|
+
# Act
|
|
293
|
+
result = _get_train_init_example_info(example_name="test_example")
|
|
294
|
+
|
|
295
|
+
# Assert
|
|
296
|
+
mock_echo.assert_called_once_with(
|
|
297
|
+
"Error exploring directory: Network error. Please file an issue at https://github.com/basetenlabs/truss/issues"
|
|
298
|
+
)
|
|
299
|
+
assert result == []
|
|
300
|
+
|
|
301
|
+
@patch("requests.get")
|
|
302
|
+
def test_none_example_name(self, mock_get):
|
|
303
|
+
"""Test with None as example_name"""
|
|
304
|
+
# Arrange
|
|
305
|
+
mock_response = Mock()
|
|
306
|
+
mock_response.json.return_value = []
|
|
307
|
+
mock_response.raise_for_status.return_value = None
|
|
308
|
+
mock_get.return_value = mock_response
|
|
309
|
+
|
|
310
|
+
# Act
|
|
311
|
+
result = _get_train_init_example_info(example_name=None)
|
|
312
|
+
|
|
313
|
+
# Assert
|
|
314
|
+
mock_get.assert_called_once_with(
|
|
315
|
+
"https://api.github.com/repos/basetenlabs/ml-cookbook/contents/examples/None",
|
|
316
|
+
headers={},
|
|
317
|
+
)
|
|
318
|
+
assert result == []
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class TestDownloadGitDirectory:
|
|
322
|
+
"""Test cases for download_git_directory function"""
|
|
323
|
+
|
|
324
|
+
@patch("os.makedirs")
|
|
325
|
+
@patch("requests.get")
|
|
326
|
+
@patch("builtins.open", new_callable=mock_open)
|
|
327
|
+
@patch("builtins.print")
|
|
328
|
+
def test_download_files_without_training_dir(
|
|
329
|
+
self, mock_print, mock_file, mock_get, mock_makedirs
|
|
330
|
+
):
|
|
331
|
+
"""Test downloading files without a training directory"""
|
|
332
|
+
# Arrange
|
|
333
|
+
mock_response = Mock()
|
|
334
|
+
mock_response.json.return_value = [
|
|
335
|
+
{
|
|
336
|
+
"name": "file1.txt",
|
|
337
|
+
"type": "file",
|
|
338
|
+
"download_url": "https://example.com/file1.txt",
|
|
339
|
+
},
|
|
340
|
+
{
|
|
341
|
+
"name": "file2.py",
|
|
342
|
+
"type": "file",
|
|
343
|
+
"download_url": "https://example.com/file2.py",
|
|
344
|
+
},
|
|
345
|
+
]
|
|
346
|
+
mock_response.raise_for_status.return_value = None
|
|
347
|
+
|
|
348
|
+
# Mock file download responses
|
|
349
|
+
file_response1 = Mock()
|
|
350
|
+
file_response1.content = b"file1 content"
|
|
351
|
+
file_response1.raise_for_status.return_value = None
|
|
352
|
+
|
|
353
|
+
file_response2 = Mock()
|
|
354
|
+
file_response2.content = b"file2 content"
|
|
355
|
+
file_response2.raise_for_status.return_value = None
|
|
356
|
+
|
|
357
|
+
mock_get.side_effect = [mock_response, file_response1, file_response2]
|
|
358
|
+
|
|
359
|
+
# Act
|
|
360
|
+
result = download_git_directory("https://api.github.com/test", "/local/dir")
|
|
361
|
+
|
|
362
|
+
# Assert
|
|
363
|
+
assert result is True
|
|
364
|
+
mock_makedirs.assert_called_once_with("/local/dir", exist_ok=True)
|
|
365
|
+
assert mock_get.call_count == 3
|
|
366
|
+
assert mock_file.call_count == 2
|
|
367
|
+
|
|
368
|
+
@patch("os.makedirs")
|
|
369
|
+
@patch("requests.get")
|
|
370
|
+
def test_download_with_training_directory(self, mock_get, mock_makedirs):
|
|
371
|
+
"""Test downloading when training directory is present"""
|
|
372
|
+
# Arrange
|
|
373
|
+
initial_response = Mock()
|
|
374
|
+
initial_response.json.return_value = [
|
|
375
|
+
{
|
|
376
|
+
"name": "training",
|
|
377
|
+
"type": "dir",
|
|
378
|
+
"url": "https://api.github.com/training",
|
|
379
|
+
},
|
|
380
|
+
{
|
|
381
|
+
"name": "other_file.txt",
|
|
382
|
+
"type": "file",
|
|
383
|
+
"download_url": "https://example.com/other_file.txt",
|
|
384
|
+
},
|
|
385
|
+
]
|
|
386
|
+
initial_response.raise_for_status.return_value = None
|
|
387
|
+
|
|
388
|
+
training_response = Mock()
|
|
389
|
+
training_response.json.return_value = []
|
|
390
|
+
training_response.raise_for_status.return_value = None
|
|
391
|
+
|
|
392
|
+
mock_get.side_effect = [initial_response, training_response]
|
|
393
|
+
|
|
394
|
+
# Act
|
|
395
|
+
result = download_git_directory("https://api.github.com/test", "/local/dir")
|
|
396
|
+
|
|
397
|
+
# Assert
|
|
398
|
+
assert result is True
|
|
399
|
+
# Should be called twice: once for initial dir, once for training contents
|
|
400
|
+
assert mock_makedirs.call_count == 2
|
|
401
|
+
|
|
402
|
+
@patch("os.makedirs")
|
|
403
|
+
@patch("requests.get")
|
|
404
|
+
def test_download_subdirectory_recursively(self, mock_get, mock_makedirs):
|
|
405
|
+
"""Test recursive download of subdirectories"""
|
|
406
|
+
# Arrange
|
|
407
|
+
initial_response = Mock()
|
|
408
|
+
initial_response.json.return_value = [
|
|
409
|
+
{"name": "subdir", "type": "dir", "url": "https://api.github.com/subdir"}
|
|
410
|
+
]
|
|
411
|
+
initial_response.raise_for_status.return_value = None
|
|
412
|
+
|
|
413
|
+
subdir_response = Mock()
|
|
414
|
+
subdir_response.json.return_value = []
|
|
415
|
+
subdir_response.raise_for_status.return_value = None
|
|
416
|
+
|
|
417
|
+
mock_get.side_effect = [initial_response, subdir_response]
|
|
418
|
+
|
|
419
|
+
# Act
|
|
420
|
+
result = download_git_directory("https://api.github.com/test", "/local/dir")
|
|
421
|
+
|
|
422
|
+
# Assert
|
|
423
|
+
assert result is True
|
|
424
|
+
expected_calls = [
|
|
425
|
+
call("/local/dir", exist_ok=True),
|
|
426
|
+
call("/local/dir/subdir", exist_ok=True),
|
|
427
|
+
]
|
|
428
|
+
mock_makedirs.assert_has_calls(expected_calls)
|
|
429
|
+
|
|
430
|
+
@patch("os.makedirs")
|
|
431
|
+
@patch("requests.get")
|
|
432
|
+
@patch("builtins.print")
|
|
433
|
+
def test_download_with_authentication_token(
|
|
434
|
+
self, mock_print, mock_get, mock_makedirs
|
|
435
|
+
):
|
|
436
|
+
"""Test download with authentication token"""
|
|
437
|
+
# Arrange
|
|
438
|
+
mock_response = Mock()
|
|
439
|
+
mock_response.json.return_value = []
|
|
440
|
+
mock_response.raise_for_status.return_value = None
|
|
441
|
+
mock_get.return_value = mock_response
|
|
442
|
+
|
|
443
|
+
# Act
|
|
444
|
+
result = download_git_directory(
|
|
445
|
+
"https://api.github.com/test", "/local/dir", token="test_token"
|
|
446
|
+
)
|
|
447
|
+
|
|
448
|
+
# Assert
|
|
449
|
+
assert result is True
|
|
450
|
+
mock_get.assert_called_once_with(
|
|
451
|
+
"https://api.github.com/test", headers={"Authorization": "token test_token"}
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
@patch("os.makedirs")
|
|
455
|
+
@patch("requests.get")
|
|
456
|
+
@patch("builtins.print")
|
|
457
|
+
def test_download_single_file_response(self, mock_print, mock_get, mock_makedirs):
|
|
458
|
+
"""Test when API returns a single file instead of a list"""
|
|
459
|
+
# Arrange
|
|
460
|
+
mock_response = Mock()
|
|
461
|
+
mock_response.json.return_value = {
|
|
462
|
+
"name": "single_file.txt",
|
|
463
|
+
"type": "file",
|
|
464
|
+
"download_url": "https://example.com/single_file.txt",
|
|
465
|
+
}
|
|
466
|
+
mock_response.raise_for_status.return_value = None
|
|
467
|
+
|
|
468
|
+
file_response = Mock()
|
|
469
|
+
file_response.content = b"single file content"
|
|
470
|
+
file_response.raise_for_status.return_value = None
|
|
471
|
+
|
|
472
|
+
mock_get.side_effect = [mock_response, file_response]
|
|
473
|
+
|
|
474
|
+
with patch("builtins.open", mock_open()) as mock_file:
|
|
475
|
+
# Act
|
|
476
|
+
result = download_git_directory("https://api.github.com/test", "/local/dir")
|
|
477
|
+
|
|
478
|
+
# Assert
|
|
479
|
+
assert result is True
|
|
480
|
+
mock_file.assert_called_once_with("/local/dir/single_file.txt", "wb")
|
|
481
|
+
|
|
482
|
+
@patch("os.makedirs")
|
|
483
|
+
@patch("requests.get")
|
|
484
|
+
@patch("builtins.print")
|
|
485
|
+
def test_download_exception_handling(self, mock_print, mock_get, mock_makedirs):
|
|
486
|
+
"""Test exception handling during download"""
|
|
487
|
+
# Arrange
|
|
488
|
+
mock_get.side_effect = Exception("Network error")
|
|
489
|
+
|
|
490
|
+
# Act
|
|
491
|
+
result = download_git_directory("https://api.github.com/test", "/local/dir")
|
|
492
|
+
|
|
493
|
+
# Assert
|
|
494
|
+
assert result is False
|
|
495
|
+
mock_print.assert_called_with("Error processing response: Network error")
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
if __name__ == "__main__":
|
|
499
|
+
pytest.main([__file__])
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
|
|
1
|
+
import asyncio
|
|
2
|
+
from unittest.mock import AsyncMock, MagicMock, call, patch
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
4
5
|
from fastapi import FastAPI, WebSocket
|
|
@@ -31,33 +32,38 @@ def client_ws(app):
|
|
|
31
32
|
|
|
32
33
|
@pytest.mark.asyncio
|
|
33
34
|
async def test_proxy_ws_bidirectional_messaging(client_ws):
|
|
34
|
-
|
|
35
|
-
client_ws.receive
|
|
36
|
-
{"type": "websocket.receive", "text": "msg1"},
|
|
37
|
-
{"type": "websocket.receive", "text": "msg2"},
|
|
38
|
-
{"type": "websocket.disconnect"},
|
|
39
|
-
]
|
|
35
|
+
client_queue = asyncio.Queue()
|
|
36
|
+
client_ws.receive = client_queue.get
|
|
40
37
|
|
|
38
|
+
server_queue = asyncio.Queue()
|
|
41
39
|
mock_server_ws = AsyncMock(spec=AsyncWebSocketSession)
|
|
42
|
-
mock_server_ws.receive
|
|
43
|
-
TextMessage(data="response1"),
|
|
44
|
-
TextMessage(data="response2"),
|
|
45
|
-
None, # server closing connection
|
|
46
|
-
]
|
|
40
|
+
mock_server_ws.receive = server_queue.get
|
|
47
41
|
mock_server_ws.__aenter__.return_value = mock_server_ws
|
|
48
42
|
mock_server_ws.__aexit__.return_value = None
|
|
49
43
|
|
|
44
|
+
client_queue.put_nowait({"type": "websocket.receive", "text": "msg1"})
|
|
45
|
+
client_queue.put_nowait({"type": "websocket.receive", "text": "msg2"})
|
|
46
|
+
server_queue.put_nowait(TextMessage(data="response1"))
|
|
47
|
+
server_queue.put_nowait(TextMessage(data="response2"))
|
|
48
|
+
|
|
50
49
|
with patch(
|
|
51
50
|
"truss.templates.control.control.endpoints.aconnect_ws",
|
|
52
51
|
return_value=mock_server_ws,
|
|
53
52
|
):
|
|
54
|
-
|
|
53
|
+
proxy_task = asyncio.create_task(proxy_ws(client_ws))
|
|
54
|
+
client_queue.put_nowait(
|
|
55
|
+
{"type": "websocket.disconnect", "code": 1002, "reason": "test-closure"}
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
await proxy_task
|
|
55
59
|
|
|
56
60
|
assert mock_server_ws.send_text.call_count == 2
|
|
57
61
|
assert mock_server_ws.send_text.call_args_list == [(("msg1",),), (("msg2",),)]
|
|
58
62
|
assert client_ws.send_text.call_count == 2
|
|
59
63
|
assert client_ws.send_text.call_args_list == [(("response1",),), (("response2",),)]
|
|
60
|
-
|
|
64
|
+
|
|
65
|
+
assert mock_server_ws.close.call_args_list[0] == call(1002, "test-closure")
|
|
66
|
+
client_ws.close.assert_called()
|
|
61
67
|
|
|
62
68
|
|
|
63
69
|
@pytest.mark.asyncio
|