proximl 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff shows the differences in content between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- proximl/__init__.py +1 -1
- proximl/checkpoints.py +56 -57
- proximl/cli/__init__.py +6 -3
- proximl/cli/checkpoint.py +18 -57
- proximl/cli/dataset.py +17 -57
- proximl/cli/job/__init__.py +89 -67
- proximl/cli/job/create.py +51 -24
- proximl/cli/model.py +14 -56
- proximl/cli/volume.py +18 -57
- proximl/datasets.py +50 -55
- proximl/jobs.py +269 -69
- proximl/models.py +51 -55
- proximl/proximl.py +159 -114
- proximl/utils/__init__.py +1 -0
- proximl/{auth.py → utils/auth.py} +4 -3
- proximl/utils/transfer.py +647 -0
- proximl/volumes.py +48 -53
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/METADATA +3 -3
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/RECORD +52 -50
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_proximl_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- proximl/cli/connection.py +0 -61
- proximl/connections.py +0 -621
- tests/unit/test_connections_unit.py +0 -182
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/LICENSE +0 -0
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/WHEEL +0 -0
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/entry_points.txt +0 -0
- {proximl-0.5.17.dist-info → proximl-1.0.1.dist-info}/top_level.txt +0 -0
proximl/jobs.py
CHANGED
|
@@ -12,7 +12,7 @@ from proximl.exceptions import (
|
|
|
12
12
|
SpecificationError,
|
|
13
13
|
ProxiMLException,
|
|
14
14
|
)
|
|
15
|
-
from proximl.connections import Connection
|
|
15
|
+
from proximl.utils.transfer import upload, download
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class Jobs(object):
|
|
@@ -77,7 +77,8 @@ class Jobs(object):
|
|
|
77
77
|
model=model,
|
|
78
78
|
endpoint=endpoint,
|
|
79
79
|
source_job_uuid=kwargs.get("source_job_uuid"),
|
|
80
|
-
project_uuid=kwargs.get("project_uuid")
|
|
80
|
+
project_uuid=kwargs.get("project_uuid")
|
|
81
|
+
or self.proximl.active_project,
|
|
81
82
|
)
|
|
82
83
|
payload = {
|
|
83
84
|
k: v
|
|
@@ -102,7 +103,9 @@ class Jobs(object):
|
|
|
102
103
|
return job
|
|
103
104
|
|
|
104
105
|
async def remove(self, id, **kwargs):
|
|
105
|
-
await self.proximl._query(
|
|
106
|
+
await self.proximl._query(
|
|
107
|
+
f"/job/{id}", "DELETE", dict(**kwargs, force=True)
|
|
108
|
+
)
|
|
106
109
|
|
|
107
110
|
|
|
108
111
|
class Job:
|
|
@@ -292,42 +295,6 @@ class Job:
|
|
|
292
295
|
)
|
|
293
296
|
return resp
|
|
294
297
|
|
|
295
|
-
async def get_connection_utility_url(self):
|
|
296
|
-
resp = await self.proximl._query(
|
|
297
|
-
f"/job/{self._id}/download",
|
|
298
|
-
"GET",
|
|
299
|
-
dict(project_uuid=self._project_uuid),
|
|
300
|
-
)
|
|
301
|
-
return resp
|
|
302
|
-
|
|
303
|
-
def get_connection_details(self):
|
|
304
|
-
details = dict(
|
|
305
|
-
entity_type="job",
|
|
306
|
-
project_uuid=self._job.get("project_uuid"),
|
|
307
|
-
cidr=self.dict.get("vpn").get("cidr"),
|
|
308
|
-
ssh_port=(
|
|
309
|
-
self._job.get("vpn").get("client").get("ssh_port")
|
|
310
|
-
if self._job.get("vpn").get("client")
|
|
311
|
-
else None
|
|
312
|
-
),
|
|
313
|
-
model_path=(
|
|
314
|
-
self._job.get("model").get("source_uri")
|
|
315
|
-
if self._job.get("model").get("source_type") == "local"
|
|
316
|
-
else None
|
|
317
|
-
),
|
|
318
|
-
input_path=(
|
|
319
|
-
self._job.get("data").get("input_uri")
|
|
320
|
-
if self._job.get("data").get("input_type") == "local"
|
|
321
|
-
else None
|
|
322
|
-
),
|
|
323
|
-
output_path=(
|
|
324
|
-
self._job.get("data").get("output_uri")
|
|
325
|
-
if self._job.get("data").get("output_type") == "local"
|
|
326
|
-
else None
|
|
327
|
-
),
|
|
328
|
-
)
|
|
329
|
-
return details
|
|
330
|
-
|
|
331
298
|
async def open(self):
|
|
332
299
|
if self.type != "notebook":
|
|
333
300
|
raise SpecificationError(
|
|
@@ -337,6 +304,7 @@ class Job:
|
|
|
337
304
|
webbrowser.open(self.notebook_url)
|
|
338
305
|
|
|
339
306
|
async def connect(self):
|
|
307
|
+
# Handle notebook/endpoint special cases
|
|
340
308
|
if self.type == "notebook" and self.status not in [
|
|
341
309
|
"new",
|
|
342
310
|
"waiting for data/model download",
|
|
@@ -352,9 +320,30 @@ class Job:
|
|
|
352
320
|
"waiting for data/model download",
|
|
353
321
|
]:
|
|
354
322
|
return self.url
|
|
323
|
+
|
|
324
|
+
# Refresh to get latest job data first, so we can check worker statuses
|
|
325
|
+
await self.refresh()
|
|
326
|
+
|
|
327
|
+
# Check worker statuses - if any worker is uploading, allow connection
|
|
328
|
+
# This handles the case where job status might be "finished" but workers are still uploading
|
|
329
|
+
workers = self._job.get("workers", [])
|
|
330
|
+
has_uploading_workers = any(
|
|
331
|
+
worker.get("status") == "uploading" for worker in workers
|
|
332
|
+
) if workers else False
|
|
333
|
+
|
|
334
|
+
# Log worker statuses for debugging
|
|
335
|
+
if workers:
|
|
336
|
+
worker_statuses = [
|
|
337
|
+
f"Worker {i+1}: {worker.get('status')}"
|
|
338
|
+
for i, worker in enumerate(workers)
|
|
339
|
+
]
|
|
340
|
+
logging.debug(
|
|
341
|
+
f"Job status: {self.status}, Worker statuses: {', '.join(worker_statuses)}, Has uploading workers: {has_uploading_workers}"
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
# Check for invalid statuses (but allow "finished" if workers are still uploading)
|
|
355
345
|
if self.status in [
|
|
356
346
|
"failed",
|
|
357
|
-
"finished",
|
|
358
347
|
"canceled",
|
|
359
348
|
"archived",
|
|
360
349
|
"removed",
|
|
@@ -364,26 +353,231 @@ class Job:
|
|
|
364
353
|
"status",
|
|
365
354
|
f"You can only connect to active jobs.",
|
|
366
355
|
)
|
|
367
|
-
if self._job.get("vpn").get("status") == "n/a":
|
|
368
|
-
logging.info("Local connection not enabled for this job.")
|
|
369
|
-
return
|
|
370
|
-
if self.status == "new":
|
|
371
|
-
await self.wait_for("waiting for data/model download")
|
|
372
|
-
connection = Connection(
|
|
373
|
-
self.proximl, entity_type="job", id=self.id, entity=self
|
|
374
|
-
)
|
|
375
|
-
await connection.start()
|
|
376
|
-
return connection.status
|
|
377
356
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
357
|
+
# Allow "finished" status if there are workers still uploading
|
|
358
|
+
# This handles reconnection scenarios where some workers are done but others are still uploading
|
|
359
|
+
if self.status == "finished":
|
|
360
|
+
if not has_uploading_workers:
|
|
361
|
+
raise SpecificationError(
|
|
362
|
+
"status",
|
|
363
|
+
f"You can only connect to active jobs.",
|
|
364
|
+
)
|
|
365
|
+
logging.info(
|
|
366
|
+
f"Job status is 'finished' but has {sum(1 for w in workers if w.get('status') == 'uploading')} worker(s) still uploading. Allowing connection to download remaining workers."
|
|
367
|
+
)
|
|
368
|
+
# If we have uploading workers, fall through to download logic
|
|
369
|
+
|
|
370
|
+
# Only allow specific statuses for connect
|
|
371
|
+
if self.status not in [
|
|
372
|
+
"waiting for data/model download",
|
|
373
|
+
"uploading",
|
|
374
|
+
"running",
|
|
375
|
+
"finished", # Allow finished if workers are still uploading
|
|
376
|
+
]:
|
|
377
|
+
if self.status == "new":
|
|
378
|
+
await self.wait_for("waiting for data/model download")
|
|
379
|
+
else:
|
|
380
|
+
raise SpecificationError(
|
|
381
|
+
"status",
|
|
382
|
+
f"You can only connect to jobs in 'waiting for data/model download', 'uploading', 'running', or 'finished' (with uploading workers) status.",
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
if self.status == "waiting for data/model download":
|
|
386
|
+
# Upload model and/or data if local
|
|
387
|
+
model = self._job.get("model", {})
|
|
388
|
+
data = self._job.get("data", {})
|
|
389
|
+
|
|
390
|
+
model_local = model.get("source_type") == "local"
|
|
391
|
+
data_local = data.get("input_type") == "local"
|
|
392
|
+
|
|
393
|
+
if not model_local and not data_local:
|
|
394
|
+
raise SpecificationError(
|
|
395
|
+
"status",
|
|
396
|
+
f"Job has no local model or data to upload. Model source_type: {model.get('source_type')}, Data input_type: {data.get('input_type')}",
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
upload_tasks = []
|
|
400
|
+
|
|
401
|
+
if model_local:
|
|
402
|
+
model_auth_token = model.get("auth_token")
|
|
403
|
+
model_hostname = model.get("hostname")
|
|
404
|
+
model_source_uri = model.get("source_uri")
|
|
405
|
+
|
|
406
|
+
if (
|
|
407
|
+
not model_auth_token
|
|
408
|
+
or not model_hostname
|
|
409
|
+
or not model_source_uri
|
|
410
|
+
):
|
|
411
|
+
raise SpecificationError(
|
|
412
|
+
"status",
|
|
413
|
+
f"Job model missing required connection properties (auth_token, hostname, source_uri).",
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
upload_tasks.append(
|
|
417
|
+
upload(model_hostname, model_auth_token, model_source_uri)
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
if data_local:
|
|
421
|
+
data_auth_token = data.get("input_auth_token")
|
|
422
|
+
data_hostname = data.get("input_hostname")
|
|
423
|
+
data_input_uri = data.get("input_uri")
|
|
424
|
+
|
|
425
|
+
if (
|
|
426
|
+
not data_auth_token
|
|
427
|
+
or not data_hostname
|
|
428
|
+
or not data_input_uri
|
|
429
|
+
):
|
|
430
|
+
raise SpecificationError(
|
|
431
|
+
"status",
|
|
432
|
+
f"Job data missing required connection properties (input_auth_token, input_hostname, input_uri).",
|
|
433
|
+
)
|
|
434
|
+
|
|
435
|
+
upload_tasks.append(
|
|
436
|
+
upload(data_hostname, data_auth_token, data_input_uri)
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
# Upload both in parallel if both are local
|
|
440
|
+
if upload_tasks:
|
|
441
|
+
await asyncio.gather(*upload_tasks)
|
|
442
|
+
|
|
443
|
+
elif self.status in ["uploading", "running", "finished"]:
|
|
444
|
+
# Download output if local
|
|
445
|
+
data = self._job.get("data", {})
|
|
446
|
+
|
|
447
|
+
if data.get("output_type") != "local":
|
|
448
|
+
raise SpecificationError(
|
|
449
|
+
"status",
|
|
450
|
+
f"Job output_type is not 'local', cannot download output.",
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
output_uri = data.get("output_uri")
|
|
454
|
+
if not output_uri:
|
|
455
|
+
raise SpecificationError(
|
|
456
|
+
"status",
|
|
457
|
+
f"Job data missing output_uri for local output download.",
|
|
458
|
+
)
|
|
459
|
+
|
|
460
|
+
# Track which workers we've already started downloading
|
|
461
|
+
downloading_workers = set()
|
|
462
|
+
download_tasks = []
|
|
463
|
+
|
|
464
|
+
# Poll until all workers are finished
|
|
465
|
+
while True:
|
|
466
|
+
# Refresh job to get latest worker statuses
|
|
467
|
+
await self.refresh()
|
|
468
|
+
|
|
469
|
+
# Get fresh workers list
|
|
470
|
+
workers = self._job.get("workers", [])
|
|
471
|
+
if not workers:
|
|
472
|
+
raise SpecificationError(
|
|
473
|
+
"status",
|
|
474
|
+
f"Job has no workers.",
|
|
475
|
+
)
|
|
476
|
+
|
|
477
|
+
# Check if job is in a terminal state AND all workers are finished
|
|
478
|
+
# Allow "finished" status if workers are still uploading
|
|
479
|
+
all_workers_finished = all(
|
|
480
|
+
worker.get("status") in ["finished", "removed"]
|
|
481
|
+
for worker in workers
|
|
482
|
+
)
|
|
483
|
+
if self.status in ["canceled", "failed"]:
|
|
484
|
+
break
|
|
485
|
+
if self.status == "finished" and all_workers_finished:
|
|
486
|
+
break
|
|
487
|
+
|
|
488
|
+
# Check all workers for uploading status
|
|
489
|
+
for worker in workers:
|
|
490
|
+
worker_id = worker.get("job_worker_uuid") or worker.get(
|
|
491
|
+
"id"
|
|
492
|
+
)
|
|
493
|
+
worker_status = worker.get("status")
|
|
494
|
+
|
|
495
|
+
# Start download for any worker that enters uploading status
|
|
496
|
+
# This handles both new connections and reconnections where some workers are already uploading
|
|
497
|
+
if (
|
|
498
|
+
worker_status == "uploading"
|
|
499
|
+
and worker_id not in downloading_workers
|
|
500
|
+
):
|
|
501
|
+
output_auth_token = worker.get("output_auth_token")
|
|
502
|
+
output_hostname = worker.get("output_hostname")
|
|
503
|
+
|
|
504
|
+
if not output_auth_token or not output_hostname:
|
|
505
|
+
logging.warning(
|
|
506
|
+
f"Worker {worker_id} in uploading status missing output_auth_token or output_hostname, skipping."
|
|
507
|
+
)
|
|
508
|
+
# Mark as downloading to avoid retrying
|
|
509
|
+
downloading_workers.add(worker_id)
|
|
510
|
+
continue
|
|
511
|
+
|
|
512
|
+
downloading_workers.add(worker_id)
|
|
513
|
+
# Create and start download task (runs in parallel)
|
|
514
|
+
logging.info(
|
|
515
|
+
f"Starting download for worker {worker_id} from {output_hostname} to {output_uri}"
|
|
516
|
+
)
|
|
517
|
+
try:
|
|
518
|
+
download_task = asyncio.create_task(
|
|
519
|
+
download(
|
|
520
|
+
output_hostname,
|
|
521
|
+
output_auth_token,
|
|
522
|
+
output_uri,
|
|
523
|
+
)
|
|
524
|
+
)
|
|
525
|
+
download_tasks.append(download_task)
|
|
526
|
+
logging.debug(
|
|
527
|
+
f"Download task created for worker {worker_id}, task: {download_task}"
|
|
528
|
+
)
|
|
529
|
+
except Exception as e:
|
|
530
|
+
logging.error(
|
|
531
|
+
f"Failed to create download task for worker {worker_id}: {e}",
|
|
532
|
+
exc_info=True,
|
|
533
|
+
)
|
|
534
|
+
raise
|
|
535
|
+
|
|
536
|
+
# Check if any download tasks have completed or failed
|
|
537
|
+
if download_tasks:
|
|
538
|
+
completed_tasks = [
|
|
539
|
+
task for task in download_tasks if task.done()
|
|
540
|
+
]
|
|
541
|
+
for task in completed_tasks:
|
|
542
|
+
try:
|
|
543
|
+
await task # This will raise if the task failed
|
|
544
|
+
logging.info(
|
|
545
|
+
f"Download task completed successfully"
|
|
546
|
+
)
|
|
547
|
+
except Exception as e:
|
|
548
|
+
logging.error(
|
|
549
|
+
f"Download task failed: {e}", exc_info=True
|
|
550
|
+
)
|
|
551
|
+
raise
|
|
552
|
+
# Remove completed tasks
|
|
553
|
+
download_tasks = [
|
|
554
|
+
task for task in download_tasks if not task.done()
|
|
555
|
+
]
|
|
556
|
+
|
|
557
|
+
# Check if all workers are finished
|
|
558
|
+
all_finished = all(
|
|
559
|
+
worker.get("status") in ["finished", "removed"]
|
|
560
|
+
for worker in workers
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
if all_finished:
|
|
564
|
+
break
|
|
565
|
+
|
|
566
|
+
# If we have active download tasks, wait a bit for them to make progress
|
|
567
|
+
# but don't wait the full 30 seconds - check more frequently
|
|
568
|
+
if download_tasks:
|
|
569
|
+
await asyncio.sleep(5)
|
|
570
|
+
else:
|
|
571
|
+
# Wait 30 seconds before next poll if no downloads in progress
|
|
572
|
+
await asyncio.sleep(30)
|
|
573
|
+
|
|
574
|
+
# Wait for all download tasks to complete
|
|
575
|
+
if download_tasks:
|
|
576
|
+
logging.info(
|
|
577
|
+
f"Waiting for {len(download_tasks)} download task(s) to complete"
|
|
578
|
+
)
|
|
579
|
+
await asyncio.gather(*download_tasks)
|
|
580
|
+
logging.info("All downloads completed")
|
|
387
581
|
|
|
388
582
|
async def remove(self, force=False):
|
|
389
583
|
await self.proximl._query(
|
|
@@ -401,7 +595,8 @@ class Job:
|
|
|
401
595
|
|
|
402
596
|
def _get_msg_handler(self, msg_handler):
|
|
403
597
|
worker_numbers = {
|
|
404
|
-
w.get("job_worker_uuid"): ind + 1
|
|
598
|
+
w.get("job_worker_uuid"): ind + 1
|
|
599
|
+
for ind, w in enumerate(self._workers)
|
|
405
600
|
}
|
|
406
601
|
worker_numbers["data_worker"] = 0
|
|
407
602
|
|
|
@@ -411,7 +606,9 @@ class Job:
|
|
|
411
606
|
if msg_handler:
|
|
412
607
|
msg_handler(data)
|
|
413
608
|
else:
|
|
414
|
-
timestamp = datetime.fromtimestamp(
|
|
609
|
+
timestamp = datetime.fromtimestamp(
|
|
610
|
+
int(data.get("time")) / 1000
|
|
611
|
+
)
|
|
415
612
|
if len(self._workers) > 1:
|
|
416
613
|
print(
|
|
417
614
|
f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: Worker {data.get('worker_number')} - {data.get('msg').rstrip()}"
|
|
@@ -424,7 +621,10 @@ class Job:
|
|
|
424
621
|
return handler
|
|
425
622
|
|
|
426
623
|
async def attach(self, msg_handler=None):
|
|
427
|
-
if
|
|
624
|
+
if (
|
|
625
|
+
self.type == "notebook"
|
|
626
|
+
and self.status != "waiting for data/model download"
|
|
627
|
+
):
|
|
428
628
|
raise SpecificationError(
|
|
429
629
|
"type",
|
|
430
630
|
"Notebooks cannot be attached to after model download is complete. Use open() instead.",
|
|
@@ -441,7 +641,9 @@ class Job:
|
|
|
441
641
|
async def copy(self, name, **kwargs):
|
|
442
642
|
logging.debug(f"copy request - name: {name} ; kwargs: {kwargs}")
|
|
443
643
|
if self.type != "notebook":
|
|
444
|
-
raise SpecificationError(
|
|
644
|
+
raise SpecificationError(
|
|
645
|
+
"job", "Only notebook job types can be copied"
|
|
646
|
+
)
|
|
445
647
|
|
|
446
648
|
job = await self.proximl.jobs.create(
|
|
447
649
|
name,
|
|
@@ -502,7 +704,9 @@ class Job:
|
|
|
502
704
|
|
|
503
705
|
POLL_INTERVAL_MIN = 5
|
|
504
706
|
POLL_INTERVAL_MAX = 60
|
|
505
|
-
POLL_INTERVAL = max(
|
|
707
|
+
POLL_INTERVAL = max(
|
|
708
|
+
min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
|
|
709
|
+
)
|
|
506
710
|
retry_count = math.ceil(timeout / POLL_INTERVAL)
|
|
507
711
|
count = 0
|
|
508
712
|
while count < retry_count:
|
|
@@ -533,11 +737,7 @@ class Job:
|
|
|
533
737
|
or (
|
|
534
738
|
status
|
|
535
739
|
== "running" ## this status could be too short for polling could miss it
|
|
536
|
-
and self.status
|
|
537
|
-
in [
|
|
538
|
-
"uploading",
|
|
539
|
-
"finished"
|
|
540
|
-
]
|
|
740
|
+
and self.status in ["uploading", "finished"]
|
|
541
741
|
)
|
|
542
742
|
):
|
|
543
743
|
return self
|
proximl/models.py
CHANGED
|
@@ -10,7 +10,7 @@ from .exceptions import (
|
|
|
10
10
|
SpecificationError,
|
|
11
11
|
ProxiMLException,
|
|
12
12
|
)
|
|
13
|
-
from .connections import Connection
|
|
13
|
+
from proximl.utils.transfer import upload, download
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class Models(object):
|
|
@@ -54,7 +54,9 @@ class Models(object):
|
|
|
54
54
|
return model
|
|
55
55
|
|
|
56
56
|
async def remove(self, id, **kwargs):
|
|
57
|
-
await self.proximl._query(
|
|
57
|
+
await self.proximl._query(
|
|
58
|
+
f"/model/{id}", "DELETE", dict(**kwargs, force=True)
|
|
59
|
+
)
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
class Model:
|
|
@@ -65,7 +67,9 @@ class Model:
|
|
|
65
67
|
self._status = self._model.get("status")
|
|
66
68
|
self._name = self._model.get("name")
|
|
67
69
|
self._size = self._model.get("size") or self._model.get("used_size")
|
|
68
|
-
self._billed_size = self._model.get("billed_size") or self._model.get(
|
|
70
|
+
self._billed_size = self._model.get("billed_size") or self._model.get(
|
|
71
|
+
"size"
|
|
72
|
+
)
|
|
69
73
|
self._project_uuid = self._model.get("project_uuid")
|
|
70
74
|
|
|
71
75
|
@property
|
|
@@ -113,57 +117,45 @@ class Model:
|
|
|
113
117
|
)
|
|
114
118
|
return resp
|
|
115
119
|
|
|
116
|
-
async def get_connection_utility_url(self):
|
|
117
|
-
resp = await self.proximl._query(
|
|
118
|
-
f"/model/{self._id}/download",
|
|
119
|
-
"GET",
|
|
120
|
-
dict(project_uuid=self._project_uuid),
|
|
121
|
-
)
|
|
122
|
-
return resp
|
|
123
|
-
|
|
124
|
-
def get_connection_details(self):
|
|
125
|
-
if self._model.get("vpn"):
|
|
126
|
-
details = dict(
|
|
127
|
-
entity_type="model",
|
|
128
|
-
project_uuid=self._model.get("project_uuid"),
|
|
129
|
-
cidr=self._model.get("vpn").get("cidr"),
|
|
130
|
-
ssh_port=self._model.get("vpn").get("client").get("ssh_port"),
|
|
131
|
-
input_path=(
|
|
132
|
-
self._model.get("source_uri")
|
|
133
|
-
if self.status in ["new", "downloading"]
|
|
134
|
-
else None
|
|
135
|
-
),
|
|
136
|
-
output_path=(
|
|
137
|
-
self._model.get("output_uri")
|
|
138
|
-
if self.status == "exporting"
|
|
139
|
-
else None
|
|
140
|
-
),
|
|
141
|
-
)
|
|
142
|
-
else:
|
|
143
|
-
details = dict()
|
|
144
|
-
logging.debug(f"Connection Details: {details}")
|
|
145
|
-
return details
|
|
146
|
-
|
|
147
120
|
async def connect(self):
|
|
148
|
-
if self.status in ["
|
|
149
|
-
|
|
150
|
-
"
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
self.proximl, entity_type="model", id=self.id, entity=self
|
|
157
|
-
)
|
|
158
|
-
await connection.start()
|
|
159
|
-
return connection.status
|
|
121
|
+
if self.status not in ["downloading", "exporting"]:
|
|
122
|
+
if self.status == "new":
|
|
123
|
+
await self.wait_for("downloading")
|
|
124
|
+
else:
|
|
125
|
+
raise SpecificationError(
|
|
126
|
+
"status",
|
|
127
|
+
f"You can only connect to downloading or exporting models.",
|
|
128
|
+
)
|
|
160
129
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
130
|
+
# Refresh to get latest entity data
|
|
131
|
+
await self.refresh()
|
|
132
|
+
|
|
133
|
+
if self.status == "downloading":
|
|
134
|
+
# Upload task - get auth_token, hostname, and source_uri from model
|
|
135
|
+
auth_token = self._model.get("auth_token")
|
|
136
|
+
hostname = self._model.get("hostname")
|
|
137
|
+
source_uri = self._model.get("source_uri")
|
|
138
|
+
|
|
139
|
+
if not auth_token or not hostname or not source_uri:
|
|
140
|
+
raise SpecificationError(
|
|
141
|
+
"status",
|
|
142
|
+
f"Model in downloading status missing required connection properties (auth_token, hostname, source_uri).",
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
await upload(hostname, auth_token, source_uri)
|
|
146
|
+
elif self.status == "exporting":
|
|
147
|
+
# Download task - get auth_token, hostname, and output_uri from model
|
|
148
|
+
auth_token = self._model.get("auth_token")
|
|
149
|
+
hostname = self._model.get("hostname")
|
|
150
|
+
output_uri = self._model.get("output_uri")
|
|
151
|
+
|
|
152
|
+
if not auth_token or not hostname or not output_uri:
|
|
153
|
+
raise SpecificationError(
|
|
154
|
+
"status",
|
|
155
|
+
f"Model in exporting status missing required connection properties (auth_token, hostname, output_uri).",
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
await download(hostname, auth_token, output_uri)
|
|
167
159
|
|
|
168
160
|
async def remove(self, force=False):
|
|
169
161
|
await self.proximl._query(
|
|
@@ -202,7 +194,9 @@ class Model:
|
|
|
202
194
|
if msg_handler:
|
|
203
195
|
msg_handler(data)
|
|
204
196
|
else:
|
|
205
|
-
timestamp = datetime.fromtimestamp(
|
|
197
|
+
timestamp = datetime.fromtimestamp(
|
|
198
|
+
int(data.get("time")) / 1000
|
|
199
|
+
)
|
|
206
200
|
print(
|
|
207
201
|
f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
|
|
208
202
|
)
|
|
@@ -231,7 +225,7 @@ class Model:
|
|
|
231
225
|
async def wait_for(self, status, timeout=300):
|
|
232
226
|
if self.status == status:
|
|
233
227
|
return
|
|
234
|
-
valid_statuses = ["downloading", "ready", "archived"]
|
|
228
|
+
valid_statuses = ["downloading", "ready","exporting", "archived"]
|
|
235
229
|
if not status in valid_statuses:
|
|
236
230
|
raise SpecificationError(
|
|
237
231
|
"status",
|
|
@@ -245,7 +239,9 @@ class Model:
|
|
|
245
239
|
)
|
|
246
240
|
POLL_INTERVAL_MIN = 5
|
|
247
241
|
POLL_INTERVAL_MAX = 60
|
|
248
|
-
POLL_INTERVAL = max(
|
|
242
|
+
POLL_INTERVAL = max(
|
|
243
|
+
min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
|
|
244
|
+
)
|
|
249
245
|
retry_count = math.ceil(timeout / POLL_INTERVAL)
|
|
250
246
|
count = 0
|
|
251
247
|
while count < retry_count:
|