trainml 0.5.16__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. examples/local_storage.py +0 -2
  2. tests/integration/test_checkpoints_integration.py +4 -3
  3. tests/integration/test_datasets_integration.py +5 -3
  4. tests/integration/test_jobs_integration.py +33 -27
  5. tests/integration/test_models_integration.py +7 -3
  6. tests/integration/test_volumes_integration.py +2 -2
  7. tests/unit/cli/test_cli_checkpoint_unit.py +312 -1
  8. tests/unit/cloudbender/test_nodes_unit.py +112 -0
  9. tests/unit/cloudbender/test_providers_unit.py +96 -0
  10. tests/unit/cloudbender/test_regions_unit.py +106 -0
  11. tests/unit/cloudbender/test_services_unit.py +141 -0
  12. tests/unit/conftest.py +23 -10
  13. tests/unit/projects/test_project_data_connectors_unit.py +39 -0
  14. tests/unit/projects/test_project_datastores_unit.py +37 -0
  15. tests/unit/projects/test_project_members_unit.py +46 -0
  16. tests/unit/projects/test_project_services_unit.py +65 -0
  17. tests/unit/projects/test_projects_unit.py +17 -1
  18. tests/unit/test_auth_unit.py +17 -2
  19. tests/unit/test_checkpoints_unit.py +256 -71
  20. tests/unit/test_datasets_unit.py +218 -68
  21. tests/unit/test_exceptions.py +133 -0
  22. tests/unit/test_gpu_types_unit.py +11 -1
  23. tests/unit/test_jobs_unit.py +1014 -95
  24. tests/unit/test_main_unit.py +20 -0
  25. tests/unit/test_models_unit.py +218 -70
  26. tests/unit/test_trainml_unit.py +627 -3
  27. tests/unit/test_volumes_unit.py +211 -70
  28. tests/unit/utils/__init__.py +1 -0
  29. tests/unit/utils/test_transfer_unit.py +4260 -0
  30. trainml/__init__.py +1 -1
  31. trainml/checkpoints.py +56 -57
  32. trainml/cli/__init__.py +6 -3
  33. trainml/cli/checkpoint.py +18 -57
  34. trainml/cli/dataset.py +17 -57
  35. trainml/cli/job/__init__.py +11 -53
  36. trainml/cli/job/create.py +51 -24
  37. trainml/cli/model.py +14 -56
  38. trainml/cli/volume.py +18 -57
  39. trainml/datasets.py +50 -55
  40. trainml/jobs.py +239 -68
  41. trainml/models.py +51 -55
  42. trainml/projects/projects.py +2 -2
  43. trainml/trainml.py +50 -16
  44. trainml/utils/__init__.py +1 -0
  45. trainml/utils/auth.py +641 -0
  46. trainml/utils/transfer.py +587 -0
  47. trainml/volumes.py +48 -53
  48. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/METADATA +3 -3
  49. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/RECORD +53 -47
  50. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/LICENSE +0 -0
  51. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/WHEEL +0 -0
  52. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/entry_points.txt +0 -0
  53. {trainml-0.5.16.dist-info → trainml-1.0.0.dist-info}/top_level.txt +0 -0
@@ -142,7 +142,9 @@ class DatasetTests:
     def test_dataset_repr(self, dataset):
         string = repr(dataset)
         regex = (
-            r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '" + dataset.id + r"'.*}\)$"
+            r"^Dataset\( trainml , \*\*{.*'dataset_uuid': '"
+            + dataset.id
+            + r"'.*}\)$"
         )
         assert isinstance(string, str)
         assert re.match(regex, string)
@@ -154,9 +156,7 @@ class DatasetTests:

     @mark.asyncio
     async def test_dataset_get_log_url(self, dataset, mock_trainml):
-        api_response = (
-            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
-        )
+        api_response = "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/logs/first_one.zip"
         mock_trainml._query = AsyncMock(return_value=api_response)
         response = await dataset.get_log_url()
         mock_trainml._query.assert_called_once_with(
@@ -181,78 +181,73 @@ class DatasetTests:
         assert response == api_response

     @mark.asyncio
-    async def test_dataset_get_connection_utility_url(self, dataset, mock_trainml):
-        api_response = (
-            "https://trainml-jobs-dev.s3.us-east-2.amazonaws.com/1/vpn/first_one.zip"
-        )
-        mock_trainml._query = AsyncMock(return_value=api_response)
-        response = await dataset.get_connection_utility_url()
-        mock_trainml._query.assert_called_once_with(
-            "/dataset/1/download", "GET", dict(project_uuid="proj-id-1")
+    async def test_dataset_connect_downloading_status(self, mock_trainml):
+        dataset = specimen.Dataset(
+            mock_trainml,
+            dataset_uuid="1",
+            project_uuid="proj-id-1",
+            name="test dataset",
+            status="downloading",
+            auth_token="test-token",
+            hostname="example.com",
+            source_uri="/path/to/source",
         )
-        assert response == api_response

-    def test_dataset_get_connection_details_no_vpn(self, dataset):
-        details = dataset.get_connection_details()
-        expected_details = dict()
-        assert details == expected_details
+        with patch(
+            "trainml.datasets.Dataset.refresh", new_callable=AsyncMock
+        ) as mock_refresh:
+            with patch(
+                "trainml.datasets.upload", new_callable=AsyncMock
+            ) as mock_upload:
+                await dataset.connect()
+                mock_refresh.assert_called_once()
+                mock_upload.assert_called_once_with(
+                    "example.com", "test-token", "/path/to/source"
+                )

-    def test_dataset_get_connection_details_local_data(self, mock_trainml):
+    @mark.asyncio
+    async def test_dataset_connect_exporting_status(
+        self, mock_trainml, tmp_path
+    ):
+        output_dir = str(tmp_path / "output")
         dataset = specimen.Dataset(
             mock_trainml,
             dataset_uuid="1",
             project_uuid="proj-id-1",
-            name="first one",
-            status="new",
-            size=100000,
-            createdAt="2020-12-31T23:59:59.000Z",
-            source_type="local",
-            source_uri="~/tensorflow-example/data",
-            vpn={
-                "status": "new",
-                "cidr": "10.106.171.0/24",
-                "client": {
-                    "port": "36017",
-                    "id": "cus-id-1",
-                    "address": "10.106.171.253",
-                    "ssh_port": 46600,
-                },
-            },
-        )
-        details = dataset.get_connection_details()
-        expected_details = dict(
-            project_uuid="proj-id-1",
-            entity_type="dataset",
-            cidr="10.106.171.0/24",
-            ssh_port=46600,
-            input_path="~/tensorflow-example/data",
-            output_path=None,
+            name="test dataset",
+            status="exporting",
+            auth_token="test-token",
+            hostname="example.com",
+            output_uri=output_dir,
         )
-        assert details == expected_details

-    @mark.asyncio
-    async def test_dataset_connect(self, dataset, mock_trainml):
         with patch(
-            "trainml.datasets.Connection",
-            autospec=True,
-        ) as mock_connection:
-            connection = mock_connection.return_value
-            connection.status = "connected"
-            resp = await dataset.connect()
-            connection.start.assert_called_once()
-            assert resp == "connected"
+            "trainml.datasets.Dataset.refresh", new_callable=AsyncMock
+        ) as mock_refresh:
+            with patch(
+                "trainml.datasets.download", new_callable=AsyncMock
+            ) as mock_download:
+                await dataset.connect()
+                mock_refresh.assert_called_once()
+                mock_download.assert_called_once_with(
+                    "example.com", "test-token", output_dir
+                )

     @mark.asyncio
-    async def test_dataset_disconnect(self, dataset, mock_trainml):
-        with patch(
-            "trainml.datasets.Connection",
-            autospec=True,
-        ) as mock_connection:
-            connection = mock_connection.return_value
-            connection.status = "removed"
-            resp = await dataset.disconnect()
-            connection.stop.assert_called_once()
-            assert resp == "removed"
+    async def test_dataset_connect_invalid_status(self, mock_trainml):
+        dataset = specimen.Dataset(
+            mock_trainml,
+            dataset_uuid="1",
+            project_uuid="proj-id-1",
+            name="test dataset",
+            status="ready",
+        )
+
+        with raises(
+            SpecificationError,
+            match="You can only connect to downloading or exporting datasets",
+        ):
+            await dataset.connect()

     @mark.asyncio
     async def test_dataset_remove(self, dataset, mock_trainml):
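
The replaced tests above retire the old VPN-based Connection flow: connect() and disconnect() no longer drive a trainml.datasets.Connection object, and connect() is now exercised purely through the patched trainml.datasets.upload and trainml.datasets.download coroutines. The sketch below is a minimal, self-contained illustration of the status dispatch those mocks imply; upload, download, and SpecificationError here are local stand-ins rather than the package's own implementations, and the real method also refreshes the dataset before transferring.

import asyncio
from types import SimpleNamespace


class SpecificationError(Exception):
    # Local stand-in for trainml.exceptions.SpecificationError(attribute, message).
    def __init__(self, attribute, message):
        super().__init__(message)
        self.attribute = attribute
        self.message = message


async def upload(hostname, auth_token, source_uri):
    # Stand-in for the trainml.datasets.upload coroutine patched in the tests.
    print(f"would upload {source_uri} to {hostname}")


async def download(hostname, auth_token, output_uri):
    # Stand-in for the trainml.datasets.download coroutine patched in the tests.
    print(f"would download from {hostname} into {output_uri}")


async def connect(dataset):
    # Dispatch mirrored from the tests: the transfer direction follows the status.
    if dataset.status == "downloading":
        await upload(dataset.hostname, dataset.auth_token, dataset.source_uri)
    elif dataset.status == "exporting":
        await download(dataset.hostname, dataset.auth_token, dataset.output_uri)
    else:
        raise SpecificationError(
            "status",
            "You can only connect to downloading or exporting datasets",
        )


asyncio.run(
    connect(
        SimpleNamespace(
            status="downloading",
            hostname="example.com",
            auth_token="test-token",
            source_uri="/path/to/source",
        )
    )
)
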
@@ -391,7 +386,9 @@ class DatasetTests:
         mock_trainml._query.assert_not_called()

     @mark.asyncio
-    async def test_dataset_wait_for_incorrect_status(self, dataset, mock_trainml):
+    async def test_dataset_wait_for_incorrect_status(
+        self, dataset, mock_trainml
+    ):
         api_response = None
         mock_trainml._query = AsyncMock(return_value=api_response)
         with raises(SpecificationError):
@@ -436,7 +433,9 @@ class DatasetTests:
         mock_trainml._query.assert_called()

     @mark.asyncio
-    async def test_dataset_wait_for_dataset_failed(self, dataset, mock_trainml):
+    async def test_dataset_wait_for_dataset_failed(
+        self, dataset, mock_trainml
+    ):
         api_response = dict(
             dataset_uuid="1",
             name="first one",
@@ -449,7 +448,156 @@ class DatasetTests:
         mock_trainml._query.assert_called()

     @mark.asyncio
-    async def test_dataset_wait_for_archived_succeeded(self, dataset, mock_trainml):
+    async def test_dataset_rename(self, dataset, mock_trainml):
+        api_response = dict(
+            dataset_uuid="1",
+            name="renamed dataset",
+            project_uuid="proj-id-1",
+            status="ready",
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        result = await dataset.rename("renamed dataset")
+        mock_trainml._query.assert_called_once_with(
+            "/dataset/1",
+            "PATCH",
+            None,
+            dict(name="renamed dataset"),
+        )
+        assert result == dataset
+        assert dataset.name == "renamed dataset"
+
+    @mark.asyncio
+    async def test_dataset_export(self, dataset, mock_trainml):
+        api_response = dict(
+            dataset_uuid="1",
+            name="first one",
+            project_uuid="proj-id-1",
+            status="exporting",
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        result = await dataset.export("aws", "s3://bucket/path", dict(key="value"))
+        mock_trainml._query.assert_called_once_with(
+            "/dataset/1/export",
+            "POST",
+            dict(project_uuid="proj-id-1"),
+            dict(
+                output_type="aws",
+                output_uri="s3://bucket/path",
+                output_options=dict(key="value"),
+            ),
+        )
+        assert result == dataset
+        assert dataset.status == "exporting"
+
+    @mark.asyncio
+    async def test_dataset_export_default_options(self, dataset, mock_trainml):
+        api_response = dict(
+            dataset_uuid="1",
+            name="first one",
+            project_uuid="proj-id-1",
+            status="exporting",
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        result = await dataset.export("aws", "s3://bucket/path")
+        mock_trainml._query.assert_called_once_with(
+            "/dataset/1/export",
+            "POST",
+            dict(project_uuid="proj-id-1"),
+            dict(
+                output_type="aws",
+                output_uri="s3://bucket/path",
+                output_options=dict(),
+            ),
+        )
+        assert result == dataset
+
+    @mark.asyncio
+    async def test_dataset_wait_for_timeout_validation(
+        self, dataset, mock_trainml
+    ):
+        with raises(SpecificationError) as exc_info:
+            await dataset.wait_for("ready", timeout=25 * 60 * 60)  # > 24 hours
+        assert "timeout" in str(exc_info.value.attribute).lower()
+        assert "less than" in str(exc_info.value.message).lower()
+
+    @mark.asyncio
+    async def test_dataset_connect_new_status_waits_for_downloading(
+        self, dataset, mock_trainml
+    ):
+        """Test that connect() waits for downloading status when status is 'new'."""
+        dataset._dataset["status"] = "new"
+        dataset._status = "new"
+        api_response_new = dict(
+            dataset_uuid="1",
+            name="first one",
+            status="new",
+        )
+        api_response_downloading = dict(
+            dataset_uuid="1",
+            name="first one",
+            status="downloading",
+            auth_token="token",
+            hostname="host",
+            source_uri="s3://bucket/path",
+        )
+        # wait_for calls refresh multiple times, then connect calls refresh again
+        # We need enough responses for wait_for polling and the final refresh
+        mock_trainml._query = AsyncMock(
+            side_effect=[
+                api_response_new,  # wait_for refresh 1
+                api_response_downloading,  # wait_for refresh 2 (status matches, wait_for returns)
+                api_response_downloading,  # connect refresh
+            ]
+        )
+        with patch("trainml.datasets.upload", new_callable=AsyncMock) as mock_upload:
+            await dataset.connect()
+            # After refresh, status should be downloading
+            assert dataset.status == "downloading"
+            mock_upload.assert_called_once()
+
+    @mark.asyncio
+    async def test_dataset_connect_downloading_missing_properties(
+        self, dataset, mock_trainml
+    ):
+        """Test connect() raises error when downloading status missing properties."""
+        dataset._dataset["status"] = "downloading"
+        api_response = dict(
+            dataset_uuid="1",
+            name="first one",
+            status="downloading",
+            # Missing auth_token, hostname, or source_uri
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await dataset.connect()
+        assert "missing required connection properties" in str(exc_info.value.message).lower()
+
+    @mark.asyncio
+    async def test_dataset_connect_exporting_missing_properties(
+        self, dataset, mock_trainml
+    ):
+        """Test connect() raises error when exporting status missing properties."""
+        dataset._dataset["status"] = "exporting"
+        api_response = dict(
+            dataset_uuid="1",
+            name="first one",
+            status="exporting",
+            # Missing auth_token, hostname, or output_uri
+        )
+        mock_trainml._query = AsyncMock(return_value=api_response)
+        with raises(SpecificationError) as exc_info:
+            await dataset.connect()
+        assert "missing required connection properties" in str(exc_info.value.message).lower()
+
+    def test_dataset_billed_size_property(self, dataset, mock_trainml):
+        """Test billed_size property access."""
+        dataset._billed_size = 50000
+        assert dataset.billed_size == 50000
+
+    @mark.asyncio
+    async def test_dataset_wait_for_archived_succeeded(
+        self, dataset, mock_trainml
+    ):
         mock_trainml._query = AsyncMock(
             side_effect=ApiError(404, dict(errorMessage="Dataset Not Found"))
         )
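
Taken together, the tests added in this hunk pin down the 1.0.0 Dataset surface: rename() PATCHes /dataset/{id} with the new name, export() POSTs to /dataset/{id}/export with output_options defaulting to an empty dict, and wait_for() rejects timeouts longer than 24 hours up front. A hedged usage sketch, assuming dataset is a trainml Dataset instance you already hold:

from trainml.exceptions import SpecificationError


async def archive_dataset(dataset):
    # PATCH /dataset/{id} with dict(name=...); returns the updated Dataset.
    await dataset.rename("renamed dataset")

    # POST /dataset/{id}/export; output_options is optional and defaults to {}.
    await dataset.export("aws", "s3://bucket/path")

    # wait_for() validates its timeout before polling: anything over 24 hours
    # raises SpecificationError immediately.
    try:
        await dataset.wait_for("ready", timeout=25 * 60 * 60)
    except SpecificationError as error:
        print(f"{error.attribute}: {error.message}")
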
@@ -457,7 +605,9 @@ class DatasetTests:
         mock_trainml._query.assert_called()

     @mark.asyncio
-    async def test_dataset_wait_for_unexpected_api_error(self, dataset, mock_trainml):
+    async def test_dataset_wait_for_unexpected_api_error(
+        self, dataset, mock_trainml
+    ):
         mock_trainml._query = AsyncMock(
             side_effect=ApiError(404, dict(errorMessage="Dataset Not Found"))
         )
@@ -5,22 +5,55 @@ import trainml.exceptions as specimen
 pytestmark = [mark.sdk, mark.unit]


+def test_trainml_exception():
+    """Test TrainMLException base class."""
+    error = specimen.TrainMLException("test message")
+    assert error.message == "test message"
+    assert repr(error) == "TrainMLException( 'test message')"
+    assert str(error) == "TrainMLException('test message')"
+
+    # Test with multiple args
+    error2 = specimen.TrainMLException("test", "arg1", "arg2")
+    assert error2.message == "test"
+
+
 def test_api_error():
+    """Test ApiError exception."""
     error = specimen.ApiError(400, dict(errorMessage="test message"))
+    assert error.status == 400
+    assert error.message == "test message"
     assert repr(error) == "ApiError(400, 'test message')"
     assert str(error) == "ApiError(400, 'test message')"
+
+    # Test with message key instead of errorMessage
+    error2 = specimen.ApiError(404, dict(message="not found"))
+    assert error2.message == "not found"
+
+    # Test with multiple args
+    error3 = specimen.ApiError(500, dict(errorMessage="server error"), "extra")
+    assert error3.message == "server error"


 def test_job_error():
+    """Test JobError exception."""
     error = specimen.JobError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
     assert (
         repr(error) == "JobError(failed, {'id': 'id-1', 'status': 'failed'})"
     )
     assert str(error) == "JobError(failed, {'id': 'id-1', 'status': 'failed'})"
+
+    # Test with string data
+    error2 = specimen.JobError("errored", "error string")
+    assert error2.message == "error string"


 def test_dataset_error():
+    """Test DatasetError exception."""
     error = specimen.DatasetError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
     assert (
         repr(error)
         == "DatasetError(failed, {'id': 'id-1', 'status': 'failed'})"
@@ -29,3 +62,103 @@ def test_dataset_error():
         str(error)
         == "DatasetError(failed, {'id': 'id-1', 'status': 'failed'})"
     )
+
+
+def test_model_error():
+    """Test ModelError exception."""
+    error = specimen.ModelError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
+    assert (
+        repr(error) == "ModelError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+    assert str(error) == "ModelError(failed, {'id': 'id-1', 'status': 'failed'})"
+
+
+def test_checkpoint_error():
+    """Test CheckpointError exception."""
+    error = specimen.CheckpointError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
+    assert (
+        repr(error)
+        == "CheckpointError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+    assert (
+        str(error)
+        == "CheckpointError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+
+
+def test_volume_error():
+    """Test VolumeError exception."""
+    error = specimen.VolumeError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
+    assert (
+        repr(error) == "VolumeError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+    assert str(error) == "VolumeError(failed, {'id': 'id-1', 'status': 'failed'})"
+
+
+def test_connection_error():
+    """Test ConnectionError exception."""
+    error = specimen.ConnectionError("connection failed")
+    assert error.message == "connection failed"
+    assert repr(error) == "ConnectionError(connection failed)"
+    assert str(error) == "ConnectionError(connection failed)"
+
+    # Test with multiple args
+    error2 = specimen.ConnectionError("test", "arg1", "arg2")
+    assert error2.message == "test"
+
+
+def test_specification_error():
+    """Test SpecificationError exception."""
+    error = specimen.SpecificationError("attr", "invalid value")
+    assert error.attribute == "attr"
+    assert error.message == "invalid value"
+    assert repr(error) == "SpecificationError(attr, invalid value)"
+    assert str(error) == "SpecificationError(attr, invalid value)"
+
+    # Test with multiple args
+    error2 = specimen.SpecificationError("attr", "test", "arg1")
+    assert error2.attribute == "attr"
+    assert error2.message == "test"
+
+
+def test_node_error():
+    """Test NodeError exception."""
+    error = specimen.NodeError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
+    assert (
+        repr(error) == "NodeError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+    assert str(error) == "NodeError(failed, {'id': 'id-1', 'status': 'failed'})"
+
+
+def test_provider_error():
+    """Test ProviderError exception."""
+    error = specimen.ProviderError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
+    assert (
+        repr(error)
+        == "ProviderError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+    assert (
+        str(error)
+        == "ProviderError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+
+
+def test_region_error():
+    """Test RegionError exception."""
+    error = specimen.RegionError("failed", dict(id="id-1", status="failed"))
+    assert error.status == "failed"
+    assert error.message == dict(id="id-1", status="failed")
+    assert (
+        repr(error) == "RegionError(failed, {'id': 'id-1', 'status': 'failed'})"
+    )
+    assert str(error) == "RegionError(failed, {'id': 'id-1', 'status': 'failed'})"
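
The expanded exception tests document a consistent surface across trainml.exceptions: ApiError exposes .status and .message (accepting either errorMessage or message in the response body), the entity errors (JobError, DatasetError, ModelError, CheckpointError, VolumeError, NodeError, ProviderError, RegionError) expose .status and .message, and SpecificationError exposes .attribute and .message. A minimal handling sketch built only on those attributes; get_log_url() is the Dataset method shown earlier, and the surrounding wiring is assumed:

from trainml.exceptions import ApiError, SpecificationError


async def fetch_log_url(dataset):
    try:
        return await dataset.get_log_url()
    except ApiError as error:
        # e.g. ApiError(404, 'Dataset Not Found')
        print(f"API error {error.status}: {error.message}")
    except SpecificationError as error:
        print(f"invalid {error.attribute}: {error.message}")
    return None
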
@@ -1,6 +1,6 @@
 import re
 from unittest.mock import AsyncMock
-from pytest import mark, fixture
+from pytest import mark, fixture, raises

 import trainml.gpu_types as specimen

@@ -39,6 +39,16 @@ class GpuTypesTests:
             f"/project/proj-id-1/gputypes", "GET"
         )

+    @mark.asyncio
+    async def test_list_gpu_types_no_project(self, mock_trainml):
+        """Test list raises error when no active project (line 11)."""
+        from trainml.exceptions import TrainMLException
+        gpu_types = specimen.GpuTypes(mock_trainml)
+        mock_trainml.project = None
+        with raises(TrainMLException) as exc_info:
+            await gpu_types.list()
+        assert "Active project not configured" in str(exc_info.value.message)
+
     @mark.asyncio
     async def test_project_refresh_gpu_types(self, gpu_types, mock_trainml):
         api_response = dict()
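
The new GpuTypes test covers the guard added in 1.0.0: list() raises TrainMLException when the client has no active project configured. A hedged caller-side sketch; the gpu_types attribute on the client is assumed from the SDK's usual layout rather than shown in this diff:

from trainml.exceptions import TrainMLException


async def list_available_gpu_types(trainml_client):
    try:
        return await trainml_client.gpu_types.list()
    except TrainMLException as error:
        # Raised when no active project is configured for the client.
        print(f"configure a project first: {error.message}")
        return []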