trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +17 -1
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/projects/projects.py +2 -2
  43. trainml/trainml.py +50 -16
  44. trainml/utils/__init__.py +1 -0
  45. trainml/utils/auth.py +641 -0
  46. trainml/utils/transfer.py +587 -0
  47. trainml/volumes.py +48 -53
  48. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  49. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
  50. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  51. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  52. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  53. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
tests/unit/cloudbender/test_nodes_unit.py CHANGED
@@ -8,6 +8,7 @@ from aiohttp import WSMessage, WSMsgType
8
8
  import trainml.cloudbender.nodes as specimen
9
9
  from trainml.exceptions import (
10
10
  ApiError,
11
+ NodeError,
11
12
  SpecificationError,
12
13
  TrainMLException,
13
14
  )
@@ -200,3 +201,114 @@ class nodeTests:
200
201
  None,
201
202
  dict(command="report"),
202
203
  )
204
+
205
+ @mark.asyncio
206
+ async def test_node_wait_for_already_at_status(self, node):
207
+ """Test wait_for returns immediately if already at target status."""
208
+ node._status = "active"
209
+ result = await node.wait_for("active")
210
+ assert result is None
211
+
212
+ @mark.asyncio
213
+ async def test_node_wait_for_invalid_status(self, node):
214
+ """Test wait_for raises error for invalid status."""
215
+ with raises(SpecificationError) as exc_info:
216
+ await node.wait_for("invalid_status")
217
+ assert "Invalid wait_for status" in str(exc_info.value.message)
218
+
219
+ @mark.asyncio
220
+ async def test_node_wait_for_timeout_validation(self, node):
221
+ """Test wait_for validates timeout (line 172)."""
222
+ node._status = "new" # Set to different status so timeout check runs
223
+ with raises(SpecificationError) as exc_info:
224
+ await node.wait_for("active", timeout=25 * 60 * 60)
225
+ assert "timeout must be less than" in str(exc_info.value.message)
226
+
227
+ @mark.asyncio
228
+ async def test_node_wait_for_success(self, node, mock_trainml):
229
+ """Test wait_for succeeds when status matches."""
230
+ node._status = "new"
231
+ api_response_new = dict(
232
+ provider_uuid="1",
233
+ region_uuid="a",
234
+ rig_uuid="x",
235
+ status="new",
236
+ )
237
+ api_response_active = dict(
238
+ provider_uuid="1",
239
+ region_uuid="a",
240
+ rig_uuid="x",
241
+ status="active",
242
+ )
243
+ mock_trainml._query = AsyncMock(
244
+ side_effect=[api_response_new, api_response_active]
245
+ )
246
+ with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
247
+ result = await node.wait_for("active", timeout=10)
248
+ assert result == node
249
+ assert node.status == "active"
250
+
251
+ @mark.asyncio
252
+ async def test_node_wait_for_archived_404(self, node, mock_trainml):
253
+ """Test wait_for handles 404 for archived status."""
254
+ node._status = "active"
255
+ api_error = ApiError(404, {"errorMessage": "Not found"})
256
+ mock_trainml._query = AsyncMock(side_effect=api_error)
257
+ with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
258
+ await node.wait_for("archived", timeout=10)
259
+
260
+ @mark.asyncio
261
+ async def test_node_wait_for_error_status(self, node, mock_trainml):
262
+ """Test wait_for raises error for errored/failed status."""
263
+ node._status = "new"
264
+ api_response_errored = dict(
265
+ provider_uuid="1",
266
+ region_uuid="a",
267
+ rig_uuid="x",
268
+ status="errored",
269
+ )
270
+ mock_trainml._query = AsyncMock(return_value=api_response_errored)
271
+ with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
272
+ with raises(NodeError):
273
+ await node.wait_for("active", timeout=10)
274
+
275
+ @mark.asyncio
276
+ async def test_node_wait_for_timeout(self, node, mock_trainml):
277
+ """Test wait_for raises timeout exception."""
278
+ node._status = "new"
279
+ api_response_new = dict(
280
+ provider_uuid="1",
281
+ region_uuid="a",
282
+ rig_uuid="x",
283
+ status="new",
284
+ )
285
+ mock_trainml._query = AsyncMock(return_value=api_response_new)
286
+ with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
287
+ with raises(TrainMLException) as exc_info:
288
+ await node.wait_for("active", timeout=0.1)
289
+ assert "Timeout waiting for" in str(exc_info.value.message)
290
+
291
+ @mark.asyncio
292
+ async def test_node_wait_for_api_error_non_404(self, node, mock_trainml):
293
+ """Test wait_for raises ApiError when not 404 for archived (line 189)."""
294
+ node._status = "active"
295
+ api_error = ApiError(500, {"errorMessage": "Server Error"})
296
+ mock_trainml._query = AsyncMock(side_effect=api_error)
297
+ with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
298
+ with raises(ApiError):
299
+ await node.wait_for("archived", timeout=10)
300
+
301
+ @mark.asyncio
302
+ async def test_node_wait_for_failed_status(self, node, mock_trainml):
303
+ """Test wait_for raises error for failed status (line 191)."""
304
+ node._status = "new"
305
+ api_response_failed = dict(
306
+ provider_uuid="1",
307
+ region_uuid="a",
308
+ rig_uuid="x",
309
+ status="failed",
310
+ )
311
+ mock_trainml._query = AsyncMock(return_value=api_response_failed)
312
+ with patch("trainml.cloudbender.nodes.asyncio.sleep", new_callable=AsyncMock):
313
+ with raises(NodeError):
314
+ await node.wait_for("active", timeout=10)
tests/unit/cloudbender/test_providers_unit.py CHANGED
@@ -139,3 +139,99 @@ class providerTests:
139
139
  mock_trainml._query.assert_called_once_with(f"/provider/1", "GET")
140
140
  assert provider.id == "provider-id-1"
141
141
  assert response.id == "provider-id-1"
142
+
143
+ def test_provider_status_property(self, provider):
144
+ """Test provider status property."""
145
+ provider._status = "ready"
146
+ assert provider.status == "ready"
147
+
148
+ @mark.asyncio
149
+ async def test_provider_wait_for_already_at_status(self, provider):
150
+ """Test wait_for returns immediately if already at target status."""
151
+ provider._status = "ready"
152
+ result = await provider.wait_for("ready")
153
+ assert result is None
154
+
155
+ @mark.asyncio
156
+ async def test_provider_wait_for_invalid_status(self, provider):
157
+ """Test wait_for raises error for invalid status."""
158
+ with raises(SpecificationError) as exc_info:
159
+ await provider.wait_for("invalid_status")
160
+ assert "Invalid wait_for status" in str(exc_info.value.message)
161
+
162
+ @mark.asyncio
163
+ async def test_provider_wait_for_timeout_validation(self, provider):
164
+ """Test wait_for validates timeout."""
165
+ with raises(SpecificationError) as exc_info:
166
+ await provider.wait_for("ready", timeout=25 * 60 * 60)
167
+ assert "timeout must be less than" in str(exc_info.value.message)
168
+
169
+ @mark.asyncio
170
+ async def test_provider_wait_for_success(self, provider, mock_trainml):
171
+ """Test wait_for succeeds when status matches."""
172
+ provider._status = "new"
173
+ api_response_new = dict(
174
+ customer_uuid="a",
175
+ provider_uuid="1",
176
+ status="new",
177
+ )
178
+ api_response_ready = dict(
179
+ customer_uuid="a",
180
+ provider_uuid="1",
181
+ status="ready",
182
+ )
183
+ mock_trainml._query = AsyncMock(
184
+ side_effect=[api_response_new, api_response_ready]
185
+ )
186
+ with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
187
+ result = await provider.wait_for("ready", timeout=10)
188
+ assert result == provider
189
+ assert provider.status == "ready"
190
+
191
+ @mark.asyncio
192
+ async def test_provider_wait_for_archived_404(self, provider, mock_trainml):
193
+ """Test wait_for handles 404 for archived status."""
194
+ provider._status = "ready"
195
+ api_error = ApiError(404, {"errorMessage": "Not found"})
196
+ mock_trainml._query = AsyncMock(side_effect=api_error)
197
+ with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
198
+ await provider.wait_for("archived", timeout=10)
199
+
200
+ @mark.asyncio
201
+ async def test_provider_wait_for_error_status(self, provider, mock_trainml):
202
+ """Test wait_for raises error for errored/failed status."""
203
+ provider._status = "new"
204
+ api_response_errored = dict(
205
+ customer_uuid="a",
206
+ provider_uuid="1",
207
+ status="errored",
208
+ )
209
+ mock_trainml._query = AsyncMock(return_value=api_response_errored)
210
+ with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
211
+ with raises(specimen.ProviderError):
212
+ await provider.wait_for("ready", timeout=10)
213
+
214
+ @mark.asyncio
215
+ async def test_provider_wait_for_timeout(self, provider, mock_trainml):
216
+ """Test wait_for raises timeout exception."""
217
+ provider._status = "new"
218
+ api_response_new = dict(
219
+ customer_uuid="a",
220
+ provider_uuid="1",
221
+ status="new",
222
+ )
223
+ mock_trainml._query = AsyncMock(return_value=api_response_new)
224
+ with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
225
+ with raises(TrainMLException) as exc_info:
226
+ await provider.wait_for("ready", timeout=0.1)
227
+ assert "Timeout waiting for" in str(exc_info.value.message)
228
+
229
+ @mark.asyncio
230
+ async def test_provider_wait_for_api_error_non_404(self, provider, mock_trainml):
231
+ """Test wait_for raises ApiError when not 404 for archived (line 115)."""
232
+ provider._status = "ready"
233
+ api_error = ApiError(500, {"errorMessage": "Server Error"})
234
+ mock_trainml._query = AsyncMock(side_effect=api_error)
235
+ with patch("trainml.cloudbender.providers.asyncio.sleep", new_callable=AsyncMock):
236
+ with raises(ApiError):
237
+ await provider.wait_for("archived", timeout=10)
tests/unit/cloudbender/test_regions_unit.py CHANGED
@@ -195,3 +195,109 @@ class regionTests:
195
195
  mock_trainml._query.assert_called_once_with(
196
196
  "/provider/1/region/a/checkpoint", "POST", None, expected_payload
197
197
  )
198
+
199
+ @mark.asyncio
200
+ async def test_region_wait_for_already_at_status(self, region):
201
+ """Test wait_for returns immediately if already at target status."""
202
+ region._status = "healthy"
203
+ result = await region.wait_for("healthy")
204
+ assert result is None
205
+
206
+ @mark.asyncio
207
+ async def test_region_wait_for_invalid_status(self, region):
208
+ """Test wait_for raises error for invalid status."""
209
+ with raises(SpecificationError) as exc_info:
210
+ await region.wait_for("invalid_status")
211
+ assert "Invalid wait_for status" in str(exc_info.value.message)
212
+
213
+ @mark.asyncio
214
+ async def test_region_wait_for_timeout_validation(self, region):
215
+ """Test wait_for validates timeout (line 135)."""
216
+ region._status = "new" # Set to different status so timeout check runs
217
+ with raises(SpecificationError) as exc_info:
218
+ await region.wait_for("healthy", timeout=25 * 60 * 60)
219
+ assert "timeout must be less than" in str(exc_info.value.message)
220
+
221
+ @mark.asyncio
222
+ async def test_region_wait_for_success(self, region, mock_trainml):
223
+ """Test wait_for succeeds when status matches."""
224
+ region._status = "new"
225
+ api_response_new = dict(
226
+ provider_uuid="1",
227
+ region_uuid="a",
228
+ status="new",
229
+ )
230
+ api_response_healthy = dict(
231
+ provider_uuid="1",
232
+ region_uuid="a",
233
+ status="healthy",
234
+ )
235
+ mock_trainml._query = AsyncMock(
236
+ side_effect=[api_response_new, api_response_healthy]
237
+ )
238
+ with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
239
+ result = await region.wait_for("healthy", timeout=10)
240
+ assert result == region
241
+ assert region.status == "healthy"
242
+
243
+ @mark.asyncio
244
+ async def test_region_wait_for_archived_404(self, region, mock_trainml):
245
+ """Test wait_for handles 404 for archived status."""
246
+ region._status = "healthy"
247
+ api_error = ApiError(404, {"errorMessage": "Not found"})
248
+ mock_trainml._query = AsyncMock(side_effect=api_error)
249
+ with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
250
+ await region.wait_for("archived", timeout=10)
251
+
252
+ @mark.asyncio
253
+ async def test_region_wait_for_error_status(self, region, mock_trainml):
254
+ """Test wait_for raises error for errored/failed status."""
255
+ region._status = "new"
256
+ api_response_errored = dict(
257
+ provider_uuid="1",
258
+ region_uuid="a",
259
+ status="errored",
260
+ )
261
+ mock_trainml._query = AsyncMock(return_value=api_response_errored)
262
+ with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
263
+ with raises(specimen.RegionError):
264
+ await region.wait_for("healthy", timeout=10)
265
+
266
+ @mark.asyncio
267
+ async def test_region_wait_for_timeout(self, region, mock_trainml):
268
+ """Test wait_for raises timeout exception."""
269
+ region._status = "new"
270
+ api_response_new = dict(
271
+ provider_uuid="1",
272
+ region_uuid="a",
273
+ status="new",
274
+ )
275
+ mock_trainml._query = AsyncMock(return_value=api_response_new)
276
+ with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
277
+ with raises(TrainMLException) as exc_info:
278
+ await region.wait_for("healthy", timeout=0.1)
279
+ assert "Timeout waiting for" in str(exc_info.value.message)
280
+
281
+ @mark.asyncio
282
+ async def test_region_wait_for_api_error_non_404(self, region, mock_trainml):
283
+ """Test wait_for raises ApiError when not 404 for archived (line 152)."""
284
+ region._status = "healthy"
285
+ api_error = ApiError(500, {"errorMessage": "Server Error"})
286
+ mock_trainml._query = AsyncMock(side_effect=api_error)
287
+ with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
288
+ with raises(ApiError):
289
+ await region.wait_for("archived", timeout=10)
290
+
291
+ @mark.asyncio
292
+ async def test_region_wait_for_failed_status(self, region, mock_trainml):
293
+ """Test wait_for raises error for failed status."""
294
+ region._status = "new"
295
+ api_response_failed = dict(
296
+ provider_uuid="1",
297
+ region_uuid="a",
298
+ status="failed",
299
+ )
300
+ mock_trainml._query = AsyncMock(return_value=api_response_failed)
301
+ with patch("trainml.cloudbender.regions.asyncio.sleep", new_callable=AsyncMock):
302
+ with raises(specimen.RegionError):
303
+ await region.wait_for("healthy", timeout=10)
tests/unit/cloudbender/test_services_unit.py CHANGED
@@ -165,3 +165,144 @@ class serviceTests:
165
165
  )
166
166
  assert service.id == "service-id-1"
167
167
  assert response.id == "service-id-1"
168
+
169
+ def test_service_status_property(self, service):
170
+ """Test service status property."""
171
+ service._status = "active"
172
+ assert service.status == "active"
173
+
174
+ def test_service_port_property(self, service):
175
+ """Test service port property."""
176
+ service._port = "443"
177
+ assert service.port == "443"
178
+
179
+ @mark.asyncio
180
+ async def test_service_wait_for_already_at_status(self, service):
181
+ """Test wait_for returns immediately if already at target status."""
182
+ service._status = "active"
183
+ result = await service.wait_for("active")
184
+ assert result is None
185
+
186
+ @mark.asyncio
187
+ async def test_service_wait_for_invalid_status(self, service):
188
+ """Test wait_for raises error for invalid status."""
189
+ with raises(SpecificationError) as exc_info:
190
+ await service.wait_for("invalid_status")
191
+ assert "Invalid wait_for status" in str(exc_info.value.message)
192
+
193
+ @mark.asyncio
194
+ async def test_service_wait_for_timeout_validation(self, service):
195
+ """Test wait_for validates timeout."""
196
+ with raises(SpecificationError) as exc_info:
197
+ await service.wait_for("active", timeout=25 * 60 * 60)
198
+ assert "timeout must be less than" in str(exc_info.value.message)
199
+
200
+ @mark.asyncio
201
+ async def test_service_wait_for_success(self, service, mock_trainml):
202
+ """Test wait_for succeeds when status matches."""
203
+ service._status = "new"
204
+ api_response_new = dict(
205
+ provider_uuid="1",
206
+ region_uuid="a",
207
+ service_id="x",
208
+ status="new",
209
+ )
210
+ api_response_active = dict(
211
+ provider_uuid="1",
212
+ region_uuid="a",
213
+ service_id="x",
214
+ status="active",
215
+ )
216
+ mock_trainml._query = AsyncMock(
217
+ side_effect=[api_response_new, api_response_active]
218
+ )
219
+ with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
220
+ result = await service.wait_for("active", timeout=10)
221
+ assert result == service
222
+ assert service.status == "active"
223
+
224
+ @mark.asyncio
225
+ async def test_service_wait_for_archived_404(self, service, mock_trainml):
226
+ """Test wait_for handles 404 for archived status."""
227
+ service._status = "active"
228
+ api_error = ApiError(404, {"errorMessage": "Not found"})
229
+ mock_trainml._query = AsyncMock(side_effect=api_error)
230
+ with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
231
+ await service.wait_for("archived", timeout=10)
232
+
233
+ @mark.asyncio
234
+ async def test_service_wait_for_timeout(self, service, mock_trainml):
235
+ """Test wait_for raises timeout exception."""
236
+ service._status = "new"
237
+ api_response_new = dict(
238
+ provider_uuid="1",
239
+ region_uuid="a",
240
+ service_id="x",
241
+ status="new",
242
+ )
243
+ mock_trainml._query = AsyncMock(return_value=api_response_new)
244
+ with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
245
+ with raises(TrainMLException) as exc_info:
246
+ await service.wait_for("active", timeout=0.1)
247
+ assert "Timeout waiting for" in str(exc_info.value.message)
248
+
249
+ @mark.asyncio
250
+ async def test_service_wait_for_api_error_non_404(self, service, mock_trainml):
251
+ """Test wait_for raises ApiError when not 404 for archived (line 181)."""
252
+ service._status = "active"
253
+ api_error = ApiError(500, {"errorMessage": "Server Error"})
254
+ mock_trainml._query = AsyncMock(side_effect=api_error)
255
+ with patch("trainml.cloudbender.services.asyncio.sleep", new_callable=AsyncMock):
256
+ with raises(ApiError):
257
+ await service.wait_for("archived", timeout=10)
258
+
259
+ @mark.asyncio
260
+ async def test_service_generate_certificate(self, service, mock_trainml):
261
+ """Test generate_certificate method."""
262
+ api_response = {
263
+ "provider_uuid": "1",
264
+ "region_uuid": "a",
265
+ "service_id": "x",
266
+ "certificate": "cert-data",
267
+ }
268
+ mock_trainml._query = AsyncMock(return_value=api_response)
269
+ result = await service.generate_certificate()
270
+ mock_trainml._query.assert_called_once_with(
271
+ "/provider/1/region/a/service/x/certificate",
272
+ "POST",
273
+ {},
274
+ dict(algorithm="ed25519"),
275
+ )
276
+ assert result == service
277
+
278
+ @mark.asyncio
279
+ async def test_service_generate_certificate_custom_algorithm(self, service, mock_trainml):
280
+ """Test generate_certificate with custom algorithm."""
281
+ api_response = {
282
+ "provider_uuid": "1",
283
+ "region_uuid": "a",
284
+ "service_id": "x",
285
+ "certificate": "cert-data",
286
+ }
287
+ mock_trainml._query = AsyncMock(return_value=api_response)
288
+ result = await service.generate_certificate(algorithm="rsa")
289
+ mock_trainml._query.assert_called_once_with(
290
+ "/provider/1/region/a/service/x/certificate",
291
+ "POST",
292
+ {},
293
+ dict(algorithm="rsa"),
294
+ )
295
+
296
+ @mark.asyncio
297
+ async def test_service_sign_client_certificate(self, service, mock_trainml):
298
+ """Test sign_client_certificate method."""
299
+ api_response = {"certificate": "signed-cert-data"}
300
+ mock_trainml._query = AsyncMock(return_value=api_response)
301
+ result = await service.sign_client_certificate("csr-data")
302
+ mock_trainml._query.assert_called_once_with(
303
+ "/provider/1/region/a/service/x/certificate/sign",
304
+ "POST",
305
+ {},
306
+ dict(csr="csr-data"),
307
+ )
308
+ assert result == api_response
tests/unit/conftest.py CHANGED
@@ -4,7 +4,7 @@ from pytest import fixture, mark
4
4
  from unittest.mock import Mock, AsyncMock, patch, create_autospec
5
5
 
6
6
  from trainml.trainml import TrainML
7
- from trainml.auth import Auth
7
+ from trainml.utils.auth import Auth
8
8
  from trainml.datasets import Dataset, Datasets
9
9
  from trainml.checkpoints import Checkpoint, Checkpoints
10
10
  from trainml.volumes import Volume, Volumes
@@ -12,14 +12,16 @@ from trainml.models import Model, Models
12
12
  from trainml.gpu_types import GpuType, GpuTypes
13
13
  from trainml.environments import Environment, Environments
14
14
  from trainml.jobs import Job, Jobs
15
- from trainml.connections import Connections
16
15
  from trainml.projects import (
17
16
  Projects,
18
17
  Project,
19
18
  )
20
19
  from trainml.projects.datastores import ProjectDatastores, ProjectDatastore
21
20
  from trainml.projects.services import ProjectServices, ProjectService
22
- from trainml.projects.data_connectors import ProjectDataConnectors, ProjectDataConnector
21
+ from trainml.projects.data_connectors import (
22
+ ProjectDataConnectors,
23
+ ProjectDataConnector,
24
+ )
23
25
  from trainml.projects.credentials import ProjectCredentials, ProjectCredential
24
26
  from trainml.projects.secrets import ProjectSecrets, ProjectSecret
25
27
 
@@ -1130,12 +1132,13 @@ def mock_trainml(
1130
1132
  trainml.gpu_types = create_autospec(GpuTypes)
1131
1133
  trainml.environments = create_autospec(Environments)
1132
1134
  trainml.jobs = create_autospec(Jobs)
1133
- trainml.connections = create_autospec(Connections)
1134
1135
  trainml.projects = create_autospec(Projects)
1135
1136
  trainml.datasets.list = AsyncMock(return_value=mock_my_datasets)
1136
1137
  trainml.datasets.list_public = AsyncMock(return_value=mock_public_datasets)
1137
1138
  trainml.checkpoints.list = AsyncMock(return_value=mock_my_checkpoints)
1138
- trainml.checkpoints.list_public = AsyncMock(return_value=mock_public_checkpoints)
1139
+ trainml.checkpoints.list_public = AsyncMock(
1140
+ return_value=mock_public_checkpoints
1141
+ )
1139
1142
  trainml.models.list = AsyncMock(return_value=mock_models)
1140
1143
  trainml.volumes.list = AsyncMock(return_value=mock_my_volumes)
1141
1144
  trainml.gpu_types.list = AsyncMock(return_value=mock_gpu_types)
@@ -1143,17 +1146,25 @@ def mock_trainml(
1143
1146
  trainml.jobs.list = AsyncMock(return_value=mock_jobs)
1144
1147
  trainml.projects.list = AsyncMock(return_value=mock_projects)
1145
1148
  trainml.projects.datastores = create_autospec(ProjectDatastores)
1146
- trainml.projects.datastores.list = AsyncMock(return_value=mock_project_datastores)
1149
+ trainml.projects.datastores.list = AsyncMock(
1150
+ return_value=mock_project_datastores
1151
+ )
1147
1152
  trainml.projects.services = create_autospec(ProjectServices)
1148
- trainml.projects.services.list = AsyncMock(return_value=mock_project_services)
1153
+ trainml.projects.services.list = AsyncMock(
1154
+ return_value=mock_project_services
1155
+ )
1149
1156
  trainml.projects.data_connectors = create_autospec(ProjectDataConnectors)
1150
1157
  trainml.projects.data_connectors.list = AsyncMock(
1151
1158
  return_value=mock_project_data_connectors
1152
1159
  )
1153
1160
  trainml.projects.credentials = create_autospec(ProjectCredentials)
1154
- trainml.projects.credentials.list = AsyncMock(return_value=mock_project_credentials)
1161
+ trainml.projects.credentials.list = AsyncMock(
1162
+ return_value=mock_project_credentials
1163
+ )
1155
1164
  trainml.projects.secrets = create_autospec(ProjectSecrets)
1156
- trainml.projects.secrets.list = AsyncMock(return_value=mock_project_secrets)
1165
+ trainml.projects.secrets.list = AsyncMock(
1166
+ return_value=mock_project_secrets
1167
+ )
1157
1168
 
1158
1169
  trainml.cloudbender = create_autospec(Cloudbender)
1159
1170
 
@@ -1166,7 +1177,9 @@ def mock_trainml(
1166
1177
  trainml.cloudbender.devices = create_autospec(Devices)
1167
1178
  trainml.cloudbender.devices.list = AsyncMock(return_value=mock_devices)
1168
1179
  trainml.cloudbender.datastores = create_autospec(Datastores)
1169
- trainml.cloudbender.datastores.list = AsyncMock(return_value=mock_datastores)
1180
+ trainml.cloudbender.datastores.list = AsyncMock(
1181
+ return_value=mock_datastores
1182
+ )
1170
1183
  trainml.cloudbender.services = create_autospec(Services)
1171
1184
  trainml.cloudbender.services.list = AsyncMock(return_value=mock_services)
1172
1185
  trainml.cloudbender.data_connectors = create_autospec(DataConnectors)
tests/unit/projects/test_project_data_connectors_unit.py CHANGED
@@ -33,6 +33,25 @@ def project_data_connector(mock_trainml):
33
33
 
34
34
 
35
35
  class ProjectDataConnectorsTests:
36
+ @mark.asyncio
37
+ async def test_project_data_connectors_get(
38
+ self, project_data_connectors, mock_trainml
39
+ ):
40
+ """Test get method (lines 11-14)."""
41
+ api_response = {
42
+ "project_uuid": "proj-id-1",
43
+ "region_uuid": "reg-id-1",
44
+ "id": "connector-id-1",
45
+ "type": "custom",
46
+ "name": "On-Prem Connection A",
47
+ }
48
+ mock_trainml._query = AsyncMock(return_value=api_response)
49
+ result = await project_data_connectors.get("connector-id-1", param1="value1")
50
+ mock_trainml._query.assert_called_once_with(
51
+ "/project/1/data_connectors/connector-id-1", "GET", dict(param1="value1")
52
+ )
53
+ assert result.id == "connector-id-1"
54
+
36
55
  @mark.asyncio
37
56
  async def test_project_data_connectors_refresh(
38
57
  self, project_data_connectors, mock_trainml
@@ -100,3 +119,23 @@ class ProjectDataConnectorTests:
100
119
  empty_project_data_connector = specimen.ProjectDataConnector(mock_trainml)
101
120
  assert bool(project_data_connector)
102
121
  assert not bool(empty_project_data_connector)
122
+
123
+ @mark.asyncio
124
+ async def test_project_data_connector_enable(self, project_data_connector, mock_trainml):
125
+ """Test enable method (line 72)."""
126
+ api_response = dict()
127
+ mock_trainml._query = AsyncMock(return_value=api_response)
128
+ await project_data_connector.enable()
129
+ mock_trainml._query.assert_called_once_with(
130
+ "/project/proj-id-1/data_connectors/ds-id-1/enable", "PATCH"
131
+ )
132
+
133
+ @mark.asyncio
134
+ async def test_project_data_connector_disable(self, project_data_connector, mock_trainml):
135
+ """Test disable method (line 77)."""
136
+ api_response = dict()
137
+ mock_trainml._query = AsyncMock(return_value=api_response)
138
+ await project_data_connector.disable()
139
+ mock_trainml._query.assert_called_once_with(
140
+ "/project/proj-id-1/data_connectors/ds-id-1/disable", "PATCH"
141
+ )
tests/unit/projects/test_project_datastores_unit.py CHANGED
@@ -33,6 +33,23 @@ def project_datastore(mock_trainml):
33
33
 
34
34
 
35
35
  class ProjectDatastoresTests:
36
+ @mark.asyncio
37
+ async def test_project_datastores_get(self, project_datastores, mock_trainml):
38
+ """Test get method (lines 11-14)."""
39
+ api_response = {
40
+ "project_uuid": "proj-id-1",
41
+ "region_uuid": "reg-id-1",
42
+ "id": "store-id-1",
43
+ "type": "nfs",
44
+ "name": "On-prem NFS",
45
+ }
46
+ mock_trainml._query = AsyncMock(return_value=api_response)
47
+ result = await project_datastores.get("store-id-1", param1="value1")
48
+ mock_trainml._query.assert_called_once_with(
49
+ "/project/1/datastores/store-id-1", "GET", dict(param1="value1")
50
+ )
51
+ assert result.id == "store-id-1"
52
+
36
53
  @mark.asyncio
37
54
  async def test_project_datastores_refresh(self, project_datastores, mock_trainml):
38
55
  api_response = dict()
@@ -94,3 +111,23 @@ class ProjectDatastoreTests:
94
111
  empty_project_datastore = specimen.ProjectDatastore(mock_trainml)
95
112
  assert bool(project_datastore)
96
113
  assert not bool(empty_project_datastore)
114
+
115
+ @mark.asyncio
116
+ async def test_project_datastore_enable(self, project_datastore, mock_trainml):
117
+ """Test enable method (line 67)."""
118
+ api_response = dict()
119
+ mock_trainml._query = AsyncMock(return_value=api_response)
120
+ await project_datastore.enable()
121
+ mock_trainml._query.assert_called_once_with(
122
+ "/project/proj-id-1/datastores/ds-id-1/enable", "PATCH"
123
+ )
124
+
125
+ @mark.asyncio
126
+ async def test_project_datastore_disable(self, project_datastore, mock_trainml):
127
+ """Test disable method (line 72)."""
128
+ api_response = dict()
129
+ mock_trainml._query = AsyncMock(return_value=api_response)
130
+ await project_datastore.disable()
131
+ mock_trainml._query.assert_called_once_with(
132
+ "/project/proj-id-1/datastores/ds-id-1/disable", "PATCH"
133
+ )