trainml 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +89 -67
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +269 -69
- trainml/models.py +51 -55
- trainml/trainml.py +159 -114
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +647 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/top_level.txt +0 -0
|
@@ -37,6 +37,42 @@ def project_member(mock_trainml):
|
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
class ProjectMembersTests:
|
|
40
|
+
@mark.asyncio
|
|
41
|
+
async def test_project_members_add(self, project_members, mock_trainml):
|
|
42
|
+
"""Test add method (lines 18-31)."""
|
|
43
|
+
api_response = {
|
|
44
|
+
"project_uuid": "proj-id-1",
|
|
45
|
+
"email": "newuser@gmail.com",
|
|
46
|
+
"owner": False,
|
|
47
|
+
"job": "all",
|
|
48
|
+
"dataset": "read",
|
|
49
|
+
"model": "all",
|
|
50
|
+
"checkpoint": "read",
|
|
51
|
+
"volume": "all",
|
|
52
|
+
}
|
|
53
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
54
|
+
result = await project_members.add(
|
|
55
|
+
email="newuser@gmail.com",
|
|
56
|
+
job="all",
|
|
57
|
+
dataset="read",
|
|
58
|
+
model="all",
|
|
59
|
+
checkpoint="read",
|
|
60
|
+
volume="all",
|
|
61
|
+
param1="value1",
|
|
62
|
+
)
|
|
63
|
+
expected_payload = dict(
|
|
64
|
+
email="newuser@gmail.com",
|
|
65
|
+
job="all",
|
|
66
|
+
dataset="read",
|
|
67
|
+
model="all",
|
|
68
|
+
checkpoint="read",
|
|
69
|
+
volume="all",
|
|
70
|
+
)
|
|
71
|
+
mock_trainml._query.assert_called_once_with(
|
|
72
|
+
"/project/1/access", "POST", dict(param1="value1"), expected_payload
|
|
73
|
+
)
|
|
74
|
+
assert result.email == "newuser@gmail.com"
|
|
75
|
+
|
|
40
76
|
@mark.asyncio
|
|
41
77
|
async def test_project_members_list(self, project_members, mock_trainml):
|
|
42
78
|
api_response = [
|
|
@@ -72,6 +108,16 @@ class ProjectMembersTests:
|
|
|
72
108
|
)
|
|
73
109
|
assert len(resp) == 2
|
|
74
110
|
|
|
111
|
+
@mark.asyncio
|
|
112
|
+
async def test_project_members_remove(self, project_members, mock_trainml):
|
|
113
|
+
"""Test remove method (line 35)."""
|
|
114
|
+
api_response = dict()
|
|
115
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
116
|
+
await project_members.remove("user@gmail.com", param1="value1")
|
|
117
|
+
mock_trainml._query.assert_called_once_with(
|
|
118
|
+
"/project/1/access", "DELETE", dict(param1="value1", email="user@gmail.com")
|
|
119
|
+
)
|
|
120
|
+
|
|
75
121
|
|
|
76
122
|
class ProjectMemberTests:
|
|
77
123
|
def test_project_member_properties(self, project_member):
|
|
@@ -34,6 +34,24 @@ def project_service(mock_trainml):
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
class ProjectServicesTests:
|
|
37
|
+
@mark.asyncio
|
|
38
|
+
async def test_project_services_get(self, project_services, mock_trainml):
|
|
39
|
+
"""Test get method (lines 11-14)."""
|
|
40
|
+
api_response = {
|
|
41
|
+
"project_uuid": "proj-id-1",
|
|
42
|
+
"region_uuid": "reg-id-1",
|
|
43
|
+
"id": "res-id-1",
|
|
44
|
+
"type": "port",
|
|
45
|
+
"name": "On-Prem Service A",
|
|
46
|
+
"hostname": "service-a.local",
|
|
47
|
+
}
|
|
48
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
49
|
+
result = await project_services.get("res-id-1", param1="value1")
|
|
50
|
+
mock_trainml._query.assert_called_once_with(
|
|
51
|
+
"/project/1/services/res-id-1", "GET", dict(param1="value1")
|
|
52
|
+
)
|
|
53
|
+
assert result.id == "res-id-1"
|
|
54
|
+
|
|
37
55
|
@mark.asyncio
|
|
38
56
|
async def test_project_services_refresh(self, project_services, mock_trainml):
|
|
39
57
|
api_response = dict()
|
|
@@ -100,3 +118,50 @@ class ProjectServiceTests:
|
|
|
100
118
|
empty_project_service = specimen.ProjectService(mock_trainml)
|
|
101
119
|
assert bool(project_service)
|
|
102
120
|
assert not bool(empty_project_service)
|
|
121
|
+
|
|
122
|
+
@mark.asyncio
|
|
123
|
+
async def test_project_service_enable(self, project_service, mock_trainml):
|
|
124
|
+
"""Test enable method (line 72)."""
|
|
125
|
+
api_response = dict()
|
|
126
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
127
|
+
await project_service.enable()
|
|
128
|
+
mock_trainml._query.assert_called_once_with(
|
|
129
|
+
"/project/proj-id-1/services/res-id-1/enable", "PATCH"
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
@mark.asyncio
|
|
133
|
+
async def test_project_service_disable(self, project_service, mock_trainml):
|
|
134
|
+
"""Test disable method (line 77)."""
|
|
135
|
+
api_response = dict()
|
|
136
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
137
|
+
await project_service.disable()
|
|
138
|
+
mock_trainml._query.assert_called_once_with(
|
|
139
|
+
"/project/proj-id-1/services/res-id-1/disable", "PATCH"
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
@mark.asyncio
|
|
143
|
+
async def test_project_service_get_service_ca_certificate(self, project_service, mock_trainml):
|
|
144
|
+
"""Test get_service_ca_certificate method (lines 82-87)."""
|
|
145
|
+
api_response = {"certificate": "ca-cert-data"}
|
|
146
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
147
|
+
result = await project_service.get_service_ca_certificate(param1="value1")
|
|
148
|
+
mock_trainml._query.assert_called_once_with(
|
|
149
|
+
"/project/proj-id-1/services/res-id-1/certificate/ca",
|
|
150
|
+
"GET",
|
|
151
|
+
dict(param1="value1"),
|
|
152
|
+
)
|
|
153
|
+
assert result == api_response
|
|
154
|
+
|
|
155
|
+
@mark.asyncio
|
|
156
|
+
async def test_project_service_sign_client_certificate(self, project_service, mock_trainml):
|
|
157
|
+
"""Test sign_client_certificate method (lines 90-96)."""
|
|
158
|
+
api_response = {"certificate": "signed-cert-data"}
|
|
159
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
160
|
+
result = await project_service.sign_client_certificate("csr-data", param1="value1")
|
|
161
|
+
mock_trainml._query.assert_called_once_with(
|
|
162
|
+
"/project/proj-id-1/services/res-id-1/certificate/sign",
|
|
163
|
+
"POST",
|
|
164
|
+
dict(param1="value1"),
|
|
165
|
+
dict(csr="csr-data"),
|
|
166
|
+
)
|
|
167
|
+
assert result == api_response
|
|
@@ -48,6 +48,22 @@ class ProjectsTests:
|
|
|
48
48
|
await projects.get("1234")
|
|
49
49
|
mock_trainml._query.assert_called_once_with("/project/1234", "GET", dict())
|
|
50
50
|
|
|
51
|
+
@mark.asyncio
|
|
52
|
+
async def test_get_current_project(self, projects, mock_trainml):
|
|
53
|
+
"""Test get_current method (lines 20-23)."""
|
|
54
|
+
api_response = {
|
|
55
|
+
"id": "project-id-1",
|
|
56
|
+
"name": "current project",
|
|
57
|
+
"owner": True,
|
|
58
|
+
}
|
|
59
|
+
mock_trainml.project = "project-id-1"
|
|
60
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
61
|
+
result = await projects.get_current(param1="value1")
|
|
62
|
+
mock_trainml._query.assert_called_once_with(
|
|
63
|
+
"/project/project-id-1", "GET", dict(param1="value1")
|
|
64
|
+
)
|
|
65
|
+
assert result.id == "project-id-1"
|
|
66
|
+
|
|
51
67
|
@mark.asyncio
|
|
52
68
|
async def test_list_projects(
|
|
53
69
|
self,
|
tests/unit/test_auth_unit.py
CHANGED
|
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, patch, mock_open, MagicMock
|
|
|
6
6
|
from pytest import mark, fixture, raises
|
|
7
7
|
from aiohttp import WSMessage, WSMsgType
|
|
8
8
|
|
|
9
|
-
import trainml.auth as specimen
|
|
9
|
+
import trainml.utils.auth as specimen
|
|
10
10
|
|
|
11
11
|
pytestmark = [mark.sdk, mark.unit]
|
|
12
12
|
|
|
@@ -21,7 +21,22 @@ pytestmark = [mark.sdk, mark.unit]
|
|
|
21
21
|
"TRAINML_POOL_ID": "pool_id",
|
|
22
22
|
},
|
|
23
23
|
)
|
|
24
|
-
|
|
24
|
+
@patch("trainml.utils.auth.boto3.client")
|
|
25
|
+
@patch("trainml.utils.auth.requests.get")
|
|
26
|
+
@patch("builtins.open", side_effect=FileNotFoundError)
|
|
27
|
+
def test_auth_from_envs(mock_open, mock_requests_get, mock_boto3_client):
|
|
28
|
+
# Mock the auth config request
|
|
29
|
+
mock_response = MagicMock()
|
|
30
|
+
mock_response.json.return_value = {
|
|
31
|
+
"region": "us-east-1",
|
|
32
|
+
"userPoolSDKClientId": "default_client_id",
|
|
33
|
+
"userPoolId": "default_pool_id",
|
|
34
|
+
}
|
|
35
|
+
mock_requests_get.return_value = mock_response
|
|
36
|
+
|
|
37
|
+
# Mock boto3 client
|
|
38
|
+
mock_boto3_client.return_value = MagicMock()
|
|
39
|
+
|
|
25
40
|
auth = specimen.Auth(config_dir=os.path.expanduser("~/.trainml"))
|
|
26
41
|
assert auth.__dict__.get("username") == "user-id"
|
|
27
42
|
assert auth.__dict__.get("password") == "key"
|
|
@@ -44,7 +44,9 @@ class CheckpointsTests:
|
|
|
44
44
|
api_response = dict()
|
|
45
45
|
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
46
46
|
await checkpoints.get("1234")
|
|
47
|
-
mock_trainml._query.assert_called_once_with(
|
|
47
|
+
mock_trainml._query.assert_called_once_with(
|
|
48
|
+
"/checkpoint/1234", "GET", dict()
|
|
49
|
+
)
|
|
48
50
|
|
|
49
51
|
@mark.asyncio
|
|
50
52
|
async def test_list_checkpoints(
|
|
@@ -55,7 +57,30 @@ class CheckpointsTests:
|
|
|
55
57
|
api_response = dict()
|
|
56
58
|
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
57
59
|
await checkpoints.list()
|
|
58
|
-
mock_trainml._query.assert_called_once_with(
|
|
60
|
+
mock_trainml._query.assert_called_once_with(
|
|
61
|
+
"/checkpoint", "GET", dict()
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
@mark.asyncio
|
|
65
|
+
async def test_list_public_checkpoints(
|
|
66
|
+
self,
|
|
67
|
+
checkpoints,
|
|
68
|
+
mock_trainml,
|
|
69
|
+
):
|
|
70
|
+
api_response = [
|
|
71
|
+
dict(
|
|
72
|
+
checkpoint_uuid="1",
|
|
73
|
+
name="public checkpoint",
|
|
74
|
+
status="ready",
|
|
75
|
+
)
|
|
76
|
+
]
|
|
77
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
78
|
+
result = await checkpoints.list_public()
|
|
79
|
+
mock_trainml._query.assert_called_once_with(
|
|
80
|
+
"/checkpoint/public", "GET", dict()
|
|
81
|
+
)
|
|
82
|
+
assert len(result) == 1
|
|
83
|
+
assert isinstance(result[0], specimen.Checkpoint)
|
|
59
84
|
|
|
60
85
|
@mark.asyncio
|
|
61
86
|
async def test_remove_checkpoint(
|
|
@@ -133,9 +158,7 @@ class CheckpointTests:
|
|
|
133
158
|
|
|
134
159
|
@mark.asyncio
|
|
135
160
|
async def test_checkpoint_get_log_url(self, checkpoint, mock_trainml):
|
|
136
|
-
api_response =
|
|
137
|
-
"https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
|
|
138
|
-
)
|
|
161
|
+
api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
|
|
139
162
|
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
140
163
|
response = await checkpoint.get_log_url()
|
|
141
164
|
mock_trainml._query.assert_called_once_with(
|
|
@@ -160,81 +183,155 @@ class CheckpointTests:
|
|
|
160
183
|
assert response == api_response
|
|
161
184
|
|
|
162
185
|
@mark.asyncio
|
|
163
|
-
async def
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
"
|
|
186
|
+
async def test_checkpoint_connect_downloading_status(self, mock_trainml):
|
|
187
|
+
checkpoint = specimen.Checkpoint(
|
|
188
|
+
mock_trainml,
|
|
189
|
+
checkpoint_uuid="1",
|
|
190
|
+
project_uuid="proj-id-1",
|
|
191
|
+
name="test checkpoint",
|
|
192
|
+
status="downloading",
|
|
193
|
+
auth_token="test-token",
|
|
194
|
+
hostname="example.com",
|
|
195
|
+
source_uri="/path/to/source",
|
|
168
196
|
)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
197
|
+
|
|
198
|
+
with patch(
|
|
199
|
+
"trainml.checkpoints.Checkpoint.refresh", new_callable=AsyncMock
|
|
200
|
+
) as mock_refresh:
|
|
201
|
+
with patch(
|
|
202
|
+
"trainml.checkpoints.upload", new_callable=AsyncMock
|
|
203
|
+
) as mock_upload:
|
|
204
|
+
await checkpoint.connect()
|
|
205
|
+
mock_refresh.assert_called_once()
|
|
206
|
+
mock_upload.assert_called_once_with(
|
|
207
|
+
"example.com", "test-token", "/path/to/source"
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
@mark.asyncio
|
|
211
|
+
async def test_checkpoint_connect_exporting_status(
|
|
212
|
+
self, mock_trainml, tmp_path
|
|
213
|
+
):
|
|
214
|
+
output_dir = str(tmp_path / "output")
|
|
215
|
+
checkpoint = specimen.Checkpoint(
|
|
216
|
+
mock_trainml,
|
|
217
|
+
checkpoint_uuid="1",
|
|
218
|
+
project_uuid="proj-id-1",
|
|
219
|
+
name="test checkpoint",
|
|
220
|
+
status="exporting",
|
|
221
|
+
auth_token="test-token",
|
|
222
|
+
hostname="example.com",
|
|
223
|
+
output_uri=output_dir,
|
|
173
224
|
)
|
|
174
|
-
assert response == api_response
|
|
175
225
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
226
|
+
with patch(
|
|
227
|
+
"trainml.checkpoints.Checkpoint.refresh", new_callable=AsyncMock
|
|
228
|
+
) as mock_refresh:
|
|
229
|
+
with patch(
|
|
230
|
+
"trainml.checkpoints.download", new_callable=AsyncMock
|
|
231
|
+
) as mock_download:
|
|
232
|
+
await checkpoint.connect()
|
|
233
|
+
mock_refresh.assert_called_once()
|
|
234
|
+
mock_download.assert_called_once_with(
|
|
235
|
+
"example.com", "test-token", output_dir
|
|
236
|
+
)
|
|
180
237
|
|
|
181
|
-
|
|
238
|
+
@mark.asyncio
|
|
239
|
+
async def test_checkpoint_connect_new_status_waits_for_downloading(
|
|
240
|
+
self, mock_trainml
|
|
241
|
+
):
|
|
182
242
|
checkpoint = specimen.Checkpoint(
|
|
183
243
|
mock_trainml,
|
|
184
244
|
checkpoint_uuid="1",
|
|
185
|
-
project_uuid="
|
|
186
|
-
name="
|
|
245
|
+
project_uuid="proj-id-1",
|
|
246
|
+
name="test checkpoint",
|
|
187
247
|
status="new",
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
"
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
"
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
with patch(
|
|
251
|
+
"trainml.checkpoints.Checkpoint.wait_for", new_callable=AsyncMock
|
|
252
|
+
) as mock_wait:
|
|
253
|
+
with patch(
|
|
254
|
+
"trainml.checkpoints.Checkpoint.refresh",
|
|
255
|
+
new_callable=AsyncMock,
|
|
256
|
+
) as mock_refresh:
|
|
257
|
+
# After refresh, status becomes downloading
|
|
258
|
+
def update_status():
|
|
259
|
+
checkpoint._status = "downloading"
|
|
260
|
+
checkpoint._checkpoint.update(
|
|
261
|
+
{
|
|
262
|
+
"auth_token": "test-token",
|
|
263
|
+
"hostname": "example.com",
|
|
264
|
+
"source_uri": "/path/to/source",
|
|
265
|
+
}
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
mock_refresh.side_effect = update_status
|
|
269
|
+
|
|
270
|
+
with patch(
|
|
271
|
+
"trainml.checkpoints.upload", new_callable=AsyncMock
|
|
272
|
+
) as mock_upload:
|
|
273
|
+
await checkpoint.connect()
|
|
274
|
+
mock_wait.assert_called_once_with("downloading")
|
|
275
|
+
mock_refresh.assert_called_once()
|
|
276
|
+
mock_upload.assert_called_once()
|
|
214
277
|
|
|
215
278
|
@mark.asyncio
|
|
216
|
-
async def
|
|
279
|
+
async def test_checkpoint_connect_invalid_status(self, mock_trainml):
|
|
280
|
+
checkpoint = specimen.Checkpoint(
|
|
281
|
+
mock_trainml,
|
|
282
|
+
checkpoint_uuid="1",
|
|
283
|
+
project_uuid="proj-id-1",
|
|
284
|
+
name="test checkpoint",
|
|
285
|
+
status="ready",
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
with raises(
|
|
289
|
+
SpecificationError,
|
|
290
|
+
match="You can only connect to downloading or exporting checkpoints",
|
|
291
|
+
):
|
|
292
|
+
await checkpoint.connect()
|
|
293
|
+
|
|
294
|
+
@mark.asyncio
|
|
295
|
+
async def test_checkpoint_connect_missing_properties_downloading(
|
|
296
|
+
self, mock_trainml
|
|
297
|
+
):
|
|
298
|
+
checkpoint = specimen.Checkpoint(
|
|
299
|
+
mock_trainml,
|
|
300
|
+
checkpoint_uuid="1",
|
|
301
|
+
project_uuid="proj-id-1",
|
|
302
|
+
name="test checkpoint",
|
|
303
|
+
status="downloading",
|
|
304
|
+
)
|
|
305
|
+
|
|
217
306
|
with patch(
|
|
218
|
-
"trainml.checkpoints.
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
assert resp == "connected"
|
|
307
|
+
"trainml.checkpoints.Checkpoint.refresh", new_callable=AsyncMock
|
|
308
|
+
):
|
|
309
|
+
with raises(
|
|
310
|
+
SpecificationError,
|
|
311
|
+
match="missing required connection properties",
|
|
312
|
+
):
|
|
313
|
+
await checkpoint.connect()
|
|
226
314
|
|
|
227
315
|
@mark.asyncio
|
|
228
|
-
async def
|
|
316
|
+
async def test_checkpoint_connect_missing_properties_exporting(
|
|
317
|
+
self, mock_trainml
|
|
318
|
+
):
|
|
319
|
+
checkpoint = specimen.Checkpoint(
|
|
320
|
+
mock_trainml,
|
|
321
|
+
checkpoint_uuid="1",
|
|
322
|
+
project_uuid="proj-id-1",
|
|
323
|
+
name="test checkpoint",
|
|
324
|
+
status="exporting",
|
|
325
|
+
)
|
|
326
|
+
|
|
229
327
|
with patch(
|
|
230
|
-
"trainml.checkpoints.
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
assert resp == "removed"
|
|
328
|
+
"trainml.checkpoints.Checkpoint.refresh", new_callable=AsyncMock
|
|
329
|
+
):
|
|
330
|
+
with raises(
|
|
331
|
+
SpecificationError,
|
|
332
|
+
match="missing required connection properties",
|
|
333
|
+
):
|
|
334
|
+
await checkpoint.connect()
|
|
238
335
|
|
|
239
336
|
@mark.asyncio
|
|
240
337
|
async def test_checkpoint_remove(self, checkpoint, mock_trainml):
|
|
@@ -340,7 +437,9 @@ class CheckpointTests:
|
|
|
340
437
|
assert response.id == "data-id-1"
|
|
341
438
|
|
|
342
439
|
@mark.asyncio
|
|
343
|
-
async def test_checkpoint_wait_for_successful(
|
|
440
|
+
async def test_checkpoint_wait_for_successful(
|
|
441
|
+
self, checkpoint, mock_trainml
|
|
442
|
+
):
|
|
344
443
|
api_response = {
|
|
345
444
|
"customer_uuid": "cus-id-1",
|
|
346
445
|
"checkpoint_uuid": "data-id-1",
|
|
@@ -373,7 +472,9 @@ class CheckpointTests:
|
|
|
373
472
|
mock_trainml._query.assert_not_called()
|
|
374
473
|
|
|
375
474
|
@mark.asyncio
|
|
376
|
-
async def test_checkpoint_wait_for_incorrect_status(
|
|
475
|
+
async def test_checkpoint_wait_for_incorrect_status(
|
|
476
|
+
self, checkpoint, mock_trainml
|
|
477
|
+
):
|
|
377
478
|
api_response = None
|
|
378
479
|
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
379
480
|
with raises(SpecificationError):
|
|
@@ -381,7 +482,9 @@ class CheckpointTests:
|
|
|
381
482
|
mock_trainml._query.assert_not_called()
|
|
382
483
|
|
|
383
484
|
@mark.asyncio
|
|
384
|
-
async def test_checkpoint_wait_for_with_delay(
|
|
485
|
+
async def test_checkpoint_wait_for_with_delay(
|
|
486
|
+
self, checkpoint, mock_trainml
|
|
487
|
+
):
|
|
385
488
|
api_response_initial = dict(
|
|
386
489
|
checkpoint_uuid="1",
|
|
387
490
|
name="first one",
|
|
@@ -437,7 +540,9 @@ class CheckpointTests:
|
|
|
437
540
|
self, checkpoint, mock_trainml
|
|
438
541
|
):
|
|
439
542
|
mock_trainml._query = AsyncMock(
|
|
440
|
-
side_effect=ApiError(
|
|
543
|
+
side_effect=ApiError(
|
|
544
|
+
404, dict(errorMessage="Checkpoint Not Found")
|
|
545
|
+
)
|
|
441
546
|
)
|
|
442
547
|
await checkpoint.wait_for("archived")
|
|
443
548
|
mock_trainml._query.assert_called()
|
|
@@ -447,8 +552,88 @@ class CheckpointTests:
|
|
|
447
552
|
self, checkpoint, mock_trainml
|
|
448
553
|
):
|
|
449
554
|
mock_trainml._query = AsyncMock(
|
|
450
|
-
side_effect=ApiError(
|
|
555
|
+
side_effect=ApiError(
|
|
556
|
+
404, dict(errorMessage="Checkpoint Not Found")
|
|
557
|
+
)
|
|
451
558
|
)
|
|
452
559
|
with raises(ApiError):
|
|
453
560
|
await checkpoint.wait_for("ready")
|
|
454
561
|
mock_trainml._query.assert_called()
|
|
562
|
+
|
|
563
|
+
@mark.asyncio
|
|
564
|
+
async def test_checkpoint_rename(self, checkpoint, mock_trainml):
|
|
565
|
+
api_response = dict(
|
|
566
|
+
checkpoint_uuid="1",
|
|
567
|
+
name="renamed checkpoint",
|
|
568
|
+
project_uuid="proj-id-1",
|
|
569
|
+
status="ready",
|
|
570
|
+
)
|
|
571
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
572
|
+
result = await checkpoint.rename("renamed checkpoint")
|
|
573
|
+
mock_trainml._query.assert_called_once_with(
|
|
574
|
+
"/checkpoint/1",
|
|
575
|
+
"PATCH",
|
|
576
|
+
dict(project_uuid="proj-id-1"),
|
|
577
|
+
dict(name="renamed checkpoint"),
|
|
578
|
+
)
|
|
579
|
+
assert result == checkpoint
|
|
580
|
+
assert checkpoint.name == "renamed checkpoint"
|
|
581
|
+
|
|
582
|
+
@mark.asyncio
|
|
583
|
+
async def test_checkpoint_export(self, checkpoint, mock_trainml):
|
|
584
|
+
api_response = dict(
|
|
585
|
+
checkpoint_uuid="1",
|
|
586
|
+
name="first one",
|
|
587
|
+
project_uuid="proj-id-1",
|
|
588
|
+
status="exporting",
|
|
589
|
+
)
|
|
590
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
591
|
+
result = await checkpoint.export("aws", "s3://bucket/path", dict(key="value"))
|
|
592
|
+
mock_trainml._query.assert_called_once_with(
|
|
593
|
+
"/checkpoint/1/export",
|
|
594
|
+
"POST",
|
|
595
|
+
dict(project_uuid="proj-id-1"),
|
|
596
|
+
dict(
|
|
597
|
+
output_type="aws",
|
|
598
|
+
output_uri="s3://bucket/path",
|
|
599
|
+
output_options=dict(key="value"),
|
|
600
|
+
),
|
|
601
|
+
)
|
|
602
|
+
assert result == checkpoint
|
|
603
|
+
assert checkpoint.status == "exporting"
|
|
604
|
+
|
|
605
|
+
@mark.asyncio
|
|
606
|
+
async def test_checkpoint_export_default_options(self, checkpoint, mock_trainml):
|
|
607
|
+
api_response = dict(
|
|
608
|
+
checkpoint_uuid="1",
|
|
609
|
+
name="first one",
|
|
610
|
+
project_uuid="proj-id-1",
|
|
611
|
+
status="exporting",
|
|
612
|
+
)
|
|
613
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
614
|
+
result = await checkpoint.export("aws", "s3://bucket/path")
|
|
615
|
+
mock_trainml._query.assert_called_once_with(
|
|
616
|
+
"/checkpoint/1/export",
|
|
617
|
+
"POST",
|
|
618
|
+
dict(project_uuid="proj-id-1"),
|
|
619
|
+
dict(
|
|
620
|
+
output_type="aws",
|
|
621
|
+
output_uri="s3://bucket/path",
|
|
622
|
+
output_options=dict(),
|
|
623
|
+
),
|
|
624
|
+
)
|
|
625
|
+
assert result == checkpoint
|
|
626
|
+
|
|
627
|
+
@mark.asyncio
|
|
628
|
+
async def test_checkpoint_wait_for_timeout_validation(
|
|
629
|
+
self, checkpoint, mock_trainml
|
|
630
|
+
):
|
|
631
|
+
with raises(SpecificationError) as exc_info:
|
|
632
|
+
await checkpoint.wait_for("ready", timeout=25 * 60 * 60) # > 24 hours
|
|
633
|
+
assert "timeout" in str(exc_info.value.attribute).lower()
|
|
634
|
+
assert "less than" in str(exc_info.value.message).lower()
|
|
635
|
+
|
|
636
|
+
def test_checkpoint_billed_size_property(self, checkpoint, mock_trainml):
|
|
637
|
+
"""Test billed_size property access."""
|
|
638
|
+
checkpoint._billed_size = 50000
|
|
639
|
+
assert checkpoint.billed_size == 50000
|