trainml 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,447 @@
1
+ import re
2
+ import json
3
+ import logging
4
+ from unittest.mock import AsyncMock, patch
5
+ from pytest import mark, fixture, raises
6
+ from aiohttp import WSMessage, WSMsgType
7
+
8
+ import trainml.volumes as specimen
9
+ from trainml.exceptions import (
10
+ ApiError,
11
+ VolumeError,
12
+ SpecificationError,
13
+ TrainMLException,
14
+ )
15
+
16
+ pytestmark = [mark.sdk, mark.unit, mark.volumes]  # pytest applies these markers to every test in this module
17
+
18
+
19
@fixture
def volumes(mock_trainml):
    """Provide a ``Volumes`` collection bound to the mocked client."""
    collection = specimen.Volumes(mock_trainml)
    yield collection
22
+
23
+
24
@fixture
def volume(mock_trainml):
    """Provide a ``Volume`` built from a fixed set of attributes.

    The instance starts in the ``downloading`` status with id ``"1"`` and
    project ``"proj-id-1"``; several tests rely on those exact values.
    """
    attributes = {
        "id": "1",
        "project_uuid": "proj-id-1",
        "name": "first one",
        "status": "downloading",
        "capacity": "10G",
        "used_size": 100000000,
        "billed_size": 100000000,
        "createdAt": "2020-12-31T23:59:59.000Z",
    }
    yield specimen.Volume(mock_trainml, **attributes)
37
+
38
+
39
class VolumesTests:
    """Unit tests for the ``Volumes`` collection wrapper.

    Each test swaps ``mock_trainml._query`` for an ``AsyncMock`` and then
    checks the arguments of the single request made by the call under test.
    """

    @mark.asyncio
    async def test_get_volume(self, volumes, mock_trainml):
        """``get("1234")`` queries ``/volume/1234`` with GET and empty params."""
        mock_trainml._query = AsyncMock(return_value={})
        await volumes.get("1234")
        mock_trainml._query.assert_called_once_with("/volume/1234", "GET", {})

    @mark.asyncio
    async def test_list_volumes(self, volumes, mock_trainml):
        """``list()`` queries ``/volume`` with GET and empty params."""
        mock_trainml._query = AsyncMock(return_value={})
        await volumes.list()
        mock_trainml._query.assert_called_once_with("/volume", "GET", {})

    @mark.asyncio
    async def test_remove_volume(self, volumes, mock_trainml):
        """``remove("4567")`` sends a DELETE with ``force=True``."""
        mock_trainml._query = AsyncMock(return_value={})
        await volumes.remove("4567")
        mock_trainml._query.assert_called_once_with(
            "/volume/4567", "DELETE", {"force": True}
        )

    @mark.asyncio
    async def test_create_volume_simple(self, volumes, mock_trainml):
        """``create`` posts the caller's settings plus the project uuid."""
        create_kwargs = {
            "name": "new volume",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/volumes/resnet50",
            "capacity": "10G",
        }
        # The request body is the caller's settings with the project added.
        wanted_body = {"project_uuid": "proj-id-1", **create_kwargs}
        server_reply = dict(
            project_uuid="cus-id-1",
            id="volume-id-1",
            name="new volume",
            status="new",
            source_type="aws",
            capacity="10G",
            source_uri="s3://trainml-examples/volumes/resnet50",
            createdAt="2020-12-20T16:46:23.909Z",
        )

        mock_trainml._query = AsyncMock(return_value=server_reply)
        created = await volumes.create(**create_kwargs)
        mock_trainml._query.assert_called_once_with(
            "/volume", "POST", None, wanted_body
        )
        assert created.id == "volume-id-1"
107
+
108
+
109
class VolumeTests:
    """Unit tests for a single ``Volume`` resource.

    Most tests use the ``volume`` fixture (id ``"1"``, project
    ``"proj-id-1"``, status ``"downloading"``) and replace the client's
    ``_query`` / ``_ws_subscribe`` coroutines with ``AsyncMock`` objects.
    """

    def test_volume_properties(self, volume):
        """Exposed properties have the expected Python types."""
        assert isinstance(volume.id, str)
        assert isinstance(volume.status, str)
        assert isinstance(volume.name, str)
        assert isinstance(volume.capacity, str)
        assert isinstance(volume.used_size, int)
        assert isinstance(volume.billed_size, int)

    def test_volume_str(self, volume):
        """``str(volume)`` is a brace-wrapped string containing the id."""
        string = str(volume)
        # Escape the id so it is matched literally, whatever it contains.
        regex = r"^{.*\"id\": \"" + re.escape(volume.id) + r"\".*}$"
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_volume_repr(self, volume):
        """``repr(volume)`` has the ``Volume( trainml , **{...})`` shape."""
        string = repr(volume)
        regex = (
            r"^Volume\( trainml , \*\*{.*'id': '"
            + re.escape(volume.id)
            + r"'.*}\)$"
        )
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_volume_bool(self, volume, mock_trainml):
        """A populated volume is truthy; one built with no data is falsy."""
        empty_volume = specimen.Volume(mock_trainml)
        assert bool(volume)
        assert not bool(empty_volume)

    @mark.asyncio
    async def test_volume_get_log_url(self, volume, mock_trainml):
        """``get_log_url`` GETs ``/volume/1/logs`` and returns the reply."""
        api_response = (
            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_log_url()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/logs", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    @mark.asyncio
    async def test_volume_get_details(self, volume, mock_trainml):
        """``get_details`` GETs ``/volume/1/details`` and returns the reply."""
        api_response = {
            "type": "directory",
            "name": "/",
            "count": "8",
            "used_size": "177M",
            "contents": [],
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_details()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/details", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    @mark.asyncio
    async def test_volume_get_connection_utility_url(self, volume, mock_trainml):
        """``get_connection_utility_url`` GETs ``/volume/1/download``."""
        api_response = (
            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.get_connection_utility_url()
        mock_trainml._query.assert_called_once_with(
            "/volume/1/download", "GET", dict(project_uuid="proj-id-1")
        )
        assert response == api_response

    def test_volume_get_connection_details_no_vpn(self, volume):
        """Without ``vpn`` data the connection details are an empty dict."""
        details = volume.get_connection_details()
        expected_details = dict()
        assert details == expected_details

    def test_volume_get_connection_details_local_data(self, mock_trainml):
        """A local-source volume with ``vpn`` data yields full details."""
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            project_uuid="a",
            name="first one",
            status="new",
            capacity="10G",
            createdAt="2020-12-31T23:59:59.000Z",
            source_type="local",
            source_uri="~/tensorflow-example",
            vpn={
                "status": "new",
                "cidr": "10.106.171.0/24",
                "client": {
                    "port": "36017",
                    "id": "cus-id-1",
                    "address": "10.106.171.253",
                    "ssh_port": 46600,
                },
                "net_prefix_type_id": 1,
            },
        )
        details = volume.get_connection_details()
        expected_details = dict(
            project_uuid="a",
            entity_type="volume",
            cidr="10.106.171.0/24",
            ssh_port=46600,
            input_path="~/tensorflow-example",
            output_path=None,
        )
        assert details == expected_details

    @mark.asyncio
    async def test_volume_connect(self, volume, mock_trainml):
        """``connect`` starts a ``Connection`` and returns its status."""
        with patch(
            "trainml.volumes.Connection",
            autospec=True,
        ) as mock_connection:
            connection = mock_connection.return_value
            connection.status = "connected"
            resp = await volume.connect()
            connection.start.assert_called_once()
            assert resp == "connected"

    @mark.asyncio
    async def test_volume_disconnect(self, volume, mock_trainml):
        """``disconnect`` stops the ``Connection`` and returns its status."""
        with patch(
            "trainml.volumes.Connection",
            autospec=True,
        ) as mock_connection:
            connection = mock_connection.return_value
            connection.status = "removed"
            resp = await volume.disconnect()
            connection.stop.assert_called_once()
            assert resp == "removed"

    @mark.asyncio
    async def test_volume_remove(self, volume, mock_trainml):
        """``remove()`` sends a DELETE with ``force=False`` by default."""
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volume.remove()
        mock_trainml._query.assert_called_once_with(
            "/volume/1",
            "DELETE",
            dict(project_uuid="proj-id-1", force=False),
        )

    def test_volume_default_ws_msg_handler(self, volume, capsys):
        """The default handler prints a timestamp followed by the message."""
        data = {
            "msg": "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n",
            "time": 1613079345318,
            "type": "subscription",
            "stream": "worker-id-1",
            "job_worker_uuid": "worker-id-1",
        }

        handler = volume._get_msg_handler(None)
        handler(data)
        captured = capsys.readouterr()
        # NOTE(review): the expected timestamp text presumably depends on how
        # the handler converts "time" (timezone) - confirm against the SDK.
        assert (
            captured.out
            == "02/11/2021, 15:35:45: download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n"
        )

    def test_volume_custom_ws_msg_handler(self, volume, capsys):
        """A caller-supplied handler receives the raw message dict."""

        def custom_handler(msg):
            print(msg.get("stream"))

        data = {
            "msg": "download: s3://trainml-examples/data/cifar10/data_batch_2.bin to ./data_batch_2.bin\n",
            "time": 1613079345318,
            "type": "subscription",
            "stream": "worker-id-1",
            "job_worker_uuid": "worker-id-1",
        }

        handler = volume._get_msg_handler(custom_handler)
        handler(data)
        captured = capsys.readouterr()
        assert captured.out == "worker-id-1\n"

    @mark.asyncio
    async def test_volume_attach(self, volume, mock_trainml):
        """``attach`` subscribes once while the volume is ``downloading``."""
        api_response = None
        mock_trainml._ws_subscribe = AsyncMock(return_value=api_response)
        refresh_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "downloading",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        volume.refresh = AsyncMock(return_value=refresh_response)
        await volume.attach()
        mock_trainml._ws_subscribe.assert_called_once()

    @mark.asyncio
    async def test_volume_attach_immediate_return(self, mock_trainml):
        """``attach`` does not subscribe when the volume is already ``ready``."""
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response = None
        mock_trainml._ws_subscribe = AsyncMock(return_value=api_response)
        refresh_response = {
            "customer_uuid": "cus-id-1",
            "id": "1",
            "name": "new volume",
            "status": "ready",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        volume.refresh = AsyncMock(return_value=refresh_response)
        await volume.attach()
        mock_trainml._ws_subscribe.assert_not_called()

    @mark.asyncio
    async def test_volume_refresh(self, volume, mock_trainml):
        """``refresh`` re-reads the volume and updates it in place."""
        api_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "ready",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.refresh()
        mock_trainml._query.assert_called_once_with(
            "/volume/1", "GET", dict(project_uuid="proj-id-1")
        )
        assert volume.id == "data-id-1"
        assert response.id == "data-id-1"

    @mark.asyncio
    async def test_volume_wait_for_successful(self, volume, mock_trainml):
        """``wait_for("ready")`` returns after one query reporting ``ready``."""
        api_response = {
            "customer_uuid": "cus-id-1",
            "id": "data-id-1",
            "name": "new volume",
            "status": "ready",
            "source_type": "aws",
            "source_uri": "s3://trainml-examples/data/cifar10",
            "createdAt": "2020-12-20T16:46:23.909Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await volume.wait_for("ready")
        mock_trainml._query.assert_called_once_with(
            "/volume/1", "GET", dict(project_uuid="proj-id-1")
        )
        assert volume.id == "data-id-1"
        assert response.id == "data-id-1"

    @mark.asyncio
    async def test_volume_wait_for_current_status(self, mock_trainml):
        """No query is made when the volume already has the wanted status."""
        volume = specimen.Volume(
            mock_trainml,
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response = None
        mock_trainml._query = AsyncMock(return_value=api_response)
        await volume.wait_for("ready")
        mock_trainml._query.assert_not_called()

    @mark.asyncio
    async def test_volume_wait_for_incorrect_status(self, volume, mock_trainml):
        """An unsupported target status raises ``SpecificationError``."""
        api_response = None
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(SpecificationError):
            await volume.wait_for("stopped")
        # Checked after the block so the assertion actually executes.
        mock_trainml._query.assert_not_called()

    @mark.asyncio
    async def test_volume_wait_for_with_delay(self, volume, mock_trainml):
        """``wait_for`` keeps polling until the status becomes ``ready``."""
        api_response_initial = dict(
            id="1",
            name="first one",
            status="new",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        api_response_final = dict(
            id="1",
            name="first one",
            status="ready",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock()
        # Two "new" replies followed by "ready" on the third poll.
        mock_trainml._query.side_effect = [
            api_response_initial,
            api_response_initial,
            api_response_final,
        ]
        response = await volume.wait_for("ready")
        assert volume.status == "ready"
        assert response.status == "ready"

    @mark.asyncio
    async def test_volume_wait_for_timeout(self, volume, mock_trainml):
        """``wait_for`` raises ``TrainMLException`` if ``ready`` never comes."""
        api_response = dict(
            id="1",
            name="first one",
            status="downloading",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(TrainMLException):
            # NOTE(review): 10 is presumably the timeout in seconds - confirm.
            await volume.wait_for("ready", 10)
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_volume_failed(self, volume, mock_trainml):
        """A ``failed`` status while waiting raises ``VolumeError``."""
        api_response = dict(
            id="1",
            name="first one",
            status="failed",
            createdAt="2020-12-31T23:59:59.000Z",
        )
        mock_trainml._query = AsyncMock(return_value=api_response)
        with raises(VolumeError):
            await volume.wait_for("ready")
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_archived_succeeded(self, volume, mock_trainml):
        """A 404 while waiting for ``archived`` does not raise."""
        mock_trainml._query = AsyncMock(
            side_effect=ApiError(404, dict(errorMessage="Volume Not Found"))
        )
        await volume.wait_for("archived")
        mock_trainml._query.assert_called()

    @mark.asyncio
    async def test_volume_wait_for_unexpected_api_error(self, volume, mock_trainml):
        """A 404 while waiting for ``ready`` propagates as ``ApiError``."""
        mock_trainml._query = AsyncMock(
            side_effect=ApiError(404, dict(errorMessage="Volume Not Found"))
        )
        with raises(ApiError):
            await volume.wait_for("ready")
        mock_trainml._query.assert_called()
trainml/__init__.py CHANGED
@@ -13,5 +13,5 @@ logging.basicConfig(
13
13
  logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
- __version__ = "0.5.3"
16
+ __version__ = "0.5.5"
17
17
  __all__ = "TrainML"
trainml/cli/__init__.py CHANGED
@@ -142,9 +142,7 @@ def configure(config):
142
142
  project for project in projects if project.id == active_project_id
143
143
  ]
144
144
 
145
- active_project_name = (
146
- active_project[0].name if len(active_project) else "UNSET"
147
- )
145
+ active_project_name = active_project[0].name if len(active_project) else "UNSET"
148
146
 
149
147
  click.echo(f"Current Active Project: {active_project_name}")
150
148
 
@@ -154,9 +152,7 @@ def configure(config):
154
152
  show_choices=True,
155
153
  default=active_project_name,
156
154
  )
157
- selected_project = [
158
- project for project in projects if project.name == name
159
- ]
155
+ selected_project = [project for project in projects if project.name == name]
160
156
  config.trainml.client.set_active_project(selected_project[0].id)
161
157
 
162
158
 
@@ -164,6 +160,7 @@ from trainml.cli.connection import connection
164
160
  from trainml.cli.dataset import dataset
165
161
  from trainml.cli.model import model
166
162
  from trainml.cli.checkpoint import checkpoint
163
+ from trainml.cli.volume import volume
167
164
  from trainml.cli.environment import environment
168
165
  from trainml.cli.gpu import gpu
169
166
  from trainml.cli.job import job
trainml/cli/job/create.py CHANGED
@@ -389,7 +389,7 @@ def notebook(
389
389
  ],
390
390
  case_sensitive=False,
391
391
  ),
392
- default="rtx3090",
392
+ default=["rtx3090"],
393
393
  multiple=True,
394
394
  show_default=True,
395
395
  help="GPU type.",
@@ -732,7 +732,7 @@ def training(
732
732
  ],
733
733
  case_sensitive=False,
734
734
  ),
735
- default="rtx3090",
735
+ default=["rtx3090"],
736
736
  show_default=True,
737
737
  multiple=True,
738
738
  help="GPU type.",
@@ -1099,7 +1099,7 @@ def from_json(config, attach, connect, file):
1099
1099
  ],
1100
1100
  case_sensitive=False,
1101
1101
  ),
1102
- default="rtx3090",
1102
+ default=["rtx3090"],
1103
1103
  show_default=True,
1104
1104
  multiple=True,
1105
1105
  help="GPU type.",