trainml 0.5.17__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +11 -53
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +239 -68
- trainml/models.py +51 -55
- trainml/trainml.py +50 -16
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +587 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/__init__.py
CHANGED
trainml/checkpoints.py
CHANGED
|
@@ -10,7 +10,7 @@ from .exceptions import (
|
|
|
10
10
|
SpecificationError,
|
|
11
11
|
TrainMLException,
|
|
12
12
|
)
|
|
13
|
-
from .
|
|
13
|
+
from trainml.utils.transfer import upload, download
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class Checkpoints(object):
|
|
@@ -23,7 +23,9 @@ class Checkpoints(object):
|
|
|
23
23
|
|
|
24
24
|
async def list(self, **kwargs):
|
|
25
25
|
resp = await self.trainml._query(f"/checkpoint", "GET", kwargs)
|
|
26
|
-
checkpoints = [
|
|
26
|
+
checkpoints = [
|
|
27
|
+
Checkpoint(self.trainml, **checkpoint) for checkpoint in resp
|
|
28
|
+
]
|
|
27
29
|
return checkpoints
|
|
28
30
|
|
|
29
31
|
async def list_public(self, **kwargs):
|
|
@@ -68,13 +70,17 @@ class Checkpoint:
|
|
|
68
70
|
def __init__(self, trainml, **kwargs):
|
|
69
71
|
self.trainml = trainml
|
|
70
72
|
self._checkpoint = kwargs
|
|
71
|
-
self._id = self._checkpoint.get(
|
|
73
|
+
self._id = self._checkpoint.get(
|
|
74
|
+
"id", self._checkpoint.get("checkpoint_uuid")
|
|
75
|
+
)
|
|
72
76
|
self._status = self._checkpoint.get("status")
|
|
73
77
|
self._name = self._checkpoint.get("name")
|
|
74
|
-
self._size = self._checkpoint.get("size") or self._checkpoint.get(
|
|
75
|
-
|
|
76
|
-
"size"
|
|
78
|
+
self._size = self._checkpoint.get("size") or self._checkpoint.get(
|
|
79
|
+
"used_size"
|
|
77
80
|
)
|
|
81
|
+
self._billed_size = self._checkpoint.get(
|
|
82
|
+
"billed_size"
|
|
83
|
+
) or self._checkpoint.get("size")
|
|
78
84
|
self._project_uuid = self._checkpoint.get("project_uuid")
|
|
79
85
|
|
|
80
86
|
@property
|
|
@@ -122,56 +128,45 @@ class Checkpoint:
|
|
|
122
128
|
)
|
|
123
129
|
return resp
|
|
124
130
|
|
|
125
|
-
async def get_connection_utility_url(self):
|
|
126
|
-
resp = await self.trainml._query(
|
|
127
|
-
f"/checkpoint/{self._id}/download",
|
|
128
|
-
"GET",
|
|
129
|
-
dict(project_uuid=self._project_uuid),
|
|
130
|
-
)
|
|
131
|
-
return resp
|
|
132
|
-
|
|
133
|
-
def get_connection_details(self):
|
|
134
|
-
if self._checkpoint.get("vpn"):
|
|
135
|
-
details = dict(
|
|
136
|
-
entity_type="checkpoint",
|
|
137
|
-
project_uuid=self._checkpoint.get("project_uuid"),
|
|
138
|
-
cidr=self._checkpoint.get("vpn").get("cidr"),
|
|
139
|
-
ssh_port=self._checkpoint.get("vpn").get("client").get("ssh_port"),
|
|
140
|
-
input_path=(
|
|
141
|
-
self._checkpoint.get("source_uri")
|
|
142
|
-
if self.status in ["new", "downloading"]
|
|
143
|
-
else None
|
|
144
|
-
),
|
|
145
|
-
output_path=(
|
|
146
|
-
self._checkpoint.get("output_uri")
|
|
147
|
-
if self.status == "exporting"
|
|
148
|
-
else None
|
|
149
|
-
),
|
|
150
|
-
)
|
|
151
|
-
else:
|
|
152
|
-
details = dict()
|
|
153
|
-
return details
|
|
154
|
-
|
|
155
131
|
async def connect(self):
|
|
156
|
-
if self.status in ["
|
|
157
|
-
|
|
158
|
-
"
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
self.trainml, entity_type="checkpoint", id=self.id, entity=self
|
|
165
|
-
)
|
|
166
|
-
await connection.start()
|
|
167
|
-
return connection.status
|
|
132
|
+
if self.status not in ["downloading", "exporting"]:
|
|
133
|
+
if self.status == "new":
|
|
134
|
+
await self.wait_for("downloading")
|
|
135
|
+
else:
|
|
136
|
+
raise SpecificationError(
|
|
137
|
+
"status",
|
|
138
|
+
f"You can only connect to downloading or exporting checkpoints.",
|
|
139
|
+
)
|
|
168
140
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
141
|
+
# Refresh to get latest entity data
|
|
142
|
+
await self.refresh()
|
|
143
|
+
|
|
144
|
+
if self.status == "downloading":
|
|
145
|
+
# Upload task - get auth_token, hostname, and source_uri from checkpoint
|
|
146
|
+
auth_token = self._checkpoint.get("auth_token")
|
|
147
|
+
hostname = self._checkpoint.get("hostname")
|
|
148
|
+
source_uri = self._checkpoint.get("source_uri")
|
|
149
|
+
|
|
150
|
+
if not auth_token or not hostname or not source_uri:
|
|
151
|
+
raise SpecificationError(
|
|
152
|
+
"status",
|
|
153
|
+
f"Checkpoint in downloading status missing required connection properties (auth_token, hostname, source_uri).",
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
await upload(hostname, auth_token, source_uri)
|
|
157
|
+
elif self.status == "exporting":
|
|
158
|
+
# Download task - get auth_token, hostname, and output_uri from checkpoint
|
|
159
|
+
auth_token = self._checkpoint.get("auth_token")
|
|
160
|
+
hostname = self._checkpoint.get("hostname")
|
|
161
|
+
output_uri = self._checkpoint.get("output_uri")
|
|
162
|
+
|
|
163
|
+
if not auth_token or not hostname or not output_uri:
|
|
164
|
+
raise SpecificationError(
|
|
165
|
+
"status",
|
|
166
|
+
f"Checkpoint in exporting status missing required connection properties (auth_token, hostname, output_uri).",
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
await download(hostname, auth_token, output_uri)
|
|
175
170
|
|
|
176
171
|
async def remove(self, force=False):
|
|
177
172
|
await self.trainml._query(
|
|
@@ -210,7 +205,9 @@ class Checkpoint:
|
|
|
210
205
|
if msg_handler:
|
|
211
206
|
msg_handler(data)
|
|
212
207
|
else:
|
|
213
|
-
timestamp = datetime.fromtimestamp(
|
|
208
|
+
timestamp = datetime.fromtimestamp(
|
|
209
|
+
int(data.get("time")) / 1000
|
|
210
|
+
)
|
|
214
211
|
print(
|
|
215
212
|
f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
|
|
216
213
|
)
|
|
@@ -239,7 +236,7 @@ class Checkpoint:
|
|
|
239
236
|
async def wait_for(self, status, timeout=300):
|
|
240
237
|
if self.status == status:
|
|
241
238
|
return
|
|
242
|
-
valid_statuses = ["downloading", "ready", "archived"]
|
|
239
|
+
valid_statuses = ["downloading", "ready", "exporting", "archived"]
|
|
243
240
|
if not status in valid_statuses:
|
|
244
241
|
raise SpecificationError(
|
|
245
242
|
"status",
|
|
@@ -254,7 +251,9 @@ class Checkpoint:
|
|
|
254
251
|
)
|
|
255
252
|
POLL_INTERVAL_MIN = 5
|
|
256
253
|
POLL_INTERVAL_MAX = 60
|
|
257
|
-
POLL_INTERVAL = max(
|
|
254
|
+
POLL_INTERVAL = max(
|
|
255
|
+
min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
|
|
256
|
+
)
|
|
258
257
|
retry_count = math.ceil(timeout / POLL_INTERVAL)
|
|
259
258
|
count = 0
|
|
260
259
|
while count < retry_count:
|
trainml/cli/__init__.py
CHANGED
|
@@ -142,7 +142,9 @@ def configure(config):
|
|
|
142
142
|
project for project in projects if project.id == active_project_id
|
|
143
143
|
]
|
|
144
144
|
|
|
145
|
-
active_project_name =
|
|
145
|
+
active_project_name = (
|
|
146
|
+
active_project[0].name if len(active_project) else "UNSET"
|
|
147
|
+
)
|
|
146
148
|
|
|
147
149
|
click.echo(f"Current Active Project: {active_project_name}")
|
|
148
150
|
|
|
@@ -152,11 +154,12 @@ def configure(config):
|
|
|
152
154
|
show_choices=True,
|
|
153
155
|
default=active_project_name,
|
|
154
156
|
)
|
|
155
|
-
selected_project = [
|
|
157
|
+
selected_project = [
|
|
158
|
+
project for project in projects if project.name == name
|
|
159
|
+
]
|
|
156
160
|
config.trainml.client.set_active_project(selected_project[0].id)
|
|
157
161
|
|
|
158
162
|
|
|
159
|
-
from trainml.cli.connection import connection
|
|
160
163
|
from trainml.cli.dataset import dataset
|
|
161
164
|
from trainml.cli.model import model
|
|
162
165
|
from trainml.cli.checkpoint import checkpoint
|
trainml/cli/checkpoint.py
CHANGED
|
@@ -35,15 +35,7 @@ def attach(config, checkpoint):
|
|
|
35
35
|
if None is found:
|
|
36
36
|
raise click.UsageError("Cannot find specified checkpoint.")
|
|
37
37
|
|
|
38
|
-
|
|
39
|
-
config.trainml.run(found.attach())
|
|
40
|
-
return config.trainml.run(found.disconnect())
|
|
41
|
-
except:
|
|
42
|
-
try:
|
|
43
|
-
config.trainml.run(found.disconnect())
|
|
44
|
-
except:
|
|
45
|
-
pass
|
|
46
|
-
raise
|
|
38
|
+
config.trainml.run(found.attach())
|
|
47
39
|
|
|
48
40
|
|
|
49
41
|
@checkpoint.command()
|
|
@@ -67,18 +59,10 @@ def connect(config, checkpoint, attach):
|
|
|
67
59
|
if None is found:
|
|
68
60
|
raise click.UsageError("Cannot find specified checkpoint.")
|
|
69
61
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
else:
|
|
75
|
-
return config.trainml.run(found.connect())
|
|
76
|
-
except:
|
|
77
|
-
try:
|
|
78
|
-
config.trainml.run(found.disconnect())
|
|
79
|
-
except:
|
|
80
|
-
pass
|
|
81
|
-
raise
|
|
62
|
+
if attach:
|
|
63
|
+
config.trainml.run(found.connect(), found.attach())
|
|
64
|
+
else:
|
|
65
|
+
config.trainml.run(found.connect())
|
|
82
66
|
|
|
83
67
|
|
|
84
68
|
@checkpoint.command()
|
|
@@ -123,41 +107,15 @@ def create(config, attach, connect, source, name, path):
|
|
|
123
107
|
)
|
|
124
108
|
)
|
|
125
109
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
"No logs to show for local sourced checkpoint without connect."
|
|
136
|
-
)
|
|
137
|
-
except:
|
|
138
|
-
try:
|
|
139
|
-
config.trainml.run(checkpoint.disconnect())
|
|
140
|
-
except:
|
|
141
|
-
pass
|
|
142
|
-
raise
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
@checkpoint.command()
|
|
146
|
-
@click.argument("checkpoint", type=click.STRING)
|
|
147
|
-
@pass_config
|
|
148
|
-
def disconnect(config, checkpoint):
|
|
149
|
-
"""
|
|
150
|
-
Disconnect and clean-up checkpoint upload.
|
|
151
|
-
|
|
152
|
-
CHECKPOINT may be specified by name or ID, but ID is preferred.
|
|
153
|
-
"""
|
|
154
|
-
checkpoints = config.trainml.run(config.trainml.client.checkpoints.list())
|
|
155
|
-
|
|
156
|
-
found = search_by_id_name(checkpoint, checkpoints)
|
|
157
|
-
if None is found:
|
|
158
|
-
raise click.UsageError("Cannot find specified checkpoint.")
|
|
159
|
-
|
|
160
|
-
return config.trainml.run(found.disconnect())
|
|
110
|
+
if connect and attach:
|
|
111
|
+
config.trainml.run(checkpoint.attach(), checkpoint.connect())
|
|
112
|
+
elif connect:
|
|
113
|
+
config.trainml.run(checkpoint.connect())
|
|
114
|
+
else:
|
|
115
|
+
raise click.UsageError(
|
|
116
|
+
"Abort!\n"
|
|
117
|
+
"No logs to show for local sourced checkpoint without connect."
|
|
118
|
+
)
|
|
161
119
|
|
|
162
120
|
|
|
163
121
|
@checkpoint.command()
|
|
@@ -236,7 +194,10 @@ def remove(config, checkpoint, force):
|
|
|
236
194
|
found = search_by_id_name(checkpoint, checkpoints)
|
|
237
195
|
if None is found:
|
|
238
196
|
if force:
|
|
239
|
-
config.trainml.run(
|
|
197
|
+
config.trainml.run(
|
|
198
|
+
config.trainml.client.checkpoints.remove(checkpoint)
|
|
199
|
+
)
|
|
200
|
+
return
|
|
240
201
|
else:
|
|
241
202
|
raise click.UsageError("Cannot find specified checkpoint.")
|
|
242
203
|
|
trainml/cli/dataset.py
CHANGED
|
@@ -35,15 +35,7 @@ def attach(config, dataset):
|
|
|
35
35
|
if None is found:
|
|
36
36
|
raise click.UsageError("Cannot find specified dataset.")
|
|
37
37
|
|
|
38
|
-
|
|
39
|
-
config.trainml.run(found.attach())
|
|
40
|
-
return config.trainml.run(found.disconnect())
|
|
41
|
-
except:
|
|
42
|
-
try:
|
|
43
|
-
config.trainml.run(found.disconnect())
|
|
44
|
-
except:
|
|
45
|
-
pass
|
|
46
|
-
raise
|
|
38
|
+
config.trainml.run(found.attach())
|
|
47
39
|
|
|
48
40
|
|
|
49
41
|
@dataset.command()
|
|
@@ -67,18 +59,10 @@ def connect(config, dataset, attach):
|
|
|
67
59
|
if None is found:
|
|
68
60
|
raise click.UsageError("Cannot find specified dataset.")
|
|
69
61
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
else:
|
|
75
|
-
return config.trainml.run(found.connect())
|
|
76
|
-
except:
|
|
77
|
-
try:
|
|
78
|
-
config.trainml.run(found.disconnect())
|
|
79
|
-
except:
|
|
80
|
-
pass
|
|
81
|
-
raise
|
|
62
|
+
if attach:
|
|
63
|
+
config.trainml.run(found.connect(), found.attach())
|
|
64
|
+
else:
|
|
65
|
+
config.trainml.run(found.connect())
|
|
82
66
|
|
|
83
67
|
|
|
84
68
|
@dataset.command()
|
|
@@ -123,41 +107,15 @@ def create(config, attach, connect, source, name, path):
|
|
|
123
107
|
)
|
|
124
108
|
)
|
|
125
109
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
"No logs to show for local sourced dataset without connect."
|
|
136
|
-
)
|
|
137
|
-
except:
|
|
138
|
-
try:
|
|
139
|
-
config.trainml.run(dataset.disconnect())
|
|
140
|
-
except:
|
|
141
|
-
pass
|
|
142
|
-
raise
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
@dataset.command()
|
|
146
|
-
@click.argument("dataset", type=click.STRING)
|
|
147
|
-
@pass_config
|
|
148
|
-
def disconnect(config, dataset):
|
|
149
|
-
"""
|
|
150
|
-
Disconnect and clean-up dataset upload.
|
|
151
|
-
|
|
152
|
-
DATASET may be specified by name or ID, but ID is preferred.
|
|
153
|
-
"""
|
|
154
|
-
datasets = config.trainml.run(config.trainml.client.datasets.list())
|
|
155
|
-
|
|
156
|
-
found = search_by_id_name(dataset, datasets)
|
|
157
|
-
if None is found:
|
|
158
|
-
raise click.UsageError("Cannot find specified dataset.")
|
|
159
|
-
|
|
160
|
-
return config.trainml.run(found.disconnect())
|
|
110
|
+
if connect and attach:
|
|
111
|
+
config.trainml.run(dataset.attach(), dataset.connect())
|
|
112
|
+
elif connect:
|
|
113
|
+
config.trainml.run(dataset.connect())
|
|
114
|
+
else:
|
|
115
|
+
raise click.UsageError(
|
|
116
|
+
"Abort!\n"
|
|
117
|
+
"No logs to show for local sourced dataset without connect."
|
|
118
|
+
)
|
|
161
119
|
|
|
162
120
|
|
|
163
121
|
@dataset.command()
|
|
@@ -252,7 +210,9 @@ def rename(config, dataset, name):
|
|
|
252
210
|
DATASET may be specified by name or ID, but ID is preferred.
|
|
253
211
|
"""
|
|
254
212
|
try:
|
|
255
|
-
dataset = config.trainml.run(
|
|
213
|
+
dataset = config.trainml.run(
|
|
214
|
+
config.trainml.client.datasets.get(dataset)
|
|
215
|
+
)
|
|
256
216
|
if dataset is None:
|
|
257
217
|
raise click.UsageError("Cannot find specified dataset.")
|
|
258
218
|
except:
|
trainml/cli/job/__init__.py
CHANGED
|
@@ -25,15 +25,7 @@ def attach(config, job):
|
|
|
25
25
|
if None is found:
|
|
26
26
|
raise click.UsageError("Cannot find specified job.")
|
|
27
27
|
|
|
28
|
-
|
|
29
|
-
config.trainml.run(found.attach())
|
|
30
|
-
return config.trainml.run(found.disconnect())
|
|
31
|
-
except:
|
|
32
|
-
try:
|
|
33
|
-
config.trainml.run(found.disconnect())
|
|
34
|
-
except:
|
|
35
|
-
pass
|
|
36
|
-
raise
|
|
28
|
+
config.trainml.run(found.attach())
|
|
37
29
|
|
|
38
30
|
|
|
39
31
|
@job.command()
|
|
@@ -58,38 +50,22 @@ def connect(config, job, attach):
|
|
|
58
50
|
raise click.UsageError("Cannot find specified job.")
|
|
59
51
|
|
|
60
52
|
if found.type != "notebook":
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
else:
|
|
66
|
-
return config.trainml.run(found.connect())
|
|
67
|
-
except:
|
|
68
|
-
try:
|
|
69
|
-
config.trainml.run(found.disconnect())
|
|
70
|
-
except:
|
|
71
|
-
pass
|
|
72
|
-
raise
|
|
53
|
+
if attach:
|
|
54
|
+
config.trainml.run(found.connect(), found.attach())
|
|
55
|
+
else:
|
|
56
|
+
config.trainml.run(found.connect())
|
|
73
57
|
else:
|
|
74
58
|
if found.status in [
|
|
75
59
|
"new",
|
|
76
60
|
"waiting for data/model download",
|
|
77
61
|
"waiting for GPUs",
|
|
78
62
|
]:
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
else:
|
|
86
|
-
return config.trainml.run(found.connect())
|
|
87
|
-
except:
|
|
88
|
-
try:
|
|
89
|
-
config.trainml.run(found.disconnect())
|
|
90
|
-
except:
|
|
91
|
-
pass
|
|
92
|
-
raise
|
|
63
|
+
if attach:
|
|
64
|
+
config.trainml.run(found.connect(), found.attach())
|
|
65
|
+
click.echo("Launching...", file=config.stdout)
|
|
66
|
+
browse(found.notebook_url)
|
|
67
|
+
else:
|
|
68
|
+
config.trainml.run(found.connect())
|
|
93
69
|
elif found.status not in [
|
|
94
70
|
"starting",
|
|
95
71
|
"running",
|
|
@@ -103,24 +79,6 @@ def connect(config, job, attach):
|
|
|
103
79
|
browse(found.notebook_url)
|
|
104
80
|
|
|
105
81
|
|
|
106
|
-
@job.command()
|
|
107
|
-
@click.argument("job", type=click.STRING)
|
|
108
|
-
@pass_config
|
|
109
|
-
def disconnect(config, job):
|
|
110
|
-
"""
|
|
111
|
-
Disconnect and clean-up job.
|
|
112
|
-
|
|
113
|
-
JOB may be specified by name or ID, but ID is preferred.
|
|
114
|
-
"""
|
|
115
|
-
jobs = config.trainml.run(config.trainml.client.jobs.list())
|
|
116
|
-
|
|
117
|
-
found = search_by_id_name(job, jobs)
|
|
118
|
-
if None is found:
|
|
119
|
-
raise click.UsageError("Cannot find specified job.")
|
|
120
|
-
|
|
121
|
-
return config.trainml.run(found.disconnect())
|
|
122
|
-
|
|
123
|
-
|
|
124
82
|
@job.command()
|
|
125
83
|
@click.option(
|
|
126
84
|
"--wait/--no-wait",
|
trainml/cli/job/create.py
CHANGED
|
@@ -275,7 +275,9 @@ def notebook(
|
|
|
275
275
|
options["environment"]["type"] = environment
|
|
276
276
|
|
|
277
277
|
try:
|
|
278
|
-
envs = [
|
|
278
|
+
envs = [
|
|
279
|
+
{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
|
|
280
|
+
]
|
|
279
281
|
options["environment"]["env"] = envs
|
|
280
282
|
except IndexError:
|
|
281
283
|
raise click.UsageError(
|
|
@@ -289,21 +291,25 @@ def notebook(
|
|
|
289
291
|
if pip_packages:
|
|
290
292
|
options["environment"]["packages"]["pip"] = pip_packages.split(",")
|
|
291
293
|
if conda_packages:
|
|
292
|
-
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
294
|
+
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
295
|
+
","
|
|
296
|
+
)
|
|
293
297
|
|
|
294
298
|
if data_dir:
|
|
295
299
|
click.echo("Creating Dataset..", file=config.stdout)
|
|
296
300
|
new_dataset = config.trainml.run(
|
|
297
|
-
config.trainml.client.datasets.create(
|
|
301
|
+
config.trainml.client.datasets.create(
|
|
302
|
+
f"Job - {name}", "local", data_dir
|
|
303
|
+
)
|
|
298
304
|
)
|
|
299
305
|
if attach:
|
|
300
306
|
config.trainml.run(new_dataset.attach(), new_dataset.connect())
|
|
301
|
-
config.trainml.run(new_dataset.disconnect())
|
|
302
307
|
else:
|
|
303
308
|
config.trainml.run(new_dataset.connect())
|
|
304
309
|
config.trainml.run(new_dataset.wait_for("ready"))
|
|
305
|
-
|
|
306
|
-
|
|
310
|
+
options["data"]["datasets"].append(
|
|
311
|
+
dict(id=new_dataset.id, type="existing")
|
|
312
|
+
)
|
|
307
313
|
|
|
308
314
|
if git_uri:
|
|
309
315
|
options["model"]["source_type"] = "git"
|
|
@@ -331,13 +337,11 @@ def notebook(
|
|
|
331
337
|
if attach or connect:
|
|
332
338
|
click.echo("Waiting for job to start...", file=config.stdout)
|
|
333
339
|
config.trainml.run(job.connect(), job.attach())
|
|
334
|
-
config.trainml.run(job.disconnect())
|
|
335
340
|
click.echo("Launching...", file=config.stdout)
|
|
336
341
|
browse(job.notebook_url)
|
|
337
342
|
else:
|
|
338
343
|
config.trainml.run(job.connect())
|
|
339
344
|
config.trainml.run(job.wait_for("running"))
|
|
340
|
-
config.trainml.run(job.disconnect())
|
|
341
345
|
elif attach or connect:
|
|
342
346
|
click.echo("Waiting for job to start...", file=config.stdout)
|
|
343
347
|
config.trainml.run(job.wait_for("running", timeout))
|
|
@@ -626,15 +630,21 @@ def training(
|
|
|
626
630
|
if output_type:
|
|
627
631
|
options["data"]["output_type"] = output_type
|
|
628
632
|
options["data"]["output_uri"] = output_uri
|
|
629
|
-
options["data"]["output_options"] = dict(
|
|
633
|
+
options["data"]["output_options"] = dict(
|
|
634
|
+
archive=archive, save_model=save_model
|
|
635
|
+
)
|
|
630
636
|
|
|
631
637
|
if output_dir:
|
|
632
638
|
options["data"]["output_type"] = "local"
|
|
633
639
|
options["data"]["output_uri"] = output_dir
|
|
634
|
-
options["data"]["output_options"] = dict(
|
|
640
|
+
options["data"]["output_options"] = dict(
|
|
641
|
+
archive=archive, save_model=save_model
|
|
642
|
+
)
|
|
635
643
|
|
|
636
644
|
try:
|
|
637
|
-
envs = [
|
|
645
|
+
envs = [
|
|
646
|
+
{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
|
|
647
|
+
]
|
|
638
648
|
options["environment"]["env"] = envs
|
|
639
649
|
except IndexError:
|
|
640
650
|
raise click.UsageError(
|
|
@@ -648,21 +658,25 @@ def training(
|
|
|
648
658
|
if pip_packages:
|
|
649
659
|
options["environment"]["packages"]["pip"] = pip_packages.split(",")
|
|
650
660
|
if conda_packages:
|
|
651
|
-
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
661
|
+
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
662
|
+
","
|
|
663
|
+
)
|
|
652
664
|
|
|
653
665
|
if data_dir:
|
|
654
666
|
click.echo("Creating Dataset..", file=config.stdout)
|
|
655
667
|
new_dataset = config.trainml.run(
|
|
656
|
-
config.trainml.client.datasets.create(
|
|
668
|
+
config.trainml.client.datasets.create(
|
|
669
|
+
f"Job - {name}", "local", data_dir
|
|
670
|
+
)
|
|
657
671
|
)
|
|
658
672
|
if attach:
|
|
659
673
|
config.trainml.run(new_dataset.attach(), new_dataset.connect())
|
|
660
|
-
config.trainml.run(new_dataset.disconnect())
|
|
661
674
|
else:
|
|
662
675
|
config.trainml.run(new_dataset.connect())
|
|
663
676
|
config.trainml.run(new_dataset.wait_for("ready"))
|
|
664
|
-
|
|
665
|
-
|
|
677
|
+
options["data"]["datasets"].append(
|
|
678
|
+
dict(id=new_dataset.id, type="existing")
|
|
679
|
+
)
|
|
666
680
|
|
|
667
681
|
if git_uri:
|
|
668
682
|
options["model"]["source_type"] = "git"
|
|
@@ -979,15 +993,21 @@ def inference(
|
|
|
979
993
|
if output_type:
|
|
980
994
|
options["data"]["output_type"] = output_type
|
|
981
995
|
options["data"]["output_uri"] = output_uri
|
|
982
|
-
options["data"]["output_options"] = dict(
|
|
996
|
+
options["data"]["output_options"] = dict(
|
|
997
|
+
archive=archive, save_model=save_model
|
|
998
|
+
)
|
|
983
999
|
|
|
984
1000
|
if output_dir:
|
|
985
1001
|
options["data"]["output_type"] = "local"
|
|
986
1002
|
options["data"]["output_uri"] = output_dir
|
|
987
|
-
options["data"]["output_options"] = dict(
|
|
1003
|
+
options["data"]["output_options"] = dict(
|
|
1004
|
+
archive=archive, save_model=save_model
|
|
1005
|
+
)
|
|
988
1006
|
|
|
989
1007
|
try:
|
|
990
|
-
envs = [
|
|
1008
|
+
envs = [
|
|
1009
|
+
{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
|
|
1010
|
+
]
|
|
991
1011
|
options["environment"]["env"] = envs
|
|
992
1012
|
except IndexError:
|
|
993
1013
|
raise click.UsageError(
|
|
@@ -1001,7 +1021,9 @@ def inference(
|
|
|
1001
1021
|
if pip_packages:
|
|
1002
1022
|
options["environment"]["packages"]["pip"] = pip_packages.split(",")
|
|
1003
1023
|
if conda_packages:
|
|
1004
|
-
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
1024
|
+
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
1025
|
+
","
|
|
1026
|
+
)
|
|
1005
1027
|
|
|
1006
1028
|
if git_uri:
|
|
1007
1029
|
options["model"]["source_type"] = "git"
|
|
@@ -1301,7 +1323,9 @@ def endpoint(
|
|
|
1301
1323
|
options["environment"]["type"] = environment
|
|
1302
1324
|
|
|
1303
1325
|
try:
|
|
1304
|
-
envs = [
|
|
1326
|
+
envs = [
|
|
1327
|
+
{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
|
|
1328
|
+
]
|
|
1305
1329
|
options["environment"]["env"] = envs
|
|
1306
1330
|
except IndexError:
|
|
1307
1331
|
raise click.UsageError(
|
|
@@ -1315,7 +1339,9 @@ def endpoint(
|
|
|
1315
1339
|
if pip_packages:
|
|
1316
1340
|
options["environment"]["packages"]["pip"] = pip_packages.split(",")
|
|
1317
1341
|
if conda_packages:
|
|
1318
|
-
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
1342
|
+
options["environment"]["packages"]["conda"] = conda_packages.split(
|
|
1343
|
+
","
|
|
1344
|
+
)
|
|
1319
1345
|
|
|
1320
1346
|
if git_uri:
|
|
1321
1347
|
options["model"]["source_type"] = "git"
|
|
@@ -1349,7 +1375,6 @@ def endpoint(
|
|
|
1349
1375
|
config.trainml.run(job.connect())
|
|
1350
1376
|
click.echo("Waiting for job to start...", file=config.stdout)
|
|
1351
1377
|
config.trainml.run(job.wait_for("running", timeout))
|
|
1352
|
-
config.trainml.run(job.disconnect())
|
|
1353
1378
|
config.trainml.run(job.refresh())
|
|
1354
1379
|
click.echo(f"Endpoint is running at: {job.url}", file=config.stdout)
|
|
1355
1380
|
else:
|
|
@@ -1357,4 +1382,6 @@ def endpoint(
|
|
|
1357
1382
|
click.echo("Waiting for job to start...", file=config.stdout)
|
|
1358
1383
|
config.trainml.run(job.wait_for("running", timeout))
|
|
1359
1384
|
config.trainml.run(job.refresh())
|
|
1360
|
-
click.echo(
|
|
1385
|
+
click.echo(
|
|
1386
|
+
f"Endpoint is running at: {job.url}", file=config.stdout
|
|
1387
|
+
)
|