trainml 0.5.6__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/test_jobs_integration.py +13 -0
- tests/unit/cli/cloudbender/test_cli_service_unit.py +34 -0
- tests/unit/cloudbender/test_data_connectors_unit.py +176 -0
- tests/unit/cloudbender/test_services_unit.py +6 -0
- tests/unit/test_projects_unit.py +45 -5
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +33 -26
- trainml/cli/cloudbender/__init__.py +1 -0
- trainml/cli/cloudbender/data_connector.py +159 -0
- trainml/cli/cloudbender/service.py +19 -2
- trainml/cloudbender/cloudbender.py +2 -0
- trainml/cloudbender/data_connectors.py +112 -0
- trainml/cloudbender/services.py +65 -1
- trainml/datasets.py +27 -9
- trainml/jobs.py +13 -6
- trainml/models.py +28 -20
- trainml/projects.py +60 -8
- trainml/volumes.py +9 -2
- {trainml-0.5.6.dist-info → trainml-0.5.8.dist-info}/METADATA +1 -1
- {trainml-0.5.6.dist-info → trainml-0.5.8.dist-info}/RECORD +24 -20
- {trainml-0.5.6.dist-info → trainml-0.5.8.dist-info}/LICENSE +0 -0
- {trainml-0.5.6.dist-info → trainml-0.5.8.dist-info}/WHEEL +0 -0
- {trainml-0.5.6.dist-info → trainml-0.5.8.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.6.dist-info → trainml-0.5.8.dist-info}/top_level.txt +0 -0
|
@@ -74,6 +74,19 @@ def list(config, provider, region):
|
|
|
74
74
|
required=True,
|
|
75
75
|
help="The region ID to create the service in.",
|
|
76
76
|
)
|
|
77
|
+
@click.option(
|
|
78
|
+
"--type",
|
|
79
|
+
"-t",
|
|
80
|
+
type=click.Choice(
|
|
81
|
+
[
|
|
82
|
+
"https",
|
|
83
|
+
"tcp",
|
|
84
|
+
"udp",
|
|
85
|
+
],
|
|
86
|
+
),
|
|
87
|
+
required=True,
|
|
88
|
+
help="The type of regional service.",
|
|
89
|
+
)
|
|
77
90
|
@click.option(
|
|
78
91
|
"--public/--no-public",
|
|
79
92
|
default=True,
|
|
@@ -82,13 +95,17 @@ def list(config, provider, region):
|
|
|
82
95
|
)
|
|
83
96
|
@click.argument("name", type=click.STRING, required=True)
|
|
84
97
|
@pass_config
|
|
85
|
-
def create(config, provider, region, public, name):
|
|
98
|
+
def create(config, provider, region, type, public, name):
|
|
86
99
|
"""
|
|
87
100
|
Creates a service.
|
|
88
101
|
"""
|
|
89
102
|
return config.trainml.run(
|
|
90
103
|
config.trainml.client.cloudbender.services.create(
|
|
91
|
-
provider_uuid=provider,
|
|
104
|
+
provider_uuid=provider,
|
|
105
|
+
region_uuid=region,
|
|
106
|
+
name=name,
|
|
107
|
+
type=type,
|
|
108
|
+
public=public,
|
|
92
109
|
)
|
|
93
110
|
)
|
|
94
111
|
|
|
@@ -3,6 +3,7 @@ from .regions import Regions
|
|
|
3
3
|
from .nodes import Nodes
|
|
4
4
|
from .devices import Devices
|
|
5
5
|
from .datastores import Datastores
|
|
6
|
+
from .data_connectors import DataConnectors
|
|
6
7
|
from .services import Services
|
|
7
8
|
from .device_configs import DeviceConfigs
|
|
8
9
|
|
|
@@ -15,5 +16,6 @@ class Cloudbender(object):
|
|
|
15
16
|
self.nodes = Nodes(trainml)
|
|
16
17
|
self.devices = Devices(trainml)
|
|
17
18
|
self.datastores = Datastores(trainml)
|
|
19
|
+
self.data_connectors = DataConnectors(trainml)
|
|
18
20
|
self.services = Services(trainml)
|
|
19
21
|
self.device_configs = DeviceConfigs(trainml)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DataConnectors(object):
    """Collection-level API for CloudBender data connectors.

    Every request targets the
    ``/provider/{provider_uuid}/region/{region_uuid}/data_connector``
    collection and is issued through the client's ``_query`` coroutine.
    """

    def __init__(self, trainml):
        # Client object whose ``_query`` coroutine performs the HTTP calls.
        self.trainml = trainml

    @staticmethod
    def _collection_path(provider_uuid, region_uuid):
        """Return the data-connector collection path for one provider region."""
        return f"/provider/{provider_uuid}/region/{region_uuid}/data_connector"

    async def get(self, provider_uuid, region_uuid, id, **kwargs):
        """Fetch one data connector by id.

        ``kwargs`` are handed to ``_query`` as its third positional argument.
        """
        path = f"{self._collection_path(provider_uuid, region_uuid)}/{id}"
        record = await self.trainml._query(path, "GET", kwargs)
        return DataConnector(self.trainml, **record)

    async def list(self, provider_uuid, region_uuid, **kwargs):
        """Return every data connector in the region as ``DataConnector`` objects."""
        records = await self.trainml._query(
            self._collection_path(provider_uuid, region_uuid), "GET", kwargs
        )
        return [DataConnector(self.trainml, **record) for record in records]

    async def create(self, provider_uuid, region_uuid, name, type, **kwargs):
        """Create a data connector and return it.

        Any field (including extra ``kwargs``) whose value is ``None`` is
        dropped from the POST body.
        """
        logging.info(f"Creating Data Connector {name}")
        fields = dict(name=name, type=type, **kwargs)
        body = {key: value for key, value in fields.items() if value is not None}
        record = await self.trainml._query(
            self._collection_path(provider_uuid, region_uuid),
            "POST",
            None,
            body,
        )
        created = DataConnector(self.trainml, **record)
        logging.info(f"Created Data Connector {name} with id {created.id}")
        return created

    async def remove(self, provider_uuid, region_uuid, id, **kwargs):
        """Delete one data connector by id; returns nothing."""
        path = f"{self._collection_path(provider_uuid, region_uuid)}/{id}"
        await self.trainml._query(path, "DELETE", kwargs)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class DataConnector:
    """A single CloudBender data connector backed by its raw API record.

    The whole record is kept in ``_data_connector`` (serialized by
    ``__str__`` / ``__repr__``); commonly used fields are cached on the
    instance.
    """

    def __init__(self, trainml, **kwargs):
        self.trainml = trainml
        self._data_connector = kwargs
        record = self._data_connector
        # The record identifies the connector through its "connector_id" key.
        self._id = record.get("connector_id")
        self._provider_uuid = record.get("provider_uuid")
        self._region_uuid = record.get("region_uuid")
        self._type = record.get("type")
        self._name = record.get("name")

    @property
    def id(self) -> str:
        """Connector id (the record's ``connector_id`` field)."""
        return self._id

    @property
    def provider_uuid(self) -> str:
        """UUID of the provider that owns the connector."""
        return self._provider_uuid

    @property
    def region_uuid(self) -> str:
        """UUID of the region the connector belongs to."""
        return self._region_uuid

    @property
    def type(self) -> str:
        """Connector type as reported by the API."""
        return self._type

    @property
    def name(self) -> str:
        """Human-readable connector name."""
        return self._name

    def __str__(self):
        return json.dumps(dict(self._data_connector))

    def __repr__(self):
        return f"DataConnector( trainml , **{self._data_connector!r})"

    def __bool__(self):
        # Truthy only when the record carried a non-empty connector id.
        return bool(self._id)

    def _resource_path(self):
        """Return the API path addressing this specific connector."""
        return (
            f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
            f"/data_connector/{self._id}"
        )

    async def remove(self):
        """Delete this connector through the API."""
        await self.trainml._query(self._resource_path(), "DELETE")

    async def refresh(self):
        """Re-fetch this connector's record and return ``self``."""
        record = await self.trainml._query(self._resource_path(), "GET")
        # Re-running the constructor refreshes every cached field at once.
        self.__init__(self.trainml, **record)
        return self
|
trainml/cloudbender/services.py
CHANGED
|
@@ -1,5 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
import math
|
|
5
|
+
|
|
6
|
+
from trainml.exceptions import (
|
|
7
|
+
ApiError,
|
|
8
|
+
SpecificationError,
|
|
9
|
+
TrainMLException,
|
|
10
|
+
)
|
|
3
11
|
|
|
4
12
|
|
|
5
13
|
class Services(object):
|
|
@@ -28,12 +36,14 @@ class Services(object):
|
|
|
28
36
|
provider_uuid,
|
|
29
37
|
region_uuid,
|
|
30
38
|
name,
|
|
39
|
+
type,
|
|
31
40
|
public,
|
|
32
41
|
**kwargs,
|
|
33
42
|
):
|
|
34
43
|
logging.info(f"Creating Service {name}")
|
|
35
44
|
data = dict(
|
|
36
45
|
name=name,
|
|
46
|
+
type=type,
|
|
37
47
|
public=public,
|
|
38
48
|
**kwargs,
|
|
39
49
|
)
|
|
@@ -65,7 +75,12 @@ class Service:
|
|
|
65
75
|
self._region_uuid = self._service.get("region_uuid")
|
|
66
76
|
self._public = self._service.get("public")
|
|
67
77
|
self._name = self._service.get("name")
|
|
68
|
-
self.
|
|
78
|
+
self._type = self._service.get("type")
|
|
79
|
+
self._hostname = self._service.get("custom_hostname") or self._service.get(
|
|
80
|
+
"hostname"
|
|
81
|
+
)
|
|
82
|
+
self._status = self._service.get("status")
|
|
83
|
+
self._port = self._service.get("port")
|
|
69
84
|
|
|
70
85
|
@property
|
|
71
86
|
def id(self) -> str:
|
|
@@ -91,6 +106,18 @@ class Service:
|
|
|
91
106
|
def hostname(self) -> str:
|
|
92
107
|
return self._hostname
|
|
93
108
|
|
|
109
|
+
@property
def status(self) -> str:
    """Status string read from the service record's ``status`` field."""
    return self._status

@property
def type(self) -> str:
    """Service type read from the record's ``type`` field."""
    return self._type

@property
def port(self) -> str:
    """Port read from the record's ``port`` field.

    NOTE(review): annotated ``str``, but the value is passed through from the
    API unchanged — confirm whether it is actually an int.
    """
    return self._port
|
|
120
|
+
|
|
94
121
|
def __str__(self):
|
|
95
122
|
return json.dumps({k: v for k, v in self._service.items()})
|
|
96
123
|
|
|
@@ -113,3 +140,40 @@ class Service:
|
|
|
113
140
|
)
|
|
114
141
|
self.__init__(self.trainml, **resp)
|
|
115
142
|
return self
|
|
143
|
+
|
|
144
|
+
async def wait_for(self, status, timeout=300):
    """Poll this service until it reaches ``status``.

    Args:
        status: Target status; must be ``"active"`` or ``"archived"``.
        timeout: Maximum number of seconds to wait (at most 24 hours).

    Returns:
        ``self`` once the service reports ``status``, whether that was already
        true on entry or was reached while polling. When waiting for
        ``"archived"`` and the refresh returns HTTP 404, returns ``None``,
        because there is no record left to refresh.

    Raises:
        SpecificationError: ``status`` is not a valid target, or ``timeout``
            exceeds 24 hours.
        ApiError: a refresh failed for any reason other than the 404 above.
        TrainMLException: ``timeout`` elapsed before ``status`` was reached.
    """
    # Already in the target state: return self, matching the polling path.
    # The original returned None here, so the result depended on timing.
    if self.status == status:
        return self
    valid_statuses = ["active", "archived"]
    if status not in valid_statuses:
        raise SpecificationError(
            "status",
            f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
        )
    MAX_TIMEOUT = 24 * 60 * 60
    if timeout > MAX_TIMEOUT:
        raise SpecificationError(
            "timeout",
            f"timeout must be less than {MAX_TIMEOUT} seconds.",
        )

    POLL_INTERVAL_MIN = 5
    POLL_INTERVAL_MAX = 60
    # Aim for roughly 60 polls over the timeout, with the interval clamped
    # to [POLL_INTERVAL_MIN, POLL_INTERVAL_MAX] seconds.
    POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
    retry_count = math.ceil(timeout / POLL_INTERVAL)
    for attempt in range(1, retry_count + 1):
        await asyncio.sleep(POLL_INTERVAL)
        try:
            await self.refresh()
        except ApiError as e:
            # A deleted service can no longer be fetched; treat that as archived.
            if status == "archived" and e.status == 404:
                return
            raise
        if self.status == status:
            return self
        logging.debug(f"self: {self}, retry count {attempt}")

    raise TrainMLException(f"Timeout waiting for {status}")
|
trainml/datasets.py
CHANGED
|
@@ -60,7 +60,10 @@ class Dataset:
|
|
|
60
60
|
self._id = self._dataset.get("id", self._dataset.get("dataset_uuid"))
|
|
61
61
|
self._status = self._dataset.get("status")
|
|
62
62
|
self._name = self._dataset.get("name")
|
|
63
|
-
self._size = self._dataset.get("size")
|
|
63
|
+
self._size = self._dataset.get("size") or self._dataset.get("used_size")
|
|
64
|
+
self._billed_size = self._dataset.get("billed_size") or self._dataset.get(
|
|
65
|
+
"size"
|
|
66
|
+
)
|
|
64
67
|
self._project_uuid = self._dataset.get("project_uuid")
|
|
65
68
|
|
|
66
69
|
@property
|
|
@@ -79,6 +82,10 @@ class Dataset:
|
|
|
79
82
|
def size(self) -> int:
|
|
80
83
|
return self._size or 0
|
|
81
84
|
|
|
85
|
+
@property
def billed_size(self) -> int:
    """Billed size of the dataset.

    Taken from the record's ``billed_size`` field, falling back to ``size``
    when that is missing or falsy. Unlike ``size``, there is no ``or 0``
    default, so this can be ``None``.
    """
    return self._billed_size
|
|
88
|
+
|
|
82
89
|
def __str__(self):
|
|
83
90
|
return json.dumps({k: v for k, v in self._dataset.items()})
|
|
84
91
|
|
|
@@ -119,12 +126,16 @@ class Dataset:
|
|
|
119
126
|
project_uuid=self._dataset.get("project_uuid"),
|
|
120
127
|
cidr=self._dataset.get("vpn").get("cidr"),
|
|
121
128
|
ssh_port=self._dataset.get("vpn").get("client").get("ssh_port"),
|
|
122
|
-
input_path=
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
129
|
+
input_path=(
|
|
130
|
+
self._dataset.get("source_uri")
|
|
131
|
+
if self.status in ["new", "downloading"]
|
|
132
|
+
else None
|
|
133
|
+
),
|
|
134
|
+
output_path=(
|
|
135
|
+
self._dataset.get("output_uri")
|
|
136
|
+
if self.status == "exporting"
|
|
137
|
+
else None
|
|
138
|
+
),
|
|
128
139
|
)
|
|
129
140
|
else:
|
|
130
141
|
details = dict()
|
|
@@ -215,14 +226,21 @@ class Dataset:
|
|
|
215
226
|
return self
|
|
216
227
|
|
|
217
228
|
async def wait_for(self, status, timeout=300):
|
|
229
|
+
if self.status == status:
|
|
230
|
+
return
|
|
218
231
|
valid_statuses = ["downloading", "ready", "archived"]
|
|
219
232
|
if not status in valid_statuses:
|
|
220
233
|
raise SpecificationError(
|
|
221
234
|
"status",
|
|
222
235
|
f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
|
|
223
236
|
)
|
|
224
|
-
|
|
225
|
-
|
|
237
|
+
MAX_TIMEOUT = 24 * 60 * 60
|
|
238
|
+
if timeout > MAX_TIMEOUT:
|
|
239
|
+
raise SpecificationError(
|
|
240
|
+
"timeout",
|
|
241
|
+
f"timeout must be less than {MAX_TIMEOUT} seconds.",
|
|
242
|
+
)
|
|
243
|
+
|
|
226
244
|
POLL_INTERVAL_MIN = 5
|
|
227
245
|
POLL_INTERVAL_MAX = 60
|
|
228
246
|
POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
|
trainml/jobs.py
CHANGED
|
@@ -468,6 +468,12 @@ class Job:
|
|
|
468
468
|
return job
|
|
469
469
|
|
|
470
470
|
async def wait_for(self, status, timeout=300):
|
|
471
|
+
if self.status == status or (
|
|
472
|
+
self.type == "training"
|
|
473
|
+
and status == "finished"
|
|
474
|
+
and self.status == "stopped"
|
|
475
|
+
):
|
|
476
|
+
return
|
|
471
477
|
valid_statuses = [
|
|
472
478
|
"waiting for data/model download",
|
|
473
479
|
"waiting for GPUs",
|
|
@@ -492,12 +498,13 @@ class Job:
|
|
|
492
498
|
"'stopped' status is deprecated for training jobs, use 'finished' instead.",
|
|
493
499
|
DeprecationWarning,
|
|
494
500
|
)
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
+
|
|
502
|
+
MAX_TIMEOUT = 24 * 60 * 60
|
|
503
|
+
if timeout > MAX_TIMEOUT:
|
|
504
|
+
raise SpecificationError(
|
|
505
|
+
"timeout",
|
|
506
|
+
f"timeout must be less than {MAX_TIMEOUT} seconds.",
|
|
507
|
+
)
|
|
501
508
|
|
|
502
509
|
POLL_INTERVAL_MIN = 5
|
|
503
510
|
POLL_INTERVAL_MAX = 60
|
trainml/models.py
CHANGED
|
@@ -32,8 +32,7 @@ class Models(object):
|
|
|
32
32
|
source_type=source_type,
|
|
33
33
|
source_uri=source_uri,
|
|
34
34
|
source_options=kwargs.get("source_options"),
|
|
35
|
-
project_uuid=kwargs.get("project_uuid")
|
|
36
|
-
or self.trainml.active_project,
|
|
35
|
+
project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
|
|
37
36
|
)
|
|
38
37
|
payload = {k: v for k, v in data.items() if v is not None}
|
|
39
38
|
logging.info(f"Creating Model {name}")
|
|
@@ -44,9 +43,7 @@ class Models(object):
|
|
|
44
43
|
return model
|
|
45
44
|
|
|
46
45
|
async def remove(self, id, **kwargs):
|
|
47
|
-
await self.trainml._query(
|
|
48
|
-
f"/model/{id}", "DELETE", dict(**kwargs, force=True)
|
|
49
|
-
)
|
|
46
|
+
await self.trainml._query(f"/model/{id}", "DELETE", dict(**kwargs, force=True))
|
|
50
47
|
|
|
51
48
|
|
|
52
49
|
class Model:
|
|
@@ -56,7 +53,8 @@ class Model:
|
|
|
56
53
|
self._id = self._model.get("id", self._model.get("model_uuid"))
|
|
57
54
|
self._status = self._model.get("status")
|
|
58
55
|
self._name = self._model.get("name")
|
|
59
|
-
self._size = self._model.get("size")
|
|
56
|
+
self._size = self._model.get("size") or self._model.get("used_size")
|
|
57
|
+
self._billed_size = self._model.get("billed_size") or self._model.get("size")
|
|
60
58
|
self._project_uuid = self._model.get("project_uuid")
|
|
61
59
|
|
|
62
60
|
@property
|
|
@@ -75,6 +73,10 @@ class Model:
|
|
|
75
73
|
def size(self) -> int:
|
|
76
74
|
return self._size
|
|
77
75
|
|
|
76
|
+
@property
def billed_size(self) -> int:
    """Billed size of the model.

    Taken from the record's ``billed_size`` field, falling back to ``size``
    when that is missing or falsy; may be ``None`` if neither is present.
    """
    return self._billed_size
|
|
79
|
+
|
|
78
80
|
def __str__(self):
|
|
79
81
|
return json.dumps({k: v for k, v in self._model.items()})
|
|
80
82
|
|
|
@@ -115,12 +117,16 @@ class Model:
|
|
|
115
117
|
project_uuid=self._model.get("project_uuid"),
|
|
116
118
|
cidr=self._model.get("vpn").get("cidr"),
|
|
117
119
|
ssh_port=self._model.get("vpn").get("client").get("ssh_port"),
|
|
118
|
-
input_path=
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
120
|
+
input_path=(
|
|
121
|
+
self._model.get("source_uri")
|
|
122
|
+
if self.status in ["new", "downloading"]
|
|
123
|
+
else None
|
|
124
|
+
),
|
|
125
|
+
output_path=(
|
|
126
|
+
self._model.get("output_uri")
|
|
127
|
+
if self.status == "exporting"
|
|
128
|
+
else None
|
|
129
|
+
),
|
|
124
130
|
)
|
|
125
131
|
else:
|
|
126
132
|
details = dict()
|
|
@@ -185,9 +191,7 @@ class Model:
|
|
|
185
191
|
if msg_handler:
|
|
186
192
|
msg_handler(data)
|
|
187
193
|
else:
|
|
188
|
-
timestamp = datetime.fromtimestamp(
|
|
189
|
-
int(data.get("time")) / 1000
|
|
190
|
-
)
|
|
194
|
+
timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
|
|
191
195
|
print(
|
|
192
196
|
f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
|
|
193
197
|
)
|
|
@@ -214,19 +218,23 @@ class Model:
|
|
|
214
218
|
return self
|
|
215
219
|
|
|
216
220
|
async def wait_for(self, status, timeout=300):
|
|
221
|
+
if self.status == status:
|
|
222
|
+
return
|
|
217
223
|
valid_statuses = ["downloading", "ready", "archived"]
|
|
218
224
|
if not status in valid_statuses:
|
|
219
225
|
raise SpecificationError(
|
|
220
226
|
"status",
|
|
221
227
|
f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
|
|
222
228
|
)
|
|
223
|
-
|
|
224
|
-
|
|
229
|
+
MAX_TIMEOUT = 24 * 60 * 60
|
|
230
|
+
if timeout > MAX_TIMEOUT:
|
|
231
|
+
raise SpecificationError(
|
|
232
|
+
"timeout",
|
|
233
|
+
f"timeout must be less than {MAX_TIMEOUT} seconds.",
|
|
234
|
+
)
|
|
225
235
|
POLL_INTERVAL_MIN = 5
|
|
226
236
|
POLL_INTERVAL_MAX = 60
|
|
227
|
-
POLL_INTERVAL = max(
|
|
228
|
-
min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
|
|
229
|
-
)
|
|
237
|
+
POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
|
|
230
238
|
retry_count = math.ceil(timeout / POLL_INTERVAL)
|
|
231
239
|
count = 0
|
|
232
240
|
while count < retry_count:
|
trainml/projects.py
CHANGED
|
@@ -10,6 +10,12 @@ class Projects(object):
|
|
|
10
10
|
resp = await self.trainml._query(f"/project/{id}", "GET", kwargs)
|
|
11
11
|
return Project(self.trainml, **resp)
|
|
12
12
|
|
|
13
|
+
async def get_current(self, **kwargs):
    """Fetch the project whose id is stored on the client as ``trainml.project``.

    ``kwargs`` are handed to ``_query`` as its third positional argument.
    NOTE(review): if ``trainml.project`` can be unset, this requests
    ``/project/None`` — confirm whether ``active_project`` was intended.
    """
    current_id = self.trainml.project
    record = await self.trainml._query(f"/project/{current_id}", "GET", kwargs)
    return Project(self.trainml, **record)
|
|
18
|
+
|
|
13
19
|
async def list(self, **kwargs):
|
|
14
20
|
resp = await self.trainml._query(f"/project", "GET", kwargs)
|
|
15
21
|
projects = [Project(self.trainml, **project) for project in resp]
|
|
@@ -72,6 +78,46 @@ class ProjectDatastore:
|
|
|
72
78
|
return bool(self._id)
|
|
73
79
|
|
|
74
80
|
|
|
81
|
+
class ProjectDataConnector:
    """Read-only view of a data connector as listed under a project.

    The raw record is kept in ``_data_connector`` for ``__str__`` and
    ``__repr__``; commonly used fields are cached as attributes.
    """

    def __init__(self, trainml, **kwargs):
        self.trainml = trainml
        self._data_connector = kwargs
        record = self._data_connector
        # Project-level records use a plain "id" key for the connector.
        self._id = record.get("id")
        self._project_uuid = record.get("project_uuid")
        self._name = record.get("name")
        self._type = record.get("type")
        self._region_uuid = record.get("region_uuid")

    @property
    def id(self) -> str:
        """Connector id (the record's ``id`` field)."""
        return self._id

    @property
    def project_uuid(self) -> str:
        """UUID of the project this connector is listed under."""
        return self._project_uuid

    @property
    def name(self) -> str:
        """Human-readable connector name."""
        return self._name

    @property
    def type(self) -> str:
        """Connector type as reported by the API."""
        return self._type

    @property
    def region_uuid(self) -> str:
        """UUID of the region hosting the connector."""
        return self._region_uuid

    def __str__(self):
        return json.dumps(dict(self._data_connector))

    def __repr__(self):
        return f"ProjectDataConnector( trainml , **{self._data_connector!r})"

    def __bool__(self):
        # Truthy only when the record carried a non-empty id.
        return bool(self._id)
|
|
119
|
+
|
|
120
|
+
|
|
75
121
|
class ProjectService:
|
|
76
122
|
def __init__(self, trainml, **kwargs):
|
|
77
123
|
self.trainml = trainml
|
|
@@ -79,9 +125,8 @@ class ProjectService:
|
|
|
79
125
|
self._id = self._service.get("id")
|
|
80
126
|
self._project_uuid = self._service.get("project_uuid")
|
|
81
127
|
self._name = self._service.get("name")
|
|
82
|
-
self._type = self._service.get("type")
|
|
83
128
|
self._hostname = self._service.get("hostname")
|
|
84
|
-
self.
|
|
129
|
+
self._public = self._service.get("public")
|
|
85
130
|
self._region_uuid = self._service.get("region_uuid")
|
|
86
131
|
|
|
87
132
|
@property
|
|
@@ -96,17 +141,13 @@ class ProjectService:
|
|
|
96
141
|
def name(self) -> str:
|
|
97
142
|
return self._name
|
|
98
143
|
|
|
99
|
-
@property
|
|
100
|
-
def type(self) -> str:
|
|
101
|
-
return self._type
|
|
102
|
-
|
|
103
144
|
@property
|
|
104
145
|
def hostname(self) -> str:
|
|
105
146
|
return self._hostname
|
|
106
147
|
|
|
107
148
|
@property
|
|
108
|
-
def
|
|
109
|
-
return self.
|
|
149
|
+
def public(self) -> bool:
|
|
150
|
+
return self._public
|
|
110
151
|
|
|
111
152
|
@property
|
|
112
153
|
def region_uuid(self) -> str:
|
|
@@ -164,6 +205,14 @@ class Project:
|
|
|
164
205
|
datastores = [ProjectDatastore(self.trainml, **datastore) for datastore in resp]
|
|
165
206
|
return datastores
|
|
166
207
|
|
|
208
|
+
async def list_data_connectors(self):
    """Return this project's data connectors as ``ProjectDataConnector`` objects."""
    endpoint = f"/project/{self._id}/data_connectors"
    records = await self.trainml._query(endpoint, "GET")
    return [ProjectDataConnector(self.trainml, **record) for record in records]
|
|
215
|
+
|
|
167
216
|
async def list_services(self):
|
|
168
217
|
resp = await self.trainml._query(f"/project/{self._id}/services", "GET")
|
|
169
218
|
services = [ProjectService(self.trainml, **service) for service in resp]
|
|
@@ -172,5 +221,8 @@ class Project:
|
|
|
172
221
|
async def refresh_datastores(self):
|
|
173
222
|
await self.trainml._query(f"/project/{self._id}/datastores", "PATCH")
|
|
174
223
|
|
|
224
|
+
async def refresh_data_connectors(self):
    """Send a PATCH to this project's ``data_connectors`` endpoint.

    The method name suggests this triggers a server-side refresh — TODO confirm.
    Returns nothing.
    """
    endpoint = f"/project/{self._id}/data_connectors"
    await self.trainml._query(endpoint, "PATCH")
|
|
226
|
+
|
|
175
227
|
async def refresh_services(self):
|
|
176
228
|
await self.trainml._query(f"/project/{self._id}/services", "PATCH")
|
trainml/volumes.py
CHANGED
|
@@ -223,14 +223,21 @@ class Volume:
|
|
|
223
223
|
return self
|
|
224
224
|
|
|
225
225
|
async def wait_for(self, status, timeout=300):
|
|
226
|
+
if self.status == status:
|
|
227
|
+
return
|
|
226
228
|
valid_statuses = ["downloading", "ready", "archived"]
|
|
227
229
|
if not status in valid_statuses:
|
|
228
230
|
raise SpecificationError(
|
|
229
231
|
"status",
|
|
230
232
|
f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
|
|
231
233
|
)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
+
|
|
235
|
+
MAX_TIMEOUT = 24 * 60 * 60
|
|
236
|
+
if timeout > MAX_TIMEOUT:
|
|
237
|
+
raise SpecificationError(
|
|
238
|
+
"timeout",
|
|
239
|
+
f"timeout must be less than {MAX_TIMEOUT} seconds.",
|
|
240
|
+
)
|
|
234
241
|
POLL_INTERVAL_MIN = 5
|
|
235
242
|
POLL_INTERVAL_MAX = 60
|
|
236
243
|
POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
|