trainml 0.5.17__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +11 -53
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +239 -68
- trainml/models.py +51 -55
- trainml/trainml.py +50 -16
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +587 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
trainml/jobs.py
CHANGED
```diff
@@ -12,7 +12,7 @@ from trainml.exceptions import (
     SpecificationError,
     TrainMLException,
 )
-from trainml.
+from trainml.utils.transfer import upload, download
 
 
 class Jobs(object):
```
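The removed `Connection` import is replaced by the new transfer helpers. Judging from the call sites later in this diff, `upload` and `download` are coroutines taking `(hostname, auth_token, path)`. A minimal sketch of that assumed interface; the hostname, token, and path below are hypothetical placeholders, not values from the diff:

```python
import asyncio

from trainml.utils.transfer import upload, download


async def main():
    # Hypothetical values; in the SDK these come off the entity record,
    # e.g. model.get("hostname"), model.get("auth_token"), model.get("source_uri").
    hostname = "transfer.example.com"
    auth_token = "<entity auth token>"
    path = "~/data/my-files"

    await upload(hostname, auth_token, path)  # push local files to the entity
    await download(hostname, auth_token, path)  # pull results back down


asyncio.run(main())
```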
```diff
@@ -77,7 +77,8 @@ class Jobs(object):
             model=model,
             endpoint=endpoint,
             source_job_uuid=kwargs.get("source_job_uuid"),
-            project_uuid=kwargs.get("project_uuid")
+            project_uuid=kwargs.get("project_uuid")
+            or self.trainml.active_project,
         )
         payload = {
             k: v
```
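`Jobs.create` now falls back to the client's active project when `project_uuid` is omitted. A hedged sketch; only the `project_uuid` behavior comes from this diff, and the remaining arguments are illustrative placeholders:

```python
import asyncio

from trainml.trainml import TrainML


async def run():
    trainml = TrainML()
    # project_uuid omitted: the job is created in trainml.active_project
    # instead of being sent with project_uuid=None as in 0.5.x.
    job = await trainml.jobs.create(
        "example job",  # hypothetical name
        type="notebook",  # illustrative arguments
        gpu_type="rtx3090",
        disk_size=10,
    )


asyncio.run(run())
```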
```diff
@@ -102,7 +103,9 @@ class Jobs(object):
         return job
 
     async def remove(self, id, **kwargs):
-        await self.trainml._query(
+        await self.trainml._query(
+            f"/job/{id}", "DELETE", dict(**kwargs, force=True)
+        )
 
 
 class Job:
```
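In 1.0.0, `Jobs.remove` (and `Models.remove` below) sends `force=True` with the DELETE. A hedged one-liner, assuming a configured client and a hypothetical id:

```python
# Inside an async context, with trainml = TrainML():
await trainml.jobs.remove("<job-uuid>")  # DELETE /job/<job-uuid> with force=True
```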
```diff
@@ -292,42 +295,6 @@ class Job:
         )
         return resp
 
-    async def get_connection_utility_url(self):
-        resp = await self.trainml._query(
-            f"/job/{self._id}/download",
-            "GET",
-            dict(project_uuid=self._project_uuid),
-        )
-        return resp
-
-    def get_connection_details(self):
-        details = dict(
-            entity_type="job",
-            project_uuid=self._job.get("project_uuid"),
-            cidr=self.dict.get("vpn").get("cidr"),
-            ssh_port=(
-                self._job.get("vpn").get("client").get("ssh_port")
-                if self._job.get("vpn").get("client")
-                else None
-            ),
-            model_path=(
-                self._job.get("model").get("source_uri")
-                if self._job.get("model").get("source_type") == "local"
-                else None
-            ),
-            input_path=(
-                self._job.get("data").get("input_uri")
-                if self._job.get("data").get("input_type") == "local"
-                else None
-            ),
-            output_path=(
-                self._job.get("data").get("output_uri")
-                if self._job.get("data").get("output_type") == "local"
-                else None
-            ),
-        )
-        return details
-
     async def open(self):
         if self.type != "notebook":
             raise SpecificationError(
```
```diff
@@ -337,6 +304,7 @@ class Job:
         webbrowser.open(self.notebook_url)
 
     async def connect(self):
+        # Handle notebook/endpoint special cases
         if self.type == "notebook" and self.status not in [
             "new",
             "waiting for data/model download",
```
```diff
@@ -352,6 +320,8 @@ class Job:
             "waiting for data/model download",
         ]:
             return self.url
+
+        # Check for invalid statuses
         if self.status in [
             "failed",
             "finished",
```
```diff
@@ -364,26 +334,221 @@ class Job:
                 "status",
                 f"You can only connect to active jobs.",
             )
-        if self._job.get("vpn").get("status") == "n/a":
-            logging.info("Local connection not enabled for this job.")
-            return
-        if self.status == "new":
-            await self.wait_for("waiting for data/model download")
-        connection = Connection(
-            self.trainml, entity_type="job", id=self.id, entity=self
-        )
-        await connection.start()
-        return connection.status
 
-
-        if self.
-
-
-
-
-
-
+        # Only allow specific statuses for connect
+        if self.status not in [
+            "waiting for data/model download",
+            "uploading",
+            "running",
+        ]:
+            if self.status == "new":
+                await self.wait_for("waiting for data/model download")
+            else:
+                raise SpecificationError(
+                    "status",
+                    f"You can only connect to jobs in 'waiting for data/model download', 'uploading', or 'running' status.",
+                )
+
+        # Refresh to get latest job data
+        await self.refresh()
+
+        # Re-check status after refresh (status may have changed if attach() is running in parallel)
+        if self.status not in [
+            "waiting for data/model download",
+            "uploading",
+            "running",
+        ]:
+            raise SpecificationError(
+                "status",
+                f"Job status changed to '{self.status}'. You can only connect to jobs in 'waiting for data/model download', 'uploading', or 'running' status.",
+            )
+
+        if self.status == "waiting for data/model download":
+            # Upload model and/or data if local
+            model = self._job.get("model", {})
+            data = self._job.get("data", {})
+
+            model_local = model.get("source_type") == "local"
+            data_local = data.get("input_type") == "local"
+
+            if not model_local and not data_local:
+                raise SpecificationError(
+                    "status",
+                    f"Job has no local model or data to upload. Model source_type: {model.get('source_type')}, Data input_type: {data.get('input_type')}",
+                )
+
+            upload_tasks = []
+
+            if model_local:
+                model_auth_token = model.get("auth_token")
+                model_hostname = model.get("hostname")
+                model_source_uri = model.get("source_uri")
+
+                if (
+                    not model_auth_token
+                    or not model_hostname
+                    or not model_source_uri
+                ):
+                    raise SpecificationError(
+                        "status",
+                        f"Job model missing required connection properties (auth_token, hostname, source_uri).",
+                    )
+
+                upload_tasks.append(
+                    upload(model_hostname, model_auth_token, model_source_uri)
+                )
+
+            if data_local:
+                data_auth_token = data.get("input_auth_token")
+                data_hostname = data.get("input_hostname")
+                data_input_uri = data.get("input_uri")
+
+                if (
+                    not data_auth_token
+                    or not data_hostname
+                    or not data_input_uri
+                ):
+                    raise SpecificationError(
+                        "status",
+                        f"Job data missing required connection properties (input_auth_token, input_hostname, input_uri).",
+                    )
+
+                upload_tasks.append(
+                    upload(data_hostname, data_auth_token, data_input_uri)
+                )
+
+            # Upload both in parallel if both are local
+            if upload_tasks:
+                await asyncio.gather(*upload_tasks)
+
+        elif self.status in ["uploading", "running"]:
+            # Download output if local
+            data = self._job.get("data", {})
+
+            if data.get("output_type") != "local":
+                raise SpecificationError(
+                    "status",
+                    f"Job output_type is not 'local', cannot download output.",
+                )
+
+            output_uri = data.get("output_uri")
+            if not output_uri:
+                raise SpecificationError(
+                    "status",
+                    f"Job data missing output_uri for local output download.",
+                )
+
+            # Track which workers we've already started downloading
+            downloading_workers = set()
+            download_tasks = []
+
+            # Poll until all workers are finished
+            while True:
+                # Refresh job to get latest worker statuses
+                await self.refresh()
+
+                # Get fresh workers list
+                workers = self._job.get("workers", [])
+                if not workers:
+                    raise SpecificationError(
+                        "status",
+                        f"Job has no workers.",
+                    )
+
+                # Check if job is finished
+                if self.status in ["finished", "canceled", "failed"]:
+                    break
+
+                # Check all workers for uploading status
+                for worker in workers:
+                    worker_id = worker.get("job_worker_uuid") or worker.get(
+                        "id"
+                    )
+                    worker_status = worker.get("status")
+
+                    # Start download for any worker that enters uploading status
+                    if (
+                        worker_status == "uploading"
+                        and worker_id not in downloading_workers
+                    ):
+                        output_auth_token = worker.get("output_auth_token")
+                        output_hostname = worker.get("output_hostname")
+
+                        if not output_auth_token or not output_hostname:
+                            logging.warning(
+                                f"Worker {worker_id} in uploading status missing output_auth_token or output_hostname, skipping."
+                            )
+                            continue
+
+                        downloading_workers.add(worker_id)
+                        # Create and start download task (runs in parallel)
+                        logging.info(
+                            f"Starting download for worker {worker_id} from {output_hostname} to {output_uri}"
+                        )
+                        try:
+                            download_task = asyncio.create_task(
+                                download(
+                                    output_hostname,
+                                    output_auth_token,
+                                    output_uri,
+                                )
+                            )
+                            download_tasks.append(download_task)
+                            logging.debug(
+                                f"Download task created for worker {worker_id}, task: {download_task}"
+                            )
+                        except Exception as e:
+                            logging.error(
+                                f"Failed to create download task for worker {worker_id}: {e}",
+                                exc_info=True,
+                            )
+                            raise
+
+                # Check if any download tasks have completed or failed
+                if download_tasks:
+                    completed_tasks = [
+                        task for task in download_tasks if task.done()
+                    ]
+                    for task in completed_tasks:
+                        try:
+                            await task  # This will raise if the task failed
+                            logging.info(
+                                f"Download task completed successfully"
+                            )
+                        except Exception as e:
+                            logging.error(
+                                f"Download task failed: {e}", exc_info=True
+                            )
+                            raise
+                    # Remove completed tasks
+                    download_tasks = [
+                        task for task in download_tasks if not task.done()
+                    ]
+
+                # Check if all workers are finished
+                all_finished = all(
+                    worker.get("status") in ["finished", "removed"]
+                    for worker in workers
+                )
+
+                if all_finished:
+                    break
+
+                # If we have active download tasks, wait a bit for them to make progress
+                # but don't wait the full 30 seconds - check more frequently
+                if download_tasks:
+                    await asyncio.sleep(5)
+                else:
+                    # Wait 30 seconds before next poll if no downloads in progress
+                    await asyncio.sleep(30)
+
+            # Wait for all download tasks to complete
+            if download_tasks:
+                logging.info(
+                    f"Waiting for {len(download_tasks)} download task(s) to complete"
+                )
+                await asyncio.gather(*download_tasks)
+                logging.info("All downloads completed")
 
     async def remove(self, force=False):
         await self.trainml._query(
```
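Taken together, `connect()` no longer stands up a VPN `Connection`; it moves the data itself. In `waiting for data/model download` it uploads any local model/input data in parallel, and in `uploading`/`running` it polls the workers and streams each worker's output down as it finishes. A hedged end-to-end sketch: the job id is hypothetical, and the two-call pattern is an assumption about typical use rather than something this diff prescribes:

```python
import asyncio

from trainml.trainml import TrainML


async def run():
    trainml = TrainML()
    # Hypothetical id of a job whose model/data use the "local" source type.
    job = await trainml.jobs.get("<job-uuid>")

    # While the job is "new"/"waiting for data/model download", connect()
    # uploads the local model and/or input data (in parallel if both are local).
    await job.connect()

    # Once the job is "uploading"/"running" with output_type == "local",
    # connect() polls the workers and downloads each worker's output to
    # the configured output_uri until every worker finishes.
    await job.connect()


asyncio.run(run())
```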
```diff
@@ -401,7 +566,8 @@ class Job:
 
     def _get_msg_handler(self, msg_handler):
         worker_numbers = {
-            w.get("job_worker_uuid"): ind + 1 for ind, w in enumerate(self._workers)
+            w.get("job_worker_uuid"): ind + 1
+            for ind, w in enumerate(self._workers)
         }
         worker_numbers["data_worker"] = 0
 
```
```diff
@@ -411,7 +577,9 @@ class Job:
         if msg_handler:
             msg_handler(data)
         else:
-            timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
+            timestamp = datetime.fromtimestamp(
+                int(data.get("time")) / 1000
+            )
             if len(self._workers) > 1:
                 print(
                     f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: Worker {data.get('worker_number')} - {data.get('msg').rstrip()}"
```
```diff
@@ -424,7 +592,10 @@ class Job:
         return handler
 
     async def attach(self, msg_handler=None):
-        if self.type == "notebook" and self.status != "waiting for data/model download":
+        if (
+            self.type == "notebook"
+            and self.status != "waiting for data/model download"
+        ):
             raise SpecificationError(
                 "type",
                 "Notebooks cannot be attached to after model download is complete. Use open() instead.",
```
```diff
@@ -441,7 +612,9 @@ class Job:
     async def copy(self, name, **kwargs):
         logging.debug(f"copy request - name: {name} ; kwargs: {kwargs}")
         if self.type != "notebook":
-            raise SpecificationError("job", "Only notebook job types can be copied")
+            raise SpecificationError(
+                "job", "Only notebook job types can be copied"
+            )
 
         job = await self.trainml.jobs.create(
             name,
```
```diff
@@ -502,7 +675,9 @@ class Job:
 
         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
-        POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
+        POLL_INTERVAL = max(
+            min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
+        )
         retry_count = math.ceil(timeout / POLL_INTERVAL)
         count = 0
         while count < retry_count:
```
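For reference, with the default `timeout=300` this formula gives `POLL_INTERVAL = max(min(300 / 60, 60), 5) = 5` seconds and `retry_count = ceil(300 / 5) = 60` polls, while a one-hour `timeout=3600` hits the 60-second cap (`min(60, 60)`) and also polls 60 times.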
```diff
@@ -533,11 +708,7 @@ class Job:
                 or (
                     status
                     == "running"  ## this status could be too short for polling could miss it
-                    and self.status
-                    in [
-                        "uploading",
-                        "finished"
-                    ]
+                    and self.status in ["uploading", "finished"]
                 )
             ):
                 return self
```
trainml/models.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from .exceptions import (
     SpecificationError,
     TrainMLException,
 )
-from .
+from trainml.utils.transfer import upload, download
 
 
 class Models(object):
```
```diff
@@ -54,7 +54,9 @@ class Models(object):
         return model
 
     async def remove(self, id, **kwargs):
-        await self.trainml._query(
+        await self.trainml._query(
+            f"/model/{id}", "DELETE", dict(**kwargs, force=True)
+        )
 
 
 class Model:
```
```diff
@@ -65,7 +67,9 @@ class Model:
         self._status = self._model.get("status")
         self._name = self._model.get("name")
         self._size = self._model.get("size") or self._model.get("used_size")
-        self._billed_size = self._model.get("billed_size") or self._model.get("size")
+        self._billed_size = self._model.get("billed_size") or self._model.get(
+            "size"
+        )
         self._project_uuid = self._model.get("project_uuid")
 
     @property
```
```diff
@@ -113,57 +117,45 @@ class Model:
         )
         return resp
 
-    async def get_connection_utility_url(self):
-        resp = await self.trainml._query(
-            f"/model/{self._id}/download",
-            "GET",
-            dict(project_uuid=self._project_uuid),
-        )
-        return resp
-
-    def get_connection_details(self):
-        if self._model.get("vpn"):
-            details = dict(
-                entity_type="model",
-                project_uuid=self._model.get("project_uuid"),
-                cidr=self._model.get("vpn").get("cidr"),
-                ssh_port=self._model.get("vpn").get("client").get("ssh_port"),
-                input_path=(
-                    self._model.get("source_uri")
-                    if self.status in ["new", "downloading"]
-                    else None
-                ),
-                output_path=(
-                    self._model.get("output_uri")
-                    if self.status == "exporting"
-                    else None
-                ),
-            )
-        else:
-            details = dict()
-        logging.debug(f"Connection Details: {details}")
-        return details
-
     async def connect(self):
-        if self.status in ["
-
-        "
-
-
-
-
-
-            self.trainml, entity_type="model", id=self.id, entity=self
-        )
-        await connection.start()
-        return connection.status
+        if self.status not in ["downloading", "exporting"]:
+            if self.status == "new":
+                await self.wait_for("downloading")
+            else:
+                raise SpecificationError(
+                    "status",
+                    f"You can only connect to downloading or exporting models.",
+                )
 
-
-
-
-
-
-
+        # Refresh to get latest entity data
+        await self.refresh()
+
+        if self.status == "downloading":
+            # Upload task - get auth_token, hostname, and source_uri from model
+            auth_token = self._model.get("auth_token")
+            hostname = self._model.get("hostname")
+            source_uri = self._model.get("source_uri")
+
+            if not auth_token or not hostname or not source_uri:
+                raise SpecificationError(
+                    "status",
+                    f"Model in downloading status missing required connection properties (auth_token, hostname, source_uri).",
+                )
+
+            await upload(hostname, auth_token, source_uri)
+        elif self.status == "exporting":
+            # Download task - get auth_token, hostname, and output_uri from model
+            auth_token = self._model.get("auth_token")
+            hostname = self._model.get("hostname")
+            output_uri = self._model.get("output_uri")
+
+            if not auth_token or not hostname or not output_uri:
+                raise SpecificationError(
+                    "status",
+                    f"Model in exporting status missing required connection properties (auth_token, hostname, output_uri).",
+                )
+
+            await download(hostname, auth_token, output_uri)
 
     async def remove(self, force=False):
         await self.trainml._query(
```
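`Model.connect()` gets the same treatment as `Job.connect()`: instead of returning a `Connection` status, it performs the transfer directly, uploading the source while the model is `downloading` and pulling the export while it is `exporting`. A hedged sketch; the model id is hypothetical:

```python
import asyncio

from trainml.trainml import TrainML


async def run():
    trainml = TrainML()
    # Hypothetical id of a model created from a local source directory.
    model = await trainml.models.get("<model-uuid>")

    # A "new" model is waited into "downloading", then the files at its
    # source_uri are uploaded; an "exporting" model downloads to output_uri.
    await model.connect()


asyncio.run(run())
```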
```diff
@@ -202,7 +194,9 @@ class Model:
         if msg_handler:
             msg_handler(data)
         else:
-            timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
+            timestamp = datetime.fromtimestamp(
+                int(data.get("time")) / 1000
+            )
             print(
                 f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
             )
```
```diff
@@ -231,7 +225,7 @@ class Model:
     async def wait_for(self, status, timeout=300):
         if self.status == status:
             return
-        valid_statuses = ["downloading", "ready", "archived"]
+        valid_statuses = ["downloading", "ready","exporting", "archived"]
         if not status in valid_statuses:
             raise SpecificationError(
                 "status",
```
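With `exporting` now a valid `wait_for` target, a caller can block until an export begins and then start the download, e.g. (hedged sketch, given a `Model` instance inside async code):

```python
await model.wait_for("exporting")  # "exporting" is now accepted
await model.connect()              # downloads the export to output_uri
```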
```diff
@@ -245,7 +239,9 @@ class Model:
         )
         POLL_INTERVAL_MIN = 5
         POLL_INTERVAL_MAX = 60
-        POLL_INTERVAL = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
+        POLL_INTERVAL = max(
+            min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN
+        )
         retry_count = math.ceil(timeout / POLL_INTERVAL)
         count = 0
         while count < retry_count:
```