trainml 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. tests/integration/test_jobs_integration.py +13 -18
  2. tests/unit/cli/cloudbender/test_cli_device_unit.py +38 -0
  3. tests/unit/cloudbender/test_datastores_unit.py +20 -0
  4. tests/unit/cloudbender/test_device_configs_unit.py +21 -0
  5. tests/unit/cloudbender/test_devices_unit.py +270 -0
  6. tests/unit/cloudbender/test_nodes_unit.py +43 -0
  7. tests/unit/cloudbender/test_providers_unit.py +16 -0
  8. tests/unit/cloudbender/test_regions_unit.py +18 -0
  9. tests/unit/cloudbender/test_reservations_unit.py +20 -0
  10. tests/unit/conftest.py +54 -3
  11. tests/unit/test_auth.py +1 -1
  12. trainml/__init__.py +1 -1
  13. trainml/auth.py +3 -7
  14. trainml/cli/cloudbender/__init__.py +1 -0
  15. trainml/cli/cloudbender/device.py +157 -0
  16. trainml/cli/dataset.py +1 -3
  17. trainml/cli/job/create.py +3 -3
  18. trainml/cli/model.py +2 -0
  19. trainml/cloudbender/cloudbender.py +2 -0
  20. trainml/cloudbender/datastores.py +8 -0
  21. trainml/cloudbender/device_configs.py +8 -0
  22. trainml/cloudbender/devices.py +190 -0
  23. trainml/cloudbender/nodes.py +22 -0
  24. trainml/cloudbender/providers.py +8 -0
  25. trainml/cloudbender/regions.py +8 -0
  26. trainml/cloudbender/reservations.py +8 -0
  27. trainml/datasets.py +25 -14
  28. trainml/gpu_types.py +5 -0
  29. trainml/models.py +22 -3
  30. trainml/trainml.py +2 -1
  31. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/METADATA +1 -1
  32. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/RECORD +36 -32
  33. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/LICENSE +0 -0
  34. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/WHEEL +0 -0
  35. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/entry_points.txt +0 -0
  36. {trainml-0.5.0.dist-info → trainml-0.5.2.dist-info}/top_level.txt +0 -0
tests/unit/conftest.py CHANGED
@@ -22,6 +22,7 @@ from trainml.cloudbender import Cloudbender
22
22
  from trainml.cloudbender.providers import Provider, Providers
23
23
  from trainml.cloudbender.regions import Region, Regions
24
24
  from trainml.cloudbender.nodes import Node, Nodes
25
+ from trainml.cloudbender.devices import Device, Devices
25
26
  from trainml.cloudbender.datastores import Datastore, Datastores
26
27
  from trainml.cloudbender.reservations import Reservation, Reservations
27
28
  from trainml.cloudbender.device_configs import DeviceConfig, DeviceConfigs
@@ -484,19 +485,19 @@ def mock_jobs():
484
485
  {
485
486
  "rig_uuid": "rig-id-1",
486
487
  "job_worker_uuid": "worker-id-11",
487
- "command": "PYTHONPATH=$PYTHONPATH:$TRAINML_MODEL_PATH python -m official.vision.image_classification.resnet_cifar_main --num_gpus=1 --data_dir=$TRAINML_DATA_PATH --model_dir=$TRAINML_OUTPUT_PATH --enable_checkpoint_and_export=True --train_epochs=1 --batch_size=1024",
488
+ "command": "PYTHONPATH=$PYTHONPATH:$ML_MODEL_PATH python -m official.vision.image_classification.resnet_cifar_main --num_gpus=1 --data_dir=$ML_DATA_PATH --model_dir=$ML_OUTPUT_PATH --enable_checkpoint_and_export=True --train_epochs=1 --batch_size=1024",
488
489
  "status": "stopped",
489
490
  },
490
491
  {
491
492
  "rig_uuid": "rig-id-2",
492
493
  "job_worker_uuid": "worker-id-12",
493
- "command": "PYTHONPATH=$PYTHONPATH:$TRAINML_MODEL_PATH python -m official.vision.image_classification.resnet_cifar_main --num_gpus=1 --data_dir=$TRAINML_DATA_PATH --model_dir=$TRAINML_OUTPUT_PATH --enable_checkpoint_and_export=True --train_epochs=1 --batch_size=1024",
494
+ "command": "PYTHONPATH=$PYTHONPATH:$ML_MODEL_PATH python -m official.vision.image_classification.resnet_cifar_main --num_gpus=1 --data_dir=$ML_DATA_PATH --model_dir=$ML_OUTPUT_PATH --enable_checkpoint_and_export=True --train_epochs=1 --batch_size=1024",
494
495
  "status": "stopped",
495
496
  },
496
497
  {
497
498
  "rig_uuid": "rig-id-2",
498
499
  "job_worker_uuid": "worker-id-13",
499
- "command": "PYTHONPATH=$PYTHONPATH:$TRAINML_MODEL_PATH python -m official.vision.image_classification.resnet_cifar_main --num_gpus=1 --data_dir=$TRAINML_DATA_PATH --model_dir=$TRAINML_OUTPUT_PATH --enable_checkpoint_and_export=True --train_epochs=1 --batch_size=1024",
500
+ "command": "PYTHONPATH=$PYTHONPATH:$ML_MODEL_PATH python -m official.vision.image_classification.resnet_cifar_main --num_gpus=1 --data_dir=$ML_DATA_PATH --model_dir=$ML_OUTPUT_PATH --enable_checkpoint_and_export=True --train_epochs=1 --batch_size=1024",
500
501
  "status": "stopped",
501
502
  },
502
503
  ],
@@ -706,6 +707,53 @@ def mock_nodes():
706
707
  ]
707
708
 
708
709
 
710
@fixture(scope="session")
def mock_devices():
    """Session fixture yielding two active, online CloudBender devices.

    Both devices share one Mock client and differ only in device id,
    friendly name / hostname, and job config id.
    """
    trainml = Mock()

    def build(device_id, name, job_config_id):
        # Keyword order matches the API payload order, so the resulting
        # Device._device dict (and therefore str(Device)) is unchanged.
        return Device(
            trainml,
            provider_uuid="prov-id-1",
            region_uuid="reg-id-1",
            device_id=device_id,
            type="device",
            service="compute",
            friendly_name=name,
            hostname=name,
            status="active",
            online=True,
            maintenance_mode=False,
            job_status="running",
            job_last_deployed="2023-06-02T21:22:40.084Z",
            job_config_id=job_config_id,
            job_config_revision="1685740490096",
            device_config_id="conf-id-1",
        )

    yield [
        build("dev-id-1", "hq-orin-01", "job-id-1"),
        build("dev-id-2", "hq-orin-02", "job-id-2"),
    ]
755
+
756
+
709
757
  @fixture(scope="session")
710
758
  def mock_datastores():
711
759
  trainml = Mock()
@@ -863,6 +911,7 @@ def mock_trainml(
863
911
  mock_providers,
864
912
  mock_regions,
865
913
  mock_nodes,
914
+ mock_devices,
866
915
  mock_datastores,
867
916
  mock_reservations,
868
917
  mock_device_configs,
@@ -898,6 +947,8 @@ def mock_trainml(
898
947
  trainml.cloudbender.regions.list = AsyncMock(return_value=mock_regions)
899
948
  trainml.cloudbender.nodes = create_autospec(Nodes)
900
949
  trainml.cloudbender.nodes.list = AsyncMock(return_value=mock_nodes)
950
+ trainml.cloudbender.devices = create_autospec(Nodes)
951
+ trainml.cloudbender.devices.list = AsyncMock(return_value=mock_devices)
901
952
  trainml.cloudbender.datastores = create_autospec(Datastores)
902
953
  trainml.cloudbender.datastores.list = AsyncMock(
903
954
  return_value=mock_datastores
tests/unit/test_auth.py CHANGED
@@ -22,7 +22,7 @@ pytestmark = [mark.sdk, mark.unit]
22
22
  },
23
23
  )
24
24
  def test_auth_from_envs():
25
- auth = specimen.Auth()
25
+ auth = specimen.Auth(config_dir=os.path.expanduser("~/.trainml"))
26
26
  assert auth.__dict__.get("username") == "user-id"
27
27
  assert auth.__dict__.get("password") == "key"
28
28
  assert auth.__dict__.get("region") == "ap-east-1"
trainml/__init__.py CHANGED
@@ -13,5 +13,5 @@ logging.basicConfig(
13
13
logger = logging.getLogger(__name__)


__version__ = "0.5.2"
# ``__all__`` must be a sequence of names. A bare string would be iterated
# character by character, so ``from trainml import *`` would try to import
# "T", "r", ... and fail with AttributeError.
__all__ = ["TrainML"]
trainml/auth.py CHANGED
@@ -222,10 +222,6 @@ n_hex = (
222
222
  g_hex = "2"
223
223
  info_bits = bytearray("Caldera Derived Key", "utf-8")
224
224
 
225
- CONFIG_DIR = os.path.expanduser(
226
- os.environ.get("TRAINML_CONFIG_DIR") or "~/.trainml"
227
- )
228
-
229
225
 
230
226
  def hash_sha256(buf):
231
227
  """AuthenticationHelper.hash"""
@@ -512,9 +508,9 @@ class AWSSRP(object):
512
508
 
513
509
 
514
510
  class Auth(object):
515
- def __init__(self, domain_suffix="trainml.ai", **kwargs):
511
+ def __init__(self, config_dir, domain_suffix="trainml.ai", **kwargs):
516
512
  try:
517
- with open(f"{CONFIG_DIR}/environment.json", "r") as file:
513
+ with open(f"{config_dir}/environment.json", "r") as file:
518
514
  env_str = file.read().replace("\n", "")
519
515
  env = json.loads(env_str)
520
516
  except:
@@ -544,7 +540,7 @@ class Auth(object):
544
540
  )
545
541
 
546
542
  try:
547
- with open(f"{CONFIG_DIR}/credentials.json", "r") as file:
543
+ with open(f"{config_dir}/credentials.json", "r") as file:
548
544
  key_str = file.read().replace("\n", "")
549
545
  keys = json.loads(key_str)
550
546
  except:
@@ -13,5 +13,6 @@ def cloudbender(config):
13
13
  from trainml.cli.cloudbender.provider import provider
14
14
  from trainml.cli.cloudbender.region import region
15
15
  from trainml.cli.cloudbender.node import node
16
+ from trainml.cli.cloudbender.device import device
16
17
  from trainml.cli.cloudbender.datastore import datastore
17
18
  from trainml.cli.cloudbender.reservation import reservation
@@ -0,0 +1,157 @@
1
+ import click
2
+ from trainml.cli import cli, pass_config, search_by_id_name
3
+ from trainml.cli.cloudbender import cloudbender
4
+
5
+
6
@cloudbender.group()
@pass_config
def device(config):
    """trainML CloudBender device commands."""
    # Pure grouping command: the docstring above is the click help text, and
    # the subcommands registered on this group do all of the work.
11
+
12
+
13
@device.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to list devices for.",
)
@pass_config
def list(config, provider, region):
    """List devices."""
    headers = ["ID", "NAME", "STATUS", "JOB STATUS", "ONLINE", "MAINTENANCE"]
    # One separator cell per column; the row format below truncates each
    # cell to its column width.
    data = [headers, ["-" * 80] * len(headers)]

    devices = config.trainml.run(
        config.trainml.client.cloudbender.devices.list(
            provider_uuid=provider, region_uuid=region
        )
    )

    for dev in devices:
        data.append(
            [
                # Device reads its fields with dict.get, so any of these may
                # be None. A precision format spec such as ">11.10" raises
                # TypeError on None, so substitute an empty cell instead.
                dev.id or "",
                dev.name or "",
                dev.status or "",
                dev.job_status or "",
                "X" if dev.online else "",
                "X" if dev.maintenance_mode else "",
            ]
        )

    for row in data:
        click.echo(
            "{: >37.36} {: >29.28} {: >9.8} {: >11.10} {: >7.6} {: >12.11}"
            "".format(*row),
            file=config.stdout,
        )
75
+
76
+
77
@device.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to create the device in.",
)
@click.option(
    "--minion-id",
    "-m",
    type=click.STRING,
    required=True,
    help="The minion_id of the new device.",
)
@click.option(
    "--hostname",
    "-h",
    type=click.STRING,
    help="The hostname (if different from name)",
)
@click.argument("name", type=click.STRING, required=True)
@pass_config
def create(config, provider, region, minion_id, hostname, name):
    """
    Creates a device.
    """
    # NAME doubles as the hostname unless --hostname was supplied.
    if not hostname:
        hostname = name
    return config.trainml.run(
        config.trainml.client.cloudbender.devices.create(
            provider_uuid=provider,
            region_uuid=region,
            friendly_name=name,
            hostname=hostname,
            minion_id=minion_id,
        )
    )
122
+
123
+
124
@device.command()
@click.option(
    "--provider",
    "-p",
    type=click.STRING,
    required=True,
    help="The provider ID of the region.",
)
@click.option(
    "--region",
    "-r",
    type=click.STRING,
    required=True,
    help="The region ID to remove the device from.",
)
@click.argument("device", type=click.STRING)
@pass_config
def remove(config, provider, region, device):
    """
    Remove a device.

    DEVICE may be specified by name or ID, but ID is preferred.
    """
    # Click passes the DEVICE argument as the keyword ``device``, so the
    # callback parameter must carry that exact name.
    devices = config.trainml.run(
        config.trainml.client.cloudbender.devices.list(
            provider_uuid=provider, region_uuid=region
        )
    )

    found = search_by_id_name(device, devices)
    if found is None:
        raise click.UsageError("Cannot find specified device.")

    return config.trainml.run(found.remove())
trainml/cli/dataset.py CHANGED
@@ -252,9 +252,7 @@ def rename(config, dataset, name):
252
252
  DATASET may be specified by name or ID, but ID is preferred.
253
253
  """
254
254
  try:
255
- dataset = config.trainml.run(
256
- config.trainml.client.datasets.get(dataset)
257
- )
255
+ dataset = config.trainml.run(config.trainml.client.datasets.get(dataset))
258
256
  if dataset is None:
259
257
  raise click.UsageError("Cannot find specified dataset.")
260
258
  except:
trainml/cli/job/create.py CHANGED
@@ -683,7 +683,7 @@ def training(
683
683
  config.trainml.client.jobs.create(
684
684
  name=name,
685
685
  type="training",
686
- gpu_type=gpu_type,
686
+ gpu_types=gpu_type,
687
687
  cpu_count=cpu_count,
688
688
  disk_size=disk_size,
689
689
  workers=[command for command in commands],
@@ -1021,7 +1021,7 @@ def inference(
1021
1021
  config.trainml.client.jobs.create(
1022
1022
  name=name,
1023
1023
  type="inference",
1024
- gpu_type=gpu_type,
1024
+ gpu_types=gpu_type,
1025
1025
  cpu_count=cpu_count,
1026
1026
  disk_size=disk_size,
1027
1027
  workers=[command],
@@ -1330,7 +1330,7 @@ def endpoint(
1330
1330
  config.trainml.client.jobs.create(
1331
1331
  name=name,
1332
1332
  type="endpoint",
1333
- gpu_type=gpu_type,
1333
+ gpu_types=gpu_type,
1334
1334
  cpu_count=cpu_count,
1335
1335
  disk_size=disk_size,
1336
1336
  endpoint=dict(routes=routes),
trainml/cli/model.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import click
2
+ import logging
2
3
  from trainml.cli import cli, pass_config, search_by_id_name
3
4
 
4
5
 
@@ -64,6 +65,7 @@ def connect(config, model, attach):
64
65
  models = config.trainml.run(config.trainml.client.models.list())
65
66
 
66
67
  found = search_by_id_name(model, models)
68
+ logging.debug(found)
67
69
  if None is found:
68
70
  raise click.UsageError("Cannot find specified model.")
69
71
 
@@ -1,6 +1,7 @@
1
1
  from .providers import Providers
2
2
  from .regions import Regions
3
3
  from .nodes import Nodes
4
+ from .devices import Devices
4
5
  from .datastores import Datastores
5
6
  from .reservations import Reservations
6
7
  from .device_configs import DeviceConfigs
@@ -12,6 +13,7 @@ class Cloudbender(object):
12
13
  self.providers = Providers(trainml)
13
14
  self.regions = Regions(trainml)
14
15
  self.nodes = Nodes(trainml)
16
+ self.devices = Devices(trainml)
15
17
  self.datastores = Datastores(trainml)
16
18
  self.reservations = Reservations(trainml)
17
19
  self.device_configs = DeviceConfigs(trainml)
@@ -118,3 +118,11 @@ class Datastore:
118
118
  f"/provider/{self._provider_uuid}/region/{self._region_uuid}/datastore/{self._id}",
119
119
  "DELETE",
120
120
  )
121
+
122
async def refresh(self):
    """Re-fetch this datastore from the API and reload it in place.

    Returns ``self`` so the call can be chained.
    """
    url = (
        f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
        f"/datastore/{self._id}"
    )
    payload = await self.trainml._query(url, "GET")
    # Re-run the constructor on the fresh payload.
    self.__init__(self.trainml, **payload)
    return self
@@ -107,3 +107,11 @@ class DeviceConfig:
107
107
  f"/provider/{self._provider_uuid}/region/{self._region_uuid}/device/config/{self._id}",
108
108
  "DELETE",
109
109
  )
110
+
111
async def refresh(self):
    """Re-fetch this device config from the API and reload it in place.

    Returns ``self`` so the call can be chained.
    """
    url = (
        f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
        f"/device/config/{self._id}"
    )
    payload = await self.trainml._query(url, "GET")
    # Re-run the constructor on the fresh payload.
    self.__init__(self.trainml, **payload)
    return self
@@ -0,0 +1,190 @@
1
+ import json
2
+ import logging
3
+
4
+
5
class Devices(object):
    """Collection-level operations on CloudBender devices of one region."""

    def __init__(self, trainml):
        self.trainml = trainml

    @staticmethod
    def _collection_url(provider_uuid, region_uuid):
        """API path of the device collection for a provider region."""
        return f"/provider/{provider_uuid}/region/{region_uuid}/device"

    async def get(self, provider_uuid, region_uuid, id, **kwargs):
        """Fetch one device by id; extra kwargs are sent as query params."""
        url = f"{self._collection_url(provider_uuid, region_uuid)}/{id}"
        payload = await self.trainml._query(url, "GET", kwargs)
        return Device(self.trainml, **payload)

    async def list(self, provider_uuid, region_uuid, **kwargs):
        """List the region's devices; extra kwargs are sent as query params."""
        url = self._collection_url(provider_uuid, region_uuid)
        payload = await self.trainml._query(url, "GET", kwargs)
        return [Device(self.trainml, **item) for item in payload]

    async def create(
        self,
        provider_uuid,
        region_uuid,
        friendly_name,
        hostname,
        minion_id,
        **kwargs,
    ):
        """Register a new compute device and return it.

        Extra kwargs are merged into the request body. Any field whose value
        is None is dropped before sending.
        """
        logging.info(f"Creating Device {friendly_name}")
        # dict(...) with keyword unpacking: a duplicate key in kwargs (for
        # example ``type``) raises TypeError rather than silently overriding.
        body = dict(
            friendly_name=friendly_name,
            hostname=hostname,
            minion_id=minion_id,
            type="device",
            service="compute",
            **kwargs,
        )
        filtered = {key: value for key, value in body.items() if value is not None}
        url = self._collection_url(provider_uuid, region_uuid)
        payload = await self.trainml._query(url, "POST", None, filtered)
        created = Device(self.trainml, **payload)
        logging.info(f"Created Device {friendly_name} with id {created.id}")
        return created

    async def remove(self, provider_uuid, region_uuid, id, **kwargs):
        """Delete a device by id; extra kwargs are sent as query params."""
        url = f"{self._collection_url(provider_uuid, region_uuid)}/{id}"
        await self.trainml._query(url, "DELETE", kwargs)
61
+
62
+
63
class Device:
    """Client-side view of one CloudBender device.

    Wraps the raw payload returned by the device endpoints, exposes its
    fields as read-only properties, and offers coroutines that act on the
    device through ``trainml._query``.
    """

    # (private attribute, payload key) pairs copied out of the payload by
    # __init__. Keys missing from the payload become None.
    _FIELD_MAP = (
        ("_id", "device_id"),
        ("_provider_uuid", "provider_uuid"),
        ("_region_uuid", "region_uuid"),
        ("_name", "friendly_name"),
        ("_hostname", "hostname"),
        ("_status", "status"),
        ("_online", "online"),
        ("_maintenance_mode", "maintenance_mode"),
        ("_device_config_id", "device_config_id"),
        ("_job_status", "job_status"),
        ("_job_last_deployed", "job_last_deployed"),
        ("_job_config_id", "job_config_id"),
        ("_job_config_revision", "job_config_revision"),
    )

    def __init__(self, trainml, **kwargs):
        self.trainml = trainml
        # Keep the whole payload so __str__/__repr__ can echo it verbatim.
        self._device = kwargs
        for attr_name, payload_key in self._FIELD_MAP:
            setattr(self, attr_name, kwargs.get(payload_key))

    def _url(self, suffix=""):
        """This device's API path, optionally followed by a sub-resource."""
        return (
            f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
            f"/device/{self._id}{suffix}"
        )

    @property
    def id(self) -> str:
        return self._id

    @property
    def provider_uuid(self) -> str:
        return self._provider_uuid

    @property
    def region_uuid(self) -> str:
        return self._region_uuid

    @property
    def name(self) -> str:
        return self._name

    @property
    def hostname(self) -> str:
        return self._hostname

    @property
    def status(self) -> str:
        return self._status

    @property
    def online(self) -> bool:
        return self._online

    @property
    def maintenance_mode(self) -> bool:
        return self._maintenance_mode

    @property
    def device_config_id(self) -> str:
        return self._device_config_id

    @property
    def job_status(self) -> str:
        return self._job_status

    @property
    def job_last_deployed(self) -> str:
        return self._job_last_deployed

    @property
    def job_config_id(self) -> str:
        return self._job_config_id

    @property
    def job_config_revision(self) -> str:
        return self._job_config_revision

    def __str__(self):
        return json.dumps(dict(self._device))

    def __repr__(self):
        return "Device( trainml , **" + repr(self._device) + ")"

    def __bool__(self):
        # A device is truthy only when its payload carried a device id.
        return bool(self._id)

    async def remove(self):
        """Delete this device."""
        await self.trainml._query(self._url(), "DELETE")

    async def refresh(self):
        """Re-fetch this device and reload it in place; returns ``self``."""
        payload = await self.trainml._query(self._url(), "GET")
        self.__init__(self.trainml, **payload)
        return self

    async def toggle_maintenance(self):
        """PATCH the device's ``/maintenance`` sub-resource."""
        await self.trainml._query(self._url("/maintenance"), "PATCH")

    async def run_action(self, command):
        """POST ``command`` to the device's ``/action`` sub-resource."""
        body = {"command": command}
        await self.trainml._query(self._url("/action"), "POST", None, body)

    async def set_config(self, device_config_id):
        """Assign a device config, reload from the response; returns ``self``."""
        body = {"device_config_id": device_config_id}
        payload = await self.trainml._query(self._url(), "PATCH", None, body)
        self.__init__(self.trainml, **payload)
        return self

    async def deploy_endpoint(self):
        """PUT the device's ``/deploy`` sub-resource."""
        await self.trainml._query(self._url("/deploy"), "PUT")

    async def stop_endpoint(self):
        """PUT the device's ``/stop`` sub-resource."""
        await self.trainml._query(self._url("/stop"), "PUT")
@@ -131,3 +131,25 @@ class Node:
131
131
  f"/provider/{self._provider_uuid}/region/{self._region_uuid}/node/{self._id}",
132
132
  "DELETE",
133
133
  )
134
+
135
async def refresh(self):
    """Re-fetch this node from the API and reload it in place.

    Returns ``self`` so the call can be chained.
    """
    url = (
        f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
        f"/node/{self._id}"
    )
    payload = await self.trainml._query(url, "GET")
    # Re-run the constructor on the fresh payload.
    self.__init__(self.trainml, **payload)
    return self
142
+
143
async def toggle_maintenance(self):
    """PATCH this node's ``/maintenance`` sub-resource.

    Nothing is returned and local attributes are not updated; call
    ``refresh`` to see the resulting state.
    """
    url = (
        f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
        f"/node/{self._id}/maintenance"
    )
    await self.trainml._query(url, "PATCH")
148
+
149
async def run_action(self, command):
    """POST ``command`` to this node's ``/action`` sub-resource.

    The command is sent as the JSON body ``{"command": command}``; nothing
    is returned.
    """
    url = (
        f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
        f"/node/{self._id}/action"
    )
    body = {"command": command}
    await self.trainml._query(url, "POST", None, body)
@@ -61,3 +61,11 @@ class Provider:
61
61
 
62
62
  async def remove(self):
63
63
  await self.trainml._query(f"/provider/{self._id}", "DELETE")
64
+
65
async def refresh(self):
    """Re-fetch this provider from the API and reload it in place.

    Returns ``self`` so the call can be chained.
    """
    payload = await self.trainml._query(f"/provider/{self._id}", "GET")
    # Re-run the constructor on the fresh payload.
    self.__init__(self.trainml, **payload)
    return self
@@ -80,6 +80,14 @@ class Region:
80
80
  f"/provider/{self._provider_uuid}/region/{self._id}", "DELETE"
81
81
  )
82
82
 
83
async def refresh(self):
    """Re-fetch this region from the API and reload it in place.

    Returns ``self`` so the call can be chained.
    """
    url = f"/provider/{self._provider_uuid}/region/{self._id}"
    payload = await self.trainml._query(url, "GET")
    # Re-run the constructor on the fresh payload.
    self.__init__(self.trainml, **payload)
    return self
90
+
83
91
  async def add_dataset(self, project_uuid, dataset_uuid, **kwargs):
84
92
  await self.trainml._query(
85
93
  f"/provider/{self._provider_uuid}/region/{self._id}/dataset",
@@ -116,3 +116,11 @@ class Reservation:
116
116
  f"/provider/{self._provider_uuid}/region/{self._region_uuid}/reservation/{self._id}",
117
117
  "DELETE",
118
118
  )
119
+
120
async def refresh(self):
    """Re-fetch this reservation from the API and reload it in place.

    Returns ``self`` so the call can be chained.
    """
    url = (
        f"/provider/{self._provider_uuid}/region/{self._region_uuid}"
        f"/reservation/{self._id}"
    )
    payload = await self.trainml._query(url, "GET")
    # Re-run the constructor on the fresh payload.
    self.__init__(self.trainml, **payload)
    return self