trainml 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +89 -67
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +269 -69
- trainml/models.py +51 -55
- trainml/trainml.py +159 -114
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +647 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/top_level.txt +0 -0
trainml/trainml.py
CHANGED
|
@@ -7,7 +7,7 @@ import traceback
|
|
|
7
7
|
import random
|
|
8
8
|
from importlib.metadata import version
|
|
9
9
|
|
|
10
|
-
from trainml.auth import Auth
|
|
10
|
+
from trainml.utils.auth import Auth
|
|
11
11
|
from trainml.datasets import Datasets
|
|
12
12
|
from trainml.models import Models
|
|
13
13
|
from trainml.checkpoints import Checkpoints
|
|
@@ -16,7 +16,6 @@ from trainml.jobs import Jobs
|
|
|
16
16
|
from trainml.gpu_types import GpuTypes
|
|
17
17
|
from trainml.environments import Environments
|
|
18
18
|
from trainml.exceptions import ApiError, TrainMLException
|
|
19
|
-
from trainml.connections import Connections
|
|
20
19
|
from trainml.projects import Projects
|
|
21
20
|
from trainml.cloudbender import Cloudbender
|
|
22
21
|
|
|
@@ -34,13 +33,17 @@ class TrainML(object):
|
|
|
34
33
|
os.environ.get("TRAINML_CONFIG_DIR") or "~/.trainml"
|
|
35
34
|
)
|
|
36
35
|
try:
|
|
37
|
-
with open(
|
|
36
|
+
with open(
|
|
37
|
+
f"{CONFIG_DIR}/environment.json", "r", encoding="utf-8"
|
|
38
|
+
) as file:
|
|
38
39
|
env_str = file.read().replace("\n", "")
|
|
39
40
|
env = json.loads(env_str)
|
|
40
41
|
except OSError:
|
|
41
42
|
env = dict()
|
|
42
43
|
try:
|
|
43
|
-
with open(
|
|
44
|
+
with open(
|
|
45
|
+
f"{CONFIG_DIR}/config.json", "r", encoding="utf-8"
|
|
46
|
+
) as file:
|
|
44
47
|
config_str = file.read().replace("\n", "")
|
|
45
48
|
config = json.loads(config_str)
|
|
46
49
|
except OSError:
|
|
@@ -72,7 +75,6 @@ class TrainML(object):
|
|
|
72
75
|
self.jobs = Jobs(self)
|
|
73
76
|
self.gpu_types = GpuTypes(self)
|
|
74
77
|
self.environments = Environments(self)
|
|
75
|
-
self.connections = Connections(self)
|
|
76
78
|
self.projects = Projects(self)
|
|
77
79
|
self.cloudbender = Cloudbender(self)
|
|
78
80
|
self.api_url = (
|
|
@@ -92,7 +94,16 @@ class TrainML(object):
|
|
|
92
94
|
def project(self) -> str:
|
|
93
95
|
return self.active_project
|
|
94
96
|
|
|
95
|
-
async def _query(
|
|
97
|
+
async def _query(
|
|
98
|
+
self,
|
|
99
|
+
path,
|
|
100
|
+
method,
|
|
101
|
+
params=None,
|
|
102
|
+
data=None,
|
|
103
|
+
headers=None,
|
|
104
|
+
max_retries=3,
|
|
105
|
+
backoff_factor=0.5,
|
|
106
|
+
):
|
|
96
107
|
try:
|
|
97
108
|
tokens = self.auth.get_tokens()
|
|
98
109
|
except TrainMLException as e:
|
|
@@ -120,7 +131,9 @@ class TrainML(object):
|
|
|
120
131
|
)
|
|
121
132
|
if params:
|
|
122
133
|
if not isinstance(params, dict):
|
|
123
|
-
raise TrainMLException(
|
|
134
|
+
raise TrainMLException(
|
|
135
|
+
"Query parameters must be a valid dictionary"
|
|
136
|
+
)
|
|
124
137
|
params = {
|
|
125
138
|
k: (str(v).lower() if isinstance(v, bool) else v)
|
|
126
139
|
for k, v in params.items()
|
|
@@ -154,27 +167,44 @@ class TrainML(object):
|
|
|
154
167
|
params=params,
|
|
155
168
|
) as resp:
|
|
156
169
|
if (resp.status // 100) in [4, 5]:
|
|
157
|
-
if
|
|
158
|
-
|
|
170
|
+
if (
|
|
171
|
+
resp.status == 502
|
|
172
|
+
and attempt < max_retries - 1
|
|
173
|
+
):
|
|
174
|
+
wait_time = (
|
|
175
|
+
(2**attempt)
|
|
176
|
+
* backoff_factor
|
|
177
|
+
* (random.random() + 0.5)
|
|
178
|
+
)
|
|
159
179
|
await asyncio.sleep(wait_time)
|
|
160
180
|
continue
|
|
161
181
|
else:
|
|
162
182
|
what = await resp.read()
|
|
163
|
-
content_type = resp.headers.get(
|
|
183
|
+
content_type = resp.headers.get(
|
|
184
|
+
"content-type", ""
|
|
185
|
+
)
|
|
164
186
|
resp.close()
|
|
165
187
|
if content_type == "application/json":
|
|
166
|
-
raise ApiError(
|
|
188
|
+
raise ApiError(
|
|
189
|
+
resp.status,
|
|
190
|
+
json.loads(what.decode("utf8")),
|
|
191
|
+
)
|
|
167
192
|
else:
|
|
168
|
-
raise ApiError(
|
|
193
|
+
raise ApiError(
|
|
194
|
+
resp.status,
|
|
195
|
+
{"message": what.decode("utf8")},
|
|
196
|
+
)
|
|
169
197
|
results = await resp.json()
|
|
170
198
|
return results
|
|
171
199
|
except aiohttp.ClientResponseError as e:
|
|
172
200
|
if e.status == 502 and attempt < max_retries - 1:
|
|
173
|
-
wait_time = (
|
|
201
|
+
wait_time = (
|
|
202
|
+
(2**attempt) * backoff_factor * (random.random() + 0.5)
|
|
203
|
+
)
|
|
174
204
|
await asyncio.sleep(wait_time)
|
|
175
205
|
continue
|
|
176
206
|
else:
|
|
177
|
-
raise
|
|
207
|
+
raise ApiError(e.status, f"Error {e.message}")
|
|
178
208
|
|
|
179
209
|
raise TrainMLException("Unexpected API failure")
|
|
180
210
|
|
|
@@ -184,114 +214,129 @@ class TrainML(object):
|
|
|
184
214
|
"Content-Type": "application/json",
|
|
185
215
|
}
|
|
186
216
|
try:
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
async with aiohttp.ClientSession() as session:
|
|
195
|
-
done = False
|
|
196
|
-
async with session.ws_connect(
|
|
197
|
-
f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
|
|
198
|
-
headers=headers,
|
|
199
|
-
heartbeat=30,
|
|
200
|
-
) as ws:
|
|
201
|
-
asyncio.create_task(
|
|
202
|
-
ws.send_json(
|
|
203
|
-
dict(
|
|
204
|
-
action="getlogs",
|
|
205
|
-
data=dict(
|
|
206
|
-
type="init",
|
|
207
|
-
entity=entity,
|
|
208
|
-
id=id,
|
|
209
|
-
project_uuid=project_uuid,
|
|
210
|
-
),
|
|
211
|
-
)
|
|
212
|
-
)
|
|
217
|
+
try:
|
|
218
|
+
tokens = self.auth.get_tokens()
|
|
219
|
+
except TrainMLException as e:
|
|
220
|
+
raise e
|
|
221
|
+
except Exception:
|
|
222
|
+
raise TrainMLException(
|
|
223
|
+
f"Error getting authorization tokens. Verify configured credentials. Error: {traceback.format_exc()}"
|
|
213
224
|
)
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
225
|
+
async with aiohttp.ClientSession() as session:
|
|
226
|
+
done = False
|
|
227
|
+
async with session.ws_connect(
|
|
228
|
+
f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
|
|
229
|
+
headers=headers,
|
|
230
|
+
heartbeat=30,
|
|
231
|
+
) as ws:
|
|
232
|
+
asyncio.create_task(
|
|
233
|
+
ws.send_json(
|
|
234
|
+
dict(
|
|
235
|
+
action="getlogs",
|
|
236
|
+
data=dict(
|
|
237
|
+
type="init",
|
|
238
|
+
entity=entity,
|
|
239
|
+
id=id,
|
|
240
|
+
project_uuid=project_uuid,
|
|
241
|
+
),
|
|
242
|
+
)
|
|
224
243
|
)
|
|
225
244
|
)
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
245
|
+
asyncio.create_task(
|
|
246
|
+
ws.send_json(
|
|
247
|
+
dict(
|
|
248
|
+
action="subscribe",
|
|
249
|
+
data=dict(
|
|
250
|
+
type="logs",
|
|
251
|
+
entity=entity,
|
|
252
|
+
id=id,
|
|
253
|
+
project_uuid=project_uuid,
|
|
254
|
+
),
|
|
255
|
+
)
|
|
235
256
|
)
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
257
|
+
)
|
|
258
|
+
async for msg in ws:
|
|
259
|
+
if msg.type in (
|
|
260
|
+
aiohttp.WSMsgType.CLOSED,
|
|
261
|
+
aiohttp.WSMsgType.ERROR,
|
|
262
|
+
aiohttp.WSMsgType.CLOSE,
|
|
263
|
+
):
|
|
264
|
+
logging.debug(
|
|
265
|
+
f"Websocket Received Closed Message. Done? {done}"
|
|
266
|
+
)
|
|
267
|
+
await ws.close()
|
|
268
|
+
break
|
|
269
|
+
data = json.loads(msg.data)
|
|
270
|
+
if data.get("type") == "end":
|
|
271
|
+
done = True
|
|
272
|
+
asyncio.create_task(delayed_close(ws))
|
|
273
|
+
else:
|
|
274
|
+
msg_handler(data)
|
|
275
|
+
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
245
276
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
277
|
+
connection_tries = 0
|
|
278
|
+
while not done:
|
|
279
|
+
tokens = self.auth.get_tokens()
|
|
280
|
+
try:
|
|
281
|
+
async with session.ws_connect(
|
|
282
|
+
f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
|
|
283
|
+
headers=headers,
|
|
284
|
+
heartbeat=30,
|
|
285
|
+
) as ws:
|
|
286
|
+
asyncio.create_task(
|
|
287
|
+
ws.send_json(
|
|
288
|
+
dict(
|
|
289
|
+
action="subscribe",
|
|
290
|
+
data=dict(
|
|
291
|
+
type="logs",
|
|
292
|
+
entity=entity,
|
|
293
|
+
id=id,
|
|
294
|
+
project_uuid=project_uuid,
|
|
295
|
+
),
|
|
296
|
+
)
|
|
265
297
|
)
|
|
266
298
|
)
|
|
299
|
+
async for msg in ws:
|
|
300
|
+
if msg.type in (
|
|
301
|
+
aiohttp.WSMsgType.CLOSED,
|
|
302
|
+
aiohttp.WSMsgType.ERROR,
|
|
303
|
+
aiohttp.WSMsgType.CLOSE,
|
|
304
|
+
):
|
|
305
|
+
logging.debug(
|
|
306
|
+
f"Websocket Received Closed Message. Done? {done}"
|
|
307
|
+
)
|
|
308
|
+
await ws.close()
|
|
309
|
+
break
|
|
310
|
+
data = json.loads(msg.data)
|
|
311
|
+
if data.get("type") == "end":
|
|
312
|
+
done = True
|
|
313
|
+
asyncio.create_task(delayed_close(ws))
|
|
314
|
+
else:
|
|
315
|
+
msg_handler(data)
|
|
316
|
+
connection_tries = 0
|
|
317
|
+
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
318
|
+
except Exception as e:
|
|
319
|
+
connection_tries += 1
|
|
320
|
+
logging.debug(
|
|
321
|
+
f"Connection error: {traceback.format_exc()}"
|
|
267
322
|
)
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
connection_tries = 0
|
|
286
|
-
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
287
|
-
except Exception as e:
|
|
288
|
-
connection_tries += 1
|
|
289
|
-
logging.debug(f"Connection error: {traceback.format_exc()}")
|
|
290
|
-
if connection_tries == 5:
|
|
291
|
-
raise ApiError(
|
|
292
|
-
500,
|
|
293
|
-
{"message": f"Connection error: {traceback.format_exc()}"},
|
|
294
|
-
)
|
|
323
|
+
if connection_tries == 5:
|
|
324
|
+
raise ApiError(
|
|
325
|
+
500,
|
|
326
|
+
{
|
|
327
|
+
"message": f"Connection error: {traceback.format_exc()}"
|
|
328
|
+
},
|
|
329
|
+
)
|
|
330
|
+
except GeneratorExit:
|
|
331
|
+
# Handle graceful shutdown - GeneratorExit is raised during
|
|
332
|
+
# event loop cleanup. Don't re-raise to avoid "coroutine ignored"
|
|
333
|
+
# warnings.
|
|
334
|
+
logging.debug("Websocket subscription cancelled during shutdown")
|
|
335
|
+
return
|
|
336
|
+
except asyncio.CancelledError:
|
|
337
|
+
# Re-raise CancelledError to properly propagate task cancellation
|
|
338
|
+
logging.debug("Websocket subscription task cancelled")
|
|
339
|
+
raise
|
|
295
340
|
|
|
296
341
|
def set_active_project(self, project_uuid):
|
|
297
342
|
CONFIG_DIR = os.path.expanduser(
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utility modules for trainml SDK."""
|