trainml 0.5.9__py3-none-any.whl → 0.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. tests/integration/projects/conftest.py +3 -1
  2. tests/integration/projects/test_projects_data_connectors_integration.py +44 -0
  3. tests/integration/projects/test_projects_datastores_integration.py +42 -0
  4. tests/integration/projects/test_projects_services_integration.py +44 -0
  5. tests/integration/test_checkpoints_integration.py +1 -2
  6. tests/integration/test_models_integration.py +0 -1
  7. tests/unit/cli/projects/__init__.py +0 -0
  8. tests/unit/cli/projects/test_cli_project_data_connector_unit.py +28 -0
  9. tests/unit/cli/projects/test_cli_project_datastore_unit.py +26 -0
  10. tests/unit/cli/projects/test_cli_project_key_unit.py +26 -0
  11. tests/unit/cli/projects/test_cli_project_secret_unit.py +26 -0
  12. tests/unit/cli/projects/test_cli_project_service_unit.py +26 -0
  13. tests/unit/cli/projects/test_cli_project_unit.py +19 -0
  14. tests/unit/cloudbender/test_datastores_unit.py +1 -5
  15. tests/unit/conftest.py +77 -4
  16. tests/unit/test_checkpoints_unit.py +15 -23
  17. tests/unit/test_datasets_unit.py +15 -20
  18. tests/unit/test_models_unit.py +13 -16
  19. tests/unit/test_volumes_unit.py +3 -0
  20. trainml/__init__.py +1 -1
  21. trainml/checkpoints.py +14 -3
  22. trainml/cli/cloudbender/datastore.py +2 -7
  23. trainml/cli/project/__init__.py +3 -72
  24. trainml/cli/project/data_connector.py +61 -0
  25. trainml/cli/project/datastore.py +61 -0
  26. trainml/cli/project/service.py +61 -0
  27. trainml/cloudbender/data_connectors.py +8 -0
  28. trainml/cloudbender/datastores.py +9 -19
  29. trainml/cloudbender/nodes.py +44 -1
  30. trainml/cloudbender/providers.py +53 -0
  31. trainml/cloudbender/regions.py +48 -0
  32. trainml/datasets.py +14 -3
  33. trainml/exceptions.py +51 -0
  34. trainml/jobs.py +2 -13
  35. trainml/models.py +14 -3
  36. trainml/volumes.py +15 -3
  37. {trainml-0.5.9.dist-info → trainml-0.5.11.dist-info}/METADATA +1 -1
  38. {trainml-0.5.9.dist-info → trainml-0.5.11.dist-info}/RECORD +42 -40
  39. tests/integration/test_projects_integration.py +0 -44
  40. tests/unit/cli/cloudbender/test_cli_reservation_unit.py +0 -34
  41. tests/unit/cli/test_cli_project_unit.py +0 -42
  42. tests/unit/cloudbender/test_reservations_unit.py +0 -173
  43. tests/unit/test_auth.py +0 -30
  44. tests/unit/test_projects_unit.py +0 -320
  45. tests/unit/test_trainml.py +0 -54
  46. trainml/cli/cloudbender/reservation.py +0 -159
  47. trainml/cli/project.py +0 -149
  48. trainml/cloudbender/reservations.py +0 -126
  49. trainml/projects.py +0 -228
  50. {trainml-0.5.9.dist-info → trainml-0.5.11.dist-info}/LICENSE +0 -0
  51. {trainml-0.5.9.dist-info → trainml-0.5.11.dist-info}/WHEEL +0 -0
  52. {trainml-0.5.9.dist-info → trainml-0.5.11.dist-info}/entry_points.txt +0 -0
  53. {trainml-0.5.9.dist-info → trainml-0.5.11.dist-info}/top_level.txt +0 -0
tests/unit/test_datasets_unit.py CHANGED
@@ -28,6 +28,7 @@ def dataset(mock_trainml):
  dataset_uuid="1",
  project_uuid="proj-id-1",
  name="first one",
+ type="evefs",
  status="downloading",
  size=100000,
  createdAt="2020-12-31T23:59:59.000Z",
@@ -103,6 +104,7 @@ class DatasetsTests:
  name="new dataset",
  source_type="aws",
  source_uri="s3://trainml-examples/data/cifar10",
+ type="evefs",
  )
  api_response = {
  "customer_uuid": "cus-id-1",
@@ -110,6 +112,7 @@ class DatasetsTests:
  "dataset_uuid": "data-id-1",
  "name": "new dataset",
  "status": "new",
+ "type": "evefs",
  "source_type": "aws",
  "source_uri": "s3://trainml-examples/data/cifar10",
  "createdAt": "2020-12-20T16:46:23.909Z",
@@ -139,9 +142,7 @@ class DatasetTests:
  def test_dataset_repr(self, dataset):
  string = repr(dataset)
  regex = (
- r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '"
- + dataset.id
- + r"'.*}\)$"
+ r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '" + dataset.id + r"'.*}\)$"
  )
  assert isinstance(string, str)
  assert re.match(regex, string)
@@ -153,7 +154,9 @@ class DatasetTests:

  @mark.asyncio
  async def test_dataset_get_log_url(self, dataset, mock_trainml):
- api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
+ api_response = (
+ "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
+ )
  mock_trainml._query = AsyncMock(return_value=api_response)
  response = await dataset.get_log_url()
  mock_trainml._query.assert_called_once_with(
@@ -178,10 +181,10 @@ class DatasetTests:
  assert response == api_response

  @mark.asyncio
- async def test_dataset_get_connection_utility_url(
- self, dataset, mock_trainml
- ):
- api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
+ async def test_dataset_get_connection_utility_url(self, dataset, mock_trainml):
+ api_response = (
+ "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
+ )
  mock_trainml._query = AsyncMock(return_value=api_response)
  response = await dataset.get_connection_utility_url()
  mock_trainml._query.assert_called_once_with(
@@ -388,9 +391,7 @@ class DatasetTests:
  mock_trainml._query.assert_not_called()

  @mark.asyncio
- async def test_dataset_wait_for_incorrect_status(
- self, dataset, mock_trainml
- ):
+ async def test_dataset_wait_for_incorrect_status(self, dataset, mock_trainml):
  api_response = None
  mock_trainml._query = AsyncMock(return_value=api_response)
  with raises(SpecificationError):
@@ -435,9 +436,7 @@ class DatasetTests:
  mock_trainml._query.assert_called()

  @mark.asyncio
- async def test_dataset_wait_for_dataset_failed(
- self, dataset, mock_trainml
- ):
+ async def test_dataset_wait_for_dataset_failed(self, dataset, mock_trainml):
  api_response = dict(
  dataset_uuid="1",
  name="first one",
@@ -450,9 +449,7 @@ class DatasetTests:
  mock_trainml._query.assert_called()

  @mark.asyncio
- async def test_dataset_wait_for_archived_succeeded(
- self, dataset, mock_trainml
- ):
+ async def test_dataset_wait_for_archived_succeeded(self, dataset, mock_trainml):
  mock_trainml._query = AsyncMock(
  side_effect=ApiError(404, dict(errorMessage="Dataset Not Found"))
  )
@@ -460,9 +457,7 @@ class DatasetTests:
  mock_trainml._query.assert_called()

  @mark.asyncio
- async def test_dataset_wait_for_unexpected_api_error(
- self, dataset, mock_trainml
- ):
+ async def test_dataset_wait_for_unexpected_api_error(self, dataset, mock_trainml):
  mock_trainml._query = AsyncMock(
  side_effect=ApiError(404, dict(errorMessage="Dataset Not Found"))
  )
tests/unit/test_models_unit.py CHANGED
@@ -28,6 +28,7 @@ def model(mock_trainml):
  model_uuid="1",
  project_uuid="proj-id-1",
  name="first one",
+ type="evefs",
  status="downloading",
  size=100000,
  createdAt="2020-12-31T23:59:59.000Z",
@@ -44,9 +45,7 @@ class ModelsTests:
  api_response = dict()
  mock_trainml._query = AsyncMock(return_value=api_response)
  await models.get("1234")
- mock_trainml._query.assert_called_once_with(
- "/model/1234", "GET", dict()
- )
+ mock_trainml._query.assert_called_once_with("/model/1234", "GET", dict())

  @mark.asyncio
  async def test_list_models(
@@ -84,12 +83,14 @@ class ModelsTests:
  name="new model",
  source_type="aws",
  source_uri="s3://trainml-examples/models/resnet50",
+ type="evefs",
  )
  api_response = {
  "customer_uuid": "cus-id-1",
  "model_uuid": "model-id-1",
  "name": "new model",
  "status": "new",
+ "type": "evefs",
  "source_type": "aws",
  "source_uri": "s3://trainml-examples/models/resnet50",
  "createdAt": "2020-12-20T16:46:23.909Z",
@@ -118,11 +119,7 @@ class ModelTests:

  def test_model_repr(self, model):
  string = repr(model)
- regex = (
- r"^Model\( trainml , \*\*{.*'model_uuid': '"
- + model.id
- + r"'.*}\)$"
- )
+ regex = r"^Model\( trainml , \*\*{.*'model_uuid': '" + model.id + r"'.*}\)$"
  assert isinstance(string, str)
  assert re.match(regex, string)

@@ -133,7 +130,9 @@ class ModelTests:

  @mark.asyncio
  async def test_model_get_log_url(self, model, mock_trainml):
- api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
+ api_response = (
+ "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
+ )
  mock_trainml._query = AsyncMock(return_value=api_response)
  response = await model.get_log_url()
  mock_trainml._query.assert_called_once_with(
@@ -159,7 +158,9 @@ class ModelTests:

  @mark.asyncio
  async def test_model_get_connection_utility_url(self, model, mock_trainml):
- api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
+ api_response = (
+ "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
+ )
  mock_trainml._query = AsyncMock(return_value=api_response)
  response = await model.get_connection_utility_url()
  mock_trainml._query.assert_called_once_with(
@@ -425,9 +426,7 @@ class ModelTests:
  mock_trainml._query.assert_called()

  @mark.asyncio
- async def test_model_wait_for_archived_succeeded(
- self, model, mock_trainml
- ):
+ async def test_model_wait_for_archived_succeeded(self, model, mock_trainml):
  mock_trainml._query = AsyncMock(
  side_effect=ApiError(404, dict(errorMessage="Model Not Found"))
  )
@@ -435,9 +434,7 @@ class ModelTests:
  mock_trainml._query.assert_called()

  @mark.asyncio
- async def test_model_wait_for_unexpected_api_error(
- self, model, mock_trainml
- ):
+ async def test_model_wait_for_unexpected_api_error(self, model, mock_trainml):
  mock_trainml._query = AsyncMock(
  side_effect=ApiError(404, dict(errorMessage="Model Not Found"))
  )
tests/unit/test_volumes_unit.py CHANGED
@@ -26,6 +26,7 @@ def volume(mock_trainml):
  yield specimen.Volume(
  mock_trainml,
  id="1",
+ type="evefs",
  project_uuid="proj-id-1",
  name="first one",
  status="downloading",
@@ -86,12 +87,14 @@ class VolumesTests:
  source_type="aws",
  source_uri="s3://trainml-examples/volumes/resnet50",
  capacity="10G",
+ type="evefs",
  )
  api_response = {
  "project_uuid": "cus-id-1",
  "id": "volume-id-1",
  "name": "new volume",
  "status": "new",
+ "type": "evefs",
  "source_type": "aws",
  "capacity": "10G",
  "source_uri": "s3://trainml-examples/volumes/resnet50",
trainml/__init__.py CHANGED
@@ -13,5 +13,5 @@ logging.basicConfig(
  logger = logging.getLogger(__name__)


- __version__ = "0.5.9"
+ __version__ = "0.5.11"
  __all__ = "TrainML"
trainml/checkpoints.py CHANGED
@@ -31,13 +31,24 @@ class Checkpoints(object):
  datasets = [Checkpoint(self.trainml, **dataset) for dataset in resp]
  return datasets

- async def create(self, name, source_type, source_uri, **kwargs):
+ async def create(
+ self,
+ name,
+ source_type,
+ source_uri,
+ type="evefs",
+ project_uuid=None,
+ **kwargs,
+ ):
+ if not project_uuid:
+ project_uuid = self.trainml.active_project
  data = dict(
  name=name,
  source_type=source_type,
  source_uri=source_uri,
- source_options=kwargs.get("source_options"),
- project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
+ project_uuid=project_uuid,
+ type=type,
+ **kwargs,
  )
  payload = {k: v for k, v in data.items() if v is not None}
  logging.info(f"Creating Checkpoint {name}")
trainml/cli/cloudbender/datastore.py CHANGED
@@ -29,13 +29,11 @@ def datastore(config):
  def list(config, provider, region):
  """List datastores."""
  data = [
- ["ID", "NAME", "TYPE", "URI", "ROOT"],
+ ["ID", "NAME", "TYPE"],
  [
  "-" * 80,
  "-" * 80,
  "-" * 80,
- "-" * 80,
- "-" * 80,
  ],
  ]

@@ -51,15 +49,12 @@ def list(config, provider, region):
  datastore.id,
  datastore.name,
  datastore.type,
- datastore.uri,
- datastore.root,
  ]
  )

  for row in data:
  click.echo(
- "{: >37.36} {: >29.28} {: >9.8} {: >12.11} {: >12.11}"
- "".format(*row),
+ "{: >37.36} {: >29.28} {: >9.8} " "".format(*row),
  file=config.stdout,
  )

trainml/cli/project/__init__.py CHANGED
@@ -77,77 +77,8 @@ def remove(config, project):
  return config.trainml.run(found.remove())


- @project.command()
- @pass_config
- def list_datastores(config):
- """List project datastores."""
- data = [
- ["ID", "NAME", "TYPE", "REGION_UUID"],
- [
- "-" * 80,
- "-" * 80,
- "-" * 80,
- "-" * 80,
- ],
- ]
- project = config.trainml.run(
- config.trainml.client.projects.get(config.trainml.client.project)
- )
-
- datastores = config.trainml.run(project.list_datastores())
-
- for datastore in datastores:
- data.append(
- [
- datastore.id,
- datastore.name,
- datastore.type,
- datastore.region_uuid,
- ]
- )
-
- for row in data:
- click.echo(
- "{: >38.36} {: >30.28} {: >15.13} {: >38.36}" "".format(*row),
- file=config.stdout,
- )
-
-
- @project.command()
- @pass_config
- def list_services(config):
- """List project services."""
- data = [
- ["ID", "NAME", "HOSTNAME", "REGION_UUID"],
- [
- "-" * 80,
- "-" * 80,
- "-" * 80,
- "-" * 80,
- ],
- ]
- project = config.trainml.run(
- config.trainml.client.projects.get(config.trainml.client.project)
- )
-
- services = config.trainml.run(project.list_services())
-
- for service in services:
- data.append(
- [
- service.id,
- service.name,
- service.hostname,
- service.region_uuid,
- ]
- )
-
- for row in data:
- click.echo(
- "{: >38.36} {: >30.28} {: >30.28} {: >38.36}" "".format(*row),
- file=config.stdout,
- )
-
-
  from trainml.cli.project.secret import secret
  from trainml.cli.project.key import key
+ from trainml.cli.project.data_connector import data_connector
+ from trainml.cli.project.datastore import datastore
+ from trainml.cli.project.service import service
trainml/cli/project/data_connector.py ADDED
@@ -0,0 +1,61 @@
+ import click
+ import os
+ import json
+ import base64
+ from pathlib import Path
+ from trainml.cli import pass_config
+ from trainml.cli.project import project
+
+
+ @project.group()
+ @pass_config
+ def data_connector(config):
+ """trainML project data_connector commands."""
+ pass
+
+
+ @data_connector.command()
+ @pass_config
+ def list(config):
+ """List project data_connectors."""
+ data = [
+ ["ID", "NAME", "TYPE", "REGION_UUID"],
+ [
+ "-" * 80,
+ "-" * 80,
+ "-" * 80,
+ "-" * 80,
+ ],
+ ]
+ project = config.trainml.run(
+ config.trainml.client.projects.get(config.trainml.client.project)
+ )
+
+ data_connectors = config.trainml.run(project.data_connectors.list())
+
+ for data_connector in data_connectors:
+ data.append(
+ [
+ data_connector.id,
+ data_connector.name,
+ data_connector.type,
+ data_connector.region_uuid,
+ ]
+ )
+
+ for row in data:
+ click.echo(
+ "{: >38.36} {: >30.28} {: >15.13} {: >38.36}" "".format(*row),
+ file=config.stdout,
+ )
+
+
+ @data_connector.command()
+ @pass_config
+ def refresh(config):
+ """
+ Refresh project data_connector list.
+ """
+ project = config.trainml.run(config.trainml.client.projects.get_current())
+
+ return config.trainml.run(project.data_connectors.refresh())
trainml/cli/project/datastore.py ADDED
@@ -0,0 +1,61 @@
+ import click
+ import os
+ import json
+ import base64
+ from pathlib import Path
+ from trainml.cli import pass_config
+ from trainml.cli.project import project
+
+
+ @project.group()
+ @pass_config
+ def datastore(config):
+ """trainML project datastore commands."""
+ pass
+
+
+ @datastore.command()
+ @pass_config
+ def list(config):
+ """List project datastores."""
+ data = [
+ ["ID", "NAME", "TYPE", "REGION_UUID"],
+ [
+ "-" * 80,
+ "-" * 80,
+ "-" * 80,
+ "-" * 80,
+ ],
+ ]
+ project = config.trainml.run(
+ config.trainml.client.projects.get(config.trainml.client.project)
+ )
+
+ datastores = config.trainml.run(project.datastores.list())
+
+ for datastore in datastores:
+ data.append(
+ [
+ datastore.id,
+ datastore.name,
+ datastore.type,
+ datastore.region_uuid,
+ ]
+ )
+
+ for row in data:
+ click.echo(
+ "{: >38.36} {: >30.28} {: >15.13} {: >38.36}" "".format(*row),
+ file=config.stdout,
+ )
+
+
+ @datastore.command()
+ @pass_config
+ def refresh(config):
+ """
+ Refresh project datastore list.
+ """
+ project = config.trainml.run(config.trainml.client.projects.get_current())
+
+ return config.trainml.run(project.datastores.refresh())
trainml/cli/project/service.py ADDED
@@ -0,0 +1,61 @@
+ import click
+ import os
+ import json
+ import base64
+ from pathlib import Path
+ from trainml.cli import pass_config
+ from trainml.cli.project import project
+
+
+ @project.group()
+ @pass_config
+ def service(config):
+ """trainML project service commands."""
+ pass
+
+
+ @service.command()
+ @pass_config
+ def list(config):
+ """List project services."""
+ data = [
+ ["ID", "NAME", "TYPE", "REGION_UUID"],
+ [
+ "-" * 80,
+ "-" * 80,
+ "-" * 80,
+ "-" * 80,
+ ],
+ ]
+ project = config.trainml.run(
+ config.trainml.client.projects.get(config.trainml.client.project)
+ )
+
+ services = config.trainml.run(project.services.list())
+
+ for service in services:
+ data.append(
+ [
+ service.id,
+ service.name,
+ service.hostname,
+ service.region_uuid,
+ ]
+ )
+
+ for row in data:
+ click.echo(
+ "{: >38.36} {: >30.28} {: >15.13} {: >38.36}" "".format(*row),
+ file=config.stdout,
+ )
+
+
+ @service.command()
+ @pass_config
+ def refresh(config):
+ """
+ Refresh project service list.
+ """
+ project = config.trainml.run(config.trainml.client.projects.get_current())
+
+ return config.trainml.run(project.services.refresh())
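
The three new CLI modules above are thin wrappers around per-project collections and the projects.get_current() helper they call. A sketch of the equivalent SDK-level calls, assuming a configured client with an active project; the attribute names mirror the CLI code shown above:

import asyncio

from trainml import TrainML


async def main():
    # Assumes a configured trainML client with an active project selected.
    trainml = TrainML()
    project = await trainml.projects.get_current()

    # The new CLI groups list these per-project collections.
    datastores = await project.datastores.list()
    services = await project.services.list()
    data_connectors = await project.data_connectors.list()

    for datastore in datastores:
        print(datastore.id, datastore.name, datastore.type, datastore.region_uuid)
    for service in services:
        print(service.id, service.name, service.hostname, service.region_uuid)
    for data_connector in data_connectors:
        print(data_connector.id, data_connector.name, data_connector.type)


asyncio.run(main())
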
trainml/cloudbender/data_connectors.py CHANGED
@@ -1,5 +1,13 @@
  import json
  import logging
+ import asyncio
+ import math
+
+ from trainml.exceptions import (
+ ApiError,
+ SpecificationError,
+ TrainMLException,
+ )


  class DataConnectors(object):
trainml/cloudbender/datastores.py CHANGED
@@ -1,5 +1,13 @@
  import json
  import logging
+ import asyncio
+ import math
+
+ from trainml.exceptions import (
+ ApiError,
+ SpecificationError,
+ TrainMLException,
+ )


  class Datastores(object):
@@ -20,9 +28,7 @@ class Datastores(object):
  "GET",
  kwargs,
  )
- datastores = [
- Datastore(self.trainml, **datastore) for datastore in resp
- ]
+ datastores = [Datastore(self.trainml, **datastore) for datastore in resp]
  return datastores

  async def create(
@@ -31,18 +37,12 @@ class Datastores(object):
  region_uuid,
  name,
  type,
- uri,
- root,
- options=None,
  **kwargs,
  ):
  logging.info(f"Creating Datastore {name}")
  data = dict(
  name=name,
  type=type,
- uri=uri,
- root=root,
- options=options,
  **kwargs,
  )
  payload = {k: v for k, v in data.items() if v is not None}
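
Datastores.create above no longer takes uri, root, or options as named parameters; beyond name and type, extra fields pass through **kwargs. A hedged sketch of a call against the new signature; the trainml.cloudbender.datastores attribute path, the placeholder identifiers, and the "nfs" type value are assumptions, and any additional required arguments (such as a provider identifier) are not visible in this hunk:

import asyncio

from trainml import TrainML


async def main():
    # Assumes a configured trainML client; identifiers below are placeholders.
    trainml = TrainML()
    datastore = await trainml.cloudbender.datastores.create(
        region_uuid="region-id-1",
        name="shared-storage",
        type="nfs",  # placeholder type value
        # uri/root/options are no longer named parameters; any extra fields the
        # API expects would be passed as additional keyword arguments here.
    )
    print(datastore.id, datastore.name, datastore.type)


asyncio.run(main())
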
@@ -73,8 +73,6 @@ class Datastore:
  self._region_uuid = self._datastore.get("region_uuid")
  self._type = self._datastore.get("type")
  self._name = self._datastore.get("name")
- self._uri = self._datastore.get("uri")
- self._root = self._datastore.get("root")

  @property
  def id(self) -> str:
@@ -96,14 +94,6 @@ class Datastore:
  def name(self) -> str:
  return self._name

- @property
- def uri(self) -> str:
- return self._uri
-
- @property
- def root(self) -> str:
- return self._root
-
  def __str__(self):
  return json.dumps({k: v for k, v in self._datastore.items()})

trainml/cloudbender/nodes.py CHANGED
@@ -1,5 +1,9 @@
  import json
  import logging
+ import asyncio
+ import math
+
+ from trainml.exceptions import ApiError, SpecificationError, TrainMLException, NodeError


  class Nodes(object):
@@ -29,7 +33,7 @@ class Nodes(object):
  region_uuid,
  friendly_name,
  hostname,
- minion_id,
+ minion_id=None,
  type="permanent",
  service="compute",
  **kwargs,
@@ -153,3 +157,42 @@ class Node:
  None,
  dict(command=command),
  )
+
+ async def wait_for(self, status, timeout=300):
+ if self.status == status:
+ return
+ valid_statuses = ["active", "maintenance", "offline", "stopped", "archived"]
+ if not status in valid_statuses:
+ raise SpecificationError(
+ "status",
+ f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
+ )
+ MAX_TIMEOUT = 24 * 60 * 60
+ if timeout > MAX_TIMEOUT:
+ raise SpecificationError(
+ "timeout",
+ f"timeout must be less than {MAX_TIMEOUT} seconds.",
+ )
+
+ POLL_INTERVAL_MIN = 5
+ POLL_INTERVAL_MAX = 60
+ POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
+ retry_count = math.ceil(timeout / POLL_INTERVAL)
+ count = 0
+ while count < retry_count:
+ await asyncio.sleep(POLL_INTERVAL)
+ try:
+ await self.refresh()
+ except ApiError as e:
+ if status == "archived" and e.status == 404:
+ return
+ raise e
+ if self.status in ["errored", "failed"]:
+ raise NodeError(self.status, self)
+ if self.status == status:
+ return self
+ else:
+ count += 1
+ logging.debug(f"self: {self}, retry count {count}")
+
+ raise TrainMLException(f"Timeout waiting for {status}")