trainml 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/test_checkpoints_integration.py +7 -5
- tests/integration/test_datasets_integration.py +4 -5
- tests/integration/test_jobs_integration.py +40 -2
- tests/integration/test_models_integration.py +8 -10
- tests/integration/test_projects_integration.py +2 -6
- tests/integration/test_volumes_integration.py +100 -0
- tests/unit/cli/test_cli_volume_unit.py +20 -0
- tests/unit/conftest.py +82 -9
- tests/unit/test_volumes_unit.py +447 -0
- trainml/__init__.py +1 -1
- trainml/cli/__init__.py +3 -6
- trainml/cli/job/create.py +3 -3
- trainml/cli/volume.py +235 -0
- trainml/exceptions.py +21 -12
- trainml/jobs.py +36 -39
- trainml/trainml.py +7 -15
- trainml/volumes.py +255 -0
- {trainml-0.5.3.dist-info → trainml-0.5.5.dist-info}/METADATA +1 -1
- {trainml-0.5.3.dist-info → trainml-0.5.5.dist-info}/RECORD +23 -23
- tests/integration/test_providers_integration.py +0 -46
- tests/unit/test_providers_unit.py +0 -125
- trainml/cli/job.py +0 -173
- trainml/cli/provider.py +0 -75
- trainml/providers.py +0 -63
- {trainml-0.5.3.dist-info → trainml-0.5.5.dist-info}/LICENSE +0 -0
- {trainml-0.5.3.dist-info → trainml-0.5.5.dist-info}/WHEEL +0 -0
- {trainml-0.5.3.dist-info → trainml-0.5.5.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.3.dist-info → trainml-0.5.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from unittest.mock import AsyncMock, patch
|
|
5
|
+
from pytest import mark, fixture, raises
|
|
6
|
+
from aiohttp import WSMessage, WSMsgType
|
|
7
|
+
|
|
8
|
+
import trainml.volumes as specimen
|
|
9
|
+
from trainml.exceptions import (
|
|
10
|
+
ApiError,
|
|
11
|
+
VolumeError,
|
|
12
|
+
SpecificationError,
|
|
13
|
+
TrainMLException,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Apply the sdk, unit and volumes markers to every test in this module.
pytestmark = [
    mark.sdk,
    mark.unit,
    mark.volumes,
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@fixture
def volumes(mock_trainml):
    """Yield a ``Volumes`` collection bound to the mocked TrainML client."""
    collection = specimen.Volumes(mock_trainml)
    yield collection
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@fixture
def volume(mock_trainml):
    """Yield a ``Volume`` with id "1" in the "downloading" state.

    The volume belongs to project "proj-id-1"; several tests below assert that
    this project uuid is forwarded in the query parameters.
    """
    attributes = {
        "id": "1",
        "project_uuid": "proj-id-1",
        "name": "first one",
        "status": "downloading",
        "capacity": "10G",
        "used_size": 100_000_000,
        "billed_size": 100_000_000,
        "createdAt": "2020-12-31T23:59:59.000Z",
    }
    yield specimen.Volume(mock_trainml, **attributes)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class VolumesTests:
    """Tests for the ``Volumes`` collection API: get, list, remove and create.

    Each test swaps ``mock_trainml._query`` for an ``AsyncMock`` and checks the
    exact request the collection issues.
    """

    # NOTE(review): this class name does not match pytest's default ``Test*``
    # collection pattern; presumably the project config sets ``python_classes``
    # accordingly — confirm.

    @mark.asyncio
    async def test_get_volume(self, volumes, mock_trainml):
        # Only the outgoing request matters; the response body is unused.
        mock_trainml._query = AsyncMock(return_value={})
        await volumes.get("1234")
        mock_trainml._query.assert_called_once_with("/volume/1234", "GET", {})

    @mark.asyncio
    async def test_list_volumes(self, volumes, mock_trainml):
        mock_trainml._query = AsyncMock(return_value={})
        await volumes.list()
        mock_trainml._query.assert_called_once_with("/volume", "GET", {})

    @mark.asyncio
    async def test_remove_volume(self, volumes, mock_trainml):
        mock_trainml._query = AsyncMock(return_value={})
        await volumes.remove("4567")
        # Removing through the collection always sends force=True.
        mock_trainml._query.assert_called_once_with(
            "/volume/4567", "DELETE", {"force": True}
        )

    @mark.asyncio
    async def test_create_volume_simple(self, volumes, mock_trainml):
        request = {
            "name": "new volume",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/volumes/resnet50",
            "capacity": "10G",
        }
        # The POST body is the request plus project_uuid "proj-id-1"
        # (presumably the mock client's active project — see conftest).
        payload = {"project_uuid": "proj-id-1", **request}
        server_reply = {
            "project_uuid": "cus-id-1",
            "id": "volume-id-1",
            "name": "new volume",
            "status": "new",
            "source_type": "aws",
            "capacity": "10G",
            "source_uri": "s3://trainml-examples/volumes/resnet50",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }

        mock_trainml._query = AsyncMock(return_value=server_reply)
        created = await volumes.create(**request)

        mock_trainml._query.assert_called_once_with("/volume", "POST", None, payload)
        assert created.id == "volume-id-1"
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class VolumeTests:
    """Tests for a single ``Volume`` instance.

    Covers properties, str/repr, the query wrappers, VPN connection helpers,
    websocket message handlers, attach/refresh and the ``wait_for`` state machine.
    """

    # NOTE(review): this class name does not match pytest's default ``Test*``
    # collection pattern; presumably the project config sets ``python_classes``
    # accordingly — confirm.

    def test_volume_properties(self, volume):
        assert isinstance(volume.id, str)
        assert isinstance(volume.status, str)
        assert isinstance(volume.name, str)
        assert isinstance(volume.capacity, str)
        assert isinstance(volume.used_size, int)
        assert isinstance(volume.billed_size, int)

    def test_volume_str(self, volume):
        string = str(volume)
        # Escape the id so the pattern stays correct for ids with regex metacharacters.
        regex = r"^{.*\"id\": \"" + re.escape(volume.id) + r"\".*}$"
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_volume_repr(self, volume):
        string = repr(volume)
        regex = (
            r"^Volume\( trainml , \*\*{.*'id': '" + re.escape(volume.id) + r"'.*}\)$"
        )
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_volume_bool(self, volume, mock_trainml):
        empty_volume = specimen.Volume(mock_trainml)
        assert bool(volume)
        assert not bool(empty_volume)

    @mark.asyncio
    async def test_volume_get_log_url(self, volume, mock_trainml):
        api_response = (
            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_log_url()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/logs", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    @mark.asyncio
    async def test_volume_get_details(self, volume, mock_trainml):
        api_response = {
            "type": "directory",
            "name": "/",
            "count": "8",
            "used_size": "177M",
            "contents": [],
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_details()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/details", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    @mark.asyncio
    async def test_volume_get_connection_utility_url(self, volume, mock_trainml):
        api_response = (
            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_connection_utility_url()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/download", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    def test_volume_get_connection_details_no_vpn(self, volume):
        details = volume.get_connection_details()
        expected_details = dict()
        assert details == expected_details

    def test_volume_get_connection_details_local_data(self, mock_trainml):
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            project_uuid="a",
            name="first one",
            status="new",
            capacity="10G",
            createdAt="2020-12-31T23:59:59.000Z",
            source_type="local",
            source_uri="~/tensorflow-example",
            vpn={
                "status": "new",
                "cidr": "10.106.171.0/24",
                "client": {
                    "port": "36017",
                    "id": "cus-id-1",
                    "address": "10.106.171.253",
                    "ssh_port": 46600,
                },
                "net_prefix_type_id": 1,
            },
        )
        details = volume.get_connection_details()
        expected_details = dict(
            project_uuid="a",
            entity_type="volume",
            cidr="10.106.171.0/24",
            ssh_port=46600,
            input_path="~/tensorflow-example",
            output_path=None,
        )
        assert details == expected_details

    @mark.asyncio
    async def test_volume_connect(self, volume, mock_trainml):
        with patch(
            "trainml.volumes.Connection",
            autospec=True,
        ) as mock_connection:
            connection = mock_connection.return_value
            connection.status = "connected"
            resp = await volume.connect()
            connection.start.assert_called_once()
            assert resp == "connected"

    @mark.asyncio
    async def test_volume_disconnect(self, volume, mock_trainml):
        with patch(
            "trainml.volumes.Connection",
            autospec=True,
        ) as mock_connection:
            connection = mock_connection.return_value
            connection.status = "removed"
            resp = await volume.disconnect()
            connection.stop.assert_called_once()
            assert resp == "removed"

    @mark.asyncio
    async def test_volume_remove(self, volume, mock_trainml):
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volume.remove()
        mock_trainml._query.assert_called_once_with(
            "/volume/1",
            "DELETE",
            dict(project_uuid="proj-id-1", force=False),
        )

    def test_volume_default_ws_msg_handler(self, volume, capsys):
        data = {
            "msg": "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n",
            "time": 1613079345318,
            "type": "subscription",
            "stream": "worker-id-1",
            "job_worker_uuid": "worker-id-1",
        }

        handler = volume._get_msg_handler(None)
        handler(data)
        captured = capsys.readouterr()
        # 1613079345318 ms is 2021-02-11 21:35:45 UTC. The previous literal
        # "15:35:45" only matched a UTC-6 machine, so the handler presumably
        # renders local time. Match the timestamp's shape instead of a fixed
        # wall-clock value, and keep the message itself exact.
        assert re.fullmatch(
            r"\d{2}/\d{2}/\d{4}, \d{2}:\d{2}:\d{2}: "
            + re.escape(
                "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n"
            ),
            captured.out,
        )

    def test_volume_custom_ws_msg_handler(self, volume, capsys):
        def custom_handler(msg):
            print(msg.get("stream"))

        data = {
            "msg": "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n",
            "time": 1613079345318,
            "type": "subscription",
            "stream": "worker-id-1",
            "job_worker_uuid": "worker-id-1",
        }

        handler = volume._get_msg_handler(custom_handler)
        handler(data)
        captured = capsys.readouterr()
        assert captured.out == "worker-id-1\n"

    @mark.asyncio
    async def test_volume_attach(self, volume, mock_trainml):
        api_response = None
        mock_trainml._ws_subscribe = AsyncMock(return_value=api_response)
        refresh_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "downloading",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        volume.refresh = AsyncMock(return_value=refresh_response)
        await volume.attach()
        mock_trainml._ws_subscribe.assert_called_once()

    @mark.asyncio
    async def test_volume_attach_immediate_return(self, mock_trainml):
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response = None
        mock_trainml._ws_subscribe = AsyncMock(return_value=api_response)
        refresh_response = {
            "customer_uuid": "cus-id-1",
            "id": "1",
            "name": "new volume",
            "status": "ready",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        volume.refresh = AsyncMock(return_value=refresh_response)
        await volume.attach()
        # A volume that is already "ready" must not open a websocket subscription.
        mock_trainml._ws_subscribe.assert_not_called()

    @mark.asyncio
    async def test_volume_refresh(self, volume, mock_trainml):
        api_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "ready",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.refresh()
        mock_trainml._query.assert_called_once_with(
            "/volume/1", "GET", dict(project_uuid="proj-id-1")
        )
        assert volume.id == "data-id-1"
        assert response.id == "data-id-1"

    @mark.asyncio
    async def test_volume_wait_for_successful(self, volume, mock_trainml):
        api_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "ready",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.wait_for("ready")
        mock_trainml._query.assert_called_once_with(
            "/volume/1", "GET", dict(project_uuid="proj-id-1")
        )
        assert volume.id == "data-id-1"
        assert response.id == "data-id-1"

    @mark.asyncio
    async def test_volume_wait_for_current_status(self, mock_trainml):
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response = None
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volume.wait_for("ready")
        # Already in the requested state: no polling should happen.
        mock_trainml._query.assert_not_called()

    @mark.asyncio
    async def test_volume_wait_for_incorrect_status(self, volume, mock_trainml):
        api_response = None
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(SpecificationError):
            await volume.wait_for("stopped")
        # Checked outside the `with` so it runs after the expected exception.
        mock_trainml._query.assert_not_called()

    @mark.asyncio
    async def test_volume_wait_for_with_delay(self, volume, mock_trainml):
        api_response_initial = dict(
            id="1",
            name="first one",
            status="new",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response_final = dict(
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock()
        mock_trainml._query.side_effect = [
            api_response_initial,
            api_response_initial,
            api_response_final,
        ]
        response = await volume.wait_for("ready")
        assert volume.status == "ready"
        assert response.status == "ready"

    @mark.asyncio
    async def test_volume_wait_for_timeout(self, volume, mock_trainml):
        api_response = dict(
            id="1",
            name="first one",
            status="downloading",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(TrainMLException):
            await volume.wait_for("ready", 10)
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_volume_failed(self, volume, mock_trainml):
        api_response = dict(
            id="1",
            name="first one",
            status="failed",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(VolumeError):
            await volume.wait_for("ready")
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_archived_succeeded(self, volume, mock_trainml):
        # A 404 while waiting for "archived" means the volume is gone: success.
        mock_trainml._query = AsyncMock(
            side_effect=ApiError(404, dict(errorMessage="Volume Not Found"))
        )
        await volume.wait_for("archived")
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_unexpected_api_error(self, volume, mock_trainml):
        # The same 404 while waiting for any other state must propagate.
        mock_trainml._query = AsyncMock(
            side_effect=ApiError(404, dict(errorMessage="Volume Not Found"))
        )
        with raises(ApiError):
            await volume.wait_for("ready")
        mock_trainml._query.assert_called()
|
trainml/__init__.py
CHANGED
trainml/cli/__init__.py
CHANGED
|
@@ -142,9 +142,7 @@ def configure(config):
|
|
|
142
142
|
project for project in projects if project.id == active_project_id
|
|
143
143
|
]
|
|
144
144
|
|
|
145
|
-
active_project_name = (
|
|
146
|
-
active_project[0].name if len(active_project) else "UNSET"
|
|
147
|
-
)
|
|
145
|
+
active_project_name = active_project[0].name if len(active_project) else "UNSET"
|
|
148
146
|
|
|
149
147
|
click.echo(f"Current Active Project: {active_project_name}")
|
|
150
148
|
|
|
@@ -154,9 +152,7 @@ def configure(config):
|
|
|
154
152
|
show_choices=True,
|
|
155
153
|
default=active_project_name,
|
|
156
154
|
)
|
|
157
|
-
selected_project = [
|
|
158
|
-
project for project in projects if project.name == name
|
|
159
|
-
]
|
|
155
|
+
selected_project = [project for project in projects if project.name == name]
|
|
160
156
|
config.trainml.client.set_active_project(selected_project[0].id)
|
|
161
157
|
|
|
162
158
|
|
|
@@ -164,6 +160,7 @@ from trainml.cli.connection import connection
|
|
|
164
160
|
from trainml.cli.dataset import dataset
|
|
165
161
|
from trainml.cli.model import model
|
|
166
162
|
from trainml.cli.checkpoint import checkpoint
|
|
163
|
+
from trainml.cli.volume import volume
|
|
167
164
|
from trainml.cli.environment import environment
|
|
168
165
|
from trainml.cli.gpu import gpu
|
|
169
166
|
from trainml.cli.job import job
|
trainml/cli/job/create.py
CHANGED
|
@@ -389,7 +389,7 @@ def notebook(
|
|
|
389
389
|
],
|
|
390
390
|
case_sensitive=False,
|
|
391
391
|
),
|
|
392
|
-
default="rtx3090",
|
|
392
|
+
default=["rtx3090"],
|
|
393
393
|
multiple=True,
|
|
394
394
|
show_default=True,
|
|
395
395
|
help="GPU type.",
|
|
@@ -732,7 +732,7 @@ def training(
|
|
|
732
732
|
],
|
|
733
733
|
case_sensitive=False,
|
|
734
734
|
),
|
|
735
|
-
default="rtx3090",
|
|
735
|
+
default=["rtx3090"],
|
|
736
736
|
show_default=True,
|
|
737
737
|
multiple=True,
|
|
738
738
|
help="GPU type.",
|
|
@@ -1099,7 +1099,7 @@ def from_json(config, attach, connect, file):
|
|
|
1099
1099
|
],
|
|
1100
1100
|
case_sensitive=False,
|
|
1101
1101
|
),
|
|
1102
|
-
default="rtx3090",
|
|
1102
|
+
default=["rtx3090"],
|
|
1103
1103
|
show_default=True,
|
|
1104
1104
|
multiple=True,
|
|
1105
1105
|
help="GPU type.",
|