trainml 0.5.17__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +16 -0
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/trainml.py +50 -16
  43. trainml/utils/__init__.py +1 -0
  44. trainml/utils/auth.py +641 -0
  45. trainml/utils/transfer.py +587 -0
  46. trainml/volumes.py +48 -53
  47. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  48. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/RECORD +52 -46
  49. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  50. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  51. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  52. {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/trainml.py CHANGED
@@ -7,7 +7,7 @@ import traceback
7
7
  import random
8
8
  from importlib.metadata import version
9
9
 
10
- from trainml.auth import Auth
10
+ from trainml.utils.auth import Auth
11
11
  from trainml.datasets import Datasets
12
12
  from trainml.models import Models
13
13
  from trainml.checkpoints import Checkpoints
@@ -16,7 +16,6 @@ from trainml.jobs import Jobs
16
16
  from trainml.gpu_types import GpuTypes
17
17
  from trainml.environments import Environments
18
18
  from trainml.exceptions import ApiError, TrainMLException
19
- from trainml.connections import Connections
20
19
  from trainml.projects import Projects
21
20
  from trainml.cloudbender import Cloudbender
22
21
 
@@ -34,13 +33,17 @@ class TrainML(object):
34
33
  os.environ.get("TRAINML_CONFIG_DIR") or "~/.trainml"
35
34
  )
36
35
  try:
37
- with open(f"{CONFIG_DIR}/environment.json", "r") as file:
36
+ with open(
37
+ f"{CONFIG_DIR}/environment.json", "r", encoding="utf-8"
38
+ ) as file:
38
39
  env_str = file.read().replace("\n", "")
39
40
  env = json.loads(env_str)
40
41
  except OSError:
41
42
  env = dict()
42
43
  try:
43
- with open(f"{CONFIG_DIR}/config.json", "r") as file:
44
+ with open(
45
+ f"{CONFIG_DIR}/config.json", "r", encoding="utf-8"
46
+ ) as file:
44
47
  config_str = file.read().replace("\n", "")
45
48
  config = json.loads(config_str)
46
49
  except OSError:
@@ -72,7 +75,6 @@ class TrainML(object):
72
75
  self.jobs = Jobs(self)
73
76
  self.gpu_types = GpuTypes(self)
74
77
  self.environments = Environments(self)
75
- self.connections = Connections(self)
76
78
  self.projects = Projects(self)
77
79
  self.cloudbender = Cloudbender(self)
78
80
  self.api_url = (
@@ -92,7 +94,16 @@ class TrainML(object):
92
94
  def project(self) -> str:
93
95
  return self.active_project
94
96
 
95
- async def _query(self, path, method, params=None, data=None, headers=None,max_retries=3, backoff_factor=0.5):
97
+ async def _query(
98
+ self,
99
+ path,
100
+ method,
101
+ params=None,
102
+ data=None,
103
+ headers=None,
104
+ max_retries=3,
105
+ backoff_factor=0.5,
106
+ ):
96
107
  try:
97
108
  tokens = self.auth.get_tokens()
98
109
  except TrainMLException as e:
@@ -120,7 +131,9 @@ class TrainML(object):
120
131
  )
121
132
  if params:
122
133
  if not isinstance(params, dict):
123
- raise TrainMLException("Query parameters must be a valid dictionary")
134
+ raise TrainMLException(
135
+ "Query parameters must be a valid dictionary"
136
+ )
124
137
  params = {
125
138
  k: (str(v).lower() if isinstance(v, bool) else v)
126
139
  for k, v in params.items()
@@ -154,27 +167,44 @@ class TrainML(object):
154
167
  params=params,
155
168
  ) as resp:
156
169
  if (resp.status // 100) in [4, 5]:
157
- if resp.status == 502 and attempt < max_retries - 1:
158
- wait_time = (2 ** attempt) * backoff_factor * (random.random() + 0.5)
170
+ if (
171
+ resp.status == 502
172
+ and attempt < max_retries - 1
173
+ ):
174
+ wait_time = (
175
+ (2**attempt)
176
+ * backoff_factor
177
+ * (random.random() + 0.5)
178
+ )
159
179
  await asyncio.sleep(wait_time)
160
180
  continue
161
181
  else:
162
182
  what = await resp.read()
163
- content_type = resp.headers.get("content-type", "")
183
+ content_type = resp.headers.get(
184
+ "content-type", ""
185
+ )
164
186
  resp.close()
165
187
  if content_type == "application/json":
166
- raise ApiError(resp.status, json.loads(what.decode("utf8")))
188
+ raise ApiError(
189
+ resp.status,
190
+ json.loads(what.decode("utf8")),
191
+ )
167
192
  else:
168
- raise ApiError(resp.status, {"message": what.decode("utf8")})
193
+ raise ApiError(
194
+ resp.status,
195
+ {"message": what.decode("utf8")},
196
+ )
169
197
  results = await resp.json()
170
198
  return results
171
199
  except aiohttp.ClientResponseError as e:
172
200
  if e.status == 502 and attempt < max_retries - 1:
173
- wait_time = (2 ** attempt) * backoff_factor * (random.random() + 0.5)
201
+ wait_time = (
202
+ (2**attempt) * backoff_factor * (random.random() + 0.5)
203
+ )
174
204
  await asyncio.sleep(wait_time)
175
205
  continue
176
206
  else:
177
- raise ApiError(e.status, f"Error {e.message}")
207
+ raise ApiError(e.status, f"Error {e.message}")
178
208
 
179
209
  raise TrainMLException("Unexpected API failure")
180
210
 
@@ -286,11 +316,15 @@ class TrainML(object):
286
316
  logging.debug(f"Websocket Disconnected. Done? {done}")
287
317
  except Exception as e:
288
318
  connection_tries += 1
289
- logging.debug(f"Connection error: {traceback.format_exc()}")
319
+ logging.debug(
320
+ f"Connection error: {traceback.format_exc()}"
321
+ )
290
322
  if connection_tries == 5:
291
323
  raise ApiError(
292
324
  500,
293
- {"message": f"Connection error: {traceback.format_exc()}"},
325
+ {
326
+ "message": f"Connection error: {traceback.format_exc()}"
327
+ },
294
328
  )
295
329
 
296
330
  def set_active_project(self, project_uuid):
@@ -0,0 +1 @@
1
+ """Utility modules for trainml SDK."""