trainml 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +89 -67
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +269 -69
- trainml/models.py +51 -55
- trainml/trainml.py +159 -114
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +647 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/top_level.txt +0 -0
tests/unit/test_jobs_unit.py
CHANGED
@@ -361,70 +361,81 @@ class JobTests:
         assert response == api_response

     @mark.asyncio
-    async def
-
-
-
-
-
+    async def test_job_connect_waiting_for_data_model_download_local_model_only(
+        self, mock_trainml
+    ):
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "customer_uuid": "cus-id-1",
+                "project_uuid": "proj-id-1",
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for data/model download",
+                "model": {
+                    "source_type": "local",
+                    "auth_token": "model-token",
+                    "hostname": "model-host.com",
+                    "source_uri": "/path/to/model",
+                },
+                "data": {
+                    "input_type": "trainml",
+                },
+            },
         )
-        assert response == api_response

-
-
-
-
-
-
-
-
-
-
-
-        assert details == expected_details
+        with patch(
+            "trainml.jobs.Job.refresh", new_callable=AsyncMock
+        ) as mock_refresh:
+            with patch(
+                "trainml.jobs.upload", new_callable=AsyncMock
+            ) as mock_upload:
+                await job.connect()
+                mock_refresh.assert_called_once()
+                mock_upload.assert_called_once_with(
+                    "model-host.com", "model-token", "/path/to/model"
+                )

-
+    @mark.asyncio
+    async def test_job_connect_waiting_for_data_model_download_local_data_only(
+        self, mock_trainml
+    ):
         job = specimen.Job(
             mock_trainml,
             **{
                 "customer_uuid": "cus-id-1",
                 "project_uuid": "proj-id-1",
                 "job_uuid": "job-id-1",
-                "name": "test
-                "type": "
-                "status": "
-                "model": {
-
-                "datasets": [],
-                "output_type": "local",
-                "output_uri": "~/tensorflow-example/output",
-                "status": "ready",
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for data/model download",
+                "model": {
+                    "source_type": "trainml",
                 },
-                "
-                "
-                "
-                "
-
-                "id": "cus-id-1",
-                "address": "10.106.171.253",
-                "ssh_port": 46600,
-                },
+                "data": {
+                    "input_type": "local",
+                    "input_auth_token": "data-token",
+                    "input_hostname": "data-host.com",
+                    "input_uri": "/path/to/data",
                 },
             },
         )
-        details = job.get_connection_details()
-        expected_details = dict(
-            project_uuid="proj-id-1",
-            entity_type="job",
-            cidr="10.106.171.0/24",
-            ssh_port=46600,
-            model_path=None,
-            input_path=None,
-            output_path="~/tensorflow-example/output",
-        )
-        assert details == expected_details

-
+        with patch(
+            "trainml.jobs.Job.refresh", new_callable=AsyncMock
+        ) as mock_refresh:
+            with patch(
+                "trainml.jobs.upload", new_callable=AsyncMock
+            ) as mock_upload:
+                await job.connect()
+                mock_refresh.assert_called_once()
+                mock_upload.assert_called_once_with(
+                    "data-host.com", "data-token", "/path/to/data"
+                )
+
+    @mark.asyncio
+    async def test_job_connect_waiting_for_data_model_download_both_local_parallel(
         self, mock_trainml
     ):
         job = specimen.Job(
@@ -433,63 +444,288 @@
                 "customer_uuid": "cus-id-1",
                 "project_uuid": "proj-id-1",
                 "job_uuid": "job-id-1",
-                "name": "test
-                "type": "
-                "status": "
-                "model": {
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for data/model download",
+                "model": {
+                    "source_type": "local",
+                    "auth_token": "model-token",
+                    "hostname": "model-host.com",
+                    "source_uri": "/path/to/model",
+                },
                 "data": {
-                    "datasets": [],
                     "input_type": "local",
-                    "
-                    "
+                    "input_auth_token": "data-token",
+                    "input_hostname": "data-host.com",
+                    "input_uri": "/path/to/data",
                 },
-
-
-
-
-
-
-
-
-
+            },
+        )
+
+        with patch(
+            "trainml.jobs.Job.refresh", new_callable=AsyncMock
+        ) as mock_refresh:
+            with patch(
+                "trainml.jobs.upload", new_callable=AsyncMock
+            ) as mock_upload:
+                await job.connect()
+                mock_refresh.assert_called_once()
+                assert mock_upload.call_count == 2
+                # Verify both were called with correct parameters
+                calls = mock_upload.call_args_list
+                assert any(
+                    call[0]
+                    == ("model-host.com", "model-token", "/path/to/model")
+                    for call in calls
+                )
+                assert any(
+                    call[0] == ("data-host.com", "data-token", "/path/to/data")
+                    for call in calls
+                )
+
+    @mark.asyncio
+    async def test_job_connect_waiting_for_data_model_download_neither_local_error(
+        self, mock_trainml
+    ):
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "customer_uuid": "cus-id-1",
+                "project_uuid": "proj-id-1",
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for data/model download",
+                "model": {
+                    "source_type": "trainml",
+                },
+                "data": {
+                    "input_type": "trainml",
                 },
             },
         )
-
-
-
-
-
-
-
-
-
+
+        with patch("trainml.jobs.Job.refresh", new_callable=AsyncMock):
+            with raises(
+                SpecificationError,
+                match="Job has no local model or data to upload",
+            ):
+                await job.connect()
+
+    @mark.asyncio
+    async def test_job_connect_uploading_status_single_worker(
+        self, mock_trainml, tmp_path
+    ):
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "customer_uuid": "cus-id-1",
+                "project_uuid": "proj-id-1",
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "uploading",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "uploading",
+                        "output_auth_token": "worker-token",
+                        "output_hostname": "worker-host.com",
+                    }
+                ],
+            },
         )
-
+
+        # Mock refresh to preserve job state and control loop behavior
+        refresh_call_count = [0]
+
+        async def mock_refresh():
+            refresh_call_count[0] += 1
+            # First refresh (before loop, line 346) - ensure state is correct
+            if refresh_call_count[0] == 1:
+                job._status = "uploading"
+                # Ensure workers list exists and has the uploading worker
+                if (
+                    not job._job.get("workers")
+                    or len(job._job["workers"]) == 0
+                ):
+                    job._job["workers"] = [
+                        {
+                            "job_worker_uuid": "worker-1",
+                            "status": "uploading",
+                            "output_auth_token": "worker-token",
+                            "output_hostname": "worker-host.com",
+                        }
+                    ]
+                else:
+                    job._job["workers"][0]["status"] = "uploading"
+                    job._job["workers"][0][
+                        "output_auth_token"
+                    ] = "worker-token"
+                    job._job["workers"][0][
+                        "output_hostname"
+                    ] = "worker-host.com"
+                # Also update _workers property
+                job._workers = job._job.get("workers")
+            # Second refresh (in loop, line 418, first iteration) - keep uploading so download task is created
+            elif refresh_call_count[0] == 2:
+                job._status = "uploading"
+                job._job["workers"][0]["status"] = "uploading"
+                job._job["workers"][0]["output_auth_token"] = "worker-token"
+                job._job["workers"][0]["output_hostname"] = "worker-host.com"
+                job._workers = job._job.get("workers")
+            # Third refresh (in loop, second iteration) - mark as finished to exit
+            elif refresh_call_count[0] == 3:
+                job._status = "finished"
+                job._job["workers"][0]["status"] = "finished"
+                job._workers = job._job.get("workers")
+            return job
+
+        with patch.object(job, "refresh", side_effect=mock_refresh):
+            with patch(
+                "trainml.jobs.download", new_callable=AsyncMock
+            ) as mock_download:
+                # Mock sleep - allow loop to continue
+                async def sleep_side_effect(delay):
+                    # After sleep, next refresh will mark as finished
+                    pass
+
+                with patch("asyncio.sleep", side_effect=sleep_side_effect):
+                    await job.connect()
+                # Download should be called once for the uploading worker
+                # The download task is created in the first loop iteration, then we wait for it
+                assert mock_download.call_count == 1
+                mock_download.assert_called_with(
+                    "worker-host.com",
+                    "worker-token",
+                    str(tmp_path / "output"),
+                )

     @mark.asyncio
-    async def
+    async def test_job_connect_running_status_multi_worker_polling(
+        self, mock_trainml, tmp_path
+    ):
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "customer_uuid": "cus-id-1",
+                "project_uuid": "proj-id-1",
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "running",
+                    },
+                    {
+                        "job_worker_uuid": "worker-2",
+                        "status": "running",
+                    },
+                ],
+            },
+        )
+
+        refresh_count = [0]
         with patch(
-            "trainml.jobs.
-
-
-
-
-
-
-
+            "trainml.jobs.Job.refresh", new_callable=AsyncMock
+        ) as mock_refresh:
+
+            def refresh_side_effect():
+                refresh_count[0] += 1
+                if refresh_count[0] == 1:
+                    # First refresh: worker-1 becomes uploading
+                    job._job["workers"][0]["status"] = "uploading"
+                    job._job["workers"][0]["output_auth_token"] = "token-1"
+                    job._job["workers"][0]["output_hostname"] = "host-1.com"
+                elif refresh_count[0] == 2:
+                    # Second refresh: worker-2 becomes uploading
+                    job._job["workers"][1]["status"] = "uploading"
+                    job._job["workers"][1]["output_auth_token"] = "token-2"
+                    job._job["workers"][1]["output_hostname"] = "host-2.com"
+                else:
+                    # Third refresh: both finished
+                    job._status = "finished"
+                    job._job["workers"][0]["status"] = "finished"
+                    job._job["workers"][1]["status"] = "finished"
+
+            mock_refresh.side_effect = refresh_side_effect
+
+            with patch(
+                "trainml.jobs.download", new_callable=AsyncMock
+            ) as mock_download:
+                sleep_mock = AsyncMock()
+                with patch("asyncio.sleep", sleep_mock):
+                    await job.connect()
+                # Should have called download twice (once per worker)
+                assert mock_download.call_count == 2
+                # Should have slept between polls (at least once before both workers finish)
+                assert sleep_mock.call_count >= 1
+                # Verify both downloads were called with correct parameters
+                calls = mock_download.call_args_list
+                assert any(
+                    call[0]
+                    == ("host-1.com", "token-1", str(tmp_path / "output"))
+                    for call in calls
+                )
+                assert any(
+                    call[0]
+                    == ("host-2.com", "token-2", str(tmp_path / "output"))
+                    for call in calls
+                )

     @mark.asyncio
-    async def
-
-
-
-
-
-
-
-
-
+    async def test_job_connect_invalid_status(self, mock_trainml):
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "customer_uuid": "cus-id-1",
+                "project_uuid": "proj-id-1",
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "finished",
+            },
+        )
+
+        with raises(
+            SpecificationError, match="You can only connect to active jobs"
+        ):
+            await job.connect()
+
+    @mark.asyncio
+    async def test_job_connect_uploading_no_local_output_error(
+        self, mock_trainml
+    ):
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "customer_uuid": "cus-id-1",
+                "project_uuid": "proj-id-1",
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "uploading",
+                "data": {
+                    "output_type": "s3",
+                },
+            },
+        )
+
+        with patch("trainml.jobs.Job.refresh", new_callable=AsyncMock):
+            with raises(
+                SpecificationError, match="Job output_type is not 'local'"
+            ):
+                await job.connect()

     @mark.asyncio
     async def test_job_remove(self, job, mock_trainml):
@@ -828,3 +1064,686 @@ class JobTests:
         with raises(ApiError):
             await job.wait_for("running")
         mock_trainml._query.assert_called()
+
+    @mark.asyncio
+    async def test_job_update_notebook(self, job, mock_trainml):
+        """Test Job.update() for notebook jobs."""
+        job._job["type"] = "notebook"
+        update_data = dict(environment=dict(type="DEEPLEARNING_PY310"))
+        api_response = dict(
+            job_uuid="job-id-1",
+            name="test notebook",
+            type="notebook",
+            status="new",
+            **update_data,
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        result = await job.update(update_data)
+        mock_trainml._query.assert_called_once_with(
+            "/job/job-id-1",
+            "PATCH",
+            dict(project_uuid="proj-id-1"),
+            update_data,
+        )
+        assert result == job
+
+    @mark.asyncio
+    async def test_job_update_non_notebook_error(self, job, mock_trainml):
+        """Test Job.update() raises error for non-notebook jobs."""
+        job._job["type"] = "training"
+        job._type = "training"  # Update the cached property value
+        with raises(SpecificationError) as exc_info:
+            await job.update(dict(environment=dict(type="DEEPLEARNING_PY310")))
+        assert "Only notebook jobs can be modified" in str(exc_info.value.message)
+
+    @mark.asyncio
+    async def test_job_open_notebook(self, job, mock_trainml):
+        """Test Job.open() for notebook jobs."""
+        job._job["type"] = "notebook"
+        job._job["endpoint"] = dict(url="https://example.com")
+        job._job["nb_token"] = "token123"
+        with patch("trainml.jobs.webbrowser.open") as mock_open:
+            await job.open()
+            mock_open.assert_called_once_with("https://example.com/?token=token123")
+
+    @mark.asyncio
+    async def test_job_open_non_notebook_error(self, job, mock_trainml):
+        """Test Job.open() raises error for non-notebook jobs."""
+        job._job["type"] = "training"
+        job._type = "training"  # Update the cached property value
+        # Ensure endpoint exists to avoid NoneType error
+        job._job["endpoint"] = dict(url="https://example.com")
+        with raises(SpecificationError) as exc_info:
+            await job.open()
+        assert "Only notebook jobs can be opened" in str(exc_info.value.message)
+
+    def test_job_get_create_json_comprehensive(self, job, mock_trainml):
+        """Test get_create_json() with comprehensive data."""
+        job._job = dict(
+            name="test job",
+            type="training",
+            project_uuid="proj-id-1",
+            resources=dict(
+                gpu_count=1,
+                gpu_types=["rtx3090"],
+                disk_size=10,
+                max_price=5.0,
+                preemptible=True,
+                cpu_count=4,
+            ),
+            model=dict(
+                source_type="git",
+                source_uri="git@github.com:test/repo.git",
+                project_uuid="proj-id-1",
+                checkpoints=["checkpoint-1"],
+            ),
+            data=dict(
+                datasets=["dataset-1"],
+                input_type="aws",
+                input_uri="s3://bucket/input",
+                input_options=dict(key="value"),
+                output_type="aws",
+                output_uri="s3://bucket/output",
+                output_options=dict(key="value"),
+            ),
+            environment=dict(
+                type="DEEPLEARNING_PY310",
+                env=[dict(key="KEY", value="VALUE")],
+                custom_image="custom:latest",
+                worker_key_types=["ssh"],
+                packages=["package1"],
+            ),
+            endpoint=dict(
+                routes=["/api"],
+                start_command="python app.py",
+                reservation_id="reservation-1",
+            ),
+            workers=[
+                dict(command="python train.py"),
+                dict(command="python eval.py"),
+            ],
+        )
+        result = job.get_create_json()
+        assert result["name"] == "test job"
+        assert result["type"] == "training"
+        assert result["project_uuid"] == "proj-id-1"
+        assert result["resources"]["gpu_count"] == 1
+        assert result["resources"]["gpu_types"] == ["rtx3090"]
+        assert result["model"]["source_type"] == "git"
+        assert result["data"]["datasets"] == ["dataset-1"]
+        assert result["environment"]["type"] == "DEEPLEARNING_PY310"
+        assert result["endpoint"]["routes"] == ["/api"]
+        assert result["workers"] == ["python train.py", "python eval.py"]
+
+    def test_job_get_create_json_minimal(self, job, mock_trainml):
+        """Test get_create_json() with minimal data."""
+        job._job = dict(
+            name="minimal job",
+            type="training",
+            project_uuid="proj-id-1",
+        )
+        result = job.get_create_json()
+        assert result["name"] == "minimal job"
+        assert result["type"] == "training"
+        assert result["project_uuid"] == "proj-id-1"
+        assert "resources" not in result
+        assert "model" not in result
+        assert "data" not in result
+
+    def test_job_get_create_json_partial_resources(self, job, mock_trainml):
+        """Test get_create_json() with partial resources."""
+        job._job = dict(
+            name="partial job",
+            type="training",
+            project_uuid="proj-id-1",
+            resources=dict(
+                gpu_count=1,
+                disk_size=10,
+                # Missing other resource keys
+            ),
+        )
+        result = job.get_create_json()
+        assert result["resources"]["gpu_count"] == 1
+        assert result["resources"]["disk_size"] == 10
+        assert "max_price" not in result["resources"]
+        assert "preemptible" not in result["resources"]
+
+    def test_job_workers_property(self, job):
+        """Test workers property."""
+        assert job.workers == job._workers
+
+    def test_job_credits_property(self, job):
+        """Test credits property."""
+        assert job.credits == job._credits
+
+    def test_job_notebook_url_non_notebook(self, job, mock_trainml):
+        """Test notebook_url property returns None for non-notebook jobs."""
+        job._type = "training"
+        assert job.notebook_url is None
+
+    @mark.asyncio
+    async def test_job_create_with_gpu_types(self, jobs, mock_trainml):
+        """Test create with gpu_types list (line 60)."""
+        api_response = {
+            "job_uuid": "job-id-1",
+            "name": "test job",
+            "type": "training",
+            "status": "new",
+        }
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        result = await jobs.create(
+            name="test job",
+            type="training",
+            gpu_types=["gpu-1", "gpu-2"],
+            gpu_count=1,
+        )
+        call_args = mock_trainml._query.call_args
+        # call_args is (args, kwargs), payload is in kwargs or args[3]
+        payload = call_args[1].get("data") if call_args[1] else call_args[0][3]
+        assert payload["resources"]["gpu_types"] == ["gpu-1", "gpu-2"]
+
+    @mark.asyncio
+    async def test_job_connect_notebook_invalid_status(self, mock_trainml):
+        """Test connect for notebook type with invalid status (line 314)."""
+        # Notebook type with status "running" should raise error at line 314
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test notebook",
+                "type": "notebook",
+                "status": "running",
+                "endpoint": {"url": "https://example.com"},
+            }
+        )
+        job._type = "notebook"
+        job._status = "running"
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "Notebooks cannot be connected to" in str(exc_info.value.message)
+
+    @mark.asyncio
+    async def test_job_connect_endpoint_returns_url(self, mock_trainml):
+        """Test connect for endpoint type returns url (line 322)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test endpoint",
+                "type": "endpoint",
+                "status": "running",
+                "endpoint": {"url": "https://example.com"},
+            }
+        )
+        job._type = "endpoint"
+        job._status = "running"
+        # Endpoint type with status not in ["new", "waiting for data/model download"]
+        # returns url immediately without refresh
+        result = await job.connect()
+        assert result == "https://example.com"
+
+    @mark.asyncio
+    async def test_job_connect_status_not_new_error(self, mock_trainml):
+        """Test connect raises error when status not new and not in allowed list (line 347)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for GPUs",
+            }
+        )
+        job._type = "training"
+        job._status = "waiting for GPUs"
+        api_response = dict(
+            job_uuid="job-id-1",
+            type="training",
+            status="waiting for GPUs",
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "You can only connect to jobs" in str(exc_info.value.message)
+
+    @mark.asyncio
+    async def test_job_connect_endpoint_error_non_downloading(self, mock_trainml, tmp_path):
+        """Test connect for endpoint type raises error for non-downloading status (line 361)."""
+        # Endpoint with status "waiting for data/model download" goes through normal flow
+        # Refresh at line 353 updates status to "finished", which hits line 361
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test endpoint",
+                "type": "endpoint",
+                "status": "waiting for data/model download",
+                "endpoint": {"url": "https://example.com"},
+                "data": {"input_type": "local", "input_uri": str(tmp_path / "input")},
+            }
+        )
+        job._type = "endpoint"
+        job._status = "waiting for data/model download"
+        # Refresh at line 353: status changes to "finished" - hits line 361
+        api_response_finished = dict(
+            job_uuid="job-id-1",
+            type="endpoint",
+            status="finished",  # Status changed after refresh - hits line 361
+            endpoint={"url": "https://example.com"},
+            data=dict(input_type="local", input_uri=str(tmp_path / "input")),
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response_finished)
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "Job status changed to" in str(exc_info.value.message)
+
+    @mark.asyncio
+    async def test_job_connect_missing_model_properties(self, mock_trainml):
+        """Test connect raises error when model properties missing (line 392)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for data/model download",
+                "model": {"model_uuid": "model-1", "source_type": "local"},
+            }
+        )
+        job._status = "waiting for data/model download"
+        api_response = dict(
+            job_uuid="job-id-1",
+            status="waiting for data/model download",
+            model=dict(model_uuid="model-1", source_type="local"),
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "missing required connection properties" in str(exc_info.value.message).lower()
+
+    @mark.asyncio
+    async def test_job_connect_missing_data_properties(self, mock_trainml):
+        """Test connect raises error when data properties missing (line 411)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "waiting for data/model download",
+                "data": {"dataset_uuid": "dataset-1", "input_type": "local"},
+            }
+        )
+        job._status = "waiting for data/model download"
+        api_response = dict(
+            job_uuid="job-id-1",
+            status="waiting for data/model download",
+            data=dict(dataset_uuid="dataset-1", input_type="local"),
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "missing required connection properties" in str(exc_info.value.message).lower()
+
+    @mark.asyncio
+    async def test_job_connect_missing_output_uri(self, mock_trainml, tmp_path):
+        """Test connect raises error when output_uri missing (line 436)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "uploading",
+                "data": {"output_type": "local"},
+            }
+        )
+        job._status = "uploading"
+        api_response = dict(
+            job_uuid="job-id-1",
+            status="uploading",
+            data=dict(output_type="local"),
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "missing output_uri" in str(exc_info.value.message).lower()
+
+    @mark.asyncio
+    async def test_job_connect_missing_workers(self, mock_trainml, tmp_path):
+        """Test connect raises error when workers missing (line 453)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "uploading",
+                "data": {"output_type": "local", "output_uri": str(tmp_path / "output")},
+                "workers": [],
+            }
+        )
+        job._status = "uploading"
+        job._workers = []
+        api_response = dict(
+            job_uuid="job-id-1",
+            status="uploading",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[],
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await job.connect()
+        assert "has no workers" in str(exc_info.value.message).lower()
+
+    @mark.asyncio
+    async def test_job_wait_for_training_stopped_warning(self, mock_trainml):
+        """Test wait_for for training job with stopped status shows warning (line 664)."""
+        import warnings
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+            }
+        )
+        job._type = "training"
+        job._status = "running"
+        api_response_stopped = dict(
+            job_uuid="job-id-1",
+            status="stopped",
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response_stopped)
+        with patch("trainml.jobs.asyncio.sleep", new_callable=AsyncMock):
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                await job.wait_for("stopped", timeout=10)
+                assert len(w) == 1
+                assert "deprecated" in str(w[0].message).lower()
+
+    @mark.asyncio
+    async def test_job_wait_for_timeout_validation(self, job):
+        """Test wait_for validates timeout (line 671)."""
+        with raises(SpecificationError) as exc_info:
+            await job.wait_for("finished", timeout=25 * 60 * 60)
+        assert "timeout must be less than" in str(exc_info.value.message)
+
+    @mark.asyncio
+    async def test_job_connect_worker_missing_output_auth_warning(
+        self, mock_trainml, tmp_path, caplog
+    ):
+        """Test connect logs warning when worker missing output_auth_token (lines 478-481)."""
+        import logging
+        caplog.set_level(logging.WARNING)
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "running",
+                    },
+                ],
+            },
+        )
+        job._status = "running"
+        # First refresh (line 353): initial refresh after status check
+        api_response_initial = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="running")],
+        )
+        # Second refresh (line 448, first iteration): worker becomes uploading but missing output_auth_token
+        api_response_uploading = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="uploading")],  # Missing output_auth_token and output_hostname
+        )
+        # Third refresh (line 448, second iteration): worker finished to break loop
+        api_response_finished = dict(
+            job_uuid="job-id-1",
+            status="running",  # Keep as running to avoid line 361 error
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="finished")],
+        )
+        mock_trainml._query = AsyncMock(side_effect=[api_response_initial, api_response_uploading, api_response_finished])
+
+        with patch("trainml.jobs.download", new_callable=AsyncMock):
+            with patch("asyncio.sleep", new_callable=AsyncMock):
+                await job.connect()
+        # Check that warning was logged (lines 478-481)
+        # The warning should be logged when worker is uploading but missing output_auth_token or output_hostname
+        assert "missing output_auth_token" in caplog.text.lower() or "missing output_hostname" in caplog.text.lower() or "skipping" in caplog.text.lower()
+
+    @mark.asyncio
+    async def test_job_connect_download_task_creation_exception(
+        self, mock_trainml, tmp_path
+    ):
+        """Test connect raises exception when download task creation fails (lines 500-505)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "uploading",
+                        "output_auth_token": "token-1",
+                        "output_hostname": "host-1.com",
+                    },
+                ],
+            },
+        )
+        job._status = "running"
+        api_response = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_uri=str(tmp_path / "output")),
+            workers=[
+                dict(
+                    job_worker_uuid="worker-1",
+                    status="uploading",
+                    output_auth_token="token-1",
+                    output_hostname="host-1.com",
+                )
+            ],
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+
+        with patch("trainml.jobs.Job.refresh", new_callable=AsyncMock):
+            with patch("trainml.jobs.asyncio.create_task", side_effect=Exception("Task creation failed")):
+                with patch("asyncio.sleep", new_callable=AsyncMock):
+                    with raises(Exception) as exc_info:
+                        await job.connect()
+                    assert "Task creation failed" in str(exc_info.value)
+
+    @mark.asyncio
+    async def test_job_connect_download_task_completion_exception(
+        self, mock_trainml, tmp_path
+    ):
+        """Test connect raises exception when download task fails (lines 513-522)."""
+        import asyncio
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "uploading",
+                        "output_auth_token": "token-1",
+                        "output_hostname": "host-1.com",
+                    },
+                ],
+            },
+        )
+        job._status = "running"
+        api_response_running = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[
+                dict(
+                    job_worker_uuid="worker-1",
+                    status="uploading",
+                    output_auth_token="token-1",
+                    output_hostname="host-1.com",
+                )
+            ],
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response_running)
+
+        # Create a real task that fails immediately
+        async def failing_download(*args, **kwargs):
+            raise Exception("Download failed")
+
+        # Create the task and let it fail
+        failed_task = asyncio.create_task(failing_download())
+        try:
+            await failed_task
+        except Exception:
+            pass  # Task is now done and failed
+
+        refresh_count = [0]
+        def refresh_side_effect():
+            refresh_count[0] += 1
+            if refresh_count[0] == 1:
+                job._status = "running"
+                job._job["workers"][0]["status"] = "uploading"
+
+        with patch("trainml.jobs.Job.refresh", new_callable=AsyncMock) as mock_refresh:
+            mock_refresh.side_effect = refresh_side_effect
+            with patch("trainml.jobs.download", new_callable=AsyncMock):
+                with patch("asyncio.create_task", return_value=failed_task):
+                    with patch("asyncio.sleep", new_callable=AsyncMock):
+                        with raises(Exception) as exc_info:
+                            await job.connect()
+                        assert "Download failed" in str(exc_info.value)
+
+    @mark.asyncio
+    async def test_job_connect_all_finished_break(
+        self, mock_trainml, tmp_path
+    ):
+        """Test connect breaks when all workers finished (line 535)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "running",
+                    },
+                ],
+            },
+        )
+        job._status = "running"
+        # First refresh (line 353): initial refresh after status check
+        api_response_initial = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="running")],
+        )
+        # Second refresh (line 448, first iteration): worker finished, but status stays running
+        # This tests the all_finished break at line 535
+        api_response_finished = dict(
+            job_uuid="job-id-1",
+            status="running",  # Keep status as running so it doesn't hit line 361 error
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="finished")],
+        )
+        mock_trainml._query = AsyncMock(side_effect=[api_response_initial, api_response_finished])
+
+        with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
+            await job.connect()
+        # Should break when all_finished is True (line 535)
+        # The break happens in the while loop when all workers are finished
+        # Since all workers finished immediately after first refresh, sleep should not be called
+        assert sleep_mock.call_count == 0
+
+    @mark.asyncio
+    async def test_job_connect_sleep_30_no_download_tasks(
+        self, mock_trainml, tmp_path
+    ):
+        """Test connect sleeps 30 seconds when no download tasks (line 543)."""
+        job = specimen.Job(
+            mock_trainml,
+            **{
+                "job_uuid": "job-id-1",
+                "name": "test job",
+                "type": "training",
+                "status": "running",
+                "data": {
+                    "output_type": "local",
+                    "output_uri": str(tmp_path / "output"),
+                },
+                "workers": [
+                    {
+                        "job_worker_uuid": "worker-1",
+                        "status": "running",
+                    },
+                ],
+            },
+        )
+        job._status = "running"
+        api_response_running = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="running")],  # Not uploading, so no download tasks
+        )
+        api_response_still_running = dict(
+            job_uuid="job-id-1",
+            status="running",
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="running")],  # Still running, not finished
+        )
+        api_response_finished = dict(
+            job_uuid="job-id-1",
+            status="running",  # Keep as running to avoid line 361 error
+            data=dict(output_type="local", output_uri=str(tmp_path / "output")),
+            workers=[dict(job_worker_uuid="worker-1", status="finished")],
+        )
+        mock_trainml._query = AsyncMock(side_effect=[api_response_running, api_response_still_running, api_response_finished])
+
+        with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
+            await job.connect()
+        # Should have called sleep with 30 seconds when no download tasks (line 543)
+        # First iteration: no download tasks, so sleep(30)
+        # Second iteration: worker finished, so break
+        sleep_calls = [call[0][0] for call in sleep_mock.call_args_list]
+        assert 30 in sleep_calls