trainml 0.5.17__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {trainml-0.5.17 → trainml-1.0.0}/PKG-INFO +2 -2
- {trainml-0.5.17 → trainml-1.0.0}/README.md +1 -1
- {trainml-0.5.17 → trainml-1.0.0}/examples/local_storage.py +0 -2
- {trainml-0.5.17 → trainml-1.0.0}/pyproject.toml +1 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_checkpoints_integration.py +4 -3
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_datasets_integration.py +5 -3
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_jobs_integration.py +33 -27
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_models_integration.py +7 -3
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_volumes_integration.py +2 -2
- trainml-1.0.0/tests/unit/cli/test_cli_checkpoint_unit.py +333 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_nodes_unit.py +112 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_providers_unit.py +96 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_regions_unit.py +106 -0
- trainml-1.0.0/tests/unit/cloudbender/test_services_unit.py +308 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/conftest.py +23 -10
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_project_datastores_unit.py +37 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_project_members_unit.py +46 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_project_services_unit.py +65 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_projects_unit.py +16 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_auth_unit.py +17 -2
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_checkpoints_unit.py +256 -71
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_datasets_unit.py +218 -68
- trainml-1.0.0/tests/unit/test_exceptions.py +164 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_gpu_types_unit.py +11 -1
- trainml-1.0.0/tests/unit/test_jobs_unit.py +1749 -0
- trainml-1.0.0/tests/unit/test_main_unit.py +20 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_models_unit.py +218 -70
- trainml-1.0.0/tests/unit/test_trainml_unit.py +678 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_volumes_unit.py +211 -70
- trainml-1.0.0/tests/unit/utils/__init__.py +1 -0
- trainml-1.0.0/tests/unit/utils/test_transfer_unit.py +4260 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/__init__.py +1 -1
- {trainml-0.5.17 → trainml-1.0.0}/trainml/checkpoints.py +56 -57
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/__init__.py +6 -3
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/checkpoint.py +18 -57
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/dataset.py +17 -57
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/job/__init__.py +11 -53
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/job/create.py +51 -24
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/model.py +14 -56
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/volume.py +18 -57
- {trainml-0.5.17 → trainml-1.0.0}/trainml/datasets.py +50 -55
- {trainml-0.5.17 → trainml-1.0.0}/trainml/jobs.py +239 -68
- {trainml-0.5.17 → trainml-1.0.0}/trainml/models.py +51 -55
- {trainml-0.5.17 → trainml-1.0.0}/trainml/trainml.py +50 -16
- trainml-1.0.0/trainml/utils/__init__.py +1 -0
- {trainml-0.5.17/trainml → trainml-1.0.0/trainml/utils}/auth.py +4 -3
- trainml-1.0.0/trainml/utils/transfer.py +587 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/volumes.py +48 -53
- {trainml-0.5.17 → trainml-1.0.0}/trainml.egg-info/PKG-INFO +2 -2
- {trainml-0.5.17 → trainml-1.0.0}/trainml.egg-info/SOURCES.txt +7 -5
- {trainml-0.5.17 → trainml-1.0.0}/trainml.egg-info/requires.txt +1 -1
- trainml-0.5.17/tests/unit/cli/test_cli_checkpoint_unit.py +0 -22
- trainml-0.5.17/tests/unit/cloudbender/test_services_unit.py +0 -167
- trainml-0.5.17/tests/unit/test_connections_unit.py +0 -182
- trainml-0.5.17/tests/unit/test_exceptions.py +0 -31
- trainml-0.5.17/tests/unit/test_jobs_unit.py +0 -830
- trainml-0.5.17/tests/unit/test_trainml_unit.py +0 -54
- trainml-0.5.17/trainml/cli/connection.py +0 -61
- trainml-0.5.17/trainml/connections.py +0 -621
- {trainml-0.5.17 → trainml-1.0.0}/LICENSE +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/examples/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/examples/create_dataset_and_training_job.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/examples/training_inference_pipeline.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/setup.cfg +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/setup.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/cloudbender/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/cloudbender/conftest.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/cloudbender/test_providers_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/cloudbender/test_regions_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/cloudbender/test_services_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/conftest.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/conftest.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_credentials_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_data_connectors_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_datastores_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_members_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_secrets_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/projects/test_projects_services_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_environments_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/integration/test_gpu_types_integration.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/test_cli_datastore_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/test_cli_device_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/test_cli_node_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/test_cli_provider_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/test_cli_region_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/cloudbender/test_cli_service_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/conftest.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/test_cli_project_credential_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/test_cli_project_data_connector_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/test_cli_project_datastore_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/test_cli_project_secret_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/test_cli_project_service_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/projects/test_cli_project_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/test_cli_datasets_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/test_cli_environment_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/test_cli_gpu_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/test_cli_job_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/test_cli_model_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cli/test_cli_volume_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_data_connectors_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_datastores_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_device_configs_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/cloudbender/test_devices_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_project_credentials_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/projects/test_project_secrets_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/tests/unit/test_environments_unit.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/__main__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/data_connector.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/datastore.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/device.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/node.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/provider.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/region.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/cloudbender/service.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/environment.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/gpu.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/project/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/project/credential.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/project/data_connector.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/project/datastore.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/project/secret.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cli/project/service.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/cloudbender.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/data_connectors.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/datastores.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/device_configs.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/devices.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/nodes.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/providers.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/regions.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/cloudbender/services.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/environments.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/exceptions.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/gpu_types.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/__init__.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/credentials.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/data_connectors.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/datastores.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/members.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/projects.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/secrets.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml/projects/services.py +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml.egg-info/dependency_links.txt +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml.egg-info/entry_points.txt +0 -0
- {trainml-0.5.17 → trainml-1.0.0}/trainml.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: trainml
|
|
3
|
-
Version: 0.5.17
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: trainML client SDK and command line utilities
|
|
5
5
|
Home-page: https://github.com/trainML/trainml-cli
|
|
6
6
|
Author: trainML
|
|
@@ -177,7 +177,7 @@ Description: <div align="center">
|
|
|
177
177
|
trainml dataset list
|
|
178
178
|
```
|
|
179
179
|
|
|
180
|
-
To connect to a job that
|
|
180
|
+
To connect to a job that uses the "local" file transfer method:
|
|
181
181
|
|
|
182
182
|
```
|
|
183
183
|
trainml job connect <job ID or name>
|
|
@@ -169,7 +169,7 @@ To list all datasets:
|
|
|
169
169
|
trainml dataset list
|
|
170
170
|
```
|
|
171
171
|
|
|
172
|
-
To connect to a job that
|
|
172
|
+
To connect to a job that uses the "local" file transfer method:
|
|
173
173
|
|
|
174
174
|
```
|
|
175
175
|
trainml job connect <job ID or name>
|
|
@@ -19,7 +19,6 @@ async def create_dataset():
|
|
|
19
19
|
attach_task = asyncio.create_task(dataset.attach())
|
|
20
20
|
connect_task = asyncio.create_task(dataset.connect())
|
|
21
21
|
await asyncio.gather(attach_task, connect_task)
|
|
22
|
-
await dataset.disconnect()
|
|
23
22
|
return dataset
|
|
24
23
|
|
|
25
24
|
|
|
@@ -55,7 +54,6 @@ async def run_job(dataset):
|
|
|
55
54
|
await asyncio.gather(attach_task, connect_task)
|
|
56
55
|
|
|
57
56
|
# Cleanup job
|
|
58
|
-
await job.disconnect()
|
|
59
57
|
await job.remove()
|
|
60
58
|
|
|
61
59
|
|
|
@@ -25,6 +25,7 @@ markers = [
|
|
|
25
25
|
"data_connectors: Data Connector tests",
|
|
26
26
|
"services: Services tests",
|
|
27
27
|
"device_configs: DeviceConfigs tests",
|
|
28
|
+
"local: Local Connection Utility tests",
|
|
28
29
|
"unit: All unit tests (no trainML environment required)",
|
|
29
30
|
"integration: All integration tests (trainML environment required)",
|
|
30
31
|
"sdk: All tests of the SDK",
|
|
@@ -23,6 +23,7 @@ class GetCheckpointTests:
|
|
|
23
23
|
checkpoint = await checkpoint.wait_for("archived", 60)
|
|
24
24
|
|
|
25
25
|
async def test_get_checkpoints(self, trainml, checkpoint):
|
|
26
|
+
_ = checkpoint
|
|
26
27
|
checkpoints = await trainml.checkpoints.list()
|
|
27
28
|
assert len(checkpoints) > 0
|
|
28
29
|
|
|
@@ -55,7 +56,7 @@ class GetCheckpointTests:
|
|
|
55
56
|
|
|
56
57
|
@mark.create
|
|
57
58
|
@mark.asyncio
|
|
58
|
-
async def test_checkpoint_wasabi(trainml, capsys):
|
|
59
|
+
async def test_checkpoint_wasabi(trainml):
|
|
59
60
|
checkpoint = await trainml.checkpoints.create(
|
|
60
61
|
name="CLI Automated Wasabi",
|
|
61
62
|
source_type="wasabi",
|
|
@@ -72,6 +73,7 @@ async def test_checkpoint_wasabi(trainml, capsys):
|
|
|
72
73
|
|
|
73
74
|
@mark.create
|
|
74
75
|
@mark.asyncio
|
|
76
|
+
@mark.local
|
|
75
77
|
async def test_checkpoint_local(trainml, capsys):
|
|
76
78
|
checkpoint = await trainml.checkpoints.create(
|
|
77
79
|
name="CLI Automated Local",
|
|
@@ -81,7 +83,6 @@ async def test_checkpoint_local(trainml, capsys):
|
|
|
81
83
|
attach_task = asyncio.create_task(checkpoint.attach())
|
|
82
84
|
connect_task = asyncio.create_task(checkpoint.connect())
|
|
83
85
|
await asyncio.gather(attach_task, connect_task)
|
|
84
|
-
await checkpoint.disconnect()
|
|
85
86
|
await checkpoint.refresh()
|
|
86
87
|
status = checkpoint.status
|
|
87
88
|
size = checkpoint.size
|
|
@@ -92,5 +93,5 @@ async def test_checkpoint_local(trainml, capsys):
|
|
|
92
93
|
sys.stdout.write(captured.out)
|
|
93
94
|
sys.stderr.write(captured.err)
|
|
94
95
|
assert "Starting data upload from local" in captured.out
|
|
95
|
-
assert "official/LICENSE
|
|
96
|
+
assert "official/LICENSE" in captured.out
|
|
96
97
|
assert "Upload complete" in captured.out
|
|
@@ -50,7 +50,9 @@ class GetDatasetTests:
|
|
|
50
50
|
async def test_dataset_repr(self, dataset):
|
|
51
51
|
string = repr(dataset)
|
|
52
52
|
regex = (
|
|
53
|
-
r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '"
|
|
53
|
+
r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '"
|
|
54
|
+
+ dataset.id
|
|
55
|
+
+ r"'.*}\)$"
|
|
54
56
|
)
|
|
55
57
|
assert isinstance(string, str)
|
|
56
58
|
assert re.match(regex, string)
|
|
@@ -79,6 +81,7 @@ class GetDatasetTests:
|
|
|
79
81
|
|
|
80
82
|
@mark.create
|
|
81
83
|
@mark.asyncio
|
|
84
|
+
@mark.local
|
|
82
85
|
async def test_dataset_local(trainml, capsys):
|
|
83
86
|
dataset = await trainml.datasets.create(
|
|
84
87
|
name="CLI Automated Local",
|
|
@@ -88,7 +91,6 @@ async def test_dataset_local(trainml, capsys):
|
|
|
88
91
|
attach_task = asyncio.create_task(dataset.attach())
|
|
89
92
|
connect_task = asyncio.create_task(dataset.connect())
|
|
90
93
|
await asyncio.gather(attach_task, connect_task)
|
|
91
|
-
await dataset.disconnect()
|
|
92
94
|
await dataset.refresh()
|
|
93
95
|
status = dataset.status
|
|
94
96
|
size = dataset.size
|
|
@@ -99,5 +101,5 @@ async def test_dataset_local(trainml, capsys):
|
|
|
99
101
|
sys.stdout.write(captured.out)
|
|
100
102
|
sys.stderr.write(captured.err)
|
|
101
103
|
assert "Starting data upload from local" in captured.out
|
|
102
|
-
assert "data_batch_1.bin
|
|
104
|
+
assert "data_batch_1.bin" in captured.out
|
|
103
105
|
assert "Upload complete" in captured.out
|
|
@@ -46,7 +46,10 @@ class JobLifeCycleTests:
|
|
|
46
46
|
job = await job.wait_for("running")
|
|
47
47
|
assert job.status == "running"
|
|
48
48
|
assert job.url
|
|
49
|
-
assert
|
|
49
|
+
assert (
|
|
50
|
+
extract_domain_suffix(urlparse(job.url).hostname)
|
|
51
|
+
== "proximl.cloud"
|
|
52
|
+
)
|
|
50
53
|
|
|
51
54
|
async def test_stop_job(self, job):
|
|
52
55
|
assert job.status == "running"
|
|
@@ -204,7 +207,8 @@ class JobAPIResourceValidationTests:
|
|
|
204
207
|
disk_size=10,
|
|
205
208
|
)
|
|
206
209
|
assert (
|
|
207
|
-
"Invalid Request - CPU Count must be a multiple of 4"
|
|
210
|
+
"Invalid Request - CPU Count must be a multiple of 4"
|
|
211
|
+
in error.value.message
|
|
208
212
|
)
|
|
209
213
|
|
|
210
214
|
async def test_invalid_gpu_count_for_cpu(self, trainml):
|
|
@@ -417,6 +421,7 @@ class JobAPIWorkerValidationTests:
|
|
|
417
421
|
@mark.asyncio
|
|
418
422
|
@mark.xdist_group("job_io")
|
|
419
423
|
class JobIOTests:
|
|
424
|
+
@mark.local
|
|
420
425
|
async def test_job_local_output(self, trainml, capsys):
|
|
421
426
|
temp_dir = tempfile.TemporaryDirectory()
|
|
422
427
|
job = await trainml.jobs.create(
|
|
@@ -426,7 +431,7 @@ class JobIOTests:
|
|
|
426
431
|
disk_size=10,
|
|
427
432
|
workers=["python $ML_MODEL_PATH/tensorflow/main.py"],
|
|
428
433
|
environment=dict(
|
|
429
|
-
type="
|
|
434
|
+
type="DEEPLEARNING_PY313",
|
|
430
435
|
env=[
|
|
431
436
|
dict(
|
|
432
437
|
key="CHECKPOINT_FILE",
|
|
@@ -452,13 +457,13 @@ class JobIOTests:
|
|
|
452
457
|
],
|
|
453
458
|
),
|
|
454
459
|
)
|
|
455
|
-
|
|
460
|
+
# Wait for job to reach running status since only output_type is local
|
|
461
|
+
await job.wait_for("running")
|
|
456
462
|
attach_task = asyncio.create_task(job.attach())
|
|
457
463
|
connect_task = asyncio.create_task(job.connect())
|
|
458
464
|
await asyncio.gather(attach_task, connect_task)
|
|
459
465
|
await job.refresh()
|
|
460
466
|
assert job.status == "finished"
|
|
461
|
-
await job.disconnect()
|
|
462
467
|
await job.remove()
|
|
463
468
|
upload_contents = os.listdir(temp_dir.name)
|
|
464
469
|
temp_dir.cleanup()
|
|
@@ -470,9 +475,8 @@ class JobIOTests:
|
|
|
470
475
|
captured = capsys.readouterr()
|
|
471
476
|
sys.stdout.write(captured.out)
|
|
472
477
|
sys.stderr.write(captured.err)
|
|
473
|
-
assert "Epoch 1/2" in captured.out
|
|
474
|
-
assert "
|
|
475
|
-
assert "adding: model.ckpt-0001" in captured.out
|
|
478
|
+
assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
|
|
479
|
+
assert "model.ckpt-0001" in captured.out
|
|
476
480
|
assert "Send complete" in captured.out
|
|
477
481
|
|
|
478
482
|
async def test_job_model_input_and_output(self, trainml, capsys):
|
|
@@ -513,8 +517,7 @@ class JobIOTests:
|
|
|
513
517
|
captured = capsys.readouterr()
|
|
514
518
|
sys.stdout.write(captured.out)
|
|
515
519
|
sys.stderr.write(captured.err)
|
|
516
|
-
assert "Epoch 1/2" in captured.out
|
|
517
|
-
assert "Epoch 2/2" in captured.out
|
|
520
|
+
assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
|
|
518
521
|
|
|
519
522
|
new_model = await trainml.models.get(workers[0].get("output_uuid"))
|
|
520
523
|
assert new_model.id
|
|
@@ -560,9 +563,12 @@ class JobTypeTests:
|
|
|
560
563
|
await job.wait_for("running")
|
|
561
564
|
await job.refresh()
|
|
562
565
|
assert job.url
|
|
563
|
-
assert
|
|
566
|
+
assert (
|
|
567
|
+
extract_domain_suffix(urlparse(job.url).hostname)
|
|
568
|
+
== "proximl.cloud"
|
|
569
|
+
)
|
|
564
570
|
tries = 0
|
|
565
|
-
await asyncio.sleep(180)
|
|
571
|
+
await asyncio.sleep(180) ## downloading weights can be slow
|
|
566
572
|
async with aiohttp.ClientSession() as session:
|
|
567
573
|
retry = True
|
|
568
574
|
while retry:
|
|
@@ -640,9 +646,11 @@ class JobTypeTests:
|
|
|
640
646
|
captured = capsys.readouterr()
|
|
641
647
|
sys.stdout.write(captured.out)
|
|
642
648
|
sys.stderr.write(captured.err)
|
|
643
|
-
assert "Epoch 1/2" in captured.out
|
|
644
|
-
assert
|
|
645
|
-
|
|
649
|
+
assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
|
|
650
|
+
assert (
|
|
651
|
+
"Uploading s3://trainml-example/output/resnet_cifar10"
|
|
652
|
+
in captured.out
|
|
653
|
+
)
|
|
646
654
|
assert (
|
|
647
655
|
"upload: ./model.ckpt-0002.weights.h5 to s3://trainml-example/output/resnet_cifar10/model.ckpt-0002.weights.h5"
|
|
648
656
|
in captured.out
|
|
@@ -680,9 +688,12 @@ class JobFeatureTests:
|
|
|
680
688
|
captured = capsys.readouterr()
|
|
681
689
|
sys.stdout.write(captured.out)
|
|
682
690
|
sys.stderr.write(captured.err)
|
|
683
|
-
assert
|
|
684
|
-
|
|
691
|
+
assert (
|
|
692
|
+
"Train Epoch: 1 [0/60000 (0%)]" in captured.out
|
|
693
|
+
or "Train Epoch: 1 [59520/60000 (99%)]" in captured.out
|
|
694
|
+
)
|
|
685
695
|
|
|
696
|
+
@mark.local
|
|
686
697
|
async def test_inference_job(self, trainml, capsys):
|
|
687
698
|
temp_dir = tempfile.TemporaryDirectory()
|
|
688
699
|
job = await trainml.jobs.create(
|
|
@@ -706,11 +717,11 @@ class JobFeatureTests:
|
|
|
706
717
|
)
|
|
707
718
|
assert job.id
|
|
708
719
|
await job.wait_for("running")
|
|
709
|
-
|
|
710
|
-
|
|
720
|
+
attach_task = asyncio.create_task(job.attach())
|
|
721
|
+
connect_task = asyncio.create_task(job.connect())
|
|
722
|
+
await asyncio.gather(attach_task, connect_task)
|
|
711
723
|
await job.refresh()
|
|
712
724
|
assert job.status == "finished"
|
|
713
|
-
await job.disconnect()
|
|
714
725
|
await job.remove()
|
|
715
726
|
await job.wait_for("archived")
|
|
716
727
|
captured = capsys.readouterr()
|
|
@@ -719,15 +730,10 @@ class JobFeatureTests:
|
|
|
719
730
|
upload_contents = os.listdir(temp_dir.name)
|
|
720
731
|
temp_dir.cleanup()
|
|
721
732
|
assert len(upload_contents) >= 3
|
|
722
|
-
assert any(
|
|
723
|
-
"model.ckpt-0002" in content
|
|
724
|
-
for content in upload_contents
|
|
725
|
-
)
|
|
733
|
+
assert any("model.ckpt-0002" in content for content in upload_contents)
|
|
726
734
|
|
|
727
735
|
captured = capsys.readouterr()
|
|
728
736
|
sys.stdout.write(captured.out)
|
|
729
737
|
sys.stderr.write(captured.err)
|
|
730
|
-
assert "Epoch 1/2" in captured.out
|
|
731
|
-
assert "Epoch 2/2" in captured.out
|
|
732
|
-
assert "Number of regular files transferred: 4" in captured.out
|
|
738
|
+
assert "Epoch 1/2" in captured.out or "Epoch 2/2" in captured.out
|
|
733
739
|
assert "Send complete" in captured.out
|
|
@@ -44,7 +44,11 @@ class GetModelTests:
|
|
|
44
44
|
|
|
45
45
|
async def test_model_repr(self, model):
|
|
46
46
|
string = repr(model)
|
|
47
|
-
regex =
|
|
47
|
+
regex = (
|
|
48
|
+
r"^Model\( trainml , \*\*{.*'model_uuid': '"
|
|
49
|
+
+ model.id
|
|
50
|
+
+ r"'.*}\)$"
|
|
51
|
+
)
|
|
48
52
|
assert isinstance(string, str)
|
|
49
53
|
assert re.match(regex, string)
|
|
50
54
|
|
|
@@ -68,6 +72,7 @@ async def test_model_wasabi(trainml, capsys):
|
|
|
68
72
|
|
|
69
73
|
@mark.create
|
|
70
74
|
@mark.asyncio
|
|
75
|
+
@mark.local
|
|
71
76
|
async def test_model_local(trainml, capsys):
|
|
72
77
|
model = await trainml.models.create(
|
|
73
78
|
name="CLI Automated Local",
|
|
@@ -77,7 +82,6 @@ async def test_model_local(trainml, capsys):
|
|
|
77
82
|
attach_task = asyncio.create_task(model.attach())
|
|
78
83
|
connect_task = asyncio.create_task(model.connect())
|
|
79
84
|
await asyncio.gather(attach_task, connect_task)
|
|
80
|
-
await model.disconnect()
|
|
81
85
|
await model.refresh()
|
|
82
86
|
status = model.status
|
|
83
87
|
size = model.size
|
|
@@ -88,5 +92,5 @@ async def test_model_local(trainml, capsys):
|
|
|
88
92
|
sys.stdout.write(captured.out)
|
|
89
93
|
sys.stderr.write(captured.err)
|
|
90
94
|
assert "Starting data upload from local" in captured.out
|
|
91
|
-
assert "official/LICENSE
|
|
95
|
+
assert "official/LICENSE" in captured.out
|
|
92
96
|
assert "Upload complete" in captured.out
|
|
@@ -74,6 +74,7 @@ async def test_volume_wasabi(trainml, capsys):
|
|
|
74
74
|
|
|
75
75
|
@mark.create
|
|
76
76
|
@mark.asyncio
|
|
77
|
+
@mark.local
|
|
77
78
|
async def test_volume_local(trainml, capsys):
|
|
78
79
|
volume = await trainml.volumes.create(
|
|
79
80
|
name="CLI Automated Local",
|
|
@@ -84,7 +85,6 @@ async def test_volume_local(trainml, capsys):
|
|
|
84
85
|
attach_task = asyncio.create_task(volume.attach())
|
|
85
86
|
connect_task = asyncio.create_task(volume.connect())
|
|
86
87
|
await asyncio.gather(attach_task, connect_task)
|
|
87
|
-
await volume.disconnect()
|
|
88
88
|
await volume.refresh()
|
|
89
89
|
status = volume.status
|
|
90
90
|
billed_size = volume.billed_size
|
|
@@ -97,5 +97,5 @@ async def test_volume_local(trainml, capsys):
|
|
|
97
97
|
sys.stdout.write(captured.out)
|
|
98
98
|
sys.stderr.write(captured.err)
|
|
99
99
|
assert "Starting data upload from local" in captured.out
|
|
100
|
-
assert "official/LICENSE
|
|
100
|
+
assert "official/LICENSE" in captured.out
|
|
101
101
|
assert "Upload complete" in captured.out
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import json
|
|
3
|
+
import click
|
|
4
|
+
from unittest.mock import AsyncMock, patch, Mock
|
|
5
|
+
from pytest import mark, fixture, raises
|
|
6
|
+
|
|
7
|
+
pytestmark = [mark.cli, mark.unit, mark.checkpoints]
|
|
8
|
+
|
|
9
|
+
from trainml.cli import checkpoint as specimen
|
|
10
|
+
from trainml.cli.checkpoint import pretty_size
|
|
11
|
+
from trainml.checkpoints import Checkpoint
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_pretty_size_zero():
|
|
15
|
+
"""Test pretty_size with zero/None (line 7)."""
|
|
16
|
+
result = pretty_size(None)
|
|
17
|
+
assert result == "0.00 B"
|
|
18
|
+
result = pretty_size(0)
|
|
19
|
+
assert result == "0.00 B"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def test_list(runner, mock_my_checkpoints):
|
|
23
|
+
with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
|
|
24
|
+
mock_trainml.checkpoints = AsyncMock()
|
|
25
|
+
mock_trainml.checkpoints.list = AsyncMock(
|
|
26
|
+
return_value=mock_my_checkpoints
|
|
27
|
+
)
|
|
28
|
+
result = runner.invoke(specimen, ["list"])
|
|
29
|
+
print(result)
|
|
30
|
+
assert result.exit_code == 0
|
|
31
|
+
mock_trainml.checkpoints.list.assert_called_once()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_attach_success(runner, mock_my_checkpoints):
|
|
35
|
+
"""Test attach command success (lines 32-38)."""
|
|
36
|
+
with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
|
|
37
|
+
|
|
38
|
+
async def list_async():
|
|
39
|
+
return mock_my_checkpoints
|
|
40
|
+
|
|
41
|
+
mock_trainml.checkpoints = AsyncMock()
|
|
42
|
+
mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
|
|
43
|
+
|
|
44
|
+
# Use the first checkpoint from the list
|
|
45
|
+
checkpoint = mock_my_checkpoints[0]
|
|
46
|
+
|
|
47
|
+
async def attach_async():
|
|
48
|
+
return None
|
|
49
|
+
|
|
50
|
+
checkpoint.attach = Mock(return_value=attach_async())
|
|
51
|
+
|
|
52
|
+
with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
|
|
53
|
+
result = runner.invoke(specimen, ["attach", "1"])
|
|
54
|
+
assert result.exit_code == 0
|
|
55
|
+
checkpoint.attach.assert_called_once()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_attach_not_found(runner, mock_my_checkpoints):
|
|
59
|
+
"""Test attach command when checkpoint not found (line 36)."""
|
|
60
|
+
with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
|
|
61
|
+
|
|
62
|
+
async def list_async():
|
|
63
|
+
return mock_my_checkpoints
|
|
64
|
+
|
|
65
|
+
mock_trainml.checkpoints = AsyncMock()
|
|
66
|
+
mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
|
|
67
|
+
|
|
68
|
+
with patch("trainml.cli.search_by_id_name", return_value=None):
|
|
69
|
+
result = runner.invoke(specimen, ["attach", "nonexistent"])
|
|
70
|
+
assert result.exit_code != 0
|
|
71
|
+
assert "Cannot find specified checkpoint" in result.output
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_connect_with_attach(runner, mock_my_checkpoints):
|
|
75
|
+
"""Test connect command with attach (lines 56-65, attach=True)."""
|
|
76
|
+
with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
|
|
77
|
+
|
|
78
|
+
async def list_async():
|
|
79
|
+
return mock_my_checkpoints
|
|
80
|
+
|
|
81
|
+
mock_trainml.checkpoints = AsyncMock()
|
|
82
|
+
mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
|
|
83
|
+
|
|
84
|
+
checkpoint = mock_my_checkpoints[0]
|
|
85
|
+
|
|
86
|
+
async def connect_async():
|
|
87
|
+
return None
|
|
88
|
+
|
|
89
|
+
async def attach_async():
|
|
90
|
+
return None
|
|
91
|
+
|
|
92
|
+
checkpoint.connect = Mock(return_value=connect_async())
|
|
93
|
+
checkpoint.attach = Mock(return_value=attach_async())
|
|
94
|
+
|
|
95
|
+
with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
|
|
96
|
+
result = runner.invoke(specimen, ["connect", "1"])
|
|
97
|
+
assert result.exit_code == 0
|
|
98
|
+
checkpoint.connect.assert_called_once()
|
|
99
|
+
checkpoint.attach.assert_called_once()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_connect_no_attach(runner, mock_my_checkpoints):
|
|
103
|
+
"""Test connect command without attach (lines 56-65, attach=False)."""
|
|
104
|
+
with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
|
|
105
|
+
|
|
106
|
+
async def list_async():
|
|
107
|
+
return mock_my_checkpoints
|
|
108
|
+
|
|
109
|
+
mock_trainml.checkpoints = AsyncMock()
|
|
110
|
+
mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
|
|
111
|
+
|
|
112
|
+
checkpoint = mock_my_checkpoints[0]
|
|
113
|
+
|
|
114
|
+
async def connect_async():
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
checkpoint.connect = Mock(return_value=connect_async())
|
|
118
|
+
|
|
119
|
+
with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
|
|
120
|
+
result = runner.invoke(specimen, ["connect", "--no-attach", "1"])
|
|
121
|
+
assert result.exit_code == 0
|
|
122
|
+
checkpoint.connect.assert_called_once()
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def test_connect_not_found(runner, mock_my_checkpoints):
|
|
126
|
+
"""Test connect command when checkpoint not found (line 60)."""
|
|
127
|
+
with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
|
|
128
|
+
|
|
129
|
+
async def list_async():
|
|
130
|
+
return mock_my_checkpoints
|
|
131
|
+
|
|
132
|
+
mock_trainml.checkpoints = AsyncMock()
|
|
133
|
+
mock_trainml.checkpoints.list = Mock(side_effect=lambda: list_async())
|
|
134
|
+
|
|
135
|
+
with patch("trainml.cli.search_by_id_name", return_value=None):
|
|
136
|
+
result = runner.invoke(specimen, ["connect", "nonexistent"])
|
|
137
|
+
assert result.exit_code != 0
|
|
138
|
+
assert "Cannot find specified checkpoint" in result.output
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_create_with_connect_and_attach(runner, tmp_path, mock_my_checkpoints):
    """Run ``create`` with default flags; expect create, connect and attach.

    ``checkpoints.create`` resolves to the first fixture checkpoint, whose
    ``connect`` and ``attach`` methods are replaced with async stubs.
    """
    with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
        checkpoint = mock_my_checkpoints[0]
        # AsyncMock returns a new awaitable per call, so the stubs stay
        # valid even if the command awaits them more than once.
        checkpoint.connect = AsyncMock(return_value=None)
        checkpoint.attach = AsyncMock(return_value=None)

        mock_trainml.checkpoints = AsyncMock()
        mock_trainml.checkpoints.create = AsyncMock(return_value=checkpoint)

        test_dir = tmp_path / "test_checkpoint"
        test_dir.mkdir()
        result = runner.invoke(
            specimen, ["create", "test-checkpoint", str(test_dir)]
        )
        assert result.exit_code == 0
        mock_trainml.checkpoints.create.assert_called_once()
        checkpoint.connect.assert_called_once()
        checkpoint.attach.assert_called_once()
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def test_create_with_connect_no_attach(runner, tmp_path, mock_my_checkpoints):
    """Run ``create --no-attach``; expect ``connect`` but not ``attach``.

    ``checkpoints.create`` resolves to the first fixture checkpoint, whose
    ``connect`` and ``attach`` methods are replaced with async stubs.
    """
    with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
        checkpoint = mock_my_checkpoints[0]
        # AsyncMock yields a fresh awaitable per call instead of sharing a
        # single pre-built coroutine object.
        checkpoint.connect = AsyncMock(return_value=None)
        # Fresh stub so the not-called check below is independent of any
        # attach stub another test may have left on the fixture object.
        checkpoint.attach = AsyncMock(return_value=None)

        mock_trainml.checkpoints = AsyncMock()
        mock_trainml.checkpoints.create = AsyncMock(return_value=checkpoint)

        test_dir = tmp_path / "test_checkpoint"
        test_dir.mkdir()
        result = runner.invoke(
            specimen,
            ["create", "--no-attach", "test-checkpoint", str(test_dir)],
        )
        assert result.exit_code == 0
        checkpoint.connect.assert_called_once()
        checkpoint.attach.assert_not_called()
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def test_create_no_connect(runner, tmp_path):
    """Invoke ``create --no-connect`` against a stubbed ``TrainMLRunner``.

    The command is expected to exit non-zero with output mentioning
    "No logs to show".
    """
    created = Mock(spec=Checkpoint)

    def run_stub(x):
        # Call anything exposing ``__call__``; pass other values through.
        if hasattr(x, "__call__"):
            return x()
        return x

    runner_double = Mock()
    runner_double.client = Mock()
    runner_double.client.checkpoints = Mock()
    runner_double.client.checkpoints.create = AsyncMock(return_value=created)
    runner_double.run = Mock(side_effect=run_stub)

    source_dir = tmp_path / "test_checkpoint"
    source_dir.mkdir()

    with patch("trainml.cli.TrainMLRunner", return_value=runner_double):
        outcome = runner.invoke(
            specimen,
            ["create", "--no-connect", "test-checkpoint", str(source_dir)],
        )

    assert outcome.exit_code != 0
    assert "No logs to show" in outcome.output
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def test_list_public(runner, mock_my_checkpoints):
    """Invoke ``list-public`` and confirm the public listing is fetched once."""
    with patch("trainml.cli.TrainML", new=AsyncMock) as trainml_cls:
        trainml_cls.checkpoints = AsyncMock()
        public_listing = AsyncMock(return_value=mock_my_checkpoints)
        trainml_cls.checkpoints.list_public = public_listing

        outcome = runner.invoke(specimen, ["list-public"])

        assert outcome.exit_code == 0
        # Same object as trainml_cls.checkpoints.list_public.
        public_listing.assert_called_once()
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def test_remove_success(runner, mock_my_checkpoints):
    """Run ``remove`` on a found checkpoint; expect ``remove(force=False)``.

    The checkpoint listing resolves to ``mock_my_checkpoints`` and
    ``search_by_id_name`` is patched to return the first entry.
    """
    with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
        mock_trainml.checkpoints = AsyncMock()
        mock_trainml.checkpoints.list = AsyncMock(
            return_value=mock_my_checkpoints
        )

        checkpoint = mock_my_checkpoints[0]
        # AsyncMock creates the awaitable at call time; a Mock returning a
        # pre-built coroutine object would break on a second await and
        # warn if never awaited.
        checkpoint.remove = AsyncMock(return_value=None)

        with patch("trainml.cli.search_by_id_name", return_value=checkpoint):
            result = runner.invoke(specimen, ["remove", "1"])
            assert result.exit_code == 0
            checkpoint.remove.assert_called_once_with(force=False)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def test_remove_not_found(runner, mock_my_checkpoints):
    """Check that ``remove`` errors out when the lookup returns ``None``."""
    with patch("trainml.cli.TrainML", new=AsyncMock) as trainml_cls:

        async def fetch_checkpoints():
            return mock_my_checkpoints

        trainml_cls.checkpoints = AsyncMock()
        # Every call yields a new coroutine resolving to the fixture list.
        trainml_cls.checkpoints.list = Mock(side_effect=fetch_checkpoints)

        with patch("trainml.cli.search_by_id_name", return_value=None):
            outcome = runner.invoke(specimen, ["remove", "nonexistent"])

        assert outcome.exit_code != 0
        assert "Cannot find specified checkpoint" in outcome.output
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def test_rename_success(runner, mock_my_checkpoints):
    """Run ``rename`` on a found checkpoint; expect ``rename(name=...)``.

    ``checkpoints.get`` resolves to the first fixture checkpoint, whose
    ``rename`` method is replaced with an async stub.
    """
    with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
        checkpoint = mock_my_checkpoints[0]
        # AsyncMock accepts the ``name=`` keyword and returns a fresh
        # awaitable per call, unlike a Mock wrapping one coroutine object.
        checkpoint.rename = AsyncMock(return_value=None)

        mock_trainml.checkpoints = AsyncMock()
        mock_trainml.checkpoints.get = AsyncMock(return_value=checkpoint)

        result = runner.invoke(specimen, ["rename", "1", "new-name"])
        assert result.exit_code == 0
        checkpoint.rename.assert_called_once_with(name="new-name")
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def test_rename_not_found_none(runner):
    """Check that ``rename`` errors out when ``get`` resolves to ``None``."""
    with patch("trainml.cli.TrainML", new=AsyncMock) as trainml_cls:

        async def empty_lookup(checkpoint_id):
            return None

        trainml_cls.checkpoints = AsyncMock()
        # Each call produces a new coroutine that resolves to None.
        trainml_cls.checkpoints.get = Mock(side_effect=empty_lookup)

        outcome = runner.invoke(specimen, ["rename", "nonexistent", "new-name"])

        assert outcome.exit_code != 0
        assert "Cannot find specified checkpoint" in outcome.output
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def test_rename_not_found_exception(runner):
    """Check that ``rename`` errors out when awaiting ``get`` raises."""
    with patch("trainml.cli.TrainML", new=AsyncMock) as trainml_cls:

        async def failing_lookup(checkpoint_id):
            raise Exception("Not found")

        trainml_cls.checkpoints = AsyncMock()
        # The exception surfaces when the returned coroutine is awaited.
        trainml_cls.checkpoints.get = Mock(side_effect=failing_lookup)

        outcome = runner.invoke(specimen, ["rename", "nonexistent", "new-name"])

        assert outcome.exit_code != 0
        assert "Cannot find specified checkpoint" in outcome.output
|