trainml 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/test_checkpoints_integration.py +7 -5
- tests/integration/test_datasets_integration.py +4 -5
- tests/integration/test_jobs_integration.py +40 -2
- tests/integration/test_models_integration.py +8 -10
- tests/integration/test_projects_integration.py +2 -6
- tests/integration/test_volumes_integration.py +100 -0
- tests/unit/cli/cloudbender/test_cli_reservation_unit.py +10 -14
- tests/unit/cli/test_cli_project_unit.py +5 -9
- tests/unit/cli/test_cli_volume_unit.py +20 -0
- tests/unit/cloudbender/test_services_unit.py +161 -0
- tests/unit/conftest.py +94 -21
- tests/unit/test_projects_unit.py +34 -48
- tests/unit/test_volumes_unit.py +447 -0
- trainml/__init__.py +1 -1
- trainml/cli/__init__.py +3 -6
- trainml/cli/cloudbender/__init__.py +1 -1
- trainml/cli/cloudbender/service.py +129 -0
- trainml/cli/project.py +10 -15
- trainml/cli/volume.py +235 -0
- trainml/cloudbender/cloudbender.py +2 -2
- trainml/cloudbender/services.py +115 -0
- trainml/exceptions.py +21 -12
- trainml/jobs.py +36 -39
- trainml/projects.py +19 -30
- trainml/trainml.py +7 -15
- trainml/volumes.py +255 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/METADATA +1 -1
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/RECORD +32 -29
- tests/integration/test_providers_integration.py +0 -46
- tests/unit/test_providers_unit.py +0 -125
- trainml/cli/job.py +0 -173
- trainml/cli/provider.py +0 -75
- trainml/providers.py +0 -63
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/LICENSE +0 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/WHEEL +0 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.4.dist-info → trainml-0.5.6.dist-info}/top_level.txt +0 -0
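The bulk of this release is a new volumes resource (trainml/volumes.py, trainml/cli/volume.py) plus a CloudBender services resource; the unit tests and CLI module reproduced below define their surface. As a rough, hypothetical sketch of how the volume SDK might be driven — inferred only from the calls the unit tests mock, with the client object, the `create_and_wait` helper name, and return types being assumptions rather than anything shown in this diff:

# Hypothetical usage sketch, assembled from the unit tests below; "trainml" is
# assumed to be an already-initialized SDK client exposing a volumes collection,
# and the method names/arguments mirror the calls the tests assert against.
import asyncio


async def create_and_wait(trainml):
    # POST /volume with the same fields test_create_volume_simple sends
    volume = await trainml.volumes.create(
        name="new volume",
        source_type="aws",
        source_uri="s3://trainml-examples/volumes/resnet50",
        capacity="10G",
    )
    await volume.attach()           # stream progress messages while it populates
    await volume.wait_for("ready")  # poll status; VolumeError raised on failure
    details = await volume.get_details()
    return volume, details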
tests/unit/test_volumes_unit.py ADDED

@@ -0,0 +1,447 @@
import re
import json
import logging
from unittest.mock import AsyncMock, patch
from pytest import mark, fixture, raises
from aiohttp import WSMessage, WSMsgType

import trainml.volumes as specimen
from trainml.exceptions import (
    ApiError,
    VolumeError,
    SpecificationError,
    TrainMLException,
)

pytestmark = [mark.sdk, mark.unit, mark.volumes]


@fixture
def volumes(mock_trainml):
    yield specimen.Volumes(mock_trainml)


@fixture
def volume(mock_trainml):
    yield specimen.Volume(
        mock_trainml,
        id="1",
        project_uuid="proj-id-1",
        name="first one",
        status="downloading",
        capacity="10G",
        used_size=100000000,
        billed_size=100000000,
        createdAt="2020-12-31T23:59:59.000Z",
    )


class VolumesTests:
    @mark.asyncio
    async def test_get_volume(
        self,
        volumes,
        mock_trainml,
    ):
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volumes.get("1234")
        mock_trainml._query.assert_called_once_with("/volume/1234", "GET", dict())

    @mark.asyncio
    async def test_list_volumes(
        self,
        volumes,
        mock_trainml,
    ):
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volumes.list()
        mock_trainml._query.assert_called_once_with("/volume", "GET", dict())

    @mark.asyncio
    async def test_remove_volume(
        self,
        volumes,
        mock_trainml,
    ):
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volumes.remove("4567")
        mock_trainml._query.assert_called_once_with(
            "/volume/4567", "DELETE", dict(force=True)
        )

    @mark.asyncio
    async def test_create_volume_simple(self, volumes, mock_trainml):
        requested_config = dict(
            name="new volume",
            source_type="aws",
            source_uri="s3://trainml-examples/volumes/resnet50",
            capacity="10G",
        )
        expected_payload = dict(
            project_uuid="proj-id-1",
            name="new volume",
            source_type="aws",
            source_uri="s3://trainml-examples/volumes/resnet50",
            capacity="10G",
        )
        api_response = {
            "project_uuid": "cus-id-1",
            "id": "volume-id-1",
            "name": "new volume",
            "status": "new",
            "source_type": "aws",
            "capacity": "10G",
            "source_uri": "s3://trainml-examples/volumes/resnet50",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }

        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volumes.create(**requested_config)
        mock_trainml._query.assert_called_once_with(
            "/volume", "POST", None, expected_payload
        )
        assert response.id == "volume-id-1"


class VolumeTests:
    def test_volume_properties(self, volume):
        assert isinstance(volume.id, str)
        assert isinstance(volume.status, str)
        assert isinstance(volume.name, str)
        assert isinstance(volume.capacity, str)
        assert isinstance(volume.used_size, int)
        assert isinstance(volume.billed_size, int)

    def test_volume_str(self, volume):
        string = str(volume)
        regex = r"^{.*\"id\": \"" + volume.id + r"\".*}$"
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_volume_repr(self, volume):
        string = repr(volume)
        regex = r"^Volume\( trainml , \*\*{.*'id': '" + volume.id + r"'.*}\)$"
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_volume_bool(self, volume, mock_trainml):
        empty_volume = specimen.Volume(mock_trainml)
        assert bool(volume)
        assert not bool(empty_volume)

    @mark.asyncio
    async def test_volume_get_log_url(self, volume, mock_trainml):
        api_response = (
            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_log_url()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/logs", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    @mark.asyncio
    async def test_volume_get_details(self, volume, mock_trainml):
        api_response = {
            "type": "directory",
            "name": "/",
            "count": "8",
            "used_size": "177M",
            "contents": [],
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_details()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/details", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    @mark.asyncio
    async def test_volume_get_connection_utility_url(self, volume, mock_trainml):
        api_response = (
            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_connection_utility_url()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/download", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    def test_volume_get_connection_details_no_vpn(self, volume):
        details = volume.get_connection_details()
        expected_details = dict()
        assert details == expected_details

    def test_volume_get_connection_details_local_data(self, mock_trainml):
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            project_uuid="a",
            name="first one",
            status="new",
            capacity="10G",
            createdAt="2020-12-31T23:59:59.000Z",
            source_type="local",
            source_uri="~/tensorflow-example",
            vpn={
                "status": "new",
                "cidr": "10.106.171.0/24",
                "client": {
                    "port": "36017",
                    "id": "cus-id-1",
                    "address": "10.106.171.253",
                    "ssh_port": 46600,
                },
                "net_prefix_type_id": 1,
            },
        )
        details = volume.get_connection_details()
        expected_details = dict(
            project_uuid="a",
            entity_type="volume",
            cidr="10.106.171.0/24",
            ssh_port=46600,
            input_path="~/tensorflow-example",
            output_path=None,
        )
        assert details == expected_details

    @mark.asyncio
    async def test_volume_connect(self, volume, mock_trainml):
        with patch(
            "trainml.volumes.Connection",
            autospec=True,
        ) as mock_connection:
            connection = mock_connection.return_value
            connection.status = "connected"
            resp = await volume.connect()
            connection.start.assert_called_once()
            assert resp == "connected"

    @mark.asyncio
    async def test_volume_disconnect(self, volume, mock_trainml):
        with patch(
            "trainml.volumes.Connection",
            autospec=True,
        ) as mock_connection:
            connection = mock_connection.return_value
            connection.status = "removed"
            resp = await volume.disconnect()
            connection.stop.assert_called_once()
            assert resp == "removed"

    @mark.asyncio
    async def test_volume_remove(self, volume, mock_trainml):
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volume.remove()
        mock_trainml._query.assert_called_once_with(
            "/volume/1",
            "DELETE",
            dict(project_uuid="proj-id-1", force=False),
        )

    def test_volume_default_ws_msg_handler(self, volume, capsys):
        data = {
            "msg": "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n",
            "time": 1613079345318,
            "type": "subscription",
            "stream": "worker-id-1",
            "job_worker_uuid": "worker-id-1",
        }

        handler = volume._get_msg_handler(None)
        handler(data)
        captured = capsys.readouterr()
        assert (
            captured.out
            == "02/11/2021, 15:35:45: download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n"
        )

    def test_volume_custom_ws_msg_handler(self, volume, capsys):
        def custom_handler(msg):
            print(msg.get("stream"))

        data = {
            "msg": "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n",
            "time": 1613079345318,
            "type": "subscription",
            "stream": "worker-id-1",
            "job_worker_uuid": "worker-id-1",
        }

        handler = volume._get_msg_handler(custom_handler)
        handler(data)
        captured = capsys.readouterr()
        assert captured.out == "worker-id-1\n"

    @mark.asyncio
    async def test_volume_attach(self, volume, mock_trainml):
        api_response = None
        mock_trainml._ws_subscribe = AsyncMock(return_value=api_response)
        refresh_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "downloading",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        volume.refresh = AsyncMock(return_value=refresh_response)
        await volume.attach()
        mock_trainml._ws_subscribe.assert_called_once()

    @mark.asyncio
    async def test_volume_attach_immediate_return(self, mock_trainml):
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response = None
        mock_trainml._ws_subscribe = AsyncMock(return_value=api_response)
        refresh_response = {
            "customer_uuid": "cus-id-1",
            "id": "1",
            "name": "new volume",
            "status": "ready",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        volume.refresh = AsyncMock(return_value=refresh_response)
        await volume.attach()
        mock_trainml._ws_subscribe.assert_not_called()

    @mark.asyncio
    async def test_volume_refresh(self, volume, mock_trainml):
        api_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "ready",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.refresh()
        mock_trainml._query.assert_called_once_with(
            f"/volume/1", "GET", dict(project_uuid="proj-id-1")
        )
        assert volume.id == "data-id-1"
        assert response.id == "data-id-1"

    @mark.asyncio
    async def test_volume_wait_for_successful(self, volume, mock_trainml):
        api_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "ready",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.wait_for("ready")
        mock_trainml._query.assert_called_once_with(
            f"/volume/1", "GET", dict(project_uuid="proj-id-1")
        )
        assert volume.id == "data-id-1"
        assert response.id == "data-id-1"

    @mark.asyncio
    async def test_volume_wait_for_current_status(self, mock_trainml):
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response = None
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volume.wait_for("ready")
        mock_trainml._query.assert_not_called()

    @mark.asyncio
    async def test_volume_wait_for_incorrect_status(self, volume, mock_trainml):
        api_response = None
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(SpecificationError):
            await volume.wait_for("stopped")
        mock_trainml._query.assert_not_called()

    @mark.asyncio
    async def test_volume_wait_for_with_delay(self, volume, mock_trainml):
        api_response_initial = dict(
            id="1",
            name="first one",
            status="new",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response_final = dict(
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock()
        mock_trainml._query.side_effect = [
            api_response_initial,
            api_response_initial,
            api_response_final,
        ]
        response = await volume.wait_for("ready")
        assert volume.status == "ready"
        assert response.status == "ready"

    @mark.asyncio
    async def test_volume_wait_for_timeout(self, volume, mock_trainml):
        api_response = dict(
            id="1",
            name="first one",
            status="downloading",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(TrainMLException):
            await volume.wait_for("ready", 10)
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_volume_failed(self, volume, mock_trainml):
        api_response = dict(
            id="1",
            name="first one",
            status="failed",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(VolumeError):
            await volume.wait_for("ready")
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_archived_succeeded(self, volume, mock_trainml):
        mock_trainml._query = AsyncMock(
            side_effect=ApiError(404, dict(errorMessage="Volume Not Found"))
        )
        await volume.wait_for("archived")
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_unexpected_api_error(self, volume, mock_trainml):
        mock_trainml._query = AsyncMock(
            side_effect=ApiError(404, dict(errorMessage="Volume Not Found"))
        )
        with raises(ApiError):
            await volume.wait_for("ready")
        mock_trainml._query.assert_called()
trainml/__init__.py CHANGED
trainml/cli/__init__.py CHANGED
@@ -142,9 +142,7 @@ def configure(config):
     active_project = [
         project for project in projects if project.id == active_project_id
     ]

-    active_project_name = (
-        active_project[0].name if len(active_project) else "UNSET"
-    )
+    active_project_name = active_project[0].name if len(active_project) else "UNSET"

     click.echo(f"Current Active Project: {active_project_name}")

@@ -154,9 +152,7 @@ def configure(config):
         show_choices=True,
         default=active_project_name,
     )
-    selected_project = [
-        project for project in projects if project.name == name
-    ]
+    selected_project = [project for project in projects if project.name == name]
     config.trainml.client.set_active_project(selected_project[0].id)


@@ -164,6 +160,7 @@ from trainml.cli.connection import connection
 from trainml.cli.dataset import dataset
 from trainml.cli.model import model
 from trainml.cli.checkpoint import checkpoint
+from trainml.cli.volume import volume
 from trainml.cli.environment import environment
 from trainml.cli.gpu import gpu
 from trainml.cli.job import job
trainml/cli/cloudbender/__init__.py CHANGED

@@ -15,4 +15,4 @@ from trainml.cli.cloudbender.region import region
 from trainml.cli.cloudbender.node import node
 from trainml.cli.cloudbender.device import device
 from trainml.cli.cloudbender.datastore import datastore
-from trainml.cli.cloudbender.
+from trainml.cli.cloudbender.service import service
trainml/cli/cloudbender/service.py ADDED

@@ -0,0 +1,129 @@
import click
from trainml.cli import cli, pass_config, search_by_id_name
from trainml.cli.cloudbender import cloudbender


@cloudbender.group()
@pass_config
def service(config):
    """trainML CloudBender service commands."""
    pass


@service.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to list services for.",
)
@pass_config
def list(config, provider, region):
    """List services."""
    data = [
        ["ID", "NAME", "HOSTNAME"],
        [
            "-" * 80,
            "-" * 80,
            "-" * 80,
        ],
    ]

    services = config.trainml.run(
        config.trainml.client.cloudbender.services.list(
            provider_uuid=provider, region_uuid=region
        )
    )

    for service in services:
        data.append(
            [
                service.id,
                service.name,
                service.hostname,
            ]
        )

    for row in data:
        click.echo(
            "{: >25.24} {: >29.28} {: >40.39}" "".format(*row),
            file=config.stdout,
        )


@service.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to create the service in.",
)
@click.option(
    "--public/--no-public",
    default=True,
    show_default=True,
    help="Service should be accessible from the public internet.",
)
@click.argument("name", type=click.STRING, required=True)
@pass_config
def create(config, provider, region, public, name):
    """
    Creates a service.
    """
    return config.trainml.run(
        config.trainml.client.cloudbender.services.create(
            provider_uuid=provider, region_uuid=region, name=name, public=public
        )
    )


@service.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to remove the service from.",
)
@click.argument("service", type=click.STRING)
@pass_config
def remove(config, provider, region, service):
    """
    Remove a service.

    RESERVATION may be specified by name or ID, but ID is preferred.
    """
    services = config.trainml.run(
        config.trainml.client.cloudbender.services.list(
            provider_uuid=provider, region_uuid=region
        )
    )

    found = search_by_id_name(service, services)
    if None is found:
        raise click.UsageError("Cannot find specified service.")

    return config.trainml.run(found.remove())
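Assuming the package's existing `trainml` console entry point and the way the other CloudBender groups are registered, the new commands would likely be invoked as `trainml cloudbender service list -p <provider-id> -r <region-id>` or `trainml cloudbender service create -p <provider-id> -r <region-id> <name>`; the exact invocation depends on how the `cloudbender` group wires in this module, which this diff does not show.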
trainml/cli/project.py CHANGED
@@ -115,40 +115,35 @@ def list_datastores(config):

 @project.command()
 @pass_config
-def 
-    """List project
+def list_services(config):
+    """List project services."""
     data = [
-        ["ID", "NAME", "
+        ["ID", "NAME", "HOSTNAME", "REGION_UUID"],
         [
             "-" * 80,
             "-" * 80,
             "-" * 80,
             "-" * 80,
-            "-" * 80,
-            "-" * 80,
         ],
     ]
     project = config.trainml.run(
         config.trainml.client.projects.get(config.trainml.client.project)
     )

-
+    services = config.trainml.run(project.list_services())

-    for 
+    for service in services:
         data.append(
             [
-
-
-
-
-                reservation.hostname,
-                reservation.region_uuid,
+                service.id,
+                service.name,
+                service.hostname,
+                service.region_uuid,
             ]
         )

     for row in data:
         click.echo(
-            "{: >38.36} {: >30.28} {: >
-            "".format(*row),
+            "{: >38.36} {: >30.28} {: >30.28} {: >38.36}" "".format(*row),
             file=config.stdout,
         )