trainml 0.5.17__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +11 -53
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +239 -68
- trainml/models.py +51 -55
- trainml/trainml.py +50 -16
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +587 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from unittest.mock import patch
|
|
2
|
+
from pytest import mark
|
|
3
|
+
|
|
4
|
+
pytestmark = [mark.sdk, mark.unit]
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@patch('trainml.cli.cli')
|
|
8
|
+
def test_main_module_execution(mock_cli):
|
|
9
|
+
"""Test that __main__ module calls cli() when executed as main."""
|
|
10
|
+
# Import the module to get the cli reference
|
|
11
|
+
import trainml.__main__
|
|
12
|
+
|
|
13
|
+
# Execute the code that runs when __name__ == '__main__'
|
|
14
|
+
# We need to simulate the main execution block
|
|
15
|
+
# Since we can't change __name__ after import, we'll directly call
|
|
16
|
+
# the logic that would execute
|
|
17
|
+
trainml.__main__.cli()
|
|
18
|
+
|
|
19
|
+
# Verify cli was called
|
|
20
|
+
mock_cli.assert_called_once()
|
tests/unit/test_models_unit.py
CHANGED
|
@@ -45,7 +45,9 @@ class ModelsTests:
|
|
|
45
45
|
api_response = dict()
|
|
46
46
|
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
47
47
|
await models.get("1234")
|
|
48
|
-
mock_trainml._query.assert_called_once_with(
|
|
48
|
+
mock_trainml._query.assert_called_once_with(
|
|
49
|
+
"/model/1234", "GET", dict()
|
|
50
|
+
)
|
|
49
51
|
|
|
50
52
|
@mark.asyncio
|
|
51
53
|
async def test_list_models(
|
|
@@ -119,7 +121,11 @@ class ModelTests:
|
|
|
119
121
|
|
|
120
122
|
def test_model_repr(self, model):
|
|
121
123
|
string = repr(model)
|
|
122
|
-
regex =
|
|
124
|
+
regex = (
|
|
125
|
+
r"^Model\( trainml , \*\*{.*'model_uuid': '"
|
|
126
|
+
+ model.id
|
|
127
|
+
+ r"'.*}\)$"
|
|
128
|
+
)
|
|
123
129
|
assert isinstance(string, str)
|
|
124
130
|
assert re.match(regex, string)
|
|
125
131
|
|
|
@@ -130,9 +136,7 @@ class ModelTests:
|
|
|
130
136
|
|
|
131
137
|
@mark.asyncio
|
|
132
138
|
async def test_model_get_log_url(self, model, mock_trainml):
|
|
133
|
-
api_response =
|
|
134
|
-
"https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
|
|
135
|
-
)
|
|
139
|
+
api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
|
|
136
140
|
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
137
141
|
response = await model.get_log_url()
|
|
138
142
|
mock_trainml._query.assert_called_once_with(
|
|
@@ -157,79 +161,73 @@ class ModelTests:
|
|
|
157
161
|
assert response == api_response
|
|
158
162
|
|
|
159
163
|
@mark.asyncio
|
|
160
|
-
async def
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
164
|
+
async def test_model_connect_downloading_status(self, mock_trainml):
|
|
165
|
+
model = specimen.Model(
|
|
166
|
+
mock_trainml,
|
|
167
|
+
model_uuid="1",
|
|
168
|
+
project_uuid="proj-id-1",
|
|
169
|
+
name="test model",
|
|
170
|
+
status="downloading",
|
|
171
|
+
auth_token="test-token",
|
|
172
|
+
hostname="example.com",
|
|
173
|
+
source_uri="/path/to/source",
|
|
168
174
|
)
|
|
169
|
-
assert response == api_response
|
|
170
175
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
176
|
+
with patch(
|
|
177
|
+
"trainml.models.Model.refresh", new_callable=AsyncMock
|
|
178
|
+
) as mock_refresh:
|
|
179
|
+
with patch(
|
|
180
|
+
"trainml.models.upload", new_callable=AsyncMock
|
|
181
|
+
) as mock_upload:
|
|
182
|
+
await model.connect()
|
|
183
|
+
mock_refresh.assert_called_once()
|
|
184
|
+
mock_upload.assert_called_once_with(
|
|
185
|
+
"example.com", "test-token", "/path/to/source"
|
|
186
|
+
)
|
|
175
187
|
|
|
176
|
-
|
|
188
|
+
@mark.asyncio
|
|
189
|
+
async def test_model_connect_exporting_status(
|
|
190
|
+
self, mock_trainml, tmp_path
|
|
191
|
+
):
|
|
192
|
+
output_dir = str(tmp_path / "output")
|
|
177
193
|
model = specimen.Model(
|
|
178
194
|
mock_trainml,
|
|
179
195
|
model_uuid="1",
|
|
180
|
-
project_uuid="
|
|
181
|
-
name="
|
|
182
|
-
status="
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
vpn={
|
|
188
|
-
"status": "new",
|
|
189
|
-
"cidr": "10.106.171.0/24",
|
|
190
|
-
"client": {
|
|
191
|
-
"port": "36017",
|
|
192
|
-
"id": "cus-id-1",
|
|
193
|
-
"address": "10.106.171.253",
|
|
194
|
-
"ssh_port": 46600,
|
|
195
|
-
},
|
|
196
|
-
"net_prefix_type_id": 1,
|
|
197
|
-
},
|
|
198
|
-
)
|
|
199
|
-
details = model.get_connection_details()
|
|
200
|
-
expected_details = dict(
|
|
201
|
-
project_uuid="a",
|
|
202
|
-
entity_type="model",
|
|
203
|
-
cidr="10.106.171.0/24",
|
|
204
|
-
ssh_port=46600,
|
|
205
|
-
input_path="~/tensorflow-example",
|
|
206
|
-
output_path=None,
|
|
207
|
-
)
|
|
208
|
-
assert details == expected_details
|
|
196
|
+
project_uuid="proj-id-1",
|
|
197
|
+
name="test model",
|
|
198
|
+
status="exporting",
|
|
199
|
+
auth_token="test-token",
|
|
200
|
+
hostname="example.com",
|
|
201
|
+
output_uri=output_dir,
|
|
202
|
+
)
|
|
209
203
|
|
|
210
|
-
@mark.asyncio
|
|
211
|
-
async def test_model_connect(self, model, mock_trainml):
|
|
212
204
|
with patch(
|
|
213
|
-
"trainml.models.
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
205
|
+
"trainml.models.Model.refresh", new_callable=AsyncMock
|
|
206
|
+
) as mock_refresh:
|
|
207
|
+
with patch(
|
|
208
|
+
"trainml.models.download", new_callable=AsyncMock
|
|
209
|
+
) as mock_download:
|
|
210
|
+
await model.connect()
|
|
211
|
+
mock_refresh.assert_called_once()
|
|
212
|
+
mock_download.assert_called_once_with(
|
|
213
|
+
"example.com", "test-token", output_dir
|
|
214
|
+
)
|
|
221
215
|
|
|
222
216
|
@mark.asyncio
|
|
223
|
-
async def
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
217
|
+
async def test_model_connect_invalid_status(self, mock_trainml):
|
|
218
|
+
model = specimen.Model(
|
|
219
|
+
mock_trainml,
|
|
220
|
+
model_uuid="1",
|
|
221
|
+
project_uuid="proj-id-1",
|
|
222
|
+
name="test model",
|
|
223
|
+
status="ready",
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
with raises(
|
|
227
|
+
SpecificationError,
|
|
228
|
+
match="You can only connect to downloading or exporting models",
|
|
229
|
+
):
|
|
230
|
+
await model.connect()
|
|
233
231
|
|
|
234
232
|
@mark.asyncio
|
|
235
233
|
async def test_model_remove(self, model, mock_trainml):
|
|
@@ -412,6 +410,152 @@ class ModelTests:
|
|
|
412
410
|
await model.wait_for("ready", 10)
|
|
413
411
|
mock_trainml._query.assert_called()
|
|
414
412
|
|
|
413
|
+
@mark.asyncio
|
|
414
|
+
async def test_model_rename(self, model, mock_trainml):
|
|
415
|
+
api_response = dict(
|
|
416
|
+
model_uuid="1",
|
|
417
|
+
name="renamed model",
|
|
418
|
+
project_uuid="proj-id-1",
|
|
419
|
+
status="ready",
|
|
420
|
+
)
|
|
421
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
422
|
+
result = await model.rename("renamed model")
|
|
423
|
+
mock_trainml._query.assert_called_once_with(
|
|
424
|
+
"/model/1",
|
|
425
|
+
"PATCH",
|
|
426
|
+
None,
|
|
427
|
+
dict(name="renamed model"),
|
|
428
|
+
)
|
|
429
|
+
assert result == model
|
|
430
|
+
assert model.name == "renamed model"
|
|
431
|
+
|
|
432
|
+
@mark.asyncio
|
|
433
|
+
async def test_model_export(self, model, mock_trainml):
|
|
434
|
+
api_response = dict(
|
|
435
|
+
model_uuid="1",
|
|
436
|
+
name="first one",
|
|
437
|
+
project_uuid="proj-id-1",
|
|
438
|
+
status="exporting",
|
|
439
|
+
)
|
|
440
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
441
|
+
result = await model.export("aws", "s3://bucket/path", dict(key="value"))
|
|
442
|
+
mock_trainml._query.assert_called_once_with(
|
|
443
|
+
"/model/1/export",
|
|
444
|
+
"POST",
|
|
445
|
+
dict(project_uuid="proj-id-1"),
|
|
446
|
+
dict(
|
|
447
|
+
output_type="aws",
|
|
448
|
+
output_uri="s3://bucket/path",
|
|
449
|
+
output_options=dict(key="value"),
|
|
450
|
+
),
|
|
451
|
+
)
|
|
452
|
+
assert result == model
|
|
453
|
+
assert model.status == "exporting"
|
|
454
|
+
|
|
455
|
+
@mark.asyncio
|
|
456
|
+
async def test_model_export_default_options(self, model, mock_trainml):
|
|
457
|
+
api_response = dict(
|
|
458
|
+
model_uuid="1",
|
|
459
|
+
name="first one",
|
|
460
|
+
project_uuid="proj-id-1",
|
|
461
|
+
status="exporting",
|
|
462
|
+
)
|
|
463
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
464
|
+
result = await model.export("aws", "s3://bucket/path")
|
|
465
|
+
mock_trainml._query.assert_called_once_with(
|
|
466
|
+
"/model/1/export",
|
|
467
|
+
"POST",
|
|
468
|
+
dict(project_uuid="proj-id-1"),
|
|
469
|
+
dict(
|
|
470
|
+
output_type="aws",
|
|
471
|
+
output_uri="s3://bucket/path",
|
|
472
|
+
output_options=dict(),
|
|
473
|
+
),
|
|
474
|
+
)
|
|
475
|
+
assert result == model
|
|
476
|
+
|
|
477
|
+
@mark.asyncio
|
|
478
|
+
async def test_model_wait_for_timeout_validation(
|
|
479
|
+
self, model, mock_trainml
|
|
480
|
+
):
|
|
481
|
+
with raises(SpecificationError) as exc_info:
|
|
482
|
+
await model.wait_for("ready", timeout=25 * 60 * 60) # > 24 hours
|
|
483
|
+
assert "timeout" in str(exc_info.value.attribute).lower()
|
|
484
|
+
assert "less than" in str(exc_info.value.message).lower()
|
|
485
|
+
|
|
486
|
+
@mark.asyncio
|
|
487
|
+
async def test_model_connect_new_status_waits_for_downloading(
|
|
488
|
+
self, model, mock_trainml
|
|
489
|
+
):
|
|
490
|
+
"""Test that connect() waits for downloading status when status is 'new'."""
|
|
491
|
+
model._model["status"] = "new"
|
|
492
|
+
model._status = "new"
|
|
493
|
+
api_response_new = dict(
|
|
494
|
+
model_uuid="1",
|
|
495
|
+
name="first one",
|
|
496
|
+
status="new",
|
|
497
|
+
)
|
|
498
|
+
api_response_downloading = dict(
|
|
499
|
+
model_uuid="1",
|
|
500
|
+
name="first one",
|
|
501
|
+
status="downloading",
|
|
502
|
+
auth_token="token",
|
|
503
|
+
hostname="host",
|
|
504
|
+
source_uri="s3://bucket/path",
|
|
505
|
+
)
|
|
506
|
+
# wait_for calls refresh multiple times, then connect calls refresh again
|
|
507
|
+
mock_trainml._query = AsyncMock(
|
|
508
|
+
side_effect=[
|
|
509
|
+
api_response_new, # wait_for refresh 1
|
|
510
|
+
api_response_downloading, # wait_for refresh 2 (status matches, wait_for returns)
|
|
511
|
+
api_response_downloading, # connect refresh
|
|
512
|
+
]
|
|
513
|
+
)
|
|
514
|
+
with patch("trainml.models.upload", new_callable=AsyncMock) as mock_upload:
|
|
515
|
+
await model.connect()
|
|
516
|
+
# After refresh, status should be downloading
|
|
517
|
+
assert model.status == "downloading"
|
|
518
|
+
mock_upload.assert_called_once()
|
|
519
|
+
|
|
520
|
+
@mark.asyncio
|
|
521
|
+
async def test_model_connect_downloading_missing_properties(
|
|
522
|
+
self, model, mock_trainml
|
|
523
|
+
):
|
|
524
|
+
"""Test connect() raises error when downloading status missing properties."""
|
|
525
|
+
model._model["status"] = "downloading"
|
|
526
|
+
api_response = dict(
|
|
527
|
+
model_uuid="1",
|
|
528
|
+
name="first one",
|
|
529
|
+
status="downloading",
|
|
530
|
+
# Missing auth_token, hostname, or source_uri
|
|
531
|
+
)
|
|
532
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
533
|
+
with raises(SpecificationError) as exc_info:
|
|
534
|
+
await model.connect()
|
|
535
|
+
assert "missing required connection properties" in str(exc_info.value.message).lower()
|
|
536
|
+
|
|
537
|
+
@mark.asyncio
|
|
538
|
+
async def test_model_connect_exporting_missing_properties(
|
|
539
|
+
self, model, mock_trainml
|
|
540
|
+
):
|
|
541
|
+
"""Test connect() raises error when exporting status missing properties."""
|
|
542
|
+
model._model["status"] = "exporting"
|
|
543
|
+
api_response = dict(
|
|
544
|
+
model_uuid="1",
|
|
545
|
+
name="first one",
|
|
546
|
+
status="exporting",
|
|
547
|
+
# Missing auth_token, hostname, or output_uri
|
|
548
|
+
)
|
|
549
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
550
|
+
with raises(SpecificationError) as exc_info:
|
|
551
|
+
await model.connect()
|
|
552
|
+
assert "missing required connection properties" in str(exc_info.value.message).lower()
|
|
553
|
+
|
|
554
|
+
def test_model_billed_size_property(self, model, mock_trainml):
|
|
555
|
+
"""Test billed_size property access."""
|
|
556
|
+
model._billed_size = 50000
|
|
557
|
+
assert model.billed_size == 50000
|
|
558
|
+
|
|
415
559
|
@mark.asyncio
|
|
416
560
|
async def test_model_wait_for_model_failed(self, model, mock_trainml):
|
|
417
561
|
api_response = dict(
|
|
@@ -426,7 +570,9 @@ class ModelTests:
|
|
|
426
570
|
mock_trainml._query.assert_called()
|
|
427
571
|
|
|
428
572
|
@mark.asyncio
|
|
429
|
-
async def test_model_wait_for_archived_succeeded(
|
|
573
|
+
async def test_model_wait_for_archived_succeeded(
|
|
574
|
+
self, model, mock_trainml
|
|
575
|
+
):
|
|
430
576
|
mock_trainml._query = AsyncMock(
|
|
431
577
|
side_effect=ApiError(404, dict(errorMessage="Model Not Found"))
|
|
432
578
|
)
|
|
@@ -434,7 +580,9 @@ class ModelTests:
|
|
|
434
580
|
mock_trainml._query.assert_called()
|
|
435
581
|
|
|
436
582
|
@mark.asyncio
|
|
437
|
-
async def test_model_wait_for_unexpected_api_error(
|
|
583
|
+
async def test_model_wait_for_unexpected_api_error(
|
|
584
|
+
self, model, mock_trainml
|
|
585
|
+
):
|
|
438
586
|
mock_trainml._query = AsyncMock(
|
|
439
587
|
side_effect=ApiError(404, dict(errorMessage="Model Not Found"))
|
|
440
588
|
)
|