trainml 0.5.17__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +16 -0
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/trainml.py +50 -16
  43. trainml/utils/__init__.py +1 -0
  44. trainml/utils/auth.py +641 -0
  45. trainml/utils/transfer.py +587 -0
  46. trainml/volumes.py +48 -53
  47. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  48. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/RECORD +52 -46
  49. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  50. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  51. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  52. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/__init__.py CHANGED
@@ -13,5 +13,5 @@ logging.basicConfig(
13
13
  logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
- __version__ = "0.5.17"
16
+ __version__ = "1.0.0"
17
17
  __all__ = "TrainML"
trainml/checkpoints.py CHANGED
@@ -10,7 +10,7 @@ from .exceptions import (
10
10
  SpecificationError,
11
11
  TrainMLException,
12
12
  )
13
- from .connections import Connection
13
+ from trainml.utils.transfer import upload, download
14
14
 
15
15
 
16
16
  class Checkpoints(object):
@@ -23,7 +23,9 @@ class Checkpoints(object):
23
23
 
24
24
  async def list(self, **kwargs):
25
25
  resp = await self.trainml._query(f"/checkpoint", "GET", kwargs)
26
- checkpoints = [Checkpoint(self.trainml, **checkpoint) for checkpoint in resp]
26
+ checkpoints = [
27
+ Checkpoint(self.trainml, **checkpoint) for checkpoint in resp
28
+ ]
27
29
  return checkpoints
28
30
 
29
31
  async def list_public(self, **kwargs):
@@ -68,13 +70,17 @@ class Checkpoint:
68
70
  def __init__(self, trainml, **kwargs):
69
71
  self.trainml = trainml
70
72
  self._checkpoint = kwargs
71
- self._id = self._checkpoint.get("id", self._checkpoint.get("checkpoint_uuid"))
73
+ self._id = self._checkpoint.get(
74
+ "id", self._checkpoint.get("checkpoint_uuid")
75
+ )
72
76
  self._status = self._checkpoint.get("status")
73
77
  self._name = self._checkpoint.get("name")
74
- self._size = self._checkpoint.get("size") or self._checkpoint.get("used_size")
75
- self._billed_size = self._checkpoint.get("billed_size") or self._checkpoint.get(
76
- "size"
78
+ self._size = self._checkpoint.get("size") or self._checkpoint.get(
79
+ "used_size"
77
80
  )
81
+ self._billed_size = self._checkpoint.get(
82
+ "billed_size"
83
+ ) or self._checkpoint.get("size")
78
84
  self._project_uuid = self._checkpoint.get("project_uuid")
79
85
 
80
86
  @property
@@ -122,56 +128,45 @@ class Checkpoint:
122
128
  )
123
129
  return resp
124
130
 
125
- async def get_connection_utility_url(self):
126
- resp = await self.trainml._query(
127
- f"/checkpoint/{self._id}/download",
128
- "GET",
129
- dict(project_uuid=self._project_uuid),
130
- )
131
- return resp
132
-
133
- def get_connection_details(self):
134
- if self._checkpoint.get("vpn"):
135
- details = dict(
136
- entity_type="checkpoint",
137
- project_uuid=self._checkpoint.get("project_uuid"),
138
- cidr=self._checkpoint.get("vpn").get("cidr"),
139
- ssh_port=self._checkpoint.get("vpn").get("client").get("ssh_port"),
140
- input_path=(
141
- self._checkpoint.get("source_uri")
142
- if self.status in ["new", "downloading"]
143
- else None
144
- ),
145
- output_path=(
146
- self._checkpoint.get("output_uri")
147
- if self.status == "exporting"
148
- else None
149
- ),
150
- )
151
- else:
152
- details = dict()
153
- return details
154
-
155
131
  async def connect(self):
156
- if self.status in ["ready", "failed"]:
157
- raise SpecificationError(
158
- "status",
159
- f"You can only connect to downloading or exporting checkpoints.",
160
- )
161
- if self.status == "new":
162
- await self.wait_for("downloading")
163
- connection = Connection(
164
- self.trainml, entity_type="checkpoint", id=self.id, entity=self
165
- )
166
- await connection.start()
167
- return connection.status
132
+ if self.status not in ["downloading", "exporting"]:
133
+ if self.status == "new":
134
+ await self.wait_for("downloading")
135
+ else:
136
+ raise SpecificationError(
137
+ "status",
138
+ f"You can only connect to downloading or exporting checkpoints.",
139
+ )
168
140
 
169
- async def disconnect(self):
170
- connection = Connection(
171
- self.trainml, entity_type="checkpoint", id=self.id, entity=self
172
- )
173
- await connection.stop()
174
- return connection.status
141
+ # Refresh to get latest entity data
142
+ await self.refresh()
143
+
144
+ if self.status == "downloading":
145
+ # Upload task - get auth_token, hostname, and source_uri from checkpoint
146
+ auth_token = self._checkpoint.get("auth_token")
147
+ hostname = self._checkpoint.get("hostname")
148
+ source_uri = self._checkpoint.get("source_uri")
149
+
150
+ if not auth_token or not hostname or not source_uri:
151
+ raise SpecificationError(
152
+ "status",
153
+ f"Checkpoint in downloading status missing required connection properties (auth_token, hostname, source_uri).",
154
+ )
155
+
156
+ await upload(hostname, auth_token, source_uri)
157
+ elif self.status == "exporting":
158
+ # Download task - get auth_token, hostname, and output_uri from checkpoint
159
+ auth_token = self._checkpoint.get("auth_token")
160
+ hostname = self._checkpoint.get("hostname")
161
+ output_uri = self._checkpoint.get("output_uri")
162
+
163
+ if not auth_token or not hostname or not output_uri:
164
+ raise SpecificationError(
165
+ "status",
166
+ f"Checkpoint in exporting status missing required connection properties (auth_token, hostname, output_uri).",
167
+ )
168
+
169
+ await download(hostname, auth_token, output_uri)
175
170
 
176
171
  async def remove(self, force=False):
177
172
  await self.trainml._query(
@@ -210,7 +205,9 @@ class Checkpoint:
210
205
  if msg_handler:
211
206
  msg_handler(data)
212
207
  else:
213
- timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
208
+ timestamp = datetime.fromtimestamp(
209
+ int(data.get("time")) / 1000
210
+ )
214
211
  print(
215
212
  f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
216
213
  )
@@ -239,7 +236,7 @@ class Checkpoint:
239
236
  async def wait_for(self, status, timeout=300):
240
237
  if self.status == status:
241
238
  return
242
- valid_statuses = ["downloading", "ready", "archived"]
239
+ valid_statuses = ["downloading", "ready", "exporting", "archived"]
243
240
  if not status in valid_statuses:
244
241
  raise SpecificationError(
245
242
  "status",
@@ -254,7 +251,9 @@ class Checkpoint:
254
251
  )
255
252
  POLL_INTERVAL_MIN = 5
256
253
  POLL_INTERVAL_MAX = 60
257
- POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
254
+ POLL_INTERVAL = max(
255
+ min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
256
+ )
258
257
  retry_count = math.ceil(timeout / POLL_INTERVAL)
259
258
  count = 0
260
259
  while count < retry_count:
trainml/cli/__init__.py CHANGED
@@ -142,7 +142,9 @@ def configure(config):
142
142
  project for project in projects if project.id == active_project_id
143
143
  ]
144
144
 
145
- active_project_name = active_project[0].name if len(active_project) else "UNSET"
145
+ active_project_name = (
146
+ active_project[0].name if len(active_project) else "UNSET"
147
+ )
146
148
 
147
149
  click.echo(f"Current Active Project: {active_project_name}")
148
150
 
@@ -152,11 +154,12 @@ def configure(config):
152
154
  show_choices=True,
153
155
  default=active_project_name,
154
156
  )
155
- selected_project = [project for project in projects if project.name == name]
157
+ selected_project = [
158
+ project for project in projects if project.name == name
159
+ ]
156
160
  config.trainml.client.set_active_project(selected_project[0].id)
157
161
 
158
162
 
159
- from trainml.cli.connection import connection
160
163
  from trainml.cli.dataset import dataset
161
164
  from trainml.cli.model import model
162
165
  from trainml.cli.checkpoint import checkpoint
trainml/cli/checkpoint.py CHANGED
@@ -35,15 +35,7 @@ def attach(config, checkpoint):
35
35
  if None is found:
36
36
  raise click.UsageError("Cannot find specified checkpoint.")
37
37
 
38
- try:
39
- config.trainml.run(found.attach())
40
- return config.trainml.run(found.disconnect())
41
- except:
42
- try:
43
- config.trainml.run(found.disconnect())
44
- except:
45
- pass
46
- raise
38
+ config.trainml.run(found.attach())
47
39
 
48
40
 
49
41
  @checkpoint.command()
@@ -67,18 +59,10 @@ def connect(config, checkpoint, attach):
67
59
  if None is found:
68
60
  raise click.UsageError("Cannot find specified checkpoint.")
69
61
 
70
- try:
71
- if attach:
72
- config.trainml.run(found.connect(), found.attach())
73
- return config.trainml.run(found.disconnect())
74
- else:
75
- return config.trainml.run(found.connect())
76
- except:
77
- try:
78
- config.trainml.run(found.disconnect())
79
- except:
80
- pass
81
- raise
62
+ if attach:
63
+ config.trainml.run(found.connect(), found.attach())
64
+ else:
65
+ config.trainml.run(found.connect())
82
66
 
83
67
 
84
68
  @checkpoint.command()
@@ -123,41 +107,15 @@ def create(config, attach, connect, source, name, path):
123
107
  )
124
108
  )
125
109
 
126
- try:
127
- if connect and attach:
128
- config.trainml.run(checkpoint.attach(), checkpoint.connect())
129
- return config.trainml.run(checkpoint.disconnect())
130
- elif connect:
131
- return config.trainml.run(checkpoint.connect())
132
- else:
133
- raise click.UsageError(
134
- "Abort!\n"
135
- "No logs to show for local sourced checkpoint without connect."
136
- )
137
- except:
138
- try:
139
- config.trainml.run(checkpoint.disconnect())
140
- except:
141
- pass
142
- raise
143
-
144
-
145
- @checkpoint.command()
146
- @click.argument("checkpoint", type=click.STRING)
147
- @pass_config
148
- def disconnect(config, checkpoint):
149
- """
150
- Disconnect and clean-up checkpoint upload.
151
-
152
- CHECKPOINT may be specified by name or ID, but ID is preferred.
153
- """
154
- checkpoints = config.trainml.run(config.trainml.client.checkpoints.list())
155
-
156
- found = search_by_id_name(checkpoint, checkpoints)
157
- if None is found:
158
- raise click.UsageError("Cannot find specified checkpoint.")
159
-
160
- return config.trainml.run(found.disconnect())
110
+ if connect and attach:
111
+ config.trainml.run(checkpoint.attach(), checkpoint.connect())
112
+ elif connect:
113
+ config.trainml.run(checkpoint.connect())
114
+ else:
115
+ raise click.UsageError(
116
+ "Abort!\n"
117
+ "No logs to show for local sourced checkpoint without connect."
118
+ )
161
119
 
162
120
 
163
121
  @checkpoint.command()
@@ -236,7 +194,10 @@ def remove(config, checkpoint, force):
236
194
  found = search_by_id_name(checkpoint, checkpoints)
237
195
  if None is found:
238
196
  if force:
239
- config.trainml.run(found.client.checkpoints.remove(checkpoint))
197
+ config.trainml.run(
198
+ config.trainml.client.checkpoints.remove(checkpoint)
199
+ )
200
+ return
240
201
  else:
241
202
  raise click.UsageError("Cannot find specified checkpoint.")
242
203
 
trainml/cli/dataset.py CHANGED
@@ -35,15 +35,7 @@ def attach(config, dataset):
35
35
  if None is found:
36
36
  raise click.UsageError("Cannot find specified dataset.")
37
37
 
38
- try:
39
- config.trainml.run(found.attach())
40
- return config.trainml.run(found.disconnect())
41
- except:
42
- try:
43
- config.trainml.run(found.disconnect())
44
- except:
45
- pass
46
- raise
38
+ config.trainml.run(found.attach())
47
39
 
48
40
 
49
41
  @dataset.command()
@@ -67,18 +59,10 @@ def connect(config, dataset, attach):
67
59
  if None is found:
68
60
  raise click.UsageError("Cannot find specified dataset.")
69
61
 
70
- try:
71
- if attach:
72
- config.trainml.run(found.connect(), found.attach())
73
- return config.trainml.run(found.disconnect())
74
- else:
75
- return config.trainml.run(found.connect())
76
- except:
77
- try:
78
- config.trainml.run(found.disconnect())
79
- except:
80
- pass
81
- raise
62
+ if attach:
63
+ config.trainml.run(found.connect(), found.attach())
64
+ else:
65
+ config.trainml.run(found.connect())
82
66
 
83
67
 
84
68
  @dataset.command()
@@ -123,41 +107,15 @@ def create(config, attach, connect, source, name, path):
123
107
  )
124
108
  )
125
109
 
126
- try:
127
- if connect and attach:
128
- config.trainml.run(dataset.attach(), dataset.connect())
129
- return config.trainml.run(dataset.disconnect())
130
- elif connect:
131
- return config.trainml.run(dataset.connect())
132
- else:
133
- raise click.UsageError(
134
- "Abort!\n"
135
- "No logs to show for local sourced dataset without connect."
136
- )
137
- except:
138
- try:
139
- config.trainml.run(dataset.disconnect())
140
- except:
141
- pass
142
- raise
143
-
144
-
145
- @dataset.command()
146
- @click.argument("dataset", type=click.STRING)
147
- @pass_config
148
- def disconnect(config, dataset):
149
- """
150
- Disconnect and clean-up dataset upload.
151
-
152
- DATASET may be specified by name or ID, but ID is preferred.
153
- """
154
- datasets = config.trainml.run(config.trainml.client.datasets.list())
155
-
156
- found = search_by_id_name(dataset, datasets)
157
- if None is found:
158
- raise click.UsageError("Cannot find specified dataset.")
159
-
160
- return config.trainml.run(found.disconnect())
110
+ if connect and attach:
111
+ config.trainml.run(dataset.attach(), dataset.connect())
112
+ elif connect:
113
+ config.trainml.run(dataset.connect())
114
+ else:
115
+ raise click.UsageError(
116
+ "Abort!\n"
117
+ "No logs to show for local sourced dataset without connect."
118
+ )
161
119
 
162
120
 
163
121
  @dataset.command()
@@ -252,7 +210,9 @@ def rename(config, dataset, name):
252
210
  DATASET may be specified by name or ID, but ID is preferred.
253
211
  """
254
212
  try:
255
- dataset = config.trainml.run(config.trainml.client.datasets.get(dataset))
213
+ dataset = config.trainml.run(
214
+ config.trainml.client.datasets.get(dataset)
215
+ )
256
216
  if dataset is None:
257
217
  raise click.UsageError("Cannot find specified dataset.")
258
218
  except:
trainml/cli/job/__init__.py CHANGED
@@ -25,15 +25,7 @@ def attach(config, job):
25
25
  if None is found:
26
26
  raise click.UsageError("Cannot find specified job.")
27
27
 
28
- try:
29
- config.trainml.run(found.attach())
30
- return config.trainml.run(found.disconnect())
31
- except:
32
- try:
33
- config.trainml.run(found.disconnect())
34
- except:
35
- pass
36
- raise
28
+ config.trainml.run(found.attach())
37
29
 
38
30
 
39
31
  @job.command()
@@ -58,38 +50,22 @@ def connect(config, job, attach):
58
50
  raise click.UsageError("Cannot find specified job.")
59
51
 
60
52
  if found.type != "notebook":
61
- try:
62
- if attach:
63
- config.trainml.run(found.connect(), found.attach())
64
- return config.trainml.run(found.disconnect())
65
- else:
66
- return config.trainml.run(found.connect())
67
- except:
68
- try:
69
- config.trainml.run(found.disconnect())
70
- except:
71
- pass
72
- raise
53
+ if attach:
54
+ config.trainml.run(found.connect(), found.attach())
55
+ else:
56
+ config.trainml.run(found.connect())
73
57
  else:
74
58
  if found.status in [
75
59
  "new",
76
60
  "waiting for data/model download",
77
61
  "waiting for GPUs",
78
62
  ]:
79
- try:
80
- if attach:
81
- config.trainml.run(found.connect(), found.attach())
82
- config.trainml.run(found.disconnect())
83
- click.echo("Launching...", file=config.stdout)
84
- browse(found.notebook_url)
85
- else:
86
- return config.trainml.run(found.connect())
87
- except:
88
- try:
89
- config.trainml.run(found.disconnect())
90
- except:
91
- pass
92
- raise
63
+ if attach:
64
+ config.trainml.run(found.connect(), found.attach())
65
+ click.echo("Launching...", file=config.stdout)
66
+ browse(found.notebook_url)
67
+ else:
68
+ config.trainml.run(found.connect())
93
69
  elif found.status not in [
94
70
  "starting",
95
71
  "running",
@@ -103,24 +79,6 @@ def connect(config, job, attach):
103
79
  browse(found.notebook_url)
104
80
 
105
81
 
106
- @job.command()
107
- @click.argument("job", type=click.STRING)
108
- @pass_config
109
- def disconnect(config, job):
110
- """
111
- Disconnect and clean-up job.
112
-
113
- JOB may be specified by name or ID, but ID is preferred.
114
- """
115
- jobs = config.trainml.run(config.trainml.client.jobs.list())
116
-
117
- found = search_by_id_name(job, jobs)
118
- if None is found:
119
- raise click.UsageError("Cannot find specified job.")
120
-
121
- return config.trainml.run(found.disconnect())
122
-
123
-
124
82
  @job.command()
125
83
  @click.option(
126
84
  "--wait/--no-wait",
trainml/cli/job/create.py CHANGED
@@ -275,7 +275,9 @@ def notebook(
275
275
  options["environment"]["type"] = environment
276
276
 
277
277
  try:
278
- envs = [{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env]
278
+ envs = [
279
+ {"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
280
+ ]
279
281
  options["environment"]["env"] = envs
280
282
  except IndexError:
281
283
  raise click.UsageError(
@@ -289,21 +291,25 @@ def notebook(
289
291
  if pip_packages:
290
292
  options["environment"]["packages"]["pip"] = pip_packages.split(",")
291
293
  if conda_packages:
292
- options["environment"]["packages"]["conda"] = conda_packages.split(",")
294
+ options["environment"]["packages"]["conda"] = conda_packages.split(
295
+ ","
296
+ )
293
297
 
294
298
  if data_dir:
295
299
  click.echo("Creating Dataset..", file=config.stdout)
296
300
  new_dataset = config.trainml.run(
297
- config.trainml.client.datasets.create(f"Job - {name}", "local", data_dir)
301
+ config.trainml.client.datasets.create(
302
+ f"Job - {name}", "local", data_dir
303
+ )
298
304
  )
299
305
  if attach:
300
306
  config.trainml.run(new_dataset.attach(), new_dataset.connect())
301
- config.trainml.run(new_dataset.disconnect())
302
307
  else:
303
308
  config.trainml.run(new_dataset.connect())
304
309
  config.trainml.run(new_dataset.wait_for("ready"))
305
- config.trainml.run(new_dataset.disconnect())
306
- options["data"]["datasets"].append(dict(id=new_dataset.id, type="existing"))
310
+ options["data"]["datasets"].append(
311
+ dict(id=new_dataset.id, type="existing")
312
+ )
307
313
 
308
314
  if git_uri:
309
315
  options["model"]["source_type"] = "git"
@@ -331,13 +337,11 @@ def notebook(
331
337
  if attach or connect:
332
338
  click.echo("Waiting for job to start...", file=config.stdout)
333
339
  config.trainml.run(job.connect(), job.attach())
334
- config.trainml.run(job.disconnect())
335
340
  click.echo("Launching...", file=config.stdout)
336
341
  browse(job.notebook_url)
337
342
  else:
338
343
  config.trainml.run(job.connect())
339
344
  config.trainml.run(job.wait_for("running"))
340
- config.trainml.run(job.disconnect())
341
345
  elif attach or connect:
342
346
  click.echo("Waiting for job to start...", file=config.stdout)
343
347
  config.trainml.run(job.wait_for("running", timeout))
@@ -626,15 +630,21 @@ def training(
626
630
  if output_type:
627
631
  options["data"]["output_type"] = output_type
628
632
  options["data"]["output_uri"] = output_uri
629
- options["data"]["output_options"] = dict(archive=archive, save_model=save_model)
633
+ options["data"]["output_options"] = dict(
634
+ archive=archive, save_model=save_model
635
+ )
630
636
 
631
637
  if output_dir:
632
638
  options["data"]["output_type"] = "local"
633
639
  options["data"]["output_uri"] = output_dir
634
- options["data"]["output_options"] = dict(archive=archive, save_model=save_model)
640
+ options["data"]["output_options"] = dict(
641
+ archive=archive, save_model=save_model
642
+ )
635
643
 
636
644
  try:
637
- envs = [{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env]
645
+ envs = [
646
+ {"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
647
+ ]
638
648
  options["environment"]["env"] = envs
639
649
  except IndexError:
640
650
  raise click.UsageError(
@@ -648,21 +658,25 @@ def training(
648
658
  if pip_packages:
649
659
  options["environment"]["packages"]["pip"] = pip_packages.split(",")
650
660
  if conda_packages:
651
- options["environment"]["packages"]["conda"] = conda_packages.split(",")
661
+ options["environment"]["packages"]["conda"] = conda_packages.split(
662
+ ","
663
+ )
652
664
 
653
665
  if data_dir:
654
666
  click.echo("Creating Dataset..", file=config.stdout)
655
667
  new_dataset = config.trainml.run(
656
- config.trainml.client.datasets.create(f"Job - {name}", "local", data_dir)
668
+ config.trainml.client.datasets.create(
669
+ f"Job - {name}", "local", data_dir
670
+ )
657
671
  )
658
672
  if attach:
659
673
  config.trainml.run(new_dataset.attach(), new_dataset.connect())
660
- config.trainml.run(new_dataset.disconnect())
661
674
  else:
662
675
  config.trainml.run(new_dataset.connect())
663
676
  config.trainml.run(new_dataset.wait_for("ready"))
664
- config.trainml.run(new_dataset.disconnect())
665
- options["data"]["datasets"].append(dict(id=new_dataset.id, type="existing"))
677
+ options["data"]["datasets"].append(
678
+ dict(id=new_dataset.id, type="existing")
679
+ )
666
680
 
667
681
  if git_uri:
668
682
  options["model"]["source_type"] = "git"
@@ -979,15 +993,21 @@ def inference(
979
993
  if output_type:
980
994
  options["data"]["output_type"] = output_type
981
995
  options["data"]["output_uri"] = output_uri
982
- options["data"]["output_options"] = dict(archive=archive, save_model=save_model)
996
+ options["data"]["output_options"] = dict(
997
+ archive=archive, save_model=save_model
998
+ )
983
999
 
984
1000
  if output_dir:
985
1001
  options["data"]["output_type"] = "local"
986
1002
  options["data"]["output_uri"] = output_dir
987
- options["data"]["output_options"] = dict(archive=archive, save_model=save_model)
1003
+ options["data"]["output_options"] = dict(
1004
+ archive=archive, save_model=save_model
1005
+ )
988
1006
 
989
1007
  try:
990
- envs = [{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env]
1008
+ envs = [
1009
+ {"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
1010
+ ]
991
1011
  options["environment"]["env"] = envs
992
1012
  except IndexError:
993
1013
  raise click.UsageError(
@@ -1001,7 +1021,9 @@ def inference(
1001
1021
  if pip_packages:
1002
1022
  options["environment"]["packages"]["pip"] = pip_packages.split(",")
1003
1023
  if conda_packages:
1004
- options["environment"]["packages"]["conda"] = conda_packages.split(",")
1024
+ options["environment"]["packages"]["conda"] = conda_packages.split(
1025
+ ","
1026
+ )
1005
1027
 
1006
1028
  if git_uri:
1007
1029
  options["model"]["source_type"] = "git"
@@ -1301,7 +1323,9 @@ def endpoint(
1301
1323
  options["environment"]["type"] = environment
1302
1324
 
1303
1325
  try:
1304
- envs = [{"key": e.split("=")[0], "value": e.split("=")[1]} for e in env]
1326
+ envs = [
1327
+ {"key": e.split("=")[0], "value": e.split("=")[1]} for e in env
1328
+ ]
1305
1329
  options["environment"]["env"] = envs
1306
1330
  except IndexError:
1307
1331
  raise click.UsageError(
@@ -1315,7 +1339,9 @@ def endpoint(
1315
1339
  if pip_packages:
1316
1340
  options["environment"]["packages"]["pip"] = pip_packages.split(",")
1317
1341
  if conda_packages:
1318
- options["environment"]["packages"]["conda"] = conda_packages.split(",")
1342
+ options["environment"]["packages"]["conda"] = conda_packages.split(
1343
+ ","
1344
+ )
1319
1345
 
1320
1346
  if git_uri:
1321
1347
  options["model"]["source_type"] = "git"
@@ -1349,7 +1375,6 @@ def endpoint(
1349
1375
  config.trainml.run(job.connect())
1350
1376
  click.echo("Waiting for job to start...", file=config.stdout)
1351
1377
  config.trainml.run(job.wait_for("running", timeout))
1352
- config.trainml.run(job.disconnect())
1353
1378
  config.trainml.run(job.refresh())
1354
1379
  click.echo(f"Endpoint is running at: {job.url}", file=config.stdout)
1355
1380
  else:
@@ -1357,4 +1382,6 @@ def endpoint(
1357
1382
  click.echo("Waiting for job to start...", file=config.stdout)
1358
1383
  config.trainml.run(job.wait_for("running", timeout))
1359
1384
  config.trainml.run(job.refresh())
1360
- click.echo(f"Endpoint is running at: {job.url}", file=config.stdout)
1385
+ click.echo(
1386
+ f"Endpoint is running at: {job.url}", file=config.stdout
1387
+ )