trainml 0.5.6__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/integration/test_jobs_integration.py +13 -0
- tests/unit/cli/cloudbender/test_cli_service_unit.py +34 -0
- tests/unit/cloudbender/test_data_connectors_unit.py +176 -0
- tests/unit/cloudbender/test_services_unit.py +6 -0
- tests/unit/test_projects_unit.py +45 -5
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +25 -25
- trainml/cli/cloudbender/__init__.py +1 -0
- trainml/cli/cloudbender/data_connector.py +159 -0
- trainml/cli/cloudbender/service.py +19 -2
- trainml/cloudbender/cloudbender.py +2 -0
- trainml/cloudbender/data_connectors.py +112 -0
- trainml/cloudbender/services.py +65 -1
- trainml/datasets.py +19 -8
- trainml/jobs.py +13 -6
- trainml/models.py +22 -19
- trainml/projects.py +60 -8
- trainml/volumes.py +9 -2
- {trainml-0.5.6.dist-info → trainml-0.5.7.dist-info}/METADATA +1 -1
- {trainml-0.5.6.dist-info → trainml-0.5.7.dist-info}/RECORD +24 -20
- {trainml-0.5.6.dist-info → trainml-0.5.7.dist-info}/LICENSE +0 -0
- {trainml-0.5.6.dist-info → trainml-0.5.7.dist-info}/WHEEL +0 -0
- {trainml-0.5.6.dist-info → trainml-0.5.7.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.6.dist-info → trainml-0.5.7.dist-info}/top_level.txt +0 -0
|
@@ -6,10 +6,19 @@ import asyncio
|
|
|
6
6
|
import aiohttp
|
|
7
7
|
from pytest import mark, fixture, raises
|
|
8
8
|
from trainml.exceptions import ApiError
|
|
9
|
+
from urllib.parse import urlparse
|
|
9
10
|
|
|
10
11
|
pytestmark = [mark.sdk, mark.integration, mark.jobs]
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
def extract_domain_suffix(hostname):
    """Return the last two dot-separated labels of ``hostname``.

    Used by the tests to compare the registrable part of a job URL's host
    (e.g. ``"app.job.proximl.cloud"`` -> ``"proximl.cloud"``).

    Args:
        hostname: Host name string, or ``None`` (``urlparse(url).hostname``
            is ``None`` when the URL has no network location).

    Returns:
        The final two labels joined with ``"."``, or ``None`` when the
        hostname is missing/empty or has fewer than two labels.
    """
    # urlparse() yields hostname=None for URLs without a netloc; report
    # "no suffix" instead of raising AttributeError on None.split().
    if not hostname:
        return None
    # Drop the optional root dot of a fully-qualified name so that
    # "proximl.cloud." does not produce the bogus suffix "cloud.".
    parts = hostname.rstrip(".").split(".")
    if len(parts) >= 2:
        return ".".join(parts[-2:])
    return None
|
|
20
|
+
|
|
21
|
+
|
|
13
22
|
@fixture(scope="class")
|
|
14
23
|
async def job(trainml):
|
|
15
24
|
job = await trainml.jobs.create(
|
|
@@ -34,6 +43,8 @@ class JobLifeCycleTests:
|
|
|
34
43
|
assert job.status != "running"
|
|
35
44
|
job = await job.wait_for("running")
|
|
36
45
|
assert job.status == "running"
|
|
46
|
+
assert job.url
|
|
47
|
+
assert extract_domain_suffix(urlparse(job.url).hostname) == "proximl.cloud"
|
|
37
48
|
|
|
38
49
|
async def test_stop_job(self, job):
|
|
39
50
|
assert job.status == "running"
|
|
@@ -518,6 +529,7 @@ class JobIOTests:
|
|
|
518
529
|
@mark.asyncio
|
|
519
530
|
class JobTypeTests:
|
|
520
531
|
async def test_endpoint(self, trainml):
|
|
532
|
+
|
|
521
533
|
job = await trainml.jobs.create(
|
|
522
534
|
"CLI Automated Tests - Endpoint",
|
|
523
535
|
type="endpoint",
|
|
@@ -544,6 +556,7 @@ class JobTypeTests:
|
|
|
544
556
|
await job.wait_for("running")
|
|
545
557
|
await job.refresh()
|
|
546
558
|
assert job.url
|
|
559
|
+
assert extract_domain_suffix(urlparse(job.url).hostname) == "proximl.cloud"
|
|
547
560
|
tries = 0
|
|
548
561
|
await asyncio.sleep(30)
|
|
549
562
|
async with aiohttp.ClientSession() as session:
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import json
|
|
3
|
+
import click
|
|
4
|
+
from unittest.mock import AsyncMock, patch
|
|
5
|
+
from pytest import mark, fixture, raises
|
|
6
|
+
|
|
7
|
+
pytestmark = [mark.cli, mark.unit, mark.cloudbender, mark.services]
|
|
8
|
+
|
|
9
|
+
from trainml.cli.cloudbender import service as specimen
|
|
10
|
+
from trainml.cloudbender.services import Service
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_list(runner, mock_services):
    """`service list` with provider and region forwards both IDs to the SDK."""
    cli_args = ["list", "--provider=prov-id-1", "--region=reg-id-1"]
    # patch(..., new=AsyncMock) hands back the AsyncMock class itself, so the
    # attributes below are set on the class object.
    # NOTE(review): presumably the CLI instantiates this class and the
    # instance picks these attributes up -- confirm against trainml.cli.
    with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
        mock_trainml.cloudbender = AsyncMock()
        mock_trainml.cloudbender.services = AsyncMock()
        list_mock = AsyncMock(return_value=mock_services)
        mock_trainml.cloudbender.services.list = list_mock
        outcome = runner.invoke(specimen, args=cli_args)
        assert outcome.exit_code == 0
        list_mock.assert_called_once_with(
            provider_uuid="prov-id-1", region_uuid="reg-id-1"
        )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_list_no_provider(runner, mock_services):
    """`service list` without the required options exits with an error code."""
    with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
        # Same mock wiring as the success case so that only the missing
        # command-line options can cause the failure.
        cloudbender_mock = AsyncMock()
        mock_trainml.cloudbender = cloudbender_mock
        cloudbender_mock.services = AsyncMock()
        cloudbender_mock.services.list = AsyncMock(return_value=mock_services)
        outcome = runner.invoke(specimen, ["list"])
        assert outcome.exit_code != 0
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from unittest.mock import AsyncMock, patch
|
|
5
|
+
from pytest import mark, fixture, raises
|
|
6
|
+
from aiohttp import WSMessage, WSMsgType
|
|
7
|
+
|
|
8
|
+
import trainml.cloudbender.data_connectors as specimen
|
|
9
|
+
from trainml.exceptions import (
|
|
10
|
+
ApiError,
|
|
11
|
+
SpecificationError,
|
|
12
|
+
TrainMLException,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
pytestmark = [mark.sdk, mark.unit, mark.cloudbender, mark.data_connectors]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@fixture
def data_connectors(mock_trainml):
    """Yield a ``DataConnectors`` collection built on the mocked client."""
    collection = specimen.DataConnectors(mock_trainml)
    yield collection
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@fixture
def data_connector(mock_trainml):
    """Yield a populated ``DataConnector`` built on the mocked client."""
    attributes = {
        "provider_uuid": "1",
        "region_uuid": "a",
        "connector_id": "x",
        "name": "On-Prem Data Connector",
        "type": "custom",
        "cidr": "192.168.0.50/32",
        "port": "443",
        "protocol": "tcp",
    }
    yield specimen.DataConnector(mock_trainml, **attributes)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class RegionsTests:
    """Tests for the ``DataConnectors`` collection (get/list/remove/create).

    NOTE(review): the class name looks copied from the regions test module;
    it is left unchanged so test collection and selection are unaffected.
    """

    @mark.asyncio
    async def test_get_data_connector(self, data_connectors, mock_trainml):
        """``get`` issues a GET against the connector-specific path."""
        mock_trainml._query = AsyncMock(return_value={})
        await data_connectors.get("1234", "5687", "91011")
        mock_trainml._query.assert_called_once_with(
            "/provider/1234/region/5687/data_connector/91011", "GET", {}
        )

    @mark.asyncio
    async def test_list_data_connectors(self, data_connectors, mock_trainml):
        """``list`` issues a GET against the region's connector collection."""
        mock_trainml._query = AsyncMock(return_value={})
        await data_connectors.list("1234", "5687")
        mock_trainml._query.assert_called_once_with(
            "/provider/1234/region/5687/data_connector", "GET", {}
        )

    @mark.asyncio
    async def test_remove_data_connector(self, data_connectors, mock_trainml):
        """``remove`` issues a DELETE against the connector-specific path."""
        mock_trainml._query = AsyncMock(return_value={})
        await data_connectors.remove("1234", "4567", "8910")
        mock_trainml._query.assert_called_once_with(
            "/provider/1234/region/4567/data_connector/8910", "DELETE", {}
        )

    @mark.asyncio
    async def test_create_data_connector(self, data_connectors, mock_trainml):
        """``create`` POSTs the connector settings and returns the new object."""
        # Arguments handed to the SDK call.
        requested_config = {
            "provider_uuid": "provider-id-1",
            "region_uuid": "region-id-1",
            "name": "On-Prem DataConnector",
            "type": "custom",
            "cidr": "192.168.0.50/32",
            "port": "443",
            "protocol": "tcp",
        }
        # Body expected on the wire: the provider/region IDs appear in the
        # URL path rather than in the payload.
        expected_payload = {
            "name": "On-Prem DataConnector",
            "type": "custom",
            "cidr": "192.168.0.50/32",
            "port": "443",
            "protocol": "tcp",
        }
        api_response = dict(
            provider_uuid="provider-id-1",
            region_uuid="region-id-1",
            connector_id="connector-id-1",
            name="On-Prem DataConnector",
            type="custom",
            cidr="192.168.0.50/32",
            port="443",
            protocol="tcp",
            createdAt="2020-12-31T23:59:59.000Z",
        )

        mock_trainml._query = AsyncMock(return_value=api_response)
        created = await data_connectors.create(**requested_config)
        mock_trainml._query.assert_called_once_with(
            "/provider/provider-id-1/region/region-id-1/data_connector",
            "POST",
            None,
            expected_payload,
        )
        assert created.id == "connector-id-1"
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class DataConnectorTests:
    """Tests for a single ``DataConnector`` resource object."""

    def test_data_connector_properties(self, data_connector):
        """The identifying properties are exposed as strings."""
        assert isinstance(data_connector.id, str)
        assert isinstance(data_connector.provider_uuid, str)
        assert isinstance(data_connector.region_uuid, str)
        assert isinstance(data_connector.type, str)
        assert isinstance(data_connector.name, str)

    def test_data_connector_str(self, data_connector):
        """``str()`` output is brace-wrapped text containing the connector_id."""
        string = str(data_connector)
        # Escape the id so regex metacharacters in it are matched literally.
        regex = (
            r"^{.*\"connector_id\": \""
            + re.escape(data_connector.id)
            + r"\".*}$"
        )
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_data_connector_repr(self, data_connector):
        """``repr()`` output names the class and includes the connector_id."""
        string = repr(data_connector)
        # Escape the id so regex metacharacters in it are matched literally.
        regex = (
            r"^DataConnector\( trainml , \*\*{.*'connector_id': '"
            + re.escape(data_connector.id)
            + r"'.*}\)$"
        )
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_data_connector_bool(self, data_connector, mock_trainml):
        """A populated connector is truthy; one built with no fields is falsy."""
        empty_data_connector = specimen.DataConnector(mock_trainml)
        assert bool(data_connector)
        assert not bool(empty_data_connector)

    @mark.asyncio
    async def test_data_connector_remove(self, data_connector, mock_trainml):
        """``remove`` issues a DELETE for this connector's path."""
        api_response = dict()
        mock_trainml._query = AsyncMock(return_value=api_response)
        await data_connector.remove()
        mock_trainml._query.assert_called_once_with(
            "/provider/1/region/a/data_connector/x", "DELETE"
        )

    @mark.asyncio
    async def test_data_connector_refresh(self, data_connector, mock_trainml):
        """``refresh`` re-fetches the connector and updates it in place."""
        api_response = {
            "provider_uuid": "provider-id-1",
            "region_uuid": "region-id-1",
            "connector_id": "connector-id-1",
            "name": "On-Prem DataConnector",
            "type": "custom",
            "cidr": "192.168.0.50/32",
            "port": "443",
            "protocol": "tcp",
            "createdAt": "2020-12-31T23:59:59.000Z",
        }
        mock_trainml._query = AsyncMock(return_value=api_response)
        response = await data_connector.refresh()
        # Plain string: the path has no placeholders, so no f-prefix is needed.
        mock_trainml._query.assert_called_once_with(
            "/provider/1/region/a/data_connector/x", "GET"
        )
        # Both the fixture object and the returned value carry the new id.
        assert data_connector.id == "connector-id-1"
        assert response.id == "connector-id-1"
|
|
@@ -28,6 +28,7 @@ def service(mock_trainml):
|
|
|
28
28
|
region_uuid="a",
|
|
29
29
|
service_id="x",
|
|
30
30
|
name="On-Prem Service",
|
|
31
|
+
type="https",
|
|
31
32
|
public=False,
|
|
32
33
|
hostname="app1.proximl.cloud",
|
|
33
34
|
)
|
|
@@ -79,10 +80,12 @@ class RegionsTests:
|
|
|
79
80
|
provider_uuid="provider-id-1",
|
|
80
81
|
region_uuid="region-id-1",
|
|
81
82
|
name="On-Prem Service",
|
|
83
|
+
type="https",
|
|
82
84
|
public=False,
|
|
83
85
|
)
|
|
84
86
|
expected_payload = dict(
|
|
85
87
|
name="On-Prem Service",
|
|
88
|
+
type="https",
|
|
86
89
|
public=False,
|
|
87
90
|
)
|
|
88
91
|
api_response = {
|
|
@@ -90,6 +93,7 @@ class RegionsTests:
|
|
|
90
93
|
"region_uuid": "region-id-1",
|
|
91
94
|
"service_id": "service-id-1",
|
|
92
95
|
"name": "On-Prem Service",
|
|
96
|
+
"type": "https",
|
|
93
97
|
"public": False,
|
|
94
98
|
"hostname": "app1.proximl.cloud",
|
|
95
99
|
"createdAt": "2020-12-31T23:59:59.000Z",
|
|
@@ -114,6 +118,7 @@ class serviceTests:
|
|
|
114
118
|
assert isinstance(service.public, bool)
|
|
115
119
|
assert isinstance(service.name, str)
|
|
116
120
|
assert isinstance(service.hostname, str)
|
|
121
|
+
assert isinstance(service.type, str)
|
|
117
122
|
|
|
118
123
|
def test_service_str(self, service):
|
|
119
124
|
string = str(service)
|
|
@@ -148,6 +153,7 @@ class serviceTests:
|
|
|
148
153
|
"region_uuid": "region-id-1",
|
|
149
154
|
"service_id": "service-id-1",
|
|
150
155
|
"name": "On-Prem Service",
|
|
156
|
+
"type": "https",
|
|
151
157
|
"public": False,
|
|
152
158
|
"hostname": "app1.proximl.cloud",
|
|
153
159
|
"createdAt": "2020-12-31T23:59:59.000Z",
|
tests/unit/test_projects_unit.py
CHANGED
|
@@ -48,6 +48,18 @@ def project_datastore(mock_trainml):
|
|
|
48
48
|
)
|
|
49
49
|
|
|
50
50
|
|
|
51
|
+
@fixture
def project_data_connector(mock_trainml):
    """Yield a populated ``ProjectDataConnector`` built on the mocked client."""
    attributes = {
        "id": "ds-id-1",
        "name": "connector 1",
        "project_uuid": "proj-id-1",
        "type": "custom",
        "region_uuid": "reg-id-1",
    }
    yield specimen.ProjectDataConnector(mock_trainml, **attributes)
|
|
61
|
+
|
|
62
|
+
|
|
51
63
|
@fixture
|
|
52
64
|
def project_service(mock_trainml):
|
|
53
65
|
yield specimen.ProjectService(
|
|
@@ -56,9 +68,8 @@ def project_service(mock_trainml):
|
|
|
56
68
|
name="service 1",
|
|
57
69
|
project_uuid="proj-id-1",
|
|
58
70
|
region_uuid="reg-id-1",
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
hostname="service.local",
|
|
71
|
+
public=False,
|
|
72
|
+
hostname="asdf.proximl.cloud",
|
|
62
73
|
)
|
|
63
74
|
|
|
64
75
|
|
|
@@ -152,14 +163,43 @@ class ProjectDatastoreTests:
|
|
|
152
163
|
assert not bool(empty_project_datastore)
|
|
153
164
|
|
|
154
165
|
|
|
166
|
+
class ProjectDataConnectorTests:
    """Tests for the ``ProjectDataConnector`` resource object."""

    def test_project_data_connector_properties(self, project_data_connector):
        """The identifying properties are exposed as strings."""
        assert isinstance(project_data_connector.id, str)
        assert isinstance(project_data_connector.name, str)
        assert isinstance(project_data_connector.project_uuid, str)
        assert isinstance(project_data_connector.type, str)
        assert isinstance(project_data_connector.region_uuid, str)

    def test_project_data_connector_str(self, project_data_connector):
        """``str()`` output is brace-wrapped text containing the id."""
        string = str(project_data_connector)
        # Escape the id so regex metacharacters in it are matched literally.
        regex = (
            r"^{.*\"id\": \""
            + re.escape(project_data_connector.id)
            + r"\".*}$"
        )
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_project_data_connector_repr(self, project_data_connector):
        """``repr()`` output names the class and includes the id."""
        string = repr(project_data_connector)
        # Escape the id so regex metacharacters in it are matched literally.
        regex = (
            r"^ProjectDataConnector\( trainml , \*\*{.*'id': '"
            + re.escape(project_data_connector.id)
            + r"'.*}\)$"
        )
        assert isinstance(string, str)
        assert re.match(regex, string)

    def test_project_data_connector_bool(self, project_data_connector, mock_trainml):
        """A populated connector is truthy; one built with no fields is falsy."""
        empty_project_data_connector = specimen.ProjectDataConnector(mock_trainml)
        assert bool(project_data_connector)
        assert not bool(empty_project_data_connector)
|
|
194
|
+
|
|
195
|
+
|
|
155
196
|
class ProjectServiceTests:
|
|
156
197
|
def test_project_service_properties(self, project_service):
|
|
157
198
|
assert isinstance(project_service.id, str)
|
|
158
199
|
assert isinstance(project_service.name, str)
|
|
159
200
|
assert isinstance(project_service.project_uuid, str)
|
|
160
|
-
assert isinstance(project_service.type, str)
|
|
161
201
|
assert isinstance(project_service.hostname, str)
|
|
162
|
-
assert isinstance(project_service.
|
|
202
|
+
assert isinstance(project_service.public, bool)
|
|
163
203
|
assert isinstance(project_service.region_uuid, str)
|
|
164
204
|
|
|
165
205
|
def test_project_service_str(self, project_service):
|
trainml/__init__.py
CHANGED
trainml/checkpoints.py
CHANGED
|
@@ -23,9 +23,7 @@ class Checkpoints(object):
|
|
|
23
23
|
|
|
24
24
|
async def list(self, **kwargs):
|
|
25
25
|
resp = await self.trainml._query(f"/checkpoint", "GET", kwargs)
|
|
26
|
-
checkpoints = [
|
|
27
|
-
Checkpoint(self.trainml, **checkpoint) for checkpoint in resp
|
|
28
|
-
]
|
|
26
|
+
checkpoints = [Checkpoint(self.trainml, **checkpoint) for checkpoint in resp]
|
|
29
27
|
return checkpoints
|
|
30
28
|
|
|
31
29
|
async def list_public(self, **kwargs):
|
|
@@ -39,8 +37,7 @@ class Checkpoints(object):
|
|
|
39
37
|
source_type=source_type,
|
|
40
38
|
source_uri=source_uri,
|
|
41
39
|
source_options=kwargs.get("source_options"),
|
|
42
|
-
project_uuid=kwargs.get("project_uuid")
|
|
43
|
-
or self.trainml.active_project,
|
|
40
|
+
project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
|
|
44
41
|
)
|
|
45
42
|
payload = {k: v for k, v in data.items() if v is not None}
|
|
46
43
|
logging.info(f"Creating Checkpoint {name}")
|
|
@@ -60,9 +57,7 @@ class Checkpoint:
|
|
|
60
57
|
def __init__(self, trainml, **kwargs):
|
|
61
58
|
self.trainml = trainml
|
|
62
59
|
self._checkpoint = kwargs
|
|
63
|
-
self._id = self._checkpoint.get(
|
|
64
|
-
"id", self._checkpoint.get("checkpoint_uuid")
|
|
65
|
-
)
|
|
60
|
+
self._id = self._checkpoint.get("id", self._checkpoint.get("checkpoint_uuid"))
|
|
66
61
|
self._status = self._checkpoint.get("status")
|
|
67
62
|
self._name = self._checkpoint.get("name")
|
|
68
63
|
self._size = self._checkpoint.get("size")
|
|
@@ -123,15 +118,17 @@ class Checkpoint:
|
|
|
123
118
|
entity_type="checkpoint",
|
|
124
119
|
project_uuid=self._checkpoint.get("project_uuid"),
|
|
125
120
|
cidr=self._checkpoint.get("vpn").get("cidr"),
|
|
126
|
-
ssh_port=self._checkpoint.get("vpn")
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
output_path=
|
|
133
|
-
|
|
134
|
-
|
|
121
|
+
ssh_port=self._checkpoint.get("vpn").get("client").get("ssh_port"),
|
|
122
|
+
input_path=(
|
|
123
|
+
self._checkpoint.get("source_uri")
|
|
124
|
+
if self.status in ["new", "downloading"]
|
|
125
|
+
else None
|
|
126
|
+
),
|
|
127
|
+
output_path=(
|
|
128
|
+
self._checkpoint.get("output_uri")
|
|
129
|
+
if self.status == "exporting"
|
|
130
|
+
else None
|
|
131
|
+
),
|
|
135
132
|
)
|
|
136
133
|
else:
|
|
137
134
|
details = dict()
|
|
@@ -195,9 +192,7 @@ class Checkpoint:
|
|
|
195
192
|
if msg_handler:
|
|
196
193
|
msg_handler(data)
|
|
197
194
|
else:
|
|
198
|
-
timestamp = datetime.fromtimestamp(
|
|
199
|
-
int(data.get("time")) / 1000
|
|
200
|
-
)
|
|
195
|
+
timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
|
|
201
196
|
print(
|
|
202
197
|
f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
|
|
203
198
|
)
|
|
@@ -224,19 +219,24 @@ class Checkpoint:
|
|
|
224
219
|
return self
|
|
225
220
|
|
|
226
221
|
async def wait_for(self, status, timeout=300):
|
|
222
|
+
if self.status == status:
|
|
223
|
+
return
|
|
227
224
|
valid_statuses = ["downloading", "ready", "archived"]
|
|
228
225
|
if not status in valid_statuses:
|
|
229
226
|
raise SpecificationError(
|
|
230
227
|
"status",
|
|
231
228
|
f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
|
|
232
229
|
)
|
|
233
|
-
|
|
234
|
-
|
|
230
|
+
|
|
231
|
+
MAX_TIMEOUT = 24 * 60 * 60
|
|
232
|
+
if timeout > MAX_TIMEOUT:
|
|
233
|
+
raise SpecificationError(
|
|
234
|
+
"timeout",
|
|
235
|
+
f"timeout must be less than {MAX_TIMEOUT} seconds.",
|
|
236
|
+
)
|
|
235
237
|
POLL_INTERVAL_MIN = 5
|
|
236
238
|
POLL_INTERVAL_MAX = 60
|
|
237
|
-
POLL_INTERVAL = max(
|
|
238
|
-
min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
|
|
239
|
-
)
|
|
239
|
+
POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
|
|
240
240
|
retry_count = math.ceil(timeout / POLL_INTERVAL)
|
|
241
241
|
count = 0
|
|
242
242
|
while count < retry_count:
|
|
@@ -15,4 +15,5 @@ from trainml.cli.cloudbender.region import region
|
|
|
15
15
|
from trainml.cli.cloudbender.node import node
|
|
16
16
|
from trainml.cli.cloudbender.device import device
|
|
17
17
|
from trainml.cli.cloudbender.datastore import datastore
|
|
18
|
+
from trainml.cli.cloudbender.data_connector import data_connector
|
|
18
19
|
from trainml.cli.cloudbender.service import service
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import click
|
|
2
|
+
from trainml.cli import cli, pass_config, search_by_id_name
|
|
3
|
+
from trainml.cli.cloudbender import cloudbender
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@cloudbender.group()
@pass_config
def data_connector(config):
    """trainML CloudBender data connector commands."""
    # Group callback only: it has no work of its own. The subcommands are
    # registered on it through @data_connector.command().
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@data_connector.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to list data connectors for.",
)
@pass_config
def list(config, provider, region):
    """List data connectors."""
    # Header row plus a dashed rule; the format string below truncates each
    # cell to its column width, so the 80-character rule is cut to fit.
    rule = "-" * 80
    rows = [["ID", "NAME", "TYPE"], [rule, rule, rule]]

    connectors = config.trainml.run(
        config.trainml.client.cloudbender.data_connectors.list(
            provider_uuid=provider, region_uuid=region
        )
    )
    rows.extend([item.id, item.name, item.type] for item in connectors)

    # Right-aligned, fixed-width columns written to the configured stdout.
    for cells in rows:
        click.echo(
            "{: >37.36} {: >29.28} {: >9.8}" "".format(*cells),
            file=config.stdout,
        )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@data_connector.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to create the data_connector in.",
)
@click.option(
    "--type",
    "-t",
    type=click.Choice(
        [
            "custom",
        ],
        case_sensitive=False,
    ),
    required=True,
    help="The type of data connector to create.",
)
# --protocol and --port-range are long-form only. They previously declared
# "-r" and "-p", which collide with --region and --provider above; with
# duplicate short flags only one option can own the flag, so "-p"/"-r" did
# not reach provider/region as they do on the sibling list/remove commands.
@click.option(
    "--protocol",
    type=click.STRING,
    help="The transport protocol of the data connector",
)
@click.option(
    "--port-range",
    type=click.STRING,
    help="The port range of the data connector",
)
@click.option(
    "--cidr",
    "-i",
    type=click.STRING,
    help="The IP range to allow in CIDR notation",
)
@click.argument("name", type=click.STRING, required=True)
@pass_config
def create(config, provider, region, type, protocol, port_range, cidr, name):
    """
    Creates a data_connector.
    """
    # Parameter names (including `type`) are dictated by the click option
    # names above, so they must stay as-is.
    # NOTE(review): the value is forwarded as `port_range`; confirm the SDK's
    # DataConnectors.create() accepts that keyword (its unit test passes
    # `port`).
    return config.trainml.run(
        config.trainml.client.cloudbender.data_connectors.create(
            provider_uuid=provider,
            region_uuid=region,
            name=name,
            type=type,
            protocol=protocol,
            port_range=port_range,
            cidr=cidr,
        )
    )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@data_connector.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to remove the data_connector from.",
)
@click.argument("data_connector", type=click.STRING)
@pass_config
def remove(config, provider, region, data_connector):
    """
    Remove a data_connector.

    DATA_CONNECTOR may be specified by name or ID, but ID is preferred.
    """
    # The docstring is the command's help text; it now names the
    # DATA_CONNECTOR argument instead of the copy-pasted "DATASTORE".
    # Fetch every connector in the region so the argument can be resolved
    # by either ID or name.
    data_connectors = config.trainml.run(
        config.trainml.client.cloudbender.data_connectors.list(
            provider_uuid=provider, region_uuid=region
        )
    )

    found = search_by_id_name(data_connector, data_connectors)
    if found is None:
        raise click.UsageError("Cannot find specified data_connector.")

    return config.trainml.run(found.remove())
|
|
@@ -74,6 +74,19 @@ def list(config, provider, region):
|
|
|
74
74
|
required=True,
|
|
75
75
|
help="The region ID to create the service in.",
|
|
76
76
|
)
|
|
77
|
+
@click.option(
|
|
78
|
+
"--type",
|
|
79
|
+
"-t",
|
|
80
|
+
type=click.Choice(
|
|
81
|
+
[
|
|
82
|
+
"https",
|
|
83
|
+
"tcp",
|
|
84
|
+
"udp",
|
|
85
|
+
],
|
|
86
|
+
),
|
|
87
|
+
required=True,
|
|
88
|
+
help="The type of regional service.",
|
|
89
|
+
)
|
|
77
90
|
@click.option(
|
|
78
91
|
"--public/--no-public",
|
|
79
92
|
default=True,
|
|
@@ -82,13 +95,17 @@ def list(config, provider, region):
|
|
|
82
95
|
)
|
|
83
96
|
@click.argument("name", type=click.STRING, required=True)
|
|
84
97
|
@pass_config
|
|
85
|
-
def create(config, provider, region, public, name):
|
|
98
|
+
def create(config, provider, region, type, public, name):
|
|
86
99
|
"""
|
|
87
100
|
Creates a service.
|
|
88
101
|
"""
|
|
89
102
|
return config.trainml.run(
|
|
90
103
|
config.trainml.client.cloudbender.services.create(
|
|
91
|
-
provider_uuid=provider,
|
|
104
|
+
provider_uuid=provider,
|
|
105
|
+
region_uuid=region,
|
|
106
|
+
name=name,
|
|
107
|
+
type=type,
|
|
108
|
+
public=public,
|
|
92
109
|
)
|
|
93
110
|
)
|
|
94
111
|
|