trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +17 -1
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +11 -53
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +239 -68
- trainml/models.py +51 -55
- trainml/projects/projects.py +2 -2
- trainml/trainml.py +50 -16
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +587 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/cli/model.py
CHANGED
```diff
@@ -36,15 +36,7 @@ def attach(config, model):
     if None is found:
         raise click.UsageError("Cannot find specified model.")
 
-    try:
-        config.trainml.run(found.attach())
-        return config.trainml.run(found.disconnect())
-    except:
-        try:
-            config.trainml.run(found.disconnect())
-        except:
-            pass
-        raise
+    config.trainml.run(found.attach())
 
 
 @model.command()
@@ -69,18 +61,10 @@ def connect(config, model, attach):
     if None is found:
         raise click.UsageError("Cannot find specified model.")
 
-
-
-
-
-        else:
-            return config.trainml.run(found.connect())
-    except:
-        try:
-            config.trainml.run(found.disconnect())
-        except:
-            pass
-        raise
+    if attach:
+        config.trainml.run(found.connect(), found.attach())
+    else:
+        config.trainml.run(found.connect())
 
 
 @model.command()
@@ -125,41 +109,15 @@ def create(config, attach, connect, source, name, path):
            )
        )
 
-
-
-
-
-
-
-
-
-
-                "No logs to show for local sourced model without connect."
-            )
-    except:
-        try:
-            config.trainml.run(model.disconnect())
-        except:
-            pass
-        raise
-
-
-@model.command()
-@click.argument("model", type=click.STRING)
-@pass_config
-def disconnect(config, model):
-    """
-    Disconnect and clean-up model upload.
-
-    MODEL may be specified by name or ID, but ID is preferred.
-    """
-    models = config.trainml.run(config.trainml.client.models.list())
-
-    found = search_by_id_name(model, models)
-    if None is found:
-        raise click.UsageError("Cannot find specified model.")
-
-    return config.trainml.run(found.disconnect())
+        if connect and attach:
+            config.trainml.run(model.attach(), model.connect())
+        elif connect:
+            config.trainml.run(model.connect())
+        else:
+            raise click.UsageError(
+                "Abort!\n"
+                "No logs to show for local sourced model without connect."
+            )
 
 
 @model.command()
```
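The rewritten `attach` and `connect` commands above drop the old try/except/disconnect cleanup and hand one or two coroutines directly to `config.trainml.run(...)`. A minimal sketch of a runner with that calling convention, assuming it simply gathers the awaitables on one event loop; the real helper lives in `trainml/cli/__init__.py` and its implementation is not shown in this diff:

```python
import asyncio


def run(*coroutines):
    # Hypothetical stand-in for config.trainml.run(): accept one or more
    # awaitables, drive them to completion, and return their result(s).
    # Whether the real helper runs them concurrently or in order is not
    # visible in this diff; asyncio.gather runs them concurrently here.
    async def _runner():
        return await asyncio.gather(*coroutines)

    results = asyncio.run(_runner())
    return results[0] if len(results) == 1 else results
```

With that shape, `run(found.connect(), found.attach())` mirrors the call the new `connect` command makes when the attach flag is set.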
trainml/cli/volume.py
CHANGED
```diff
@@ -35,15 +35,7 @@ def attach(config, volume):
     if None is found:
         raise click.UsageError("Cannot find specified volume.")
 
-    try:
-        config.trainml.run(found.attach())
-        return config.trainml.run(found.disconnect())
-    except:
-        try:
-            config.trainml.run(found.disconnect())
-        except:
-            pass
-        raise
+    config.trainml.run(found.attach())
 
 
 @volume.command()
@@ -67,18 +59,10 @@ def connect(config, volume, attach):
     if None is found:
         raise click.UsageError("Cannot find specified volume.")
 
-
-
-
-
-        else:
-            return config.trainml.run(found.connect())
-    except:
-        try:
-            config.trainml.run(found.disconnect())
-        except:
-            pass
-        raise
+    if attach:
+        config.trainml.run(found.connect(), found.attach())
+    else:
+        config.trainml.run(found.connect())
 
 
 @volume.command()
@@ -120,45 +104,22 @@ def create(config, attach, connect, source, name, capacity, path):
     if source == "local":
         volume = config.trainml.run(
             config.trainml.client.volumes.create(
-                name=name,
+                name=name,
+                source_type="local",
+                source_uri=path,
+                capacity=capacity,
            )
        )
 
-
-
-
-
-
-
-
-
-
-                "No logs to show for local sourced volume without connect."
-            )
-    except:
-        try:
-            config.trainml.run(volume.disconnect())
-        except:
-            pass
-        raise
-
-
-@volume.command()
-@click.argument("volume", type=click.STRING)
-@pass_config
-def disconnect(config, volume):
-    """
-    Disconnect and clean-up volume upload.
-
-    VOLUME may be specified by name or ID, but ID is preferred.
-    """
-    volumes = config.trainml.run(config.trainml.client.volumes.list())
-
-    found = search_by_id_name(volume, volumes)
-    if None is found:
-        raise click.UsageError("Cannot find specified volume.")
-
-    return config.trainml.run(found.disconnect())
+        if connect and attach:
+            config.trainml.run(volume.attach(), volume.connect())
+        elif connect:
+            config.trainml.run(volume.connect())
+        else:
+            raise click.UsageError(
+                "Abort!\n"
+                "No logs to show for local sourced volume without connect."
+            )
 
 
 @volume.command()
```
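The volume `create` command now passes `source_type`, `source_uri`, and `capacity` straight through to `volumes.create()` and replaces the removed `disconnect` command with a single `run()` call. For reference, a hypothetical SDK-level equivalent of that flow; the keyword arguments mirror the `volumes.create()` call in the hunk above, while the import path, the `gather()` pattern, and the placeholder values are assumptions rather than part of this diff:

```python
import asyncio

from trainml.trainml import TrainML  # import path assumed


async def create_local_volume():
    trainml = TrainML()
    # Argument names follow the CLI call shown above; name, path, and
    # capacity values are placeholders.
    volume = await trainml.volumes.create(
        name="my-volume",
        source_type="local",
        source_uri="~/data/my-volume",
        capacity="10G",
    )
    # The 1.0.0 CLI drives the upload with connect()/attach() in one
    # run() call; asyncio.gather is an assumed equivalent here.
    await asyncio.gather(volume.connect(), volume.attach())


if __name__ == "__main__":
    asyncio.run(create_local_volume())
```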
trainml/datasets.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from .exceptions import (
     SpecificationError,
     TrainMLException,
 )
-from .
+from trainml.utils.transfer import upload, download
 
 
 class Datasets(object):
@@ -71,10 +71,12 @@ class Dataset:
         self._id = self._dataset.get("id", self._dataset.get("dataset_uuid"))
         self._status = self._dataset.get("status")
         self._name = self._dataset.get("name")
-        self._size = self._dataset.get("size") or self._dataset.get(
-
-            "size"
+        self._size = self._dataset.get("size") or self._dataset.get(
+            "used_size"
         )
+        self._billed_size = self._dataset.get(
+            "billed_size"
+        ) or self._dataset.get("size")
         self._project_uuid = self._dataset.get("project_uuid")
 
     @property
@@ -122,56 +124,45 @@ class Dataset:
         )
         return resp
 
-    async def get_connection_utility_url(self):
-        resp = await self.trainml._query(
-            f"/dataset/{self._id}/download",
-            "GET",
-            dict(project_uuid=self._project_uuid),
-        )
-        return resp
-
-    def get_connection_details(self):
-        if self._dataset.get("vpn"):
-            details = dict(
-                entity_type="dataset",
-                project_uuid=self._dataset.get("project_uuid"),
-                cidr=self._dataset.get("vpn").get("cidr"),
-                ssh_port=self._dataset.get("vpn").get("client").get("ssh_port"),
-                input_path=(
-                    self._dataset.get("source_uri")
-                    if self.status in ["new", "downloading"]
-                    else None
-                ),
-                output_path=(
-                    self._dataset.get("output_uri")
-                    if self.status == "exporting"
-                    else None
-                ),
-            )
-        else:
-            details = dict()
-        return details
-
     async def connect(self):
-        if self.status in ["
-
-            "
-
-
-
-
-
-            self.trainml, entity_type="dataset", id=self.id, entity=self
-        )
-        await connection.start()
-        return connection.status
+        if self.status not in ["downloading", "exporting"]:
+            if self.status == "new":
+                await self.wait_for("downloading")
+            else:
+                raise SpecificationError(
+                    "status",
+                    f"You can only connect to downloading or exporting datasets.",
+                )
 
-
-
-
-
-
+        # Refresh to get latest entity data
+        await self.refresh()
+
+        if self.status == "downloading":
+            # Upload task - get auth_token, hostname, and source_uri from dataset
+            auth_token = self._dataset.get("auth_token")
+            hostname = self._dataset.get("hostname")
+            source_uri = self._dataset.get("source_uri")
+
+            if not auth_token or not hostname or not source_uri:
+                raise SpecificationError(
+                    "status",
+                    f"Dataset in downloading status missing required connection properties (auth_token, hostname, source_uri).",
+                )
+
+            await upload(hostname, auth_token, source_uri)
+        elif self.status == "exporting":
+            # Download task - get auth_token, hostname, and output_uri from dataset
+            auth_token = self._dataset.get("auth_token")
+            hostname = self._dataset.get("hostname")
+            output_uri = self._dataset.get("output_uri")
+
+            if not auth_token or not hostname or not output_uri:
+                raise SpecificationError(
+                    "status",
+                    f"Dataset in exporting status missing required connection properties (auth_token, hostname, output_uri).",
+                )
+
+            await download(hostname, auth_token, output_uri)
 
     async def remove(self, force=False):
         await self.trainml._query(
```
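The new `connect()` above replaces the old VPN-based connection flow with direct transfers through the new `trainml/utils/transfer.py`: a dataset in `downloading` status uploads its `source_uri` to the returned `hostname` using `auth_token`, and one in `exporting` status downloads to `output_uri`. A hypothetical end-to-end driver for the upload case, assuming `datasets.create()` mirrors the `volumes.create()` keywords shown earlier and that the client import path is `trainml.trainml.TrainML`:

```python
import asyncio

from trainml.trainml import TrainML  # import path assumed


async def upload_local_dataset():
    trainml = TrainML()
    # source_type/source_uri keywords are assumed to match volumes.create();
    # the name and path are placeholders.
    dataset = await trainml.datasets.create(
        name="my-dataset",
        source_type="local",
        source_uri="~/data/my-dataset",
    )
    # connect() waits for "downloading", refreshes the entity, then calls
    # trainml.utils.transfer.upload(hostname, auth_token, source_uri).
    await dataset.connect()
    # wait_for() polls until the requested status is reached (see below).
    await dataset.wait_for("ready")


if __name__ == "__main__":
    asyncio.run(upload_local_dataset())
```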
```diff
@@ -210,7 +201,9 @@ class Dataset:
             if msg_handler:
                 msg_handler(data)
             else:
-                timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
+                timestamp = datetime.fromtimestamp(
+                    int(data.get("time")) / 1000
+                )
                 print(
                     f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
                 )
@@ -239,7 +232,7 @@ class Dataset:
     async def wait_for(self, status, timeout=300):
         if self.status == status:
             return
-        valid_statuses = ["downloading", "ready", "archived"]
+        valid_statuses = ["downloading", "ready", "exporting", "archived"]
         if not status in valid_statuses:
             raise SpecificationError(
                 "status",
@@ -254,7 +247,9 @@ class Dataset:
 
         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
-        POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
+        POLL_INTERVAL = max(
+            min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
+        )
         retry_count = math.ceil(timeout / POLL_INTERVAL)
         count = 0
         while count < retry_count:
```
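For reference, the re-wrapped `POLL_INTERVAL` expression clamps the polling interval to one sixtieth of the timeout, bounded by `POLL_INTERVAL_MIN` and `POLL_INTERVAL_MAX`. A quick standalone check of those bounds (the timeout values are only illustrative):

```python
POLL_INTERVAL_MIN = 5
POLL_INTERVAL_MAX = 60


def poll_interval(timeout):
    # Same expression as the hunk above.
    return max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)


assert poll_interval(300) == 5     # default timeout -> minimum 5 s interval
assert poll_interval(1800) == 30   # 30-minute timeout -> 30 s polls
assert poll_interval(7200) == 60   # long timeouts are capped at 60 s
```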