trainml 0.5.17__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/local_storage.py +0 -2
- tests/integration/test_checkpoints_integration.py +4 -3
- tests/integration/test_datasets_integration.py +5 -3
- tests/integration/test_jobs_integration.py +33 -27
- tests/integration/test_models_integration.py +7 -3
- tests/integration/test_volumes_integration.py +2 -2
- tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
- tests/unit/cloudbender/test_nodes_unit.py +112 -0
- tests/unit/cloudbender/test_providers_unit.py +96 -0
- tests/unit/cloudbender/test_regions_unit.py +106 -0
- tests/unit/cloudbender/test_services_unit.py +141 -0
- tests/unit/conftest.py +23 -10
- tests/unit/projects/test_project_data_connectors_unit.py +39 -0
- tests/unit/projects/test_project_datastores_unit.py +37 -0
- tests/unit/projects/test_project_members_unit.py +46 -0
- tests/unit/projects/test_project_services_unit.py +65 -0
- tests/unit/projects/test_projects_unit.py +16 -0
- tests/unit/test_auth_unit.py +17 -2
- tests/unit/test_checkpoints_unit.py +256 -71
- tests/unit/test_datasets_unit.py +218 -68
- tests/unit/test_exceptions.py +133 -0
- tests/unit/test_gpu_types_unit.py +11 -1
- tests/unit/test_jobs_unit.py +1014 -95
- tests/unit/test_main_unit.py +20 -0
- tests/unit/test_models_unit.py +218 -70
- tests/unit/test_trainml_unit.py +627 -3
- tests/unit/test_volumes_unit.py +211 -70
- tests/unit/utils/__init__.py +1 -0
- tests/unit/utils/test_transfer_unit.py +4260 -0
- trainml/__init__.py +1 -1
- trainml/checkpoints.py +56 -57
- trainml/cli/__init__.py +6 -3
- trainml/cli/checkpoint.py +18 -57
- trainml/cli/dataset.py +17 -57
- trainml/cli/job/__init__.py +89 -67
- trainml/cli/job/create.py +51 -24
- trainml/cli/model.py +14 -56
- trainml/cli/volume.py +18 -57
- trainml/datasets.py +50 -55
- trainml/jobs.py +269 -69
- trainml/models.py +51 -55
- trainml/trainml.py +159 -114
- trainml/utils/__init__.py +1 -0
- trainml/utils/auth.py +641 -0
- trainml/utils/transfer.py +647 -0
- trainml/volumes.py +48 -53
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/METADATA +3 -3
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/RECORD +52 -46
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/LICENSE +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/WHEEL +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/entry_points.txt +0 -0
- {trainml-0.5.17.dist-info → trainml-1.0.1.dist-info}/top_level.txt +0 -0
|
@@ -8,6 +8,7 @@ from aiohttp import WSMessage, WSMsgType
|
|
|
8
8
|
import trainml.cloudbender.nodes as specimen
|
|
9
9
|
from trainml.exceptions import (
|
|
10
10
|
ApiError,
|
|
11
|
+
NodeError,
|
|
11
12
|
SpecificationError,
|
|
12
13
|
TrainMLException,
|
|
13
14
|
)
|
|
@@ -200,3 +201,114 @@ class nodeTests:
|
|
|
200
201
|
None,
|
|
201
202
|
dict(command="report"),
|
|
202
203
|
)
|
|
204
|
+
|
|
205
|
+
@mark.asyncio
|
|
206
|
+
async def test_node_wait_for_already_at_status(self, node):
|
|
207
|
+
"""Test wait_for returns immediately if already at target status."""
|
|
208
|
+
node._status = "active"
|
|
209
|
+
result = await node.wait_for("active")
|
|
210
|
+
assert result is None
|
|
211
|
+
|
|
212
|
+
@mark.asyncio
|
|
213
|
+
async def test_node_wait_for_invalid_status(self, node):
|
|
214
|
+
"""Test wait_for raises error for invalid status."""
|
|
215
|
+
with raises(SpecificationError) as exc_info:
|
|
216
|
+
await node.wait_for("invalid_status")
|
|
217
|
+
assert "Invalid wait_for status" in str(exc_info.value.message)
|
|
218
|
+
|
|
219
|
+
@mark.asyncio
|
|
220
|
+
async def test_node_wait_for_timeout_validation(self, node):
|
|
221
|
+
"""Test wait_for validates timeout (line 172)."""
|
|
222
|
+
node._status = "new" # Set to different status so timeout check runs
|
|
223
|
+
with raises(SpecificationError) as exc_info:
|
|
224
|
+
await node.wait_for("active", timeout=25 * 60 * 60)
|
|
225
|
+
assert "timeout must be less than" in str(exc_info.value.message)
|
|
226
|
+
|
|
227
|
+
@mark.asyncio
|
|
228
|
+
async def test_node_wait_for_success(self, node, mock_trainml):
|
|
229
|
+
"""Test wait_for succeeds when status matches."""
|
|
230
|
+
node._status = "new"
|
|
231
|
+
api_response_new = dict(
|
|
232
|
+
provider_uuid="1",
|
|
233
|
+
region_uuid="a",
|
|
234
|
+
rig_uuid="x",
|
|
235
|
+
status="new",
|
|
236
|
+
)
|
|
237
|
+
api_response_active = dict(
|
|
238
|
+
provider_uuid="1",
|
|
239
|
+
region_uuid="a",
|
|
240
|
+
rig_uuid="x",
|
|
241
|
+
status="active",
|
|
242
|
+
)
|
|
243
|
+
mock_trainml._query = AsyncMock(
|
|
244
|
+
side_effect=[api_response_new, api_response_active]
|
|
245
|
+
)
|
|
246
|
+
with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
|
|
247
|
+
result = await node.wait_for("active", timeout=10)
|
|
248
|
+
assert result == node
|
|
249
|
+
assert node.status == "active"
|
|
250
|
+
|
|
251
|
+
@mark.asyncio
|
|
252
|
+
async def test_node_wait_for_archived_404(self, node, mock_trainml):
|
|
253
|
+
"""Test wait_for handles 404 for archived status."""
|
|
254
|
+
node._status = "active"
|
|
255
|
+
api_error = ApiError(404, {"errorMessage": "Not found"})
|
|
256
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
257
|
+
with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
|
|
258
|
+
await node.wait_for("archived", timeout=10)
|
|
259
|
+
|
|
260
|
+
@mark.asyncio
|
|
261
|
+
async def test_node_wait_for_error_status(self, node, mock_trainml):
|
|
262
|
+
"""Test wait_for raises error for errored/failed status."""
|
|
263
|
+
node._status = "new"
|
|
264
|
+
api_response_errored = dict(
|
|
265
|
+
provider_uuid="1",
|
|
266
|
+
region_uuid="a",
|
|
267
|
+
rig_uuid="x",
|
|
268
|
+
status="errored",
|
|
269
|
+
)
|
|
270
|
+
mock_trainml._query = AsyncMock(return_value=api_response_errored)
|
|
271
|
+
with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
|
|
272
|
+
with raises(NodeError):
|
|
273
|
+
await node.wait_for("active", timeout=10)
|
|
274
|
+
|
|
275
|
+
@mark.asyncio
|
|
276
|
+
async def test_node_wait_for_timeout(self, node, mock_trainml):
|
|
277
|
+
"""Test wait_for raises timeout exception."""
|
|
278
|
+
node._status = "new"
|
|
279
|
+
api_response_new = dict(
|
|
280
|
+
provider_uuid="1",
|
|
281
|
+
region_uuid="a",
|
|
282
|
+
rig_uuid="x",
|
|
283
|
+
status="new",
|
|
284
|
+
)
|
|
285
|
+
mock_trainml._query = AsyncMock(return_value=api_response_new)
|
|
286
|
+
with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
|
|
287
|
+
with raises(TrainMLException) as exc_info:
|
|
288
|
+
await node.wait_for("active", timeout=0.1)
|
|
289
|
+
assert "Timeout waiting for" in str(exc_info.value.message)
|
|
290
|
+
|
|
291
|
+
@mark.asyncio
|
|
292
|
+
async def test_node_wait_for_api_error_non_404(self, node, mock_trainml):
|
|
293
|
+
"""Test wait_for raises ApiError when not 404 for archived (line 189)."""
|
|
294
|
+
node._status = "active"
|
|
295
|
+
api_error = ApiError(500, {"errorMessage": "Server Error"})
|
|
296
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
297
|
+
with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
|
|
298
|
+
with raises(ApiError):
|
|
299
|
+
await node.wait_for("archived", timeout=10)
|
|
300
|
+
|
|
301
|
+
@mark.asyncio
|
|
302
|
+
async def test_node_wait_for_failed_status(self, node, mock_trainml):
|
|
303
|
+
"""Test wait_for raises error for failed status (line 191)."""
|
|
304
|
+
node._status = "new"
|
|
305
|
+
api_response_failed = dict(
|
|
306
|
+
provider_uuid="1",
|
|
307
|
+
region_uuid="a",
|
|
308
|
+
rig_uuid="x",
|
|
309
|
+
status="failed",
|
|
310
|
+
)
|
|
311
|
+
mock_trainml._query = AsyncMock(return_value=api_response_failed)
|
|
312
|
+
with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
|
|
313
|
+
with raises(NodeError):
|
|
314
|
+
await node.wait_for("active", timeout=10)
|
|
@@ -139,3 +139,99 @@ class providerTests:
|
|
|
139
139
|
mock_trainml._query.assert_called_once_with(f"/provider/1", "GET")
|
|
140
140
|
assert provider.id == "provider-id-1"
|
|
141
141
|
assert response.id == "provider-id-1"
|
|
142
|
+
|
|
143
|
+
def test_provider_status_property(self, provider):
|
|
144
|
+
"""Test provider status property."""
|
|
145
|
+
provider._status = "ready"
|
|
146
|
+
assert provider.status == "ready"
|
|
147
|
+
|
|
148
|
+
@mark.asyncio
|
|
149
|
+
async def test_provider_wait_for_already_at_status(self, provider):
|
|
150
|
+
"""Test wait_for returns immediately if already at target status."""
|
|
151
|
+
provider._status = "ready"
|
|
152
|
+
result = await provider.wait_for("ready")
|
|
153
|
+
assert result is None
|
|
154
|
+
|
|
155
|
+
@mark.asyncio
|
|
156
|
+
async def test_provider_wait_for_invalid_status(self, provider):
|
|
157
|
+
"""Test wait_for raises error for invalid status."""
|
|
158
|
+
with raises(SpecificationError) as exc_info:
|
|
159
|
+
await provider.wait_for("invalid_status")
|
|
160
|
+
assert "Invalid wait_for status" in str(exc_info.value.message)
|
|
161
|
+
|
|
162
|
+
@mark.asyncio
|
|
163
|
+
async def test_provider_wait_for_timeout_validation(self, provider):
|
|
164
|
+
"""Test wait_for validates timeout."""
|
|
165
|
+
with raises(SpecificationError) as exc_info:
|
|
166
|
+
await provider.wait_for("ready", timeout=25 * 60 * 60)
|
|
167
|
+
assert "timeout must be less than" in str(exc_info.value.message)
|
|
168
|
+
|
|
169
|
+
@mark.asyncio
|
|
170
|
+
async def test_provider_wait_for_success(self, provider, mock_trainml):
|
|
171
|
+
"""Test wait_for succeeds when status matches."""
|
|
172
|
+
provider._status = "new"
|
|
173
|
+
api_response_new = dict(
|
|
174
|
+
customer_uuid="a",
|
|
175
|
+
provider_uuid="1",
|
|
176
|
+
status="new",
|
|
177
|
+
)
|
|
178
|
+
api_response_ready = dict(
|
|
179
|
+
customer_uuid="a",
|
|
180
|
+
provider_uuid="1",
|
|
181
|
+
status="ready",
|
|
182
|
+
)
|
|
183
|
+
mock_trainml._query = AsyncMock(
|
|
184
|
+
side_effect=[api_response_new, api_response_ready]
|
|
185
|
+
)
|
|
186
|
+
with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
|
|
187
|
+
result = await provider.wait_for("ready", timeout=10)
|
|
188
|
+
assert result == provider
|
|
189
|
+
assert provider.status == "ready"
|
|
190
|
+
|
|
191
|
+
@mark.asyncio
|
|
192
|
+
async def test_provider_wait_for_archived_404(self, provider, mock_trainml):
|
|
193
|
+
"""Test wait_for handles 404 for archived status."""
|
|
194
|
+
provider._status = "ready"
|
|
195
|
+
api_error = ApiError(404, {"errorMessage": "Not found"})
|
|
196
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
197
|
+
with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
|
|
198
|
+
await provider.wait_for("archived", timeout=10)
|
|
199
|
+
|
|
200
|
+
@mark.asyncio
|
|
201
|
+
async def test_provider_wait_for_error_status(self, provider, mock_trainml):
|
|
202
|
+
"""Test wait_for raises error for errored/failed status."""
|
|
203
|
+
provider._status = "new"
|
|
204
|
+
api_response_errored = dict(
|
|
205
|
+
customer_uuid="a",
|
|
206
|
+
provider_uuid="1",
|
|
207
|
+
status="errored",
|
|
208
|
+
)
|
|
209
|
+
mock_trainml._query = AsyncMock(return_value=api_response_errored)
|
|
210
|
+
with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
|
|
211
|
+
with raises(specimen.ProviderError):
|
|
212
|
+
await provider.wait_for("ready", timeout=10)
|
|
213
|
+
|
|
214
|
+
@mark.asyncio
|
|
215
|
+
async def test_provider_wait_for_timeout(self, provider, mock_trainml):
|
|
216
|
+
"""Test wait_for raises timeout exception."""
|
|
217
|
+
provider._status = "new"
|
|
218
|
+
api_response_new = dict(
|
|
219
|
+
customer_uuid="a",
|
|
220
|
+
provider_uuid="1",
|
|
221
|
+
status="new",
|
|
222
|
+
)
|
|
223
|
+
mock_trainml._query = AsyncMock(return_value=api_response_new)
|
|
224
|
+
with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
|
|
225
|
+
with raises(TrainMLException) as exc_info:
|
|
226
|
+
await provider.wait_for("ready", timeout=0.1)
|
|
227
|
+
assert "Timeout waiting for" in str(exc_info.value.message)
|
|
228
|
+
|
|
229
|
+
@mark.asyncio
|
|
230
|
+
async def test_provider_wait_for_api_error_non_404(self, provider, mock_trainml):
|
|
231
|
+
"""Test wait_for raises ApiError when not 404 for archived (line 115)."""
|
|
232
|
+
provider._status = "ready"
|
|
233
|
+
api_error = ApiError(500, {"errorMessage": "Server Error"})
|
|
234
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
235
|
+
with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
|
|
236
|
+
with raises(ApiError):
|
|
237
|
+
await provider.wait_for("archived", timeout=10)
|
|
@@ -195,3 +195,109 @@ class regionTests:
|
|
|
195
195
|
mock_trainml._query.assert_called_once_with(
|
|
196
196
|
"/provider/1/region/a/checkpoint", "POST", None, expected_payload
|
|
197
197
|
)
|
|
198
|
+
|
|
199
|
+
@mark.asyncio
|
|
200
|
+
async def test_region_wait_for_already_at_status(self, region):
|
|
201
|
+
"""Test wait_for returns immediately if already at target status."""
|
|
202
|
+
region._status = "healthy"
|
|
203
|
+
result = await region.wait_for("healthy")
|
|
204
|
+
assert result is None
|
|
205
|
+
|
|
206
|
+
@mark.asyncio
|
|
207
|
+
async def test_region_wait_for_invalid_status(self, region):
|
|
208
|
+
"""Test wait_for raises error for invalid status."""
|
|
209
|
+
with raises(SpecificationError) as exc_info:
|
|
210
|
+
await region.wait_for("invalid_status")
|
|
211
|
+
assert "Invalid wait_for status" in str(exc_info.value.message)
|
|
212
|
+
|
|
213
|
+
@mark.asyncio
|
|
214
|
+
async def test_region_wait_for_timeout_validation(self, region):
|
|
215
|
+
"""Test wait_for validates timeout (line 135)."""
|
|
216
|
+
region._status = "new" # Set to different status so timeout check runs
|
|
217
|
+
with raises(SpecificationError) as exc_info:
|
|
218
|
+
await region.wait_for("healthy", timeout=25 * 60 * 60)
|
|
219
|
+
assert "timeout must be less than" in str(exc_info.value.message)
|
|
220
|
+
|
|
221
|
+
@mark.asyncio
|
|
222
|
+
async def test_region_wait_for_success(self, region, mock_trainml):
|
|
223
|
+
"""Test wait_for succeeds when status matches."""
|
|
224
|
+
region._status = "new"
|
|
225
|
+
api_response_new = dict(
|
|
226
|
+
provider_uuid="1",
|
|
227
|
+
region_uuid="a",
|
|
228
|
+
status="new",
|
|
229
|
+
)
|
|
230
|
+
api_response_healthy = dict(
|
|
231
|
+
provider_uuid="1",
|
|
232
|
+
region_uuid="a",
|
|
233
|
+
status="healthy",
|
|
234
|
+
)
|
|
235
|
+
mock_trainml._query = AsyncMock(
|
|
236
|
+
side_effect=[api_response_new, api_response_healthy]
|
|
237
|
+
)
|
|
238
|
+
with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
|
|
239
|
+
result = await region.wait_for("healthy", timeout=10)
|
|
240
|
+
assert result == region
|
|
241
|
+
assert region.status == "healthy"
|
|
242
|
+
|
|
243
|
+
@mark.asyncio
|
|
244
|
+
async def test_region_wait_for_archived_404(self, region, mock_trainml):
|
|
245
|
+
"""Test wait_for handles 404 for archived status."""
|
|
246
|
+
region._status = "healthy"
|
|
247
|
+
api_error = ApiError(404, {"errorMessage": "Not found"})
|
|
248
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
249
|
+
with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
|
|
250
|
+
await region.wait_for("archived", timeout=10)
|
|
251
|
+
|
|
252
|
+
@mark.asyncio
|
|
253
|
+
async def test_region_wait_for_error_status(self, region, mock_trainml):
|
|
254
|
+
"""Test wait_for raises error for errored/failed status."""
|
|
255
|
+
region._status = "new"
|
|
256
|
+
api_response_errored = dict(
|
|
257
|
+
provider_uuid="1",
|
|
258
|
+
region_uuid="a",
|
|
259
|
+
status="errored",
|
|
260
|
+
)
|
|
261
|
+
mock_trainml._query = AsyncMock(return_value=api_response_errored)
|
|
262
|
+
with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
|
|
263
|
+
with raises(specimen.RegionError):
|
|
264
|
+
await region.wait_for("healthy", timeout=10)
|
|
265
|
+
|
|
266
|
+
@mark.asyncio
|
|
267
|
+
async def test_region_wait_for_timeout(self, region, mock_trainml):
|
|
268
|
+
"""Test wait_for raises timeout exception."""
|
|
269
|
+
region._status = "new"
|
|
270
|
+
api_response_new = dict(
|
|
271
|
+
provider_uuid="1",
|
|
272
|
+
region_uuid="a",
|
|
273
|
+
status="new",
|
|
274
|
+
)
|
|
275
|
+
mock_trainml._query = AsyncMock(return_value=api_response_new)
|
|
276
|
+
with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
|
|
277
|
+
with raises(TrainMLException) as exc_info:
|
|
278
|
+
await region.wait_for("healthy", timeout=0.1)
|
|
279
|
+
assert "Timeout waiting for" in str(exc_info.value.message)
|
|
280
|
+
|
|
281
|
+
@mark.asyncio
|
|
282
|
+
async def test_region_wait_for_api_error_non_404(self, region, mock_trainml):
|
|
283
|
+
"""Test wait_for raises ApiError when not 404 for archived (line 152)."""
|
|
284
|
+
region._status = "healthy"
|
|
285
|
+
api_error = ApiError(500, {"errorMessage": "Server Error"})
|
|
286
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
287
|
+
with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
|
|
288
|
+
with raises(ApiError):
|
|
289
|
+
await region.wait_for("archived", timeout=10)
|
|
290
|
+
|
|
291
|
+
@mark.asyncio
|
|
292
|
+
async def test_region_wait_for_failed_status(self, region, mock_trainml):
|
|
293
|
+
"""Test wait_for raises error for failed status."""
|
|
294
|
+
region._status = "new"
|
|
295
|
+
api_response_failed = dict(
|
|
296
|
+
provider_uuid="1",
|
|
297
|
+
region_uuid="a",
|
|
298
|
+
status="failed",
|
|
299
|
+
)
|
|
300
|
+
mock_trainml._query = AsyncMock(return_value=api_response_failed)
|
|
301
|
+
with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
|
|
302
|
+
with raises(specimen.RegionError):
|
|
303
|
+
await region.wait_for("healthy", timeout=10)
|
|
@@ -165,3 +165,144 @@ class serviceTests:
|
|
|
165
165
|
)
|
|
166
166
|
assert service.id == "service-id-1"
|
|
167
167
|
assert response.id == "service-id-1"
|
|
168
|
+
|
|
169
|
+
def test_service_status_property(self, service):
|
|
170
|
+
"""Test service status property."""
|
|
171
|
+
service._status = "active"
|
|
172
|
+
assert service.status == "active"
|
|
173
|
+
|
|
174
|
+
def test_service_port_property(self, service):
|
|
175
|
+
"""Test service port property."""
|
|
176
|
+
service._port = "443"
|
|
177
|
+
assert service.port == "443"
|
|
178
|
+
|
|
179
|
+
@mark.asyncio
|
|
180
|
+
async def test_service_wait_for_already_at_status(self, service):
|
|
181
|
+
"""Test wait_for returns immediately if already at target status."""
|
|
182
|
+
service._status = "active"
|
|
183
|
+
result = await service.wait_for("active")
|
|
184
|
+
assert result is None
|
|
185
|
+
|
|
186
|
+
@mark.asyncio
|
|
187
|
+
async def test_service_wait_for_invalid_status(self, service):
|
|
188
|
+
"""Test wait_for raises error for invalid status."""
|
|
189
|
+
with raises(SpecificationError) as exc_info:
|
|
190
|
+
await service.wait_for("invalid_status")
|
|
191
|
+
assert "Invalid wait_for status" in str(exc_info.value.message)
|
|
192
|
+
|
|
193
|
+
@mark.asyncio
|
|
194
|
+
async def test_service_wait_for_timeout_validation(self, service):
|
|
195
|
+
"""Test wait_for validates timeout."""
|
|
196
|
+
with raises(SpecificationError) as exc_info:
|
|
197
|
+
await service.wait_for("active", timeout=25 * 60 * 60)
|
|
198
|
+
assert "timeout must be less than" in str(exc_info.value.message)
|
|
199
|
+
|
|
200
|
+
@mark.asyncio
|
|
201
|
+
async def test_service_wait_for_success(self, service, mock_trainml):
|
|
202
|
+
"""Test wait_for succeeds when status matches."""
|
|
203
|
+
service._status = "new"
|
|
204
|
+
api_response_new = dict(
|
|
205
|
+
provider_uuid="1",
|
|
206
|
+
region_uuid="a",
|
|
207
|
+
service_id="x",
|
|
208
|
+
status="new",
|
|
209
|
+
)
|
|
210
|
+
api_response_active = dict(
|
|
211
|
+
provider_uuid="1",
|
|
212
|
+
region_uuid="a",
|
|
213
|
+
service_id="x",
|
|
214
|
+
status="active",
|
|
215
|
+
)
|
|
216
|
+
mock_trainml._query = AsyncMock(
|
|
217
|
+
side_effect=[api_response_new, api_response_active]
|
|
218
|
+
)
|
|
219
|
+
with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
|
|
220
|
+
result = await service.wait_for("active", timeout=10)
|
|
221
|
+
assert result == service
|
|
222
|
+
assert service.status == "active"
|
|
223
|
+
|
|
224
|
+
@mark.asyncio
|
|
225
|
+
async def test_service_wait_for_archived_404(self, service, mock_trainml):
|
|
226
|
+
"""Test wait_for handles 404 for archived status."""
|
|
227
|
+
service._status = "active"
|
|
228
|
+
api_error = ApiError(404, {"errorMessage": "Not found"})
|
|
229
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
230
|
+
with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
|
|
231
|
+
await service.wait_for("archived", timeout=10)
|
|
232
|
+
|
|
233
|
+
@mark.asyncio
|
|
234
|
+
async def test_service_wait_for_timeout(self, service, mock_trainml):
|
|
235
|
+
"""Test wait_for raises timeout exception."""
|
|
236
|
+
service._status = "new"
|
|
237
|
+
api_response_new = dict(
|
|
238
|
+
provider_uuid="1",
|
|
239
|
+
region_uuid="a",
|
|
240
|
+
service_id="x",
|
|
241
|
+
status="new",
|
|
242
|
+
)
|
|
243
|
+
mock_trainml._query = AsyncMock(return_value=api_response_new)
|
|
244
|
+
with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
|
|
245
|
+
with raises(TrainMLException) as exc_info:
|
|
246
|
+
await service.wait_for("active", timeout=0.1)
|
|
247
|
+
assert "Timeout waiting for" in str(exc_info.value.message)
|
|
248
|
+
|
|
249
|
+
@mark.asyncio
|
|
250
|
+
async def test_service_wait_for_api_error_non_404(self, service, mock_trainml):
|
|
251
|
+
"""Test wait_for raises ApiError when not 404 for archived (line 181)."""
|
|
252
|
+
service._status = "active"
|
|
253
|
+
api_error = ApiError(500, {"errorMessage": "Server Error"})
|
|
254
|
+
mock_trainml._query = AsyncMock(side_effect=api_error)
|
|
255
|
+
with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
|
|
256
|
+
with raises(ApiError):
|
|
257
|
+
await service.wait_for("archived", timeout=10)
|
|
258
|
+
|
|
259
|
+
@mark.asyncio
|
|
260
|
+
async def test_service_generate_certificate(self, service, mock_trainml):
|
|
261
|
+
"""Test generate_certificate method."""
|
|
262
|
+
api_response = {
|
|
263
|
+
"provider_uuid": "1",
|
|
264
|
+
"region_uuid": "a",
|
|
265
|
+
"service_id": "x",
|
|
266
|
+
"certificate": "cert-data",
|
|
267
|
+
}
|
|
268
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
269
|
+
result = await service.generate_certificate()
|
|
270
|
+
mock_trainml._query.assert_called_once_with(
|
|
271
|
+
"/provider/1/region/a/service/x/certificate",
|
|
272
|
+
"POST",
|
|
273
|
+
{},
|
|
274
|
+
dict(algorithm="ed25519"),
|
|
275
|
+
)
|
|
276
|
+
assert result == service
|
|
277
|
+
|
|
278
|
+
@mark.asyncio
|
|
279
|
+
async def test_service_generate_certificate_custom_algorithm(self, service, mock_trainml):
|
|
280
|
+
"""Test generate_certificate with custom algorithm."""
|
|
281
|
+
api_response = {
|
|
282
|
+
"provider_uuid": "1",
|
|
283
|
+
"region_uuid": "a",
|
|
284
|
+
"service_id": "x",
|
|
285
|
+
"certificate": "cert-data",
|
|
286
|
+
}
|
|
287
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
288
|
+
result = await service.generate_certificate(algorithm="rsa")
|
|
289
|
+
mock_trainml._query.assert_called_once_with(
|
|
290
|
+
"/provider/1/region/a/service/x/certificate",
|
|
291
|
+
"POST",
|
|
292
|
+
{},
|
|
293
|
+
dict(algorithm="rsa"),
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
@mark.asyncio
|
|
297
|
+
async def test_service_sign_client_certificate(self, service, mock_trainml):
|
|
298
|
+
"""Test sign_client_certificate method."""
|
|
299
|
+
api_response = {"certificate": "signed-cert-data"}
|
|
300
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
301
|
+
result = await service.sign_client_certificate("csr-data")
|
|
302
|
+
mock_trainml._query.assert_called_once_with(
|
|
303
|
+
"/provider/1/region/a/service/x/certificate/sign",
|
|
304
|
+
"POST",
|
|
305
|
+
{},
|
|
306
|
+
dict(csr="csr-data"),
|
|
307
|
+
)
|
|
308
|
+
assert result == api_response
|
tests/unit/conftest.py
CHANGED
|
@@ -4,7 +4,7 @@ from pytest import fixture, mark
|
|
|
4
4
|
from unittest.mock import Mock, AsyncMock, patch, create_autospec
|
|
5
5
|
|
|
6
6
|
from trainml.trainml import TrainML
|
|
7
|
-
from trainml.auth import Auth
|
|
7
|
+
from trainml.utils.auth import Auth
|
|
8
8
|
from trainml.datasets import Dataset, Datasets
|
|
9
9
|
from trainml.checkpoints import Checkpoint, Checkpoints
|
|
10
10
|
from trainml.volumes import Volume, Volumes
|
|
@@ -12,14 +12,16 @@ from trainml.models import Model, Models
|
|
|
12
12
|
from trainml.gpu_types import GpuType, GpuTypes
|
|
13
13
|
from trainml.environments import Environment, Environments
|
|
14
14
|
from trainml.jobs import Job, Jobs
|
|
15
|
-
from trainml.connections import Connections
|
|
16
15
|
from trainml.projects import (
|
|
17
16
|
Projects,
|
|
18
17
|
Project,
|
|
19
18
|
)
|
|
20
19
|
from trainml.projects.datastores import ProjectDatastores, ProjectDatastore
|
|
21
20
|
from trainml.projects.services import ProjectServices, ProjectService
|
|
22
|
-
from trainml.projects.data_connectors import
|
|
21
|
+
from trainml.projects.data_connectors import (
|
|
22
|
+
ProjectDataConnectors,
|
|
23
|
+
ProjectDataConnector,
|
|
24
|
+
)
|
|
23
25
|
from trainml.projects.credentials import ProjectCredentials, ProjectCredential
|
|
24
26
|
from trainml.projects.secrets import ProjectSecrets, ProjectSecret
|
|
25
27
|
|
|
@@ -1130,12 +1132,13 @@ def mock_trainml(
|
|
|
1130
1132
|
trainml.gpu_types = create_autospec(GpuTypes)
|
|
1131
1133
|
trainml.environments = create_autospec(Environments)
|
|
1132
1134
|
trainml.jobs = create_autospec(Jobs)
|
|
1133
|
-
trainml.connections = create_autospec(Connections)
|
|
1134
1135
|
trainml.projects = create_autospec(Projects)
|
|
1135
1136
|
trainml.datasets.list = AsyncMock(return_value=mock_my_datasets)
|
|
1136
1137
|
trainml.datasets.list_public = AsyncMock(return_value=mock_public_datasets)
|
|
1137
1138
|
trainml.checkpoints.list = AsyncMock(return_value=mock_my_checkpoints)
|
|
1138
|
-
trainml.checkpoints.list_public = AsyncMock(
|
|
1139
|
+
trainml.checkpoints.list_public = AsyncMock(
|
|
1140
|
+
return_value=mock_public_checkpoints
|
|
1141
|
+
)
|
|
1139
1142
|
trainml.models.list = AsyncMock(return_value=mock_models)
|
|
1140
1143
|
trainml.volumes.list = AsyncMock(return_value=mock_my_volumes)
|
|
1141
1144
|
trainml.gpu_types.list = AsyncMock(return_value=mock_gpu_types)
|
|
@@ -1143,17 +1146,25 @@ def mock_trainml(
|
|
|
1143
1146
|
trainml.jobs.list = AsyncMock(return_value=mock_jobs)
|
|
1144
1147
|
trainml.projects.list = AsyncMock(return_value=mock_projects)
|
|
1145
1148
|
trainml.projects.datastores = create_autospec(ProjectDatastores)
|
|
1146
|
-
trainml.projects.datastores.list = AsyncMock(
|
|
1149
|
+
trainml.projects.datastores.list = AsyncMock(
|
|
1150
|
+
return_value=mock_project_datastores
|
|
1151
|
+
)
|
|
1147
1152
|
trainml.projects.services = create_autospec(ProjectServices)
|
|
1148
|
-
trainml.projects.services.list = AsyncMock(
|
|
1153
|
+
trainml.projects.services.list = AsyncMock(
|
|
1154
|
+
return_value=mock_project_services
|
|
1155
|
+
)
|
|
1149
1156
|
trainml.projects.data_connectors = create_autospec(ProjectDataConnectors)
|
|
1150
1157
|
trainml.projects.data_connectors.list = AsyncMock(
|
|
1151
1158
|
return_value=mock_project_data_connectors
|
|
1152
1159
|
)
|
|
1153
1160
|
trainml.projects.credentials = create_autospec(ProjectCredentials)
|
|
1154
|
-
trainml.projects.credentials.list = AsyncMock(
|
|
1161
|
+
trainml.projects.credentials.list = AsyncMock(
|
|
1162
|
+
return_value=mock_project_credentials
|
|
1163
|
+
)
|
|
1155
1164
|
trainml.projects.secrets = create_autospec(ProjectSecrets)
|
|
1156
|
-
trainml.projects.secrets.list = AsyncMock(
|
|
1165
|
+
trainml.projects.secrets.list = AsyncMock(
|
|
1166
|
+
return_value=mock_project_secrets
|
|
1167
|
+
)
|
|
1157
1168
|
|
|
1158
1169
|
trainml.cloudbender = create_autospec(Cloudbender)
|
|
1159
1170
|
|
|
@@ -1166,7 +1177,9 @@ def mock_trainml(
|
|
|
1166
1177
|
trainml.cloudbender.devices = create_autospec(Devices)
|
|
1167
1178
|
trainml.cloudbender.devices.list = AsyncMock(return_value=mock_devices)
|
|
1168
1179
|
trainml.cloudbender.datastores = create_autospec(Datastores)
|
|
1169
|
-
trainml.cloudbender.datastores.list = AsyncMock(
|
|
1180
|
+
trainml.cloudbender.datastores.list = AsyncMock(
|
|
1181
|
+
return_value=mock_datastores
|
|
1182
|
+
)
|
|
1170
1183
|
trainml.cloudbender.services = create_autospec(Services)
|
|
1171
1184
|
trainml.cloudbender.services.list = AsyncMock(return_value=mock_services)
|
|
1172
1185
|
trainml.cloudbender.data_connectors = create_autospec(DataConnectors)
|
|
@@ -33,6 +33,25 @@ def project_data_connector(mock_trainml):
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class ProjectDataConnectorsTests:
|
|
36
|
+
@mark.asyncio
|
|
37
|
+
async def test_project_data_connectors_get(
|
|
38
|
+
self, project_data_connectors, mock_trainml
|
|
39
|
+
):
|
|
40
|
+
"""Test get method (lines 11-14)."""
|
|
41
|
+
api_response = {
|
|
42
|
+
"project_uuid": "proj-id-1",
|
|
43
|
+
"region_uuid": "reg-id-1",
|
|
44
|
+
"id": "connector-id-1",
|
|
45
|
+
"type": "custom",
|
|
46
|
+
"name": "On-Prem Connection A",
|
|
47
|
+
}
|
|
48
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
49
|
+
result = await project_data_connectors.get("connector-id-1", param1="value1")
|
|
50
|
+
mock_trainml._query.assert_called_once_with(
|
|
51
|
+
"/project/1/data_connectors/connector-id-1", "GET", dict(param1="value1")
|
|
52
|
+
)
|
|
53
|
+
assert result.id == "connector-id-1"
|
|
54
|
+
|
|
36
55
|
@mark.asyncio
|
|
37
56
|
async def test_project_data_connectors_refresh(
|
|
38
57
|
self, project_data_connectors, mock_trainml
|
|
@@ -100,3 +119,23 @@ class ProjectDataConnectorTests:
|
|
|
100
119
|
empty_project_data_connector = specimen.ProjectDataConnector(mock_trainml)
|
|
101
120
|
assert bool(project_data_connector)
|
|
102
121
|
assert not bool(empty_project_data_connector)
|
|
122
|
+
|
|
123
|
+
@mark.asyncio
|
|
124
|
+
async def test_project_data_connector_enable(self, project_data_connector, mock_trainml):
|
|
125
|
+
"""Test enable method (line 72)."""
|
|
126
|
+
api_response = dict()
|
|
127
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
128
|
+
await project_data_connector.enable()
|
|
129
|
+
mock_trainml._query.assert_called_once_with(
|
|
130
|
+
"/project/proj-id-1/data_connectors/ds-id-1/enable", "PATCH"
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
@mark.asyncio
|
|
134
|
+
async def test_project_data_connector_disable(self, project_data_connector, mock_trainml):
|
|
135
|
+
"""Test disable method (line 77)."""
|
|
136
|
+
api_response = dict()
|
|
137
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
138
|
+
await project_data_connector.disable()
|
|
139
|
+
mock_trainml._query.assert_called_once_with(
|
|
140
|
+
"/project/proj-id-1/data_connectors/ds-id-1/disable", "PATCH"
|
|
141
|
+
)
|
|
@@ -33,6 +33,23 @@ def project_datastore(mock_trainml):
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class ProjectDatastoresTests:
|
|
36
|
+
@mark.asyncio
|
|
37
|
+
async def test_project_datastores_get(self, project_datastores, mock_trainml):
|
|
38
|
+
"""Test get method (lines 11-14)."""
|
|
39
|
+
api_response = {
|
|
40
|
+
"project_uuid": "proj-id-1",
|
|
41
|
+
"region_uuid": "reg-id-1",
|
|
42
|
+
"id": "store-id-1",
|
|
43
|
+
"type": "nfs",
|
|
44
|
+
"name": "On-prem NFS",
|
|
45
|
+
}
|
|
46
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
47
|
+
result = await project_datastores.get("store-id-1", param1="value1")
|
|
48
|
+
mock_trainml._query.assert_called_once_with(
|
|
49
|
+
"/project/1/datastores/store-id-1", "GET", dict(param1="value1")
|
|
50
|
+
)
|
|
51
|
+
assert result.id == "store-id-1"
|
|
52
|
+
|
|
36
53
|
@mark.asyncio
|
|
37
54
|
async def test_project_datastores_refresh(self, project_datastores, mock_trainml):
|
|
38
55
|
api_response = dict()
|
|
@@ -94,3 +111,23 @@ class ProjectDatastoreTests:
|
|
|
94
111
|
empty_project_datastore = specimen.ProjectDatastore(mock_trainml)
|
|
95
112
|
assert bool(project_datastore)
|
|
96
113
|
assert not bool(empty_project_datastore)
|
|
114
|
+
|
|
115
|
+
@mark.asyncio
|
|
116
|
+
async def test_project_datastore_enable(self, project_datastore, mock_trainml):
|
|
117
|
+
"""Test enable method (line 67)."""
|
|
118
|
+
api_response = dict()
|
|
119
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
120
|
+
await project_datastore.enable()
|
|
121
|
+
mock_trainml._query.assert_called_once_with(
|
|
122
|
+
"/project/proj-id-1/datastores/ds-id-1/enable", "PATCH"
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
@mark.asyncio
|
|
126
|
+
async def test_project_datastore_disable(self, project_datastore, mock_trainml):
|
|
127
|
+
"""Test disable method (line 72)."""
|
|
128
|
+
api_response = dict()
|
|
129
|
+
mock_trainml._query = AsyncMock(return_value=api_response)
|
|
130
|
+
await project_datastore.disable()
|
|
131
|
+
mock_trainml._query.assert_called_once_with(
|
|
132
|
+
"/project/proj-id-1/datastores/ds-id-1/disable", "PATCH"
|
|
133
|
+
)
|