trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +17 -1
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +11 -53
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +239 -68
- trainml/models.py +51 -55
- trainml/projects/projects.py +2 -2
- trainml/trainml.py +50 -16
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +587 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/trainml.py
CHANGED
|
@@ -7,7 +7,7 @@ import traceback
|
|
|
7
7
|
import random
|
|
8
8
|
from importlib.metadata import version
|
|
9
9
|
|
|
10
|
-
from trainml.auth import Auth
|
|
10
|
+
from trainml.utils.auth import Auth
|
|
11
11
|
from trainml.datasets import Datasets
|
|
12
12
|
from trainml.models import Models
|
|
13
13
|
from trainml.checkpoints import Checkpoints
|
|
@@ -16,7 +16,6 @@ from trainml.jobs import Jobs
|
|
|
16
16
|
from trainml.gpu_types import GpuTypes
|
|
17
17
|
from trainml.environments import Environments
|
|
18
18
|
from trainml.exceptions import ApiError, TrainMLException
|
|
19
|
-
from trainml.connections import Connections
|
|
20
19
|
from trainml.projects import Projects
|
|
21
20
|
from trainml.cloudbender import Cloudbender
|
|
22
21
|
|
|
@@ -34,13 +33,17 @@ class TrainML(object):
|
|
|
34
33
|
os.environ.get("TRAINML_CONFIG_DIR") or "~/.trainml"
|
|
35
34
|
)
|
|
36
35
|
try:
|
|
37
|
-
with open(
|
|
36
|
+
with open(
|
|
37
|
+
f"{CONFIG_DIR}/environment.json", "r", encoding="utf-8"
|
|
38
|
+
) as file:
|
|
38
39
|
env_str = file.read().replace("\n", "")
|
|
39
40
|
env = json.loads(env_str)
|
|
40
41
|
except OSError:
|
|
41
42
|
env = dict()
|
|
42
43
|
try:
|
|
43
|
-
with open(
|
|
44
|
+
with open(
|
|
45
|
+
f"{CONFIG_DIR}/config.json", "r", encoding="utf-8"
|
|
46
|
+
) as file:
|
|
44
47
|
config_str = file.read().replace("\n", "")
|
|
45
48
|
config = json.loads(config_str)
|
|
46
49
|
except OSError:
|
|
@@ -72,7 +75,6 @@ class TrainML(object):
|
|
|
72
75
|
self.jobs = Jobs(self)
|
|
73
76
|
self.gpu_types = GpuTypes(self)
|
|
74
77
|
self.environments = Environments(self)
|
|
75
|
-
self.connections = Connections(self)
|
|
76
78
|
self.projects = Projects(self)
|
|
77
79
|
self.cloudbender = Cloudbender(self)
|
|
78
80
|
self.api_url = (
|
|
@@ -92,7 +94,16 @@ class TrainML(object):
|
|
|
92
94
|
def project(self) -> str:
|
|
93
95
|
return self.active_project
|
|
94
96
|
|
|
95
|
-
async def _query(
|
|
97
|
+
async def _query(
|
|
98
|
+
self,
|
|
99
|
+
path,
|
|
100
|
+
method,
|
|
101
|
+
params=None,
|
|
102
|
+
data=None,
|
|
103
|
+
headers=None,
|
|
104
|
+
max_retries=3,
|
|
105
|
+
backoff_factor=0.5,
|
|
106
|
+
):
|
|
96
107
|
try:
|
|
97
108
|
tokens = self.auth.get_tokens()
|
|
98
109
|
except TrainMLException as e:
|
|
@@ -120,7 +131,9 @@ class TrainML(object):
|
|
|
120
131
|
)
|
|
121
132
|
if params:
|
|
122
133
|
if not isinstance(params, dict):
|
|
123
|
-
raise TrainMLException(
|
|
134
|
+
raise TrainMLException(
|
|
135
|
+
"Query parameters must be a valid dictionary"
|
|
136
|
+
)
|
|
124
137
|
params = {
|
|
125
138
|
k: (str(v).lower() if isinstance(v, bool) else v)
|
|
126
139
|
for k, v in params.items()
|
|
@@ -154,27 +167,44 @@ class TrainML(object):
|
|
|
154
167
|
params=params,
|
|
155
168
|
) as resp:
|
|
156
169
|
if (resp.status // 100) in [4, 5]:
|
|
157
|
-
if
|
|
158
|
-
|
|
170
|
+
if (
|
|
171
|
+
resp.status == 502
|
|
172
|
+
and attempt < max_retries - 1
|
|
173
|
+
):
|
|
174
|
+
wait_time = (
|
|
175
|
+
(2**attempt)
|
|
176
|
+
* backoff_factor
|
|
177
|
+
* (random.random() + 0.5)
|
|
178
|
+
)
|
|
159
179
|
await asyncio.sleep(wait_time)
|
|
160
180
|
continue
|
|
161
181
|
else:
|
|
162
182
|
what = await resp.read()
|
|
163
|
-
content_type = resp.headers.get(
|
|
183
|
+
content_type = resp.headers.get(
|
|
184
|
+
"content-type", ""
|
|
185
|
+
)
|
|
164
186
|
resp.close()
|
|
165
187
|
if content_type == "application/json":
|
|
166
|
-
raise ApiError(
|
|
188
|
+
raise ApiError(
|
|
189
|
+
resp.status,
|
|
190
|
+
json.loads(what.decode("utf8")),
|
|
191
|
+
)
|
|
167
192
|
else:
|
|
168
|
-
raise ApiError(
|
|
193
|
+
raise ApiError(
|
|
194
|
+
resp.status,
|
|
195
|
+
{"message": what.decode("utf8")},
|
|
196
|
+
)
|
|
169
197
|
results = await resp.json()
|
|
170
198
|
return results
|
|
171
199
|
except aiohttp.ClientResponseError as e:
|
|
172
200
|
if e.status == 502 and attempt < max_retries - 1:
|
|
173
|
-
wait_time = (
|
|
201
|
+
wait_time = (
|
|
202
|
+
(2**attempt) * backoff_factor * (random.random() + 0.5)
|
|
203
|
+
)
|
|
174
204
|
await asyncio.sleep(wait_time)
|
|
175
205
|
continue
|
|
176
206
|
else:
|
|
177
|
-
raise
|
|
207
|
+
raise ApiError(e.status, f"Error {e.message}")
|
|
178
208
|
|
|
179
209
|
raise TrainMLException("Unexpected API failure")
|
|
180
210
|
|
|
@@ -286,11 +316,15 @@ class TrainML(object):
|
|
|
286
316
|
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
287
317
|
except Exception as e:
|
|
288
318
|
connection_tries += 1
|
|
289
|
-
logging.debug(
|
|
319
|
+
logging.debug(
|
|
320
|
+
f"Connection error: {traceback.format_exc()}"
|
|
321
|
+
)
|
|
290
322
|
if connection_tries == 5:
|
|
291
323
|
raise ApiError(
|
|
292
324
|
500,
|
|
293
|
-
{
|
|
325
|
+
{
|
|
326
|
+
"message": f"Connection error: {traceback.format_exc()}"
|
|
327
|
+
},
|
|
294
328
|
)
|
|
295
329
|
|
|
296
330
|
def set_active_project(self, project_uuid):
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utility modules for trainml SDK."""
|