proximl 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. examples/local_storage.py +0 -2
  2. proximl/__init__.py +1 -1
  3. proximl/checkpoints.py +56 -57
  4. proximl/cli/__init__.py +6 -3
  5. proximl/cli/checkpoint.py +18 -57
  6. proximl/cli/dataset.py +17 -57
  7. proximl/cli/job/__init__.py +89 -67
  8. proximl/cli/job/create.py +51 -24
  9. proximl/cli/model.py +14 -56
  10. proximl/cli/volume.py +18 -57
  11. proximl/datasets.py +50 -55
  12. proximl/jobs.py +269 -69
  13. proximl/models.py +51 -55
  14. proximl/proximl.py +159 -114
  15. proximl/utils/__init__.py +1 -0
  16. proximl/{auth.py → utils/auth.py} +4 -3
  17. proximl/utils/transfer.py +647 -0
  18. proximl/volumes.py +48 -53
  19. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/METADATA +3 -3
  20. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/RECORD +52 -50
  21. tests/integration/test_checkpoints_integration.py +4 -3
  22. tests/integration/test_datasets_integration.py +5 -3
  23. tests/integration/test_jobs_integration.py +33 -27
  24. tests/integration/test_models_integration.py +7 -3
  25. tests/integration/test_volumes_integration.py +2 -2
  26. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  27. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  28. tests/unit/cloudbender/test_providers_unit.py +96 -0
  29. tests/unit/cloudbender/test_regions_unit.py +106 -0
  30. tests/unit/cloudbender/test_services_unit.py +141 -0
  31. tests/unit/conftest.py +23 -10
  32. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  33. tests/unit/projects/test_project_datastores_unit.py +37 -0
  34. tests/unit/projects/test_project_members_unit.py +46 -0
  35. tests/unit/projects/test_project_services_unit.py +65 -0
  36. tests/unit/projects/test_projects_unit.py +16 -0
  37. tests/unit/test_auth_unit.py +17 -2
  38. tests/unit/test_checkpoints_unit.py +256 -71
  39. tests/unit/test_datasets_unit.py +218 -68
  40. tests/unit/test_exceptions.py +133 -0
  41. tests/unit/test_gpu_types_unit.py +11 -1
  42. tests/unit/test_jobs_unit.py +1014 -95
  43. tests/unit/test_main_unit.py +20 -0
  44. tests/unit/test_models_unit.py +218 -70
  45. tests/unit/test_proximl_unit.py +627 -3
  46. tests/unit/test_volumes_unit.py +211 -70
  47. tests/unit/utils/__init__.py +1 -0
  48. tests/unit/utils/test_transfer_unit.py +4260 -0
  49. proximl/cli/connection.py +0 -61
  50. proximl/connections.py +0 -621
  51. tests/unit/test_connections_unit.py +0 -182
  52. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/LICENSE +0 -0
  53. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/WHEEL +0 -0
  54. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/entry_points.txt +0 -0
  55. {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/top_level.txt +0 -0
proximl/proximl.py CHANGED
@@ -7,7 +7,7 @@ import traceback
7
7
  import random
8
8
  from importlib.metadata import version
9
9
 
10
- from proximl.auth import Auth
10
+ from proximl.utils.auth import Auth
11
11
  from proximl.datasets import Datasets
12
12
  from proximl.models import Models
13
13
  from proximl.checkpoints import Checkpoints
@@ -16,7 +16,6 @@ from proximl.jobs import Jobs
16
16
  from proximl.gpu_types import GpuTypes
17
17
  from proximl.environments import Environments
18
18
  from proximl.exceptions import ApiError, ProxiMLException
19
- from proximl.connections import Connections
20
19
  from proximl.projects import Projects
21
20
  from proximl.cloudbender import Cloudbender
22
21
 
@@ -34,13 +33,17 @@ class ProxiML(object):
34
33
  os.environ.get("PROXIML_CONFIG_DIR") or "~/.proximl"
35
34
  )
36
35
  try:
37
- with open(f"{CONFIG_DIR}/environment.json", "r") as file:
36
+ with open(
37
+ f"{CONFIG_DIR}/environment.json", "r", encoding="utf-8"
38
+ ) as file:
38
39
  env_str = file.read().replace("\n", "")
39
40
  env = json.loads(env_str)
40
41
  except OSError:
41
42
  env = dict()
42
43
  try:
43
- with open(f"{CONFIG_DIR}/config.json", "r") as file:
44
+ with open(
45
+ f"{CONFIG_DIR}/config.json", "r", encoding="utf-8"
46
+ ) as file:
44
47
  config_str = file.read().replace("\n", "")
45
48
  config = json.loads(config_str)
46
49
  except OSError:
@@ -72,7 +75,6 @@ class ProxiML(object):
72
75
  self.jobs = Jobs(self)
73
76
  self.gpu_types = GpuTypes(self)
74
77
  self.environments = Environments(self)
75
- self.connections = Connections(self)
76
78
  self.projects = Projects(self)
77
79
  self.cloudbender = Cloudbender(self)
78
80
  self.api_url = (
@@ -92,7 +94,16 @@ class ProxiML(object):
92
94
  def project(self) -> str:
93
95
  return self.active_project
94
96
 
95
- async def _query(self, path, method, params=None, data=None, headers=None,max_retries=3, backoff_factor=0.5):
97
+ async def _query(
98
+ self,
99
+ path,
100
+ method,
101
+ params=None,
102
+ data=None,
103
+ headers=None,
104
+ max_retries=3,
105
+ backoff_factor=0.5,
106
+ ):
96
107
  try:
97
108
  tokens = self.auth.get_tokens()
98
109
  except ProxiMLException as e:
@@ -120,7 +131,9 @@ class ProxiML(object):
120
131
  )
121
132
  if params:
122
133
  if not isinstance(params, dict):
123
- raise ProxiMLException("Query parameters must be a valid dictionary")
134
+ raise ProxiMLException(
135
+ "Query parameters must be a valid dictionary"
136
+ )
124
137
  params = {
125
138
  k: (str(v).lower() if isinstance(v, bool) else v)
126
139
  for k, v in params.items()
@@ -154,27 +167,44 @@ class ProxiML(object):
154
167
  params=params,
155
168
  ) as resp:
156
169
  if (resp.status // 100) in [4, 5]:
157
- if resp.status == 502 and attempt < max_retries - 1:
158
- wait_time = (2 ** attempt) * backoff_factor * (random.random() + 0.5)
170
+ if (
171
+ resp.status == 502
172
+ and attempt < max_retries - 1
173
+ ):
174
+ wait_time = (
175
+ (2**attempt)
176
+ * backoff_factor
177
+ * (random.random() + 0.5)
178
+ )
159
179
  await asyncio.sleep(wait_time)
160
180
  continue
161
181
  else:
162
182
  what = await resp.read()
163
- content_type = resp.headers.get("content-type", "")
183
+ content_type = resp.headers.get(
184
+ "content-type", ""
185
+ )
164
186
  resp.close()
165
187
  if content_type == "application/json":
166
- raise ApiError(resp.status, json.loads(what.decode("utf8")))
188
+ raise ApiError(
189
+ resp.status,
190
+ json.loads(what.decode("utf8")),
191
+ )
167
192
  else:
168
- raise ApiError(resp.status, {"message": what.decode("utf8")})
193
+ raise ApiError(
194
+ resp.status,
195
+ {"message": what.decode("utf8")},
196
+ )
169
197
  results = await resp.json()
170
198
  return results
171
199
  except aiohttp.ClientResponseError as e:
172
200
  if e.status == 502 and attempt < max_retries - 1:
173
- wait_time = (2 ** attempt) * backoff_factor * (random.random() + 0.5)
201
+ wait_time = (
202
+ (2**attempt) * backoff_factor * (random.random() + 0.5)
203
+ )
174
204
  await asyncio.sleep(wait_time)
175
205
  continue
176
206
  else:
177
- raise ApiError(e.status, f"Error {e.message}")
207
+ raise ApiError(e.status, f"Error {e.message}")
178
208
 
179
209
  raise ProxiMLException("Unexpected API failure")
180
210
 
@@ -184,114 +214,129 @@ class ProxiML(object):
184
214
  "Content-Type": "application/json",
185
215
  }
186
216
  try:
187
- tokens = self.auth.get_tokens()
188
- except ProxiMLException as e:
189
- raise e
190
- except Exception:
191
- raise ProxiMLException(
192
- f"Error getting authorization tokens. Verify configured credentials. Error: {traceback.format_exc()}"
193
- )
194
- async with aiohttp.ClientSession() as session:
195
- done = False
196
- async with session.ws_connect(
197
- f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
198
- headers=headers,
199
- heartbeat=30,
200
- ) as ws:
201
- asyncio.create_task(
202
- ws.send_json(
203
- dict(
204
- action="getlogs",
205
- data=dict(
206
- type="init",
207
- entity=entity,
208
- id=id,
209
- project_uuid=project_uuid,
210
- ),
211
- )
212
- )
217
+ try:
218
+ tokens = self.auth.get_tokens()
219
+ except ProxiMLException as e:
220
+ raise e
221
+ except Exception:
222
+ raise ProxiMLException(
223
+ f"Error getting authorization tokens. Verify configured credentials. Error: {traceback.format_exc()}"
213
224
  )
214
- asyncio.create_task(
215
- ws.send_json(
216
- dict(
217
- action="subscribe",
218
- data=dict(
219
- type="logs",
220
- entity=entity,
221
- id=id,
222
- project_uuid=project_uuid,
223
- ),
225
+ async with aiohttp.ClientSession() as session:
226
+ done = False
227
+ async with session.ws_connect(
228
+ f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
229
+ headers=headers,
230
+ heartbeat=30,
231
+ ) as ws:
232
+ asyncio.create_task(
233
+ ws.send_json(
234
+ dict(
235
+ action="getlogs",
236
+ data=dict(
237
+ type="init",
238
+ entity=entity,
239
+ id=id,
240
+ project_uuid=project_uuid,
241
+ ),
242
+ )
224
243
  )
225
244
  )
226
- )
227
- async for msg in ws:
228
- if msg.type in (
229
- aiohttp.WSMsgType.CLOSED,
230
- aiohttp.WSMsgType.ERROR,
231
- aiohttp.WSMsgType.CLOSE,
232
- ):
233
- logging.debug(
234
- f"Websocket Received Closed Message. Done? {done}"
245
+ asyncio.create_task(
246
+ ws.send_json(
247
+ dict(
248
+ action="subscribe",
249
+ data=dict(
250
+ type="logs",
251
+ entity=entity,
252
+ id=id,
253
+ project_uuid=project_uuid,
254
+ ),
255
+ )
235
256
  )
236
- await ws.close()
237
- break
238
- data = json.loads(msg.data)
239
- if data.get("type") == "end":
240
- done = True
241
- asyncio.create_task(delayed_close(ws))
242
- else:
243
- msg_handler(data)
244
- logging.debug(f"Websocket Disconnected. Done? {done}")
257
+ )
258
+ async for msg in ws:
259
+ if msg.type in (
260
+ aiohttp.WSMsgType.CLOSED,
261
+ aiohttp.WSMsgType.ERROR,
262
+ aiohttp.WSMsgType.CLOSE,
263
+ ):
264
+ logging.debug(
265
+ f"Websocket Received Closed Message. Done? {done}"
266
+ )
267
+ await ws.close()
268
+ break
269
+ data = json.loads(msg.data)
270
+ if data.get("type") == "end":
271
+ done = True
272
+ asyncio.create_task(delayed_close(ws))
273
+ else:
274
+ msg_handler(data)
275
+ logging.debug(f"Websocket Disconnected. Done? {done}")
245
276
 
246
- connection_tries = 0
247
- while not done:
248
- tokens = self.auth.get_tokens()
249
- try:
250
- async with session.ws_connect(
251
- f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
252
- headers=headers,
253
- heartbeat=30,
254
- ) as ws:
255
- asyncio.create_task(
256
- ws.send_json(
257
- dict(
258
- action="subscribe",
259
- data=dict(
260
- type="logs",
261
- entity=entity,
262
- id=id,
263
- project_uuid=project_uuid,
264
- ),
277
+ connection_tries = 0
278
+ while not done:
279
+ tokens = self.auth.get_tokens()
280
+ try:
281
+ async with session.ws_connect(
282
+ f"wss://{self.ws_url}?Authorization={tokens.get('id_token')}",
283
+ headers=headers,
284
+ heartbeat=30,
285
+ ) as ws:
286
+ asyncio.create_task(
287
+ ws.send_json(
288
+ dict(
289
+ action="subscribe",
290
+ data=dict(
291
+ type="logs",
292
+ entity=entity,
293
+ id=id,
294
+ project_uuid=project_uuid,
295
+ ),
296
+ )
265
297
  )
266
298
  )
299
+ async for msg in ws:
300
+ if msg.type in (
301
+ aiohttp.WSMsgType.CLOSED,
302
+ aiohttp.WSMsgType.ERROR,
303
+ aiohttp.WSMsgType.CLOSE,
304
+ ):
305
+ logging.debug(
306
+ f"Websocket Received Closed Message. Done? {done}"
307
+ )
308
+ await ws.close()
309
+ break
310
+ data = json.loads(msg.data)
311
+ if data.get("type") == "end":
312
+ done = True
313
+ asyncio.create_task(delayed_close(ws))
314
+ else:
315
+ msg_handler(data)
316
+ connection_tries = 0
317
+ logging.debug(f"Websocket Disconnected. Done? {done}")
318
+ except Exception as e:
319
+ connection_tries += 1
320
+ logging.debug(
321
+ f"Connection error: {traceback.format_exc()}"
267
322
  )
268
- async for msg in ws:
269
- if msg.type in (
270
- aiohttp.WSMsgType.CLOSED,
271
- aiohttp.WSMsgType.ERROR,
272
- aiohttp.WSMsgType.CLOSE,
273
- ):
274
- logging.debug(
275
- f"Websocket Received Closed Message. Done? {done}"
276
- )
277
- await ws.close()
278
- break
279
- data = json.loads(msg.data)
280
- if data.get("type") == "end":
281
- done = True
282
- asyncio.create_task(delayed_close(ws))
283
- else:
284
- msg_handler(data)
285
- connection_tries = 0
286
- logging.debug(f"Websocket Disconnected. Done? {done}")
287
- except Exception as e:
288
- connection_tries += 1
289
- logging.debug(f"Connection error: {traceback.format_exc()}")
290
- if connection_tries == 5:
291
- raise ApiError(
292
- 500,
293
- {"message": f"Connection error: {traceback.format_exc()}"},
294
- )
323
+ if connection_tries == 5:
324
+ raise ApiError(
325
+ 500,
326
+ {
327
+ "message": f"Connection error: {traceback.format_exc()}"
328
+ },
329
+ )
330
+ except GeneratorExit:
331
+ # Handle graceful shutdown - GeneratorExit is raised during
332
+ # event loop cleanup. Don't re-raise to avoid "coroutine ignored"
333
+ # warnings.
334
+ logging.debug("Websocket subscription cancelled during shutdown")
335
+ return
336
+ except asyncio.CancelledError:
337
+ # Re-raise CancelledError to properly propagate task cancellation
338
+ logging.debug("Websocket subscription task cancelled")
339
+ raise
295
340
 
296
341
  def set_active_project(self, project_uuid):
297
342
  CONFIG_DIR = os.path.expanduser(
@@ -0,0 +1 @@
1
+ """Utility modules for proximl SDK."""
@@ -46,7 +46,7 @@
46
46
  # editorial revisions, annotations, elaborations, or other modifications
47
47
  # represent, as a whole, an original work of authorship. For the purposes
48
48
  # of this License, Derivative Works shall not include works that remain
49
- # separable from, or merely link (or bind by name) to the interfaces of,
49
+ # separable from, or merely link (or bind by name) to the interfaces of
50
50
  # the Work and Derivative Works thereof.
51
51
 
52
52
  # "Contribution" shall mean any work of authorship, including
@@ -609,15 +609,16 @@ class Auth(object):
609
609
  logging.debug(f"ID Token Verification: {id_verify}")
610
610
  if id_verify:
611
611
  id_token = tokens["AuthenticationResult"]["IdToken"]
612
+ self.id_token = id_token
613
+
612
614
  access_verify = self.verify_token(
613
615
  tokens["AuthenticationResult"]["AccessToken"], "access_token"
614
616
  )
615
617
  logging.debug(f"Access Token Verification: {access_verify}")
616
618
  if access_verify:
617
619
  access_token = tokens["AuthenticationResult"]["AccessToken"]
620
+ self.access_token = access_token
618
621
 
619
- self.id_token = id_token
620
- self.access_token = access_token
621
622
  self.refresh_token = refresh_token
622
623
  self.expires = (
623
624
  id_verify.get("exp") - 300