proximl 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- proximl/__init__.py +1 -1
- proximl/checkpoints.py +56 -57
- proximl/cli/__init__.py +6 -3
- proximl/cli/checkpoint.py +18 -57
- proximl/cli/dataset.py +17 -57
- proximl/cli/job/__init__.py +89 -67
- proximl/cli/job/create.py +51 -24
- proximl/cli/model.py +14 -56
- proximl/cli/volume.py +18 -57
- proximl/datasets.py +50 -55
- proximl/jobs.py +269 -69
- proximl/models.py +51 -55
- proximl/proximl.py +159 -114
- proximl/utils/__init__.py +1 -0
- proximl/{auth.py → utils/auth.py} +4 -3
- proximl/utils/transfer.py +647 -0
- proximl/volumes.py +48 -53
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/METADATA +3 -3
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/RECORD +52 -50
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_proximl_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- proximl/cli/connection.py +0 -61
- proximl/connections.py +0 -621
- tests/unit/test_connections_unit.py +0 -182
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/LICENSE +0 -0
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/WHEEL +0 -0
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/entry_points.txt +0 -0
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/top_level.txt +0 -0
proximl/proximl.py
CHANGED
|
@@ -7,7 +7,7 @@ import traceback
|
|
|
7
7
|
import random
|
|
8
8
|
from importlib.metadata import version
|
|
9
9
|
|
|
10
|
-
from proximl.auth import Auth
|
|
10
|
+
from proximl.utils.auth import Auth
|
|
11
11
|
from proximl.datasets import Datasets
|
|
12
12
|
from proximl.models import Models
|
|
13
13
|
from proximl.checkpoints import Checkpoints
|
|
@@ -16,7 +16,6 @@ from proximl.jobs import Jobs
|
|
|
16
16
|
from proximl.gpu_types import GpuTypes
|
|
17
17
|
from proximl.environments import Environments
|
|
18
18
|
from proximl.exceptions import ApiError, ProxiMLException
|
|
19
|
-
from proximl.connections import Connections
|
|
20
19
|
from proximl.projects import Projects
|
|
21
20
|
from proximl.cloudbender import Cloudbender
|
|
22
21
|
|
|
@@ -34,13 +33,17 @@ class ProxiML(object):
|
|
|
34
33
|
os.environ.get("PROXIML_CONFIG_DIR") or "~/.proximl"
|
|
35
34
|
)
|
|
36
35
|
try:
|
|
37
|
-
with open(
|
|
36
|
+
with open(
|
|
37
|
+
f"{CONFIG_DIR}/environment.json", "r", encoding="utf-8"
|
|
38
|
+
) as file:
|
|
38
39
|
env_str = file.read().replace("\n", "")
|
|
39
40
|
env = json.loads(env_str)
|
|
40
41
|
except OSError:
|
|
41
42
|
env = dict()
|
|
42
43
|
try:
|
|
43
|
-
with open(
|
|
44
|
+
with open(
|
|
45
|
+
f"{CONFIG_DIR}/config.json", "r", encoding="utf-8"
|
|
46
|
+
) as file:
|
|
44
47
|
config_str = file.read().replace("\n", "")
|
|
45
48
|
config = json.loads(config_str)
|
|
46
49
|
except OSError:
|
|
@@ -72,7 +75,6 @@ class ProxiML(object):
|
|
|
72
75
|
self.jobs = Jobs(self)
|
|
73
76
|
self.gpu_types = GpuTypes(self)
|
|
74
77
|
self.environments = Environments(self)
|
|
75
|
-
self.connections = Connections(self)
|
|
76
78
|
self.projects = Projects(self)
|
|
77
79
|
self.cloudbender = Cloudbender(self)
|
|
78
80
|
self.api_url = (
|
|
@@ -92,7 +94,16 @@ class ProxiML(object):
|
|
|
92
94
|
def project(self) -> str:
|
|
93
95
|
return self.active_project
|
|
94
96
|
|
|
95
|
-
async def _query(
|
|
97
|
+
async def _query(
|
|
98
|
+
self,
|
|
99
|
+
path,
|
|
100
|
+
method,
|
|
101
|
+
params=None,
|
|
102
|
+
data=None,
|
|
103
|
+
headers=None,
|
|
104
|
+
max_retries=3,
|
|
105
|
+
backoff_factor=0.5,
|
|
106
|
+
):
|
|
96
107
|
try:
|
|
97
108
|
tokens = self.auth.get_tokens()
|
|
98
109
|
except ProxiMLException as e:
|
|
@@ -120,7 +131,9 @@ class ProxiML(object):
|
|
|
120
131
|
)
|
|
121
132
|
if params:
|
|
122
133
|
if not isinstance(params, dict):
|
|
123
|
-
raise ProxiMLException(
|
|
134
|
+
raise ProxiMLException(
|
|
135
|
+
"Query parameters must be a valid dictionary"
|
|
136
|
+
)
|
|
124
137
|
params = {
|
|
125
138
|
k: (str(v).lower() if isinstance(v, bool) else v)
|
|
126
139
|
for k, v in params.items()
|
|
@@ -154,27 +167,44 @@ class ProxiML(object):
|
|
|
154
167
|
params=params,
|
|
155
168
|
) as resp:
|
|
156
169
|
if (resp.status // 100) in [4, 5]:
|
|
157
|
-
if
|
|
158
|
-
|
|
170
|
+
if (
|
|
171
|
+
resp.status == 502
|
|
172
|
+
and attempt < max_retries - 1
|
|
173
|
+
):
|
|
174
|
+
wait_time = (
|
|
175
|
+
(2**attempt)
|
|
176
|
+
* backoff_factor
|
|
177
|
+
* (random.random() + 0.5)
|
|
178
|
+
)
|
|
159
179
|
await asyncio.sleep(wait_time)
|
|
160
180
|
continue
|
|
161
181
|
else:
|
|
162
182
|
what = await resp.read()
|
|
163
|
-
content_type = resp.headers.get(
|
|
183
|
+
content_type = resp.headers.get(
|
|
184
|
+
"content-type", ""
|
|
185
|
+
)
|
|
164
186
|
resp.close()
|
|
165
187
|
if content_type == "application/json":
|
|
166
|
-
raise ApiError(
|
|
188
|
+
raise ApiError(
|
|
189
|
+
resp.status,
|
|
190
|
+
json.loads(what.decode("utf8")),
|
|
191
|
+
)
|
|
167
192
|
else:
|
|
168
|
-
raise ApiError(
|
|
193
|
+
raise ApiError(
|
|
194
|
+
resp.status,
|
|
195
|
+
{"message": what.decode("utf8")},
|
|
196
|
+
)
|
|
169
197
|
results = await resp.json()
|
|
170
198
|
return results
|
|
171
199
|
except aiohttp.ClientResponseError as e:
|
|
172
200
|
if e.status == 502 and attempt < max_retries - 1:
|
|
173
|
-
wait_time = (
|
|
201
|
+
wait_time = (
|
|
202
|
+
(2**attempt) * backoff_factor * (random.random() + 0.5)
|
|
203
|
+
)
|
|
174
204
|
await asyncio.sleep(wait_time)
|
|
175
205
|
continue
|
|
176
206
|
else:
|
|
177
|
-
raise
|
|
207
|
+
raise ApiError(e.status, f"Error {e.message}")
|
|
178
208
|
|
|
179
209
|
raise ProxiMLException("Unexpected API failure")
|
|
180
210
|
|
|
@@ -184,114 +214,129 @@ class ProxiML(object):
|
|
|
184
214
|
"Content-Type": "application/json",
|
|
185
215
|
}
|
|
186
216
|
try:
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
async with aiohttp.ClientSession() as session:
|
|
195
|
-
done = False
|
|
196
|
-
async with session.ws_connect(
|
|
197
|
-
f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
|
|
198
|
-
headers=headers,
|
|
199
|
-
heartbeat=30,
|
|
200
|
-
) as ws:
|
|
201
|
-
asyncio.create_task(
|
|
202
|
-
ws.send_json(
|
|
203
|
-
dict(
|
|
204
|
-
action="getlogs",
|
|
205
|
-
data=dict(
|
|
206
|
-
type="init",
|
|
207
|
-
entity=entity,
|
|
208
|
-
id=id,
|
|
209
|
-
project_uuid=project_uuid,
|
|
210
|
-
),
|
|
211
|
-
)
|
|
212
|
-
)
|
|
217
|
+
try:
|
|
218
|
+
tokens = self.auth.get_tokens()
|
|
219
|
+
except ProxiMLException as e:
|
|
220
|
+
raise e
|
|
221
|
+
except Exception:
|
|
222
|
+
raise ProxiMLException(
|
|
223
|
+
f"Error getting authorization tokens. Verify configured credentials. Error: {traceback.format_exc()}"
|
|
213
224
|
)
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
225
|
+
async with aiohttp.ClientSession() as session:
|
|
226
|
+
done = False
|
|
227
|
+
async with session.ws_connect(
|
|
228
|
+
f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
|
|
229
|
+
headers=headers,
|
|
230
|
+
heartbeat=30,
|
|
231
|
+
) as ws:
|
|
232
|
+
asyncio.create_task(
|
|
233
|
+
ws.send_json(
|
|
234
|
+
dict(
|
|
235
|
+
action="getlogs",
|
|
236
|
+
data=dict(
|
|
237
|
+
type="init",
|
|
238
|
+
entity=entity,
|
|
239
|
+
id=id,
|
|
240
|
+
project_uuid=project_uuid,
|
|
241
|
+
),
|
|
242
|
+
)
|
|
224
243
|
)
|
|
225
244
|
)
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
245
|
+
asyncio.create_task(
|
|
246
|
+
ws.send_json(
|
|
247
|
+
dict(
|
|
248
|
+
action="subscribe",
|
|
249
|
+
data=dict(
|
|
250
|
+
type="logs",
|
|
251
|
+
entity=entity,
|
|
252
|
+
id=id,
|
|
253
|
+
project_uuid=project_uuid,
|
|
254
|
+
),
|
|
255
|
+
)
|
|
235
256
|
)
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
257
|
+
)
|
|
258
|
+
async for msg in ws:
|
|
259
|
+
if msg.type in (
|
|
260
|
+
aiohttp.WSMsgType.CLOSED,
|
|
261
|
+
aiohttp.WSMsgType.ERROR,
|
|
262
|
+
aiohttp.WSMsgType.CLOSE,
|
|
263
|
+
):
|
|
264
|
+
logging.debug(
|
|
265
|
+
f"Websocket Received Closed Message. Done? {done}"
|
|
266
|
+
)
|
|
267
|
+
await ws.close()
|
|
268
|
+
break
|
|
269
|
+
data = json.loads(msg.data)
|
|
270
|
+
if data.get("type") == "end":
|
|
271
|
+
done = True
|
|
272
|
+
asyncio.create_task(delayed_close(ws))
|
|
273
|
+
else:
|
|
274
|
+
msg_handler(data)
|
|
275
|
+
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
245
276
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
277
|
+
connection_tries = 0
|
|
278
|
+
while not done:
|
|
279
|
+
tokens = self.auth.get_tokens()
|
|
280
|
+
try:
|
|
281
|
+
async with session.ws_connect(
|
|
282
|
+
f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
|
|
283
|
+
headers=headers,
|
|
284
|
+
heartbeat=30,
|
|
285
|
+
) as ws:
|
|
286
|
+
asyncio.create_task(
|
|
287
|
+
ws.send_json(
|
|
288
|
+
dict(
|
|
289
|
+
action="subscribe",
|
|
290
|
+
data=dict(
|
|
291
|
+
type="logs",
|
|
292
|
+
entity=entity,
|
|
293
|
+
id=id,
|
|
294
|
+
project_uuid=project_uuid,
|
|
295
|
+
),
|
|
296
|
+
)
|
|
265
297
|
)
|
|
266
298
|
)
|
|
299
|
+
async for msg in ws:
|
|
300
|
+
if msg.type in (
|
|
301
|
+
aiohttp.WSMsgType.CLOSED,
|
|
302
|
+
aiohttp.WSMsgType.ERROR,
|
|
303
|
+
aiohttp.WSMsgType.CLOSE,
|
|
304
|
+
):
|
|
305
|
+
logging.debug(
|
|
306
|
+
f"Websocket Received Closed Message. Done? {done}"
|
|
307
|
+
)
|
|
308
|
+
await ws.close()
|
|
309
|
+
break
|
|
310
|
+
data = json.loads(msg.data)
|
|
311
|
+
if data.get("type") == "end":
|
|
312
|
+
done = True
|
|
313
|
+
asyncio.create_task(delayed_close(ws))
|
|
314
|
+
else:
|
|
315
|
+
msg_handler(data)
|
|
316
|
+
connection_tries = 0
|
|
317
|
+
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
318
|
+
except Exception as e:
|
|
319
|
+
connection_tries += 1
|
|
320
|
+
logging.debug(
|
|
321
|
+
f"Connection error: {traceback.format_exc()}"
|
|
267
322
|
)
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
connection_tries = 0
|
|
286
|
-
logging.debug(f"Websocket Disconnected. Done? {done}")
|
|
287
|
-
except Exception as e:
|
|
288
|
-
connection_tries += 1
|
|
289
|
-
logging.debug(f"Connection error: {traceback.format_exc()}")
|
|
290
|
-
if connection_tries == 5:
|
|
291
|
-
raise ApiError(
|
|
292
|
-
500,
|
|
293
|
-
{"message": f"Connection error: {traceback.format_exc()}"},
|
|
294
|
-
)
|
|
323
|
+
if connection_tries == 5:
|
|
324
|
+
raise ApiError(
|
|
325
|
+
500,
|
|
326
|
+
{
|
|
327
|
+
"message": f"Connection error: {traceback.format_exc()}"
|
|
328
|
+
},
|
|
329
|
+
)
|
|
330
|
+
except GeneratorExit:
|
|
331
|
+
# Handle graceful shutdown - GeneratorExit is raised during
|
|
332
|
+
# event loop cleanup. Don't re-raise to avoid "coroutine ignored"
|
|
333
|
+
# warnings.
|
|
334
|
+
logging.debug("Websocket subscription cancelled during shutdown")
|
|
335
|
+
return
|
|
336
|
+
except asyncio.CancelledError:
|
|
337
|
+
# Re-raise CancelledError to properly propagate task cancellation
|
|
338
|
+
logging.debug("Websocket subscription task cancelled")
|
|
339
|
+
raise
|
|
295
340
|
|
|
296
341
|
def set_active_project(self, project_uuid):
|
|
297
342
|
CONFIG_DIR = os.path.expanduser(
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utility modules for proximl SDK."""
|
|
@@ -46,7 +46,7 @@
|
|
|
46
46
|
# editorial revisions, annotations, elaborations, or other modifications
|
|
47
47
|
# represent, as a whole, an original work of authorship. For the purposes
|
|
48
48
|
# of this License, Derivative Works shall not include works that remain
|
|
49
|
-
# separable from, or merely link (or bind by name) to the interfaces of
|
|
49
|
+
# separable from, or merely link (or bind by name) to the interfaces of
|
|
50
50
|
# the Work and Derivative Works thereof.
|
|
51
51
|
|
|
52
52
|
# "Contribution" shall mean any work of authorship, including
|
|
@@ -609,15 +609,16 @@ class Auth(object):
|
|
|
609
609
|
logging.debug(f"ID Token Verification: {id_verify}")
|
|
610
610
|
if id_verify:
|
|
611
611
|
id_token = tokens["AuthenticationResult"]["IdToken"]
|
|
612
|
+
self.id_token = id_token
|
|
613
|
+
|
|
612
614
|
access_verify = self.verify_token(
|
|
613
615
|
tokens["AuthenticationResult"]["AccessToken"], "access_token"
|
|
614
616
|
)
|
|
615
617
|
logging.debug(f"Access Token Verification: {access_verify}")
|
|
616
618
|
if access_verify:
|
|
617
619
|
access_token = tokens["AuthenticationResult"]["AccessToken"]
|
|
620
|
+
self.access_token = access_token
|
|
618
621
|
|
|
619
|
-
self.id_token = id_token
|
|
620
|
-
self.access_token = access_token
|
|
621
622
|
self.refresh_token = refresh_token
|
|
622
623
|
self.expires = (
|
|
623
624
|
id_verify.get("exp") - 300
|