trainml 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/test_jobs_integration.py +13 -0
- tests/unit/cli/cloudbender/test_cli_reservation_unit.py +10 -14
- tests/unit/cli/cloudbender/test_cli_service_unit.py +34 -0
- tests/unit/cli/test_cli_project_unit.py +5 -9
- tests/unit/cloudbender/test_data_connectors_unit.py +176 -0
- tests/unit/cloudbender/test_services_unit.py +167 -0
- tests/unit/conftest.py +13 -13
- tests/unit/test_projects_unit.py +77 -51
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +25 -25
- trainml/cli/cloudbender/__init__.py +2 -1
- trainml/cli/cloudbender/data_connector.py +159 -0
- trainml/cli/cloudbender/service.py +146 -0
- trainml/cli/project.py +10 -15
- trainml/cloudbender/cloudbender.py +4 -2
- trainml/cloudbender/data_connectors.py +112 -0
- trainml/cloudbender/services.py +179 -0
- trainml/datasets.py +19 -8
- trainml/jobs.py +13 -6
- trainml/models.py +22 -19
- trainml/projects.py +72 -31
- trainml/volumes.py +9 -2
- {trainml-0.5.5.dist-info → trainml-0.5.7.dist-info}/METADATA +1 -1
- {trainml-0.5.5.dist-info → trainml-0.5.7.dist-info}/RECORD +28 -21
- {trainml-0.5.5.dist-info → trainml-0.5.7.dist-info}/LICENSE +0 -0
- {trainml-0.5.5.dist-info → trainml-0.5.7.dist-info}/WHEEL +0 -0
- {trainml-0.5.5.dist-info → trainml-0.5.7.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.5.dist-info → trainml-0.5.7.dist-info}/top_level.txt +0 -0
trainml/cloudbender/data_connectors.py
ADDED
@@ -0,0 +1,112 @@
+import json
+import logging
+
+
+class DataConnectors(object):
+    def __init__(self, trainml):
+        self.trainml = trainml
+
+    async def get(self, provider_uuid, region_uuid, id, **kwargs):
+        resp = await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/data_connector/{id}",
+            "GET",
+            kwargs,
+        )
+        return DataConnector(self.trainml, **resp)
+
+    async def list(self, provider_uuid, region_uuid, **kwargs):
+        resp = await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/data_connector",
+            "GET",
+            kwargs,
+        )
+        data_connectors = [
+            DataConnector(self.trainml, **data_connector) for data_connector in resp
+        ]
+        return data_connectors
+
+    async def create(
+        self,
+        provider_uuid,
+        region_uuid,
+        name,
+        type,
+        **kwargs,
+    ):
+        logging.info(f"Creating Data Connector {name}")
+        data = dict(
+            name=name,
+            type=type,
+            **kwargs,
+        )
+        payload = {k: v for k, v in data.items() if v is not None}
+        resp = await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/data_connector",
+            "POST",
+            None,
+            payload,
+        )
+        data_connector = DataConnector(self.trainml, **resp)
+        logging.info(f"Created Data Connector {name} with id {data_connector.id}")
+        return data_connector
+
+    async def remove(self, provider_uuid, region_uuid, id, **kwargs):
+        await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/data_connector/{id}",
+            "DELETE",
+            kwargs,
+        )
+
+
+class DataConnector:
+    def __init__(self, trainml, **kwargs):
+        self.trainml = trainml
+        self._data_connector = kwargs
+        self._id = self._data_connector.get("connector_id")
+        self._provider_uuid = self._data_connector.get("provider_uuid")
+        self._region_uuid = self._data_connector.get("region_uuid")
+        self._type = self._data_connector.get("type")
+        self._name = self._data_connector.get("name")
+
+    @property
+    def id(self) -> str:
+        return self._id
+
+    @property
+    def provider_uuid(self) -> str:
+        return self._provider_uuid
+
+    @property
+    def region_uuid(self) -> str:
+        return self._region_uuid
+
+    @property
+    def type(self) -> str:
+        return self._type
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def __str__(self):
+        return json.dumps({k: v for k, v in self._data_connector.items()})
+
+    def __repr__(self):
+        return f"DataConnector( trainml , **{self._data_connector.__repr__()})"
+
+    def __bool__(self):
+        return bool(self._id)
+
+    async def remove(self):
+        await self.trainml._query(
+            f"/provider/{self._provider_uuid}/region/{self._region_uuid}/data_connector/{self._id}",
+            "DELETE",
+        )
+
+    async def refresh(self):
+        resp = await self.trainml._query(
+            f"/provider/{self._provider_uuid}/region/{self._region_uuid}/data_connector/{self._id}",
+            "GET",
+        )
+        self.__init__(self.trainml, **resp)
+        return self
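The new `DataConnectors`/`DataConnector` classes mirror the other CloudBender resource wrappers: the collection methods are scoped to a provider and region UUID, and the returned objects expose `remove()` and `refresh()`. A minimal usage sketch follows; the client entry point, the `cloudbender.data_connectors` attribute (the `trainml/cloudbender/cloudbender.py` change in this release suggests that wiring but is not shown here), and the UUID and connector `type` values are assumptions, not taken from the diff.

```python
import asyncio

from trainml.trainml import TrainML  # assumed client entry point


async def main():
    trainml_client = TrainML()

    # Placeholder identifiers; real values come from the provider/region APIs.
    provider_uuid = "my-provider-uuid"
    region_uuid = "my-region-uuid"

    # Create a connector in the region; extra keyword arguments are passed
    # through to the POST payload after None values are stripped.
    connector = await trainml_client.cloudbender.data_connectors.create(
        provider_uuid,
        region_uuid,
        name="warehouse-share",
        type="nfs",  # placeholder type value
    )
    print(connector.id, connector.name, connector.type)

    # List all connectors in the same provider/region scope.
    connectors = await trainml_client.cloudbender.data_connectors.list(
        provider_uuid, region_uuid
    )
    for c in connectors:
        print(c)

    # Instance-level delete hits the same DELETE endpoint as the collection's remove().
    await connector.remove()


asyncio.run(main())
```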
trainml/cloudbender/services.py
ADDED
@@ -0,0 +1,179 @@
+import json
+import logging
+import asyncio
+import math
+
+from trainml.exceptions import (
+    ApiError,
+    SpecificationError,
+    TrainMLException,
+)
+
+
+class Services(object):
+    def __init__(self, trainml):
+        self.trainml = trainml
+
+    async def get(self, provider_uuid, region_uuid, id, **kwargs):
+        resp = await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/service/{id}",
+            "GET",
+            kwargs,
+        )
+        return Service(self.trainml, **resp)
+
+    async def list(self, provider_uuid, region_uuid, **kwargs):
+        resp = await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/service",
+            "GET",
+            kwargs,
+        )
+        services = [Service(self.trainml, **service) for service in resp]
+        return services
+
+    async def create(
+        self,
+        provider_uuid,
+        region_uuid,
+        name,
+        type,
+        public,
+        **kwargs,
+    ):
+        logging.info(f"Creating Service {name}")
+        data = dict(
+            name=name,
+            type=type,
+            public=public,
+            **kwargs,
+        )
+        payload = {k: v for k, v in data.items() if v is not None}
+        resp = await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/service",
+            "POST",
+            None,
+            payload,
+        )
+        service = Service(self.trainml, **resp)
+        logging.info(f"Created Service {name} with id {service.id}")
+        return service
+
+    async def remove(self, provider_uuid, region_uuid, id, **kwargs):
+        await self.trainml._query(
+            f"/provider/{provider_uuid}/region/{region_uuid}/service/{id}",
+            "DELETE",
+            kwargs,
+        )
+
+
+class Service:
+    def __init__(self, trainml, **kwargs):
+        self.trainml = trainml
+        self._service = kwargs
+        self._id = self._service.get("service_id")
+        self._provider_uuid = self._service.get("provider_uuid")
+        self._region_uuid = self._service.get("region_uuid")
+        self._public = self._service.get("public")
+        self._name = self._service.get("name")
+        self._type = self._service.get("type")
+        self._hostname = self._service.get("custom_hostname") or self._service.get(
+            "hostname"
+        )
+        self._status = self._service.get("status")
+        self._port = self._service.get("port")
+
+    @property
+    def id(self) -> str:
+        return self._id
+
+    @property
+    def provider_uuid(self) -> str:
+        return self._provider_uuid
+
+    @property
+    def region_uuid(self) -> str:
+        return self._region_uuid
+
+    @property
+    def public(self) -> bool:
+        return self._public
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def hostname(self) -> str:
+        return self._hostname
+
+    @property
+    def status(self) -> str:
+        return self._status
+
+    @property
+    def type(self) -> str:
+        return self._type
+
+    @property
+    def port(self) -> str:
+        return self._port
+
+    def __str__(self):
+        return json.dumps({k: v for k, v in self._service.items()})
+
+    def __repr__(self):
+        return f"Service( trainml , **{self._service.__repr__()})"
+
+    def __bool__(self):
+        return bool(self._id)
+
+    async def remove(self):
+        await self.trainml._query(
+            f"/provider/{self._provider_uuid}/region/{self._region_uuid}/service/{self._id}",
+            "DELETE",
+        )
+
+    async def refresh(self):
+        resp = await self.trainml._query(
+            f"/provider/{self._provider_uuid}/region/{self._region_uuid}/service/{self._id}",
+            "GET",
+        )
+        self.__init__(self.trainml, **resp)
+        return self
+
+    async def wait_for(self, status, timeout=300):
+        if self.status == status:
+            return
+        valid_statuses = ["active", "archived"]
+        if not status in valid_statuses:
+            raise SpecificationError(
+                "status",
+                f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
+            )
+        MAX_TIMEOUT = 24 * 60 * 60
+        if timeout > MAX_TIMEOUT:
+            raise SpecificationError(
+                "timeout",
+                f"timeout must be less than {MAX_TIMEOUT} seconds.",
+            )
+
+        POLL_INTERVAL_MIN = 5
+        POLL_INTERVAL_MAX = 60
+        POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
+        retry_count = math.ceil(timeout / POLL_INTERVAL)
+        count = 0
+        while count < retry_count:
+            await asyncio.sleep(POLL_INTERVAL)
+            try:
+                await self.refresh()
+            except ApiError as e:
+                if status == "archived" and e.status == 404:
+                    return
+                raise e
+            if self.status == status:
+                return self
+            else:
+                count += 1
+                logging.debug(f"self: {self}, retry count {count}")
+
+        raise TrainMLException(f"Timeout waiting for {status}")
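`Services` follows the same provider/region scoping and adds a `wait_for()` poller whose valid target states are "active" and "archived". Below is a hedged usage sketch under the same assumptions as the data connector example (client entry point, a `cloudbender.services` attribute, and placeholder UUID, name, and type values):

```python
import asyncio

from trainml.trainml import TrainML  # assumed client entry point


async def main():
    trainml_client = TrainML()
    provider_uuid = "my-provider-uuid"  # placeholder
    region_uuid = "my-region-uuid"      # placeholder

    # `type`, `public`, and any extra keyword arguments are passed through to
    # the POST payload; the accepted values are not documented in this diff.
    service = await trainml_client.cloudbender.services.create(
        provider_uuid,
        region_uuid,
        name="internal-dashboard",
        type="https",  # placeholder type value
        public=False,
    )

    # Poll until the region reports the requested state, or let wait_for raise
    # TrainMLException once the timeout is exhausted.
    await service.wait_for("active", timeout=600)
    print(service.hostname, service.port, service.status)

    await service.remove()


asyncio.run(main())
```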
trainml/datasets.py
CHANGED
@@ -119,12 +119,16 @@ class Dataset:
                 project_uuid=self._dataset.get("project_uuid"),
                 cidr=self._dataset.get("vpn").get("cidr"),
                 ssh_port=self._dataset.get("vpn").get("client").get("ssh_port"),
-                input_path=
-
-
-
-
-
+                input_path=(
+                    self._dataset.get("source_uri")
+                    if self.status in ["new", "downloading"]
+                    else None
+                ),
+                output_path=(
+                    self._dataset.get("output_uri")
+                    if self.status == "exporting"
+                    else None
+                ),
             )
         else:
             details = dict()
@@ -215,14 +219,21 @@ class Dataset:
         return self

     async def wait_for(self, status, timeout=300):
+        if self.status == status:
+            return
         valid_statuses = ["downloading", "ready", "archived"]
         if not status in valid_statuses:
             raise SpecificationError(
                 "status",
                 f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
             )
-
-
+        MAX_TIMEOUT = 24 * 60 * 60
+        if timeout > MAX_TIMEOUT:
+            raise SpecificationError(
+                "timeout",
+                f"timeout must be less than {MAX_TIMEOUT} seconds.",
+            )
+
         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
         POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
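The two additions here, an early return when the dataset is already in the requested state and a 24-hour cap on `timeout`, are applied to the same `wait_for()` method in `trainml/models.py` and `trainml/volumes.py` below and appear in the new `Service.wait_for()` above. To show how the shared polling parameters behave, here is a standalone sketch of the interval/retry arithmetic; the constants are copied from the diff, but the helper function itself is illustrative and not part of the SDK.

```python
import math

# Constants copied from the wait_for() implementations in this release.
POLL_INTERVAL_MIN = 5
POLL_INTERVAL_MAX = 60
MAX_TIMEOUT = 24 * 60 * 60  # new upper bound on the timeout argument


def polling_schedule(timeout: int):
    """Return (poll interval in seconds, number of polls) for a given timeout."""
    if timeout > MAX_TIMEOUT:
        # wait_for() now raises SpecificationError("timeout", ...) in this case.
        raise ValueError(f"timeout must be less than {MAX_TIMEOUT} seconds.")
    poll_interval = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
    return poll_interval, math.ceil(timeout / poll_interval)


# Short timeouts poll every 5 seconds; anything over an hour polls every 60 seconds.
for timeout in (60, 300, 3600, MAX_TIMEOUT):
    print(timeout, polling_schedule(timeout))
```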
trainml/jobs.py
CHANGED
@@ -468,6 +468,12 @@ class Job:
         return job

     async def wait_for(self, status, timeout=300):
+        if self.status == status or (
+            self.type == "training"
+            and status == "finished"
+            and self.status == "stopped"
+        ):
+            return
         valid_statuses = [
             "waiting for data/model download",
             "waiting for GPUs",
@@ -492,12 +498,13 @@
                 "'stopped' status is deprecated for training jobs, use 'finished' instead.",
                 DeprecationWarning,
             )
-
-
-
-
-
-
+
+        MAX_TIMEOUT = 24 * 60 * 60
+        if timeout > MAX_TIMEOUT:
+            raise SpecificationError(
+                "timeout",
+                f"timeout must be less than {MAX_TIMEOUT} seconds.",
+            )

         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
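The new guard treats a training job that reports the deprecated "stopped" status as already satisfying a `wait_for("finished")` call, so callers migrating from the old status name return immediately instead of polling until the timeout. A standalone sketch of that predicate follows; the function name and example job types are illustrative, not SDK code.

```python
def already_satisfied(job_type: str, current_status: str, requested_status: str) -> bool:
    # Mirrors the early-return condition added to Job.wait_for(): an exact
    # status match, or a training job reporting the deprecated "stopped"
    # status when "finished" was requested.
    return current_status == requested_status or (
        job_type == "training"
        and requested_status == "finished"
        and current_status == "stopped"
    )


assert already_satisfied("training", "stopped", "finished")
assert already_satisfied("notebook", "running", "running")
assert not already_satisfied("notebook", "stopped", "finished")
```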
trainml/models.py
CHANGED
@@ -32,8 +32,7 @@ class Models(object):
             source_type=source_type,
             source_uri=source_uri,
             source_options=kwargs.get("source_options"),
-            project_uuid=kwargs.get("project_uuid")
-            or self.trainml.active_project,
+            project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
         )
         payload = {k: v for k, v in data.items() if v is not None}
         logging.info(f"Creating Model {name}")
@@ -44,9 +43,7 @@ class Models(object):
         return model

     async def remove(self, id, **kwargs):
-        await self.trainml._query(
-            f"/model/{id}", "DELETE", dict(**kwargs, force=True)
-        )
+        await self.trainml._query(f"/model/{id}", "DELETE", dict(**kwargs, force=True))


 class Model:
@@ -115,12 +112,16 @@ class Model:
                 project_uuid=self._model.get("project_uuid"),
                 cidr=self._model.get("vpn").get("cidr"),
                 ssh_port=self._model.get("vpn").get("client").get("ssh_port"),
-                input_path=
-
-
-
-
-
+                input_path=(
+                    self._model.get("source_uri")
+                    if self.status in ["new", "downloading"]
+                    else None
+                ),
+                output_path=(
+                    self._model.get("output_uri")
+                    if self.status == "exporting"
+                    else None
+                ),
             )
         else:
             details = dict()
@@ -185,9 +186,7 @@
             if msg_handler:
                 msg_handler(data)
             else:
-                timestamp = datetime.fromtimestamp(
-                    int(data.get("time")) / 1000
-                )
+                timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
                 print(
                     f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
                 )
@@ -214,19 +213,23 @@
         return self

     async def wait_for(self, status, timeout=300):
+        if self.status == status:
+            return
         valid_statuses = ["downloading", "ready", "archived"]
         if not status in valid_statuses:
             raise SpecificationError(
                 "status",
                 f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
             )
-
-
+        MAX_TIMEOUT = 24 * 60 * 60
+        if timeout > MAX_TIMEOUT:
+            raise SpecificationError(
+                "timeout",
+                f"timeout must be less than {MAX_TIMEOUT} seconds.",
+            )
         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
-        POLL_INTERVAL = max(
-            min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
-        )
+        POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
         retry_count = math.ceil(timeout / POLL_INTERVAL)
         count = 0
         while count < retry_count:
trainml/projects.py
CHANGED
@@ -10,6 +10,12 @@ class Projects(object):
         resp = await self.trainml._query(f"/project/{id}", "GET", kwargs)
         return Project(self.trainml, **resp)

+    async def get_current(self, **kwargs):
+        resp = await self.trainml._query(
+            f"/project/{self.trainml.project}", "GET", kwargs
+        )
+        return Project(self.trainml, **resp)
+
     async def list(self, **kwargs):
         resp = await self.trainml._query(f"/project", "GET", kwargs)
         projects = [Project(self.trainml, **project) for project in resp]
@@ -72,17 +78,15 @@ class ProjectDatastore:
         return bool(self._id)


-class
+class ProjectDataConnector:
     def __init__(self, trainml, **kwargs):
         self.trainml = trainml
-        self.
-        self._id = self.
-        self._project_uuid = self.
-        self._name = self.
-        self._type = self.
-        self.
-        self._resource = self._reservation.get("resource")
-        self._region_uuid = self._reservation.get("region_uuid")
+        self._data_connector = kwargs
+        self._id = self._data_connector.get("id")
+        self._project_uuid = self._data_connector.get("project_uuid")
+        self._name = self._data_connector.get("name")
+        self._type = self._data_connector.get("type")
+        self._region_uuid = self._data_connector.get("region_uuid")

     @property
     def id(self) -> str:
@@ -100,25 +104,60 @@ class ProjectReservation:
     def type(self) -> str:
         return self._type

+    @property
+    def region_uuid(self) -> str:
+        return self._region_uuid
+
+    def __str__(self):
+        return json.dumps({k: v for k, v in self._data_connector.items()})
+
+    def __repr__(self):
+        return f"ProjectDataConnector( trainml , **{self._data_connector.__repr__()})"
+
+    def __bool__(self):
+        return bool(self._id)
+
+
+class ProjectService:
+    def __init__(self, trainml, **kwargs):
+        self.trainml = trainml
+        self._service = kwargs
+        self._id = self._service.get("id")
+        self._project_uuid = self._service.get("project_uuid")
+        self._name = self._service.get("name")
+        self._hostname = self._service.get("hostname")
+        self._public = self._service.get("public")
+        self._region_uuid = self._service.get("region_uuid")
+
+    @property
+    def id(self) -> str:
+        return self._id
+
+    @property
+    def project_uuid(self) -> str:
+        return self._project_uuid
+
+    @property
+    def name(self) -> str:
+        return self._name
+
     @property
     def hostname(self) -> str:
         return self._hostname

     @property
-    def
-        return self.
+    def public(self) -> bool:
+        return self._public

     @property
     def region_uuid(self) -> str:
         return self._region_uuid

     def __str__(self):
-        return json.dumps({k: v for k, v in self.
+        return json.dumps({k: v for k, v in self._service.items()})

     def __repr__(self):
-        return (
-            f"ProjectReservation( trainml , **{self._reservation.__repr__()})"
-        )
+        return f"ProjectService( trainml , **{self._service.__repr__()})"

     def __bool__(self):
         return bool(self._id)
@@ -162,26 +201,28 @@ class Project:
         await self.trainml._query(f"/project/{self._id}", "DELETE")

     async def list_datastores(self):
-        resp = await self.trainml._query(
-
-        )
-        datastores = [
-            ProjectDatastore(self.trainml, **datastore) for datastore in resp
-        ]
+        resp = await self.trainml._query(f"/project/{self._id}/datastores", "GET")
+        datastores = [ProjectDatastore(self.trainml, **datastore) for datastore in resp]
         return datastores

-    async def
-        resp = await self.trainml._query(
-
-
-
-            ProjectReservation(self.trainml, **reservation)
-            for reservation in resp
+    async def list_data_connectors(self):
+        resp = await self.trainml._query(f"/project/{self._id}/data_connectors", "GET")
+        data_connectors = [
+            ProjectDataConnector(self.trainml, **data_connector)
+            for data_connector in resp
         ]
-        return
+        return data_connectors
+
+    async def list_services(self):
+        resp = await self.trainml._query(f"/project/{self._id}/services", "GET")
+        services = [ProjectService(self.trainml, **service) for service in resp]
+        return services

     async def refresh_datastores(self):
         await self.trainml._query(f"/project/{self._id}/datastores", "PATCH")

-    async def
-        await self.trainml._query(f"/project/{self._id}/
+    async def refresh_data_connectors(self):
+        await self.trainml._query(f"/project/{self._id}/data_connectors", "PATCH")
+
+    async def refresh_services(self):
+        await self.trainml._query(f"/project/{self._id}/services", "PATCH")
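This change adds `Projects.get_current()`, new `ProjectDataConnector` and `ProjectService` wrappers, and project-level list/refresh calls for data connectors and services. A hedged usage sketch follows, assuming the client entry point used elsewhere in the SDK and that the collection is exposed as `trainml.projects`:

```python
import asyncio

from trainml.trainml import TrainML  # assumed client entry point


async def main():
    trainml_client = TrainML()

    # New in this release: fetch the currently active project directly.
    project = await trainml_client.projects.get_current()

    # New project-scoped views over region resources.
    connectors = await project.list_data_connectors()
    services = await project.list_services()
    for connector in connectors:
        print(connector.name, connector.type, connector.region_uuid)
    for service in services:
        print(service.name, service.hostname, service.public)

    # PATCH endpoints that ask the platform to re-sync these lists.
    await project.refresh_data_connectors()
    await project.refresh_services()


asyncio.run(main())
```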
trainml/volumes.py
CHANGED
@@ -223,14 +223,21 @@ class Volume:
         return self

     async def wait_for(self, status, timeout=300):
+        if self.status == status:
+            return
         valid_statuses = ["downloading", "ready", "archived"]
         if not status in valid_statuses:
             raise SpecificationError(
                 "status",
                 f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
             )
-
-
+
+        MAX_TIMEOUT = 24 * 60 * 60
+        if timeout > MAX_TIMEOUT:
+            raise SpecificationError(
+                "timeout",
+                f"timeout must be less than {MAX_TIMEOUT} seconds.",
+            )
         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
         POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)