trainml 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. tests/integration/test_jobs_integration.py +13 -18
  2. tests/unit/cli/cloudbender/test_cli_device_unit.py +38 -0
  3. tests/unit/cloudbender/test_datastores_unit.py +20 -0
  4. tests/unit/cloudbender/test_device_configs_unit.py +21 -0
  5. tests/unit/cloudbender/test_devices_unit.py +270 -0
  6. tests/unit/cloudbender/test_nodes_unit.py +43 -0
  7. tests/unit/cloudbender/test_providers_unit.py +16 -0
  8. tests/unit/cloudbender/test_regions_unit.py +18 -0
  9. tests/unit/cloudbender/test_reservations_unit.py +20 -0
  10. tests/unit/conftest.py +54 -3
  11. tests/unit/test_auth.py +1 -1
  12. trainml/__init__.py +1 -1
  13. trainml/auth.py +3 -7
  14. trainml/cli/cloudbender/__init__.py +1 -0
  15. trainml/cli/cloudbender/device.py +157 -0
  16. trainml/cli/dataset.py +1 -3
  17. trainml/cli/job/create.py +3 -3
  18. trainml/cli/model.py +2 -0
  19. trainml/cloudbender/cloudbender.py +2 -0
  20. trainml/cloudbender/datastores.py +8 -0
  21. trainml/cloudbender/device_configs.py +8 -0
  22. trainml/cloudbender/devices.py +190 -0
  23. trainml/cloudbender/nodes.py +22 -0
  24. trainml/cloudbender/providers.py +8 -0
  25. trainml/cloudbender/regions.py +8 -0
  26. trainml/cloudbender/reservations.py +8 -0
  27. trainml/datasets.py +25 -14
  28. trainml/gpu_types.py +5 -0
  29. trainml/models.py +22 -3
  30. trainml/trainml.py +2 -1
  31. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/METADATA +1 -1
  32. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/RECORD +36 -32
  33. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/LICENSE +0 -0
  34. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/WHEEL +0 -0
  35. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/entry_points.txt +0 -0
  36. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/top_level.txt +0 -0
@@ -100,7 +100,7 @@ class JobLifeCycleTests:
100
100
  training_job = await job.copy(
101
101
  name="CLI Automated Tests - Job Convert",
102
102
  type="training",
103
- workers=["python $TRAINML_MODEL_PATH/tensorflow/main.py"],
103
+ workers=["python $ML_MODEL_PATH/tensorflow/main.py"],
104
104
  data=dict(
105
105
  datasets=[
106
106
  dict(
@@ -191,8 +191,7 @@ class JobAPIResourceValidationTests:
191
191
  disk_size=10,
192
192
  )
193
193
  assert (
194
- "Invalid Request - CPU Count must be a multiple of 4"
195
- in error.value.message
194
+ "Invalid Request - CPU Count must be a multiple of 4" in error.value.message
196
195
  )
197
196
 
198
197
  async def test_invalid_gpu_count_for_cpu(self, trainml):
@@ -373,7 +372,7 @@ class JobIOTests:
373
372
  type="training",
374
373
  gpu_types=["gtx1060"],
375
374
  disk_size=10,
376
- workers=["python $TRAINML_MODEL_PATH/tensorflow/main.py"],
375
+ workers=["python $ML_MODEL_PATH/tensorflow/main.py"],
377
376
  environment=dict(
378
377
  type="DEEPLEARNING_PY310",
379
378
  env=[
@@ -431,7 +430,7 @@ class JobIOTests:
431
430
  source_uri="git@github.com:trainML/environment-tests.git",
432
431
  )
433
432
  await model.wait_for("ready", 300)
434
- assert model.size >= 500000
433
+ assert model.size >= 200000
435
434
 
436
435
  job = await trainml.jobs.create(
437
436
  "CLI Automated Tests - Training With trainML Model Output",
@@ -440,7 +439,7 @@ class JobIOTests:
440
439
  gpu_count=1,
441
440
  cpu_count=8,
442
441
  disk_size=10,
443
- worker_commands=["python $TRAINML_MODEL_PATH/tensorflow/main.py"],
442
+ worker_commands=["python $ML_MODEL_PATH/tensorflow/main.py"],
444
443
  data=dict(
445
444
  datasets=[
446
445
  dict(
@@ -548,22 +547,21 @@ class JobTypeTests:
548
547
  ),
549
548
  environment=dict(
550
549
  type="CUSTOM",
551
- custom_image="tensorflow/tensorflow:2.10.0-gpu",
550
+ custom_image="tensorflow/tensorflow:2.13.0-gpu",
552
551
  packages=dict(
553
552
  pip=[
554
- "tensorflow_addons==0.18.0",
555
553
  "matplotlib",
556
554
  "scipy",
557
- "tensorflow_hub==0.12.0",
555
+ "tensorflow_hub",
558
556
  "keras_applications",
559
557
  "keras_preprocessing",
560
- "protobuf==3.20.1",
561
- "typing-extensions==4.4.0",
558
+ "protobuf",
559
+ "typing-extensions",
562
560
  ]
563
561
  ),
564
562
  ),
565
563
  worker_commands=[
566
- "python $TRAINML_MODEL_PATH/tensorflow/main.py",
564
+ "python $ML_MODEL_PATH/tensorflow/main.py",
567
565
  ],
568
566
  data=dict(
569
567
  datasets=[
@@ -589,10 +587,7 @@ class JobTypeTests:
589
587
  sys.stderr.write(captured.err)
590
588
  assert "Epoch 1/2" in captured.out
591
589
  assert "Epoch 2/2" in captured.out
592
- assert (
593
- "Uploading s3://trainml-example/output/resnet_cifar10"
594
- in captured.out
595
- )
590
+ assert "Uploading s3://trainml-example/output/resnet_cifar10" in captured.out
596
591
  assert (
597
592
  "upload: ./model.ckpt-0002.data-00000-of-00001 to s3://trainml-example/output/resnet_cifar10/model.ckpt-0002.data-00000-of-00001"
598
593
  in captured.out
@@ -615,7 +610,7 @@ class JobFeatureTests:
615
610
  source_uri="git@github.com:trainML/environment-tests.git",
616
611
  ),
617
612
  worker_commands=[
618
- "python $TRAINML_MODEL_PATH/pytorch/main.py",
613
+ "python $ML_MODEL_PATH/pytorch/main.py",
619
614
  ],
620
615
  data=dict(
621
616
  datasets=[dict(id="MNIST", public=True)],
@@ -641,7 +636,7 @@ class JobFeatureTests:
641
636
  gpu_count=1,
642
637
  disk_size=10,
643
638
  workers=[
644
- "python $TRAINML_MODEL_PATH/tensorflow/main.py",
639
+ "python $ML_MODEL_PATH/tensorflow/main.py",
645
640
  ],
646
641
  data=dict(
647
642
  input_type="wasabi",
@@ -0,0 +1,38 @@
1
+ import re
2
+ import json
3
+ import click
4
+ from unittest.mock import AsyncMock, patch
5
+ from pytest import mark, fixture, raises
6
+
7
+ pytestmark = [mark.cli, mark.unit, mark.cloudbender, mark.devices]
8
+
9
+ from trainml.cli.cloudbender import device as specimen
10
+ from trainml.cloudbender.devices import Device
11
+
12
+
13
+ def test_list(runner, mock_devices):
14
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
15
+ mock_trainml.cloudbender = AsyncMock()
16
+ mock_trainml.cloudbender.devices = AsyncMock()
17
+ mock_trainml.cloudbender.devices.list = AsyncMock(
18
+ return_value=mock_devices
19
+ )
20
+ result = runner.invoke(
21
+ specimen,
22
+ args=["list", "--provider=prov-id-1", "--region=reg-id-1"],
23
+ )
24
+ assert result.exit_code == 0
25
+ mock_trainml.cloudbender.devices.list.assert_called_once_with(
26
+ provider_uuid="prov-id-1", region_uuid="reg-id-1"
27
+ )
28
+
29
+
30
+ def test_list_no_provider(runner, mock_devices):
31
+ with patch("trainml.cli.TrainML", new=AsyncMock) as mock_trainml:
32
+ mock_trainml.cloudbender = AsyncMock()
33
+ mock_trainml.cloudbender.devices = AsyncMock()
34
+ mock_trainml.cloudbender.devices.list = AsyncMock(
35
+ return_value=mock_devices
36
+ )
37
+ result = runner.invoke(specimen, ["list"])
38
+ assert result.exit_code != 0
@@ -151,3 +151,23 @@ class datastoreTests:
151
151
  mock_trainml._query.assert_called_once_with(
152
152
  "/provider/1/region/a/datastore/x", "DELETE"
153
153
  )
154
+
155
+ @mark.asyncio
156
+ async def test_datastore_refresh(self, datastore, mock_trainml):
157
+ api_response = {
158
+ "provider_uuid": "provider-id-1",
159
+ "region_uuid": "region-id-1",
160
+ "store_id": "store-id-1",
161
+ "name": "On-Prem Datastore",
162
+ "type": "nfs",
163
+ "uri": "192.168.0.50",
164
+ "root": "/exports",
165
+ "createdAt": "2020-12-31T23:59:59.000Z",
166
+ }
167
+ mock_trainml._query = AsyncMock(return_value=api_response)
168
+ response = await datastore.refresh()
169
+ mock_trainml._query.assert_called_once_with(
170
+ f"/provider/1/region/a/datastore/x", "GET"
171
+ )
172
+ assert datastore.id == "store-id-1"
173
+ assert response.id == "store-id-1"
@@ -150,3 +150,24 @@ class device_configTests:
150
150
  mock_trainml._query.assert_called_once_with(
151
151
  "/provider/1/region/a/device/config/x", "DELETE"
152
152
  )
153
+
154
+ @mark.asyncio
155
+ async def test_device_config_refresh(self, device_config, mock_trainml):
156
+ api_response = {
157
+ "provider_uuid": "provider-id-1",
158
+ "region_uuid": "region-id-1",
159
+ "config_id": "device_config-id-1",
160
+ "name": "IoT 1",
161
+ "model_uuid": "model-id-1",
162
+ "model_project_uuid": "proj-id-1",
163
+ "image": "nvidia/cuda",
164
+ "command": "python run.py",
165
+ "createdAt": "2020-12-31T23:59:59.000Z",
166
+ }
167
+ mock_trainml._query = AsyncMock(return_value=api_response)
168
+ response = await device_config.refresh()
169
+ mock_trainml._query.assert_called_once_with(
170
+ f"/provider/1/region/a/device/config/x", "GET"
171
+ )
172
+ assert device_config.id == "device_config-id-1"
173
+ assert response.id == "device_config-id-1"
@@ -0,0 +1,270 @@
1
+ import re
2
+ import json
3
+ import logging
4
+ from unittest.mock import AsyncMock, patch
5
+ from pytest import mark, fixture, raises
6
+ from aiohttp import WSMessage, WSMsgType
7
+
8
+ import trainml.cloudbender.devices as specimen
9
+ from trainml.exceptions import (
10
+ ApiError,
11
+ SpecificationError,
12
+ TrainMLException,
13
+ )
14
+
15
+ pytestmark = [mark.sdk, mark.unit, mark.cloudbender, mark.devices]
16
+
17
+
18
+ @fixture
19
+ def devices(mock_trainml):
20
+ yield specimen.Devices(mock_trainml)
21
+
22
+
23
+ @fixture
24
+ def device(mock_trainml):
25
+ yield specimen.Device(
26
+ mock_trainml,
27
+ provider_uuid="1",
28
+ region_uuid="a",
29
+ device_id="x",
30
+ type="device",
31
+ service="compute",
32
+ friendly_name="hq-orin-01",
33
+ hostname="hq-orin-01",
34
+ status="active",
35
+ online=True,
36
+ maintenance_mode=False,
37
+ job_status="stopped",
38
+ job_last_deployed="2023-06-02T21:22:40.084Z",
39
+ job_config_id="job-id-1",
40
+ job_config_revision="1685740490096",
41
+ device_config_id="conf-id-2",
42
+ )
43
+
44
+
45
+ class RegionsTests:
46
+ @mark.asyncio
47
+ async def test_get_device(
48
+ self,
49
+ devices,
50
+ mock_trainml,
51
+ ):
52
+ api_response = dict()
53
+ mock_trainml._query = AsyncMock(return_value=api_response)
54
+ await devices.get("1234", "5687", "91011")
55
+ mock_trainml._query.assert_called_once_with(
56
+ "/provider/1234/region/5687/device/91011", "GET", {}
57
+ )
58
+
59
+ @mark.asyncio
60
+ async def test_list_devices(
61
+ self,
62
+ devices,
63
+ mock_trainml,
64
+ ):
65
+ api_response = dict()
66
+ mock_trainml._query = AsyncMock(return_value=api_response)
67
+ await devices.list("1234", "5687")
68
+ mock_trainml._query.assert_called_once_with(
69
+ "/provider/1234/region/5687/device", "GET", {}
70
+ )
71
+
72
+ @mark.asyncio
73
+ async def test_remove_device(
74
+ self,
75
+ devices,
76
+ mock_trainml,
77
+ ):
78
+ api_response = dict()
79
+ mock_trainml._query = AsyncMock(return_value=api_response)
80
+ await devices.remove("1234", "4567", "8910")
81
+ mock_trainml._query.assert_called_once_with(
82
+ "/provider/1234/region/4567/device/8910", "DELETE", {}
83
+ )
84
+
85
+ @mark.asyncio
86
+ async def test_create_device(self, devices, mock_trainml):
87
+ requested_config = dict(
88
+ provider_uuid="provider-id-1",
89
+ region_uuid="region-id-1",
90
+ friendly_name="phys-device",
91
+ hostname="phys-device",
92
+ minion_id="asdf",
93
+ )
94
+ expected_payload = dict(
95
+ friendly_name="phys-device",
96
+ hostname="phys-device",
97
+ minion_id="asdf",
98
+ type="device",
99
+ service="compute",
100
+ )
101
+ api_response = {
102
+ "provider_uuid": "provider-id-1",
103
+ "region_uuid": "region-id-1",
104
+ "device_id": "rig-id-1",
105
+ "name": "phys-device",
106
+ "type": "device",
107
+ "service": "compute",
108
+ "status": "new",
109
+ "online": False,
110
+ "maintenance_mode": True,
111
+ "job_status": "stopped",
112
+ "job_last_deployed": "2023-06-02T21:22:40.084Z",
113
+ "job_config_id": "job-id-1",
114
+ "job_config_revision": "1685740490096",
115
+ "device_config_id": "conf-id-1",
116
+ "createdAt": "2020-12-31T23:59:59.000Z",
117
+ }
118
+
119
+ mock_trainml._query = AsyncMock(return_value=api_response)
120
+ response = await devices.create(**requested_config)
121
+ mock_trainml._query.assert_called_once_with(
122
+ "/provider/provider-id-1/region/region-id-1/device",
123
+ "POST",
124
+ None,
125
+ expected_payload,
126
+ )
127
+ assert response.id == "rig-id-1"
128
+
129
+
130
+ class deviceTests:
131
+ def test_device_properties(self, device):
132
+ assert isinstance(device.id, str)
133
+ assert isinstance(device.provider_uuid, str)
134
+ assert isinstance(device.region_uuid, str)
135
+ assert isinstance(device.name, str)
136
+ assert isinstance(device.hostname, str)
137
+ assert isinstance(device.status, str)
138
+ assert isinstance(device.online, bool)
139
+ assert isinstance(device.maintenance_mode, bool)
140
+ assert isinstance(device.device_config_id, str)
141
+ assert isinstance(device.job_status, str)
142
+ assert isinstance(device.job_last_deployed, str)
143
+ assert isinstance(device.job_config_id, str)
144
+ assert isinstance(device.job_config_revision, str)
145
+
146
+ def test_device_str(self, device):
147
+ string = str(device)
148
+ regex = r"^{.*\"device_id\": \"" + device.id + r"\".*}$"
149
+ assert isinstance(string, str)
150
+ assert re.match(regex, string)
151
+
152
+ def test_device_repr(self, device):
153
+ string = repr(device)
154
+ regex = (
155
+ r"^Device\( trainml , \*\*{.*'device_id': '"
156
+ + device.id
157
+ + r"'.*}\)$"
158
+ )
159
+ assert isinstance(string, str)
160
+ assert re.match(regex, string)
161
+
162
+ def test_device_bool(self, device, mock_trainml):
163
+ empty_device = specimen.Device(mock_trainml)
164
+ assert bool(device)
165
+ assert not bool(empty_device)
166
+
167
+ @mark.asyncio
168
+ async def test_device_remove(self, device, mock_trainml):
169
+ api_response = dict()
170
+ mock_trainml._query = AsyncMock(return_value=api_response)
171
+ await device.remove()
172
+ mock_trainml._query.assert_called_once_with(
173
+ "/provider/1/region/a/device/x", "DELETE"
174
+ )
175
+
176
+ @mark.asyncio
177
+ async def test_device_refresh(self, device, mock_trainml):
178
+ api_response = {
179
+ "provider_uuid": "provider-id-1",
180
+ "region_uuid": "region-id-1",
181
+ "device_id": "device-id-1",
182
+ "name": "phys-device",
183
+ "type": "device",
184
+ "service": "compute",
185
+ "status": "new",
186
+ "online": False,
187
+ "maintenance_mode": True,
188
+ "job_status": "stopped",
189
+ "job_last_deployed": "2023-06-02T21:22:40.084Z",
190
+ "job_config_id": "job-id-1",
191
+ "job_config_revision": "1685740490096",
192
+ "device_config_id": "conf-id-1",
193
+ "createdAt": "2020-12-31T23:59:59.000Z",
194
+ }
195
+ mock_trainml._query = AsyncMock(return_value=api_response)
196
+ response = await device.refresh()
197
+ mock_trainml._query.assert_called_once_with(
198
+ f"/provider/1/region/a/device/x", "GET"
199
+ )
200
+ assert device.id == "device-id-1"
201
+ assert response.id == "device-id-1"
202
+
203
+ @mark.asyncio
204
+ async def test_device_toggle_maintenance(self, device, mock_trainml):
205
+ api_response = None
206
+ mock_trainml._query = AsyncMock(return_value=api_response)
207
+ await device.toggle_maintenance()
208
+ mock_trainml._query.assert_called_once_with(
209
+ "/provider/1/region/a/device/x/maintenance", "PATCH"
210
+ )
211
+
212
+ @mark.asyncio
213
+ async def test_device_run_action(self, device, mock_trainml):
214
+ api_response = None
215
+ mock_trainml._query = AsyncMock(return_value=api_response)
216
+ await device.run_action(command="report")
217
+ mock_trainml._query.assert_called_once_with(
218
+ "/provider/1/region/a/device/x/action",
219
+ "POST",
220
+ None,
221
+ dict(command="report"),
222
+ )
223
+
224
+ @mark.asyncio
225
+ async def test_device_set_config(self, device, mock_trainml):
226
+ api_response = {
227
+ "provider_uuid": "provider-id-1",
228
+ "region_uuid": "region-id-1",
229
+ "device_id": "device-id-1",
230
+ "name": "phys-device",
231
+ "type": "device",
232
+ "service": "compute",
233
+ "status": "new",
234
+ "online": False,
235
+ "maintenance_mode": True,
236
+ "job_status": "stopped",
237
+ "job_last_deployed": "2023-06-02T21:22:40.084Z",
238
+ "job_config_id": "job-id-1",
239
+ "job_config_revision": "1685740490096",
240
+ "device_config_id": "config-id-1",
241
+ "createdAt": "2020-12-31T23:59:59.000Z",
242
+ }
243
+ mock_trainml._query = AsyncMock(return_value=api_response)
244
+ response = await device.set_config(device_config_id="config-id-1")
245
+ mock_trainml._query.assert_called_once_with(
246
+ "/provider/1/region/a/device/x",
247
+ "PATCH",
248
+ None,
249
+ dict(device_config_id="config-id-1"),
250
+ )
251
+ assert device.id == "device-id-1"
252
+ assert response.id == "device-id-1"
253
+
254
+ @mark.asyncio
255
+ async def test_device_deploy_endpoint(self, device, mock_trainml):
256
+ api_response = None
257
+ mock_trainml._query = AsyncMock(return_value=api_response)
258
+ await device.deploy_endpoint()
259
+ mock_trainml._query.assert_called_once_with(
260
+ "/provider/1/region/a/device/x/deploy", "PUT"
261
+ )
262
+
263
+ @mark.asyncio
264
+ async def test_device_stop_endpoint(self, device, mock_trainml):
265
+ api_response = None
266
+ mock_trainml._query = AsyncMock(return_value=api_response)
267
+ await device.stop_endpoint()
268
+ mock_trainml._query.assert_called_once_with(
269
+ "/provider/1/region/a/device/x/stop", "PUT"
270
+ )
@@ -157,3 +157,46 @@ class nodeTests:
157
157
  mock_trainml._query.assert_called_once_with(
158
158
  "/provider/1/region/a/node/x", "DELETE"
159
159
  )
160
+
161
+ @mark.asyncio
162
+ async def test_node_refresh(self, node, mock_trainml):
163
+ api_response = {
164
+ "provider_uuid": "provider-id-1",
165
+ "region_uuid": "region-id-1",
166
+ "rig_uuid": "rig-id-1",
167
+ "name": "phys-node",
168
+ "type": "permanent",
169
+ "service": "compute",
170
+ "status": "new",
171
+ "online": False,
172
+ "maintenance_mode": True,
173
+ "createdAt": "2020-12-31T23:59:59.000Z",
174
+ }
175
+ mock_trainml._query = AsyncMock(return_value=api_response)
176
+ response = await node.refresh()
177
+ mock_trainml._query.assert_called_once_with(
178
+ f"/provider/1/region/a/node/x", "GET"
179
+ )
180
+ assert node.id == "rig-id-1"
181
+ assert response.id == "rig-id-1"
182
+
183
+ @mark.asyncio
184
+ async def test_node_toggle_maintenance(self, node, mock_trainml):
185
+ api_response = None
186
+ mock_trainml._query = AsyncMock(return_value=api_response)
187
+ await node.toggle_maintenance()
188
+ mock_trainml._query.assert_called_once_with(
189
+ "/provider/1/region/a/node/x/maintenance", "PATCH"
190
+ )
191
+
192
+ @mark.asyncio
193
+ async def test_node_run_action(self, node, mock_trainml):
194
+ api_response = None
195
+ mock_trainml._query = AsyncMock(return_value=api_response)
196
+ await node.run_action(command="report")
197
+ mock_trainml._query.assert_called_once_with(
198
+ "/provider/1/region/a/node/x/action",
199
+ "POST",
200
+ None,
201
+ dict(command="report"),
202
+ )
@@ -123,3 +123,19 @@ class providerTests:
123
123
  mock_trainml._query = AsyncMock(return_value=api_response)
124
124
  await provider.remove()
125
125
  mock_trainml._query.assert_called_once_with("/provider/1", "DELETE")
126
+
127
+ @mark.asyncio
128
+ async def test_provider_refresh(self, provider, mock_trainml):
129
+ api_response = {
130
+ "customer_uuid": "cust-id-1",
131
+ "provider_uuid": "provider-id-1",
132
+ "type": "new provider",
133
+ "credits": 0.0,
134
+ "payment_mode": "credits",
135
+ "createdAt": "2020-12-31T23:59:59.000Z",
136
+ }
137
+ mock_trainml._query = AsyncMock(return_value=api_response)
138
+ response = await provider.refresh()
139
+ mock_trainml._query.assert_called_once_with(f"/provider/1", "GET")
140
+ assert provider.id == "provider-id-1"
141
+ assert response.id == "provider-id-1"
@@ -139,6 +139,24 @@ class regionTests:
139
139
  "/provider/1/region/a", "DELETE"
140
140
  )
141
141
 
142
+ @mark.asyncio
143
+ async def test_region_refresh(self, region, mock_trainml):
144
+ api_response = {
145
+ "provider_uuid": "provider-id-1",
146
+ "region_uuid": "region-id-1",
147
+ "provider_type": "physical",
148
+ "name": "phys-region",
149
+ "status": "new",
150
+ "createdAt": "2020-12-31T23:59:59.000Z",
151
+ }
152
+ mock_trainml._query = AsyncMock(return_value=api_response)
153
+ response = await region.refresh()
154
+ mock_trainml._query.assert_called_once_with(
155
+ f"/provider/1/region/a", "GET"
156
+ )
157
+ assert region.id == "region-id-1"
158
+ assert response.id == "region-id-1"
159
+
142
160
  @mark.asyncio
143
161
  async def test_region_stage_dataset(self, region, mock_trainml):
144
162
  api_response = dict()
@@ -151,3 +151,23 @@ class reservationTests:
151
151
  mock_trainml._query.assert_called_once_with(
152
152
  "/provider/1/region/a/reservation/x", "DELETE"
153
153
  )
154
+
155
+ @mark.asyncio
156
+ async def test_reservation_refresh(self, reservation, mock_trainml):
157
+ api_response = {
158
+ "provider_uuid": "provider-id-1",
159
+ "region_uuid": "region-id-1",
160
+ "reservation_id": "reservation-id-1",
161
+ "name": "On-Prem Reservation",
162
+ "type": "port",
163
+ "resource": "8001",
164
+ "hostname": "service.local",
165
+ "createdAt": "2020-12-31T23:59:59.000Z",
166
+ }
167
+ mock_trainml._query = AsyncMock(return_value=api_response)
168
+ response = await reservation.refresh()
169
+ mock_trainml._query.assert_called_once_with(
170
+ f"/provider/1/region/a/reservation/x", "GET"
171
+ )
172
+ assert reservation.id == "reservation-id-1"
173
+ assert response.id == "reservation-id-1"