trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +17 -1
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/projects/projects.py +2 -2
  43. trainml/trainml.py +50 -16
  44. trainml/utils/__init__.py +1 -0
  45. trainml/utils/auth.py +641 -0
  46. trainml/utils/transfer.py +587 -0
  47. trainml/volumes.py +48 -53
  48. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  49. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
  50. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  51. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  52. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  53. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/cli/model.py CHANGED
@@ -36,15 +36,7 @@ def attach(config, model):
36
36
  if None is found:
37
37
  raise click.UsageError("Cannot find specified model.")
38
38
 
39
- try:
40
- config.trainml.run(found.attach())
41
- return config.trainml.run(found.disconnect())
42
- except:
43
- try:
44
- config.trainml.run(found.disconnect())
45
- except:
46
- pass
47
- raise
39
+ config.trainml.run(found.attach())
48
40
 
49
41
 
50
42
  @model.command()
@@ -69,18 +61,10 @@ def connect(config, model, attach):
69
61
  if None is found:
70
62
  raise click.UsageError("Cannot find specified model.")
71
63
 
72
- try:
73
- if attach:
74
- config.trainml.run(found.connect(), found.attach())
75
- return config.trainml.run(found.disconnect())
76
- else:
77
- return config.trainml.run(found.connect())
78
- except:
79
- try:
80
- config.trainml.run(found.disconnect())
81
- except:
82
- pass
83
- raise
64
+ if attach:
65
+ config.trainml.run(found.connect(), found.attach())
66
+ else:
67
+ config.trainml.run(found.connect())
84
68
 
85
69
 
86
70
  @model.command()
@@ -125,41 +109,15 @@ def create(config, attach, connect, source, name, path):
125
109
  )
126
110
  )
127
111
 
128
- try:
129
- if connect and attach:
130
- config.trainml.run(model.attach(), model.connect())
131
- return config.trainml.run(model.disconnect())
132
- elif connect:
133
- return config.trainml.run(model.connect())
134
- else:
135
- raise click.UsageError(
136
- "Abort!\n"
137
- "No logs to show for local sourced model without connect."
138
- )
139
- except:
140
- try:
141
- config.trainml.run(model.disconnect())
142
- except:
143
- pass
144
- raise
145
-
146
-
147
- @model.command()
148
- @click.argument("model", type=click.STRING)
149
- @pass_config
150
- def disconnect(config, model):
151
- """
152
- Disconnect and clean-up model upload.
153
-
154
- MODEL may be specified by name or ID, but ID is preferred.
155
- """
156
- models = config.trainml.run(config.trainml.client.models.list())
157
-
158
- found = search_by_id_name(model, models)
159
- if None is found:
160
- raise click.UsageError("Cannot find specified model.")
161
-
162
- return config.trainml.run(found.disconnect())
112
+ if connect and attach:
113
+ config.trainml.run(model.attach(), model.connect())
114
+ elif connect:
115
+ config.trainml.run(model.connect())
116
+ else:
117
+ raise click.UsageError(
118
+ "Abort!\n"
119
+ "No logs to show for local sourced model without connect."
120
+ )
163
121
 
164
122
 
165
123
  @model.command()
trainml/cli/volume.py CHANGED
@@ -35,15 +35,7 @@ def attach(config, volume):
35
35
  if None is found:
36
36
  raise click.UsageError("Cannot find specified volume.")
37
37
 
38
- try:
39
- config.trainml.run(found.attach())
40
- return config.trainml.run(found.disconnect())
41
- except:
42
- try:
43
- config.trainml.run(found.disconnect())
44
- except:
45
- pass
46
- raise
38
+ config.trainml.run(found.attach())
47
39
 
48
40
 
49
41
  @volume.command()
@@ -67,18 +59,10 @@ def connect(config, volume, attach):
67
59
  if None is found:
68
60
  raise click.UsageError("Cannot find specified volume.")
69
61
 
70
- try:
71
- if attach:
72
- config.trainml.run(found.connect(), found.attach())
73
- return config.trainml.run(found.disconnect())
74
- else:
75
- return config.trainml.run(found.connect())
76
- except:
77
- try:
78
- config.trainml.run(found.disconnect())
79
- except:
80
- pass
81
- raise
62
+ if attach:
63
+ config.trainml.run(found.connect(), found.attach())
64
+ else:
65
+ config.trainml.run(found.connect())
82
66
 
83
67
 
84
68
  @volume.command()
@@ -120,45 +104,22 @@ def create(config, attach, connect, source, name, capacity, path):
120
104
  if source == "local":
121
105
  volume = config.trainml.run(
122
106
  config.trainml.client.volumes.create(
123
- name=name, source_type="local", source_uri=path, capacity=capacity
107
+ name=name,
108
+ source_type="local",
109
+ source_uri=path,
110
+ capacity=capacity,
124
111
  )
125
112
  )
126
113
 
127
- try:
128
- if connect and attach:
129
- config.trainml.run(volume.attach(), volume.connect())
130
- return config.trainml.run(volume.disconnect())
131
- elif connect:
132
- return config.trainml.run(volume.connect())
133
- else:
134
- raise click.UsageError(
135
- "Abort!\n"
136
- "No logs to show for local sourced volume without connect."
137
- )
138
- except:
139
- try:
140
- config.trainml.run(volume.disconnect())
141
- except:
142
- pass
143
- raise
144
-
145
-
146
- @volume.command()
147
- @click.argument("volume", type=click.STRING)
148
- @pass_config
149
- def disconnect(config, volume):
150
- """
151
- Disconnect and clean-up volume upload.
152
-
153
- VOLUME may be specified by name or ID, but ID is preferred.
154
- """
155
- volumes = config.trainml.run(config.trainml.client.volumes.list())
156
-
157
- found = search_by_id_name(volume, volumes)
158
- if None is found:
159
- raise click.UsageError("Cannot find specified volume.")
160
-
161
- return config.trainml.run(found.disconnect())
114
+ if connect and attach:
115
+ config.trainml.run(volume.attach(), volume.connect())
116
+ elif connect:
117
+ config.trainml.run(volume.connect())
118
+ else:
119
+ raise click.UsageError(
120
+ "Abort!\n"
121
+ "No logs to show for local sourced volume without connect."
122
+ )
162
123
 
163
124
 
164
125
  @volume.command()
trainml/datasets.py CHANGED
@@ -10,7 +10,7 @@ from .exceptions import (
10
10
  SpecificationError,
11
11
  TrainMLException,
12
12
  )
13
- from .connections import Connection
13
+ from trainml.utils.transfer import upload, download
14
14
 
15
15
 
16
16
  class Datasets(object):
@@ -71,10 +71,12 @@ class Dataset:
71
71
  self._id = self._dataset.get("id", self._dataset.get("dataset_uuid"))
72
72
  self._status = self._dataset.get("status")
73
73
  self._name = self._dataset.get("name")
74
- self._size = self._dataset.get("size") or self._dataset.get("used_size")
75
- self._billed_size = self._dataset.get("billed_size") or self._dataset.get(
76
- "size"
74
+ self._size = self._dataset.get("size") or self._dataset.get(
75
+ "used_size"
77
76
  )
77
+ self._billed_size = self._dataset.get(
78
+ "billed_size"
79
+ ) or self._dataset.get("size")
78
80
  self._project_uuid = self._dataset.get("project_uuid")
79
81
 
80
82
  @property
@@ -122,56 +124,45 @@ class Dataset:
122
124
  )
123
125
  return resp
124
126
 
125
- async def get_connection_utility_url(self):
126
- resp = await self.trainml._query(
127
- f"/dataset/{self._id}/download",
128
- "GET",
129
- dict(project_uuid=self._project_uuid),
130
- )
131
- return resp
132
-
133
- def get_connection_details(self):
134
- if self._dataset.get("vpn"):
135
- details = dict(
136
- entity_type="dataset",
137
- project_uuid=self._dataset.get("project_uuid"),
138
- cidr=self._dataset.get("vpn").get("cidr"),
139
- ssh_port=self._dataset.get("vpn").get("client").get("ssh_port"),
140
- input_path=(
141
- self._dataset.get("source_uri")
142
- if self.status in ["new", "downloading"]
143
- else None
144
- ),
145
- output_path=(
146
- self._dataset.get("output_uri")
147
- if self.status == "exporting"
148
- else None
149
- ),
150
- )
151
- else:
152
- details = dict()
153
- return details
154
-
155
127
  async def connect(self):
156
- if self.status in ["ready", "failed"]:
157
- raise SpecificationError(
158
- "status",
159
- f"You can only connect to downloading or exporting datasets.",
160
- )
161
- if self.status == "new":
162
- await self.wait_for("downloading")
163
- connection = Connection(
164
- self.trainml, entity_type="dataset", id=self.id, entity=self
165
- )
166
- await connection.start()
167
- return connection.status
128
+ if self.status not in ["downloading", "exporting"]:
129
+ if self.status == "new":
130
+ await self.wait_for("downloading")
131
+ else:
132
+ raise SpecificationError(
133
+ "status",
134
+ f"You can only connect to downloading or exporting datasets.",
135
+ )
168
136
 
169
- async def disconnect(self):
170
- connection = Connection(
171
- self.trainml, entity_type="dataset", id=self.id, entity=self
172
- )
173
- await connection.stop()
174
- return connection.status
137
+ # Refresh to get latest entity data
138
+ await self.refresh()
139
+
140
+ if self.status == "downloading":
141
+ # Upload task - get auth_token, hostname, and source_uri from dataset
142
+ auth_token = self._dataset.get("auth_token")
143
+ hostname = self._dataset.get("hostname")
144
+ source_uri = self._dataset.get("source_uri")
145
+
146
+ if not auth_token or not hostname or not source_uri:
147
+ raise SpecificationError(
148
+ "status",
149
+ f"Dataset in downloading status missing required connection properties (auth_token, hostname, source_uri).",
150
+ )
151
+
152
+ await upload(hostname, auth_token, source_uri)
153
+ elif self.status == "exporting":
154
+ # Download task - get auth_token, hostname, and output_uri from dataset
155
+ auth_token = self._dataset.get("auth_token")
156
+ hostname = self._dataset.get("hostname")
157
+ output_uri = self._dataset.get("output_uri")
158
+
159
+ if not auth_token or not hostname or not output_uri:
160
+ raise SpecificationError(
161
+ "status",
162
+ f"Dataset in exporting status missing required connection properties (auth_token, hostname, output_uri).",
163
+ )
164
+
165
+ await download(hostname, auth_token, output_uri)
175
166
 
176
167
  async def remove(self, force=False):
177
168
  await self.trainml._query(
@@ -210,7 +201,9 @@ class Dataset:
210
201
  if msg_handler:
211
202
  msg_handler(data)
212
203
  else:
213
- timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
204
+ timestamp = datetime.fromtimestamp(
205
+ int(data.get("time")) / 1000
206
+ )
214
207
  print(
215
208
  f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
216
209
  )
@@ -239,7 +232,7 @@ class Dataset:
239
232
  async def wait_for(self, status, timeout=300):
240
233
  if self.status == status:
241
234
  return
242
- valid_statuses = ["downloading", "ready", "archived"]
235
+ valid_statuses = ["downloading", "ready", "exporting", "archived"]
243
236
  if not status in valid_statuses:
244
237
  raise SpecificationError(
245
238
  "status",
@@ -254,7 +247,9 @@ class Dataset:
254
247
 
255
248
  POLL_INTERVAL_MIN = 5
256
249
  POLL_INTERVAL_MAX = 60
257
- POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
250
+ POLL_INTERVAL = max(
251
+ min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
252
+ )
258
253
  retry_count = math.ceil(timeout / POLL_INTERVAL)
259
254
  count = 0
260
255
  while count < retry_count: