trainml 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
trainml/volumes.py ADDED
@@ -0,0 +1,255 @@
1
+ import json
2
+ import logging
3
+ import math
4
+ import asyncio
5
+ from datetime import datetime
6
+
7
+ from .exceptions import (
8
+ VolumeError,
9
+ ApiError,
10
+ SpecificationError,
11
+ TrainMLException,
12
+ )
13
+ from .connections import Connection
14
+
15
+
16
class Volumes(object):
    """Collection-level API for trainML volumes (the ``/volume`` endpoints)."""

    def __init__(self, trainml):
        # Client used for every request; must provide an async ``_query`` and
        # an ``active_project`` attribute (both used below).
        self.trainml = trainml

    async def get(self, id, **kwargs):
        """Fetch a single volume by id.

        Extra keyword arguments are forwarded to ``_query`` as query params.
        Returns a ``Volume`` built from the API response.
        """
        resp = await self.trainml._query(f"/volume/{id}", "GET", kwargs)
        return Volume(self.trainml, **resp)

    async def list(self, **kwargs):
        """List volumes; keyword arguments are forwarded as query params."""
        resp = await self.trainml._query("/volume", "GET", kwargs)
        return [Volume(self.trainml, **volume) for volume in resp]

    async def create(self, name, source_type, source_uri, capacity, **kwargs):
        """Create a new volume and return it as a ``Volume``.

        Optional keyword arguments read here: ``source_options`` and
        ``project_uuid``. The latter falls back to the client's
        ``active_project``. Fields whose value is None are omitted from the
        POST payload.
        """
        data = dict(
            name=name,
            source_type=source_type,
            source_uri=source_uri,
            capacity=capacity,
            source_options=kwargs.get("source_options"),
            project_uuid=kwargs.get("project_uuid") or self.trainml.active_project,
        )
        payload = {k: v for k, v in data.items() if v is not None}
        logging.info(f"Creating Volume {name}")
        resp = await self.trainml._query("/volume", "POST", None, payload)
        volume = Volume(self.trainml, **resp)
        logging.info(f"Created Volume {name} with id {volume.id}")

        return volume

    async def remove(self, id, **kwargs):
        """Delete a volume by id.

        Keyword arguments are sent as query params. ``force`` defaults to
        True. An explicit ``force`` from the caller is honored; previously
        passing it raised ``TypeError`` for a duplicate keyword argument.
        """
        params = dict(kwargs)
        params.setdefault("force", True)
        await self.trainml._query(f"/volume/{id}", "DELETE", params)
48
+
49
+
50
class Volume:
    """A single trainML volume wrapping the API's JSON representation.

    ``refresh``, ``rename`` and ``export`` re-initialize the instance in place
    from the API response and return ``self``.
    """

    def __init__(self, trainml, **kwargs):
        self.trainml = trainml
        # Raw API payload; __str__/__repr__ and get_connection_details read it.
        self._volume = kwargs
        self._id = self._volume.get("id")
        self._status = self._volume.get("status")
        self._name = self._volume.get("name")
        self._capacity = self._volume.get("capacity")
        self._used_size = self._volume.get("used_size")
        self._billed_size = self._volume.get("billed_size")
        self._project_uuid = self._volume.get("project_uuid")

    @property
    def id(self) -> str:
        return self._id

    @property
    def status(self) -> str:
        return self._status

    @property
    def name(self) -> str:
        return self._name

    @property
    def capacity(self) -> str:
        return self._capacity

    @property
    def used_size(self) -> int:
        return self._used_size

    @property
    def billed_size(self) -> int:
        return self._billed_size

    def __str__(self):
        return json.dumps(self._volume)

    def __repr__(self):
        return f"Volume( trainml , **{self._volume.__repr__()})"

    def __bool__(self):
        # A volume without an id (e.g. built from an empty payload) is falsy.
        return bool(self._id)

    async def get_log_url(self):
        """Return the raw API response of ``GET /volume/{id}/logs``."""
        resp = await self.trainml._query(
            f"/volume/{self._id}/logs",
            "GET",
            dict(project_uuid=self._project_uuid),
        )
        return resp

    async def get_details(self):
        """Return the raw API response of ``GET /volume/{id}/details``."""
        resp = await self.trainml._query(
            f"/volume/{self._id}/details",
            "GET",
            dict(project_uuid=self._project_uuid),
        )
        return resp

    async def get_connection_utility_url(self):
        """Return the raw API response of ``GET /volume/{id}/download``."""
        resp = await self.trainml._query(
            f"/volume/{self._id}/download",
            "GET",
            dict(project_uuid=self._project_uuid),
        )
        return resp

    def get_connection_details(self):
        """Build the connection-details dict used for VPN connections.

        Returns an empty dict when the payload has no ``vpn`` entry.
        ``input_path`` is only set while the volume is "new" or
        "downloading". ``output_path`` is only set while it is "exporting".
        """
        vpn = self._volume.get("vpn")
        if not vpn:
            return dict()
        return dict(
            entity_type="volume",
            project_uuid=self._volume.get("project_uuid"),
            cidr=vpn.get("cidr"),
            ssh_port=vpn.get("client").get("ssh_port"),
            input_path=(
                self._volume.get("source_uri")
                if self.status in ["new", "downloading"]
                else None
            ),
            output_path=(
                self._volume.get("output_uri")
                if self.status == "exporting"
                else None
            ),
        )

    async def connect(self):
        """Start a connection to this volume and return its status.

        Raises:
            SpecificationError: if the volume is "ready" or "failed".
        """
        if self.status in ["ready", "failed"]:
            raise SpecificationError(
                "status",
                "You can only connect to downloading or exporting volumes.",
            )
        if self.status == "new":
            # A new volume must reach "downloading" before it can be connected.
            await self.wait_for("downloading")
        connection = Connection(
            self.trainml, entity_type="volume", id=self.id, entity=self
        )
        await connection.start()
        return connection.status

    async def disconnect(self):
        """Stop the connection to this volume and return its status."""
        connection = Connection(
            self.trainml, entity_type="volume", id=self.id, entity=self
        )
        await connection.stop()
        return connection.status

    async def remove(self, force=False):
        """Delete this volume; ``force`` is passed through as a query param."""
        await self.trainml._query(
            f"/volume/{self._id}",
            "DELETE",
            dict(project_uuid=self._project_uuid, force=force),
        )

    async def rename(self, name):
        """Rename the volume, reload from the response, and return ``self``."""
        resp = await self.trainml._query(
            f"/volume/{self._id}",
            "PATCH",
            dict(project_uuid=self._project_uuid),
            dict(name=name),
        )
        self.__init__(self.trainml, **resp)
        return self

    async def export(self, output_type, output_uri, output_options=None):
        """Start an export of the volume and return ``self``.

        ``output_options`` defaults to an empty dict. None is used as the
        default to avoid a shared mutable default argument.
        """
        if output_options is None:
            output_options = dict()
        resp = await self.trainml._query(
            f"/volume/{self._id}/export",
            "POST",
            dict(project_uuid=self._project_uuid),
            dict(
                output_type=output_type,
                output_uri=output_uri,
                output_options=output_options,
            ),
        )
        self.__init__(self.trainml, **resp)
        return self

    def _get_msg_handler(self, msg_handler):
        """Wrap ``msg_handler`` for websocket messages.

        Only messages whose ``type`` is "subscription" are handled. Without a
        custom handler, they are printed with a local timestamp derived from
        the millisecond ``time`` field.
        """

        def handler(data):
            if data.get("type") == "subscription":
                if msg_handler:
                    msg_handler(data)
                else:
                    timestamp = datetime.fromtimestamp(int(data.get("time")) / 1000)
                    print(
                        f"{timestamp.strftime('%m/%d/%Y, %H:%M:%S')}: {data.get('msg').rstrip()}"
                    )

        return handler

    async def attach(self, msg_handler=None):
        """Stream volume messages unless the volume is "ready" or "failed"."""
        await self.refresh()
        if self.status not in ["ready", "failed"]:
            await self.trainml._ws_subscribe(
                "volume",
                self._project_uuid,
                self.id,
                self._get_msg_handler(msg_handler),
            )

    async def refresh(self):
        """Reload this volume from the API and return ``self``."""
        resp = await self.trainml._query(
            f"/volume/{self.id}",
            "GET",
            dict(project_uuid=self._project_uuid),
        )
        self.__init__(self.trainml, **resp)
        return self

    async def wait_for(self, status, timeout=300):
        """Poll until the volume reaches ``status``.

        Returns ``self`` once the status matches, including when it already
        matches on entry. Returns None when waiting for "archived" and the
        API answers 404.

        Raises:
            SpecificationError: ``status`` is not a waitable status.
            VolumeError: the volume enters "failed".
            TrainMLException: ``timeout`` seconds elapse first.
        """
        valid_statuses = ["downloading", "ready", "archived"]
        if status not in valid_statuses:
            raise SpecificationError(
                "status",
                f"Invalid wait_for status {status}. Valid statuses are: {valid_statuses}",
            )
        if self.status == status:
            return self
        POLL_INTERVAL_MIN = 5
        POLL_INTERVAL_MAX = 60
        # Aim for ~60 polls over the timeout, clamped to [5s, 60s] per poll.
        poll_interval = max(min(timeout / 60, POLL_INTERVAL_MAX), POLL_INTERVAL_MIN)
        retry_count = math.ceil(timeout / poll_interval)
        for count in range(1, retry_count + 1):
            await asyncio.sleep(poll_interval)
            try:
                await self.refresh()
            except ApiError as e:
                # While waiting for "archived", a 404 means the volume is gone.
                if status == "archived" and e.status == 404:
                    return
                raise
            if self.status == status:
                return self
            if self.status == "failed":
                raise VolumeError(self.status, self)
            # Lazy formatting: str(self) serializes the whole payload.
            logging.debug("self: %s, retry count %s", self, count)

        raise TrainMLException(f"Timeout waiting for {status}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: trainml
3
- Version: 0.5.3
3
+ Version: 0.5.5
4
4
  Summary: trainML client SDK and command line utilities
5
5
  Home-page: https://github.com/trainML/trainml-cli
6
6
  Author: trainML
@@ -4,18 +4,18 @@ examples/local_storage.py,sha256=w8iAeqr5CLOCOkNrqGzEDtybjDGGY7SQUqeE0ibMUrM,174
4
4
  examples/training_inference_pipeline.py,sha256=SNr4RFT9y69F9G9tMD8ONUbJmXRFrq1yxynq-FbfEf8,2334
5
5
  tests/integration/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  tests/integration/conftest.py,sha256=VWWTfofsFcBOdSCXQxNYbMcDEEaErDi2wFFMn--LE4Y,1134
7
- tests/integration/test_checkpoints_integration.py,sha256=Ev-GxXiBupOi3KduYoEMhb2u-MZFEBTCi-E2zJgQ5AA,3128
8
- tests/integration/test_datasets_integration.py,sha256=jskX8y9moLvAkZLcQL4iSBrUXSGGuWJ95Vw1JuxxwD8,3424
7
+ tests/integration/test_checkpoints_integration.py,sha256=mLha1BhVZ916OJIDOKF6vah3kxJXQu7pbAsHq7lsNCE,3230
8
+ tests/integration/test_datasets_integration.py,sha256=zdHOevduuMUWvVxaHBslpmH8AdvPdqEJ95MdqCC5_rw,3499
9
9
  tests/integration/test_environments_integration.py,sha256=0IckhJvQhd8j4Ouiu0hMq2b7iA1dbZpZYmknyfWjsFM,1403
10
10
  tests/integration/test_gpu_types_integration.py,sha256=V2OncokZWWVq_l5FSmKEDM4EsWrmpB-zKiVPt-we0aY,1256
11
- tests/integration/test_jobs_integration.py,sha256=A70q7QADTeETqUsjEkrHaPAbB13GIYiBecTtWhailwA,23341
12
- tests/integration/test_models_integration.py,sha256=xJPq_3m0Cf1liMH8e49ON_L3MO5XcPtJIz_MD9plkyU,2848
13
- tests/integration/test_projects_integration.py,sha256=_tmMRFFBe29WaWWuEy3_0j7SKyJc_60JS99hgJUHTG4,1492
14
- tests/integration/test_providers_integration.py,sha256=bBVvlSDLCnofSaD36OB21KlinEKeK04gPB--bbHJiP0,1419
11
+ tests/integration/test_jobs_integration.py,sha256=Wpva99kfDWz1IRI-2l8GoHgpLsY-cVIIAWRNj_ik2I8,24730
12
+ tests/integration/test_models_integration.py,sha256=UPRAz0lcpzGihsnDUARoafbd5sZ6OM8TIeh8HNN6Bg0,2902
13
+ tests/integration/test_projects_integration.py,sha256=BX-LqLfzawTQUhtx--5dw7QqR8kl_CJvwSCyNXDUQTw,1446
14
+ tests/integration/test_volumes_integration.py,sha256=gOmZpwwFxqeOAVmfKWSTmuyshx8nb2zu_0xv1RUEepM,3270
15
15
  tests/integration/cloudbender/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  tests/integration/cloudbender/test_providers_integration.py,sha256=oV8ydFsosDZ_Z1Dkg2IN-ZhWuIl5e_HkHAORMsOsAJc,1473
17
17
  tests/unit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- tests/unit/conftest.py,sha256=AZfqMjB6qXyJdH6cRpxrZK4ciioG6HJTQIIDDkJB2H8,29059
18
+ tests/unit/conftest.py,sha256=fbID1lG1FZ-bywEjM7kEriyN_SUqPasRJVIhIqqsf98,31197
19
19
  tests/unit/test_auth.py,sha256=nfhlOCR7rUsn_MaD8QQtBc2v0k8pIxqbzGgRAZK1WGc,858
20
20
  tests/unit/test_checkpoints_unit.py,sha256=4Add2DXZCuriSZ0atvOXc8fsEGMaEfPhYmT8Q3UgP5E,16008
21
21
  tests/unit/test_connections_unit.py,sha256=FzN2ddQxNpjxzNGUsXhjTk0HnD24wSPelPTL4o_r-Ho,5507
@@ -26,8 +26,8 @@ tests/unit/test_gpu_types_unit.py,sha256=c9ie6YSYT5onBnlmHvHWON9WgQiJ1eO2C-4Tk-U
26
26
  tests/unit/test_jobs_unit.py,sha256=bZxN9HUfHCyQCjZCZGn6WFIhu8S5FU1z5ZG9sgH2XEg,26835
27
27
  tests/unit/test_models_unit.py,sha256=uezWF7FUHGmCSQBtpyyKhBttTnCTRjxU22NsHdJLYYg,15064
28
28
  tests/unit/test_projects_unit.py,sha256=iyqYntMj1UivqpL8GUnJeX2p4m9coEG8qL4c0iSfCCA,9530
29
- tests/unit/test_providers_unit.py,sha256=nEizghnC8pfDubkCw-kMmS_QQOUUWBk3i8D44pnyljo,3700
30
29
  tests/unit/test_trainml.py,sha256=8vAKvFD1xYsx_VY4HFVa0b1MUlMoNApY6TO8r7vI-UQ,1701
30
+ tests/unit/test_volumes_unit.py,sha256=KHVmdbQIiX8tEE09U-XsH-vl6wfYGVoRzR_UQJlhOVE,15305
31
31
  tests/unit/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
32
  tests/unit/cli/conftest.py,sha256=w6p_2URicywJKUCtY79tSD_mx8cwJtxHbK_Lu3grOYs,236
33
33
  tests/unit/cli/test_cli_checkpoint_unit.py,sha256=6gO6PWWxNJiL520mhTFkSUJT7p5-dptLwUgDjPeHNrA,702
@@ -37,6 +37,7 @@ tests/unit/cli/test_cli_gpu_unit.py,sha256=FIq3tQIDmeD-pvxkhJKMykfnlWcVxho2vRkoC
37
37
  tests/unit/cli/test_cli_job_unit.py,sha256=xUqkLFDIyI1ExiVVgr-218gQSlFSYCm4RRX_eyipJhY,611
38
38
  tests/unit/cli/test_cli_model_unit.py,sha256=fE-CRVg8gbtDlwrKBkf-hc9x7EhFlYeE3jlum1E27EA,629
39
39
  tests/unit/cli/test_cli_project_unit.py,sha256=1tEfwXJqnac41bDtvYcN_gpScnb4UY6s5dHfTl-0YoQ,1756
40
+ tests/unit/cli/test_cli_volume_unit.py,sha256=oggGL2eLiaExP6rSdFmQevxLp6nw5o7SKUEqMKBmy_A,644
40
41
  tests/unit/cli/cloudbender/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
42
  tests/unit/cli/cloudbender/test_cli_datastore_unit.py,sha256=DQWDjqg4viBZRONi00nVzqF9rJ5qKOKRub9pKbTmMWU,1381
42
43
  tests/unit/cli/cloudbender/test_cli_device_unit.py,sha256=2BSMyXQ8fOzNKh_-pa_tx7fy_GCdNlGLNvkiA7vuV7s,1342
@@ -52,30 +53,29 @@ tests/unit/cloudbender/test_nodes_unit.py,sha256=BDpfJXCBNNpLt5rhJMk2BVXDQ_4QSmx
52
53
  tests/unit/cloudbender/test_providers_unit.py,sha256=OgxifgC1IqLH8DNMKXy1Ne9_7a75ea6kHEOfRSRoQuQ,4373
53
54
  tests/unit/cloudbender/test_regions_unit.py,sha256=BbJICLIQmlotpA1UmLD0KTW_H9g2UW0J8ZYzQk1_Xjc,6299
54
55
  tests/unit/cloudbender/test_reservations_unit.py,sha256=nWEZ_p9EF2C49nbgL7Dt4NG2Irmyt94ZqJJQDyNfGFI,5624
55
- trainml/__init__.py,sha256=QyP74mv1Pi-vn0Sk_sHRvM5Souv9e7BBIiEmiN1SnIA,432
56
+ trainml/__init__.py,sha256=ZVeJu_28DU8_PtbnRLowX5jQ6IxkuYUpQHov-bC8Bt0,432
56
57
  trainml/__main__.py,sha256=JgErYkiskih8Y6oRwowALtR-rwQhAAdqOYWjQraRIPI,59
57
58
  trainml/auth.py,sha256=gruZv27nhttrCbhcVQTH9kZkF2uMm1E06SwA_2pQAHQ,26565
58
59
  trainml/checkpoints.py,sha256=726yaeTHzIPvkQ-BhqmvJ6u0s1Oh732N6UAFbuWEv0Q,8274
59
60
  trainml/connections.py,sha256=h-S1NZbOkaXpIlpRStA6q-3eXc_OMlFWOLzF8R9SVG8,20029
60
61
  trainml/datasets.py,sha256=5zhxpmK12aA7iXp4n-8bD00Tie7bi3LU0tPNJKdnZlY,7935
61
62
  trainml/environments.py,sha256=OH4o08zXZ7IJ2CiA1rPnys2Fl45r8qvQHfM2mCBRAIc,1507
62
- trainml/exceptions.py,sha256=hhR78fI8rbU3fWQ-kUsgxOLyYP8D2bQcjLjrskqPG0Q,3724
63
+ trainml/exceptions.py,sha256=MG1FkcjRacv3HoPuBS1IWLCUk0wGHEQ6DaOzXNymsNI,4094
63
64
  trainml/gpu_types.py,sha256=mm-dwfYc02192bmYPIJmzesndyBcoOdkKYBaYZXOUwU,1901
64
- trainml/jobs.py,sha256=cUuiMvtJSFSupv2jy7fqAvUOkO8FRLfBDe3RuSZUjfE,17807
65
+ trainml/jobs.py,sha256=iuSKkZDK908K0JwZjSbEk1G6IzdKp7vGkXsAUfih6R8,17838
65
66
  trainml/models.py,sha256=Lqs3OJMuOZXx8cfGFNC3JZ8nVK6l9_jU1xM1Uw1c5UQ,7750
66
67
  trainml/projects.py,sha256=Jk-0xhEgBFPIVMAK_Szxp3B-h-tJJSIEtTF8JWAYzo8,5209
67
- trainml/providers.py,sha256=97VegYVSeK0BuYv04hfBY_awNBbGz_GR_mdDrknfO-A,1844
68
- trainml/trainml.py,sha256=OrDXkJiStA_fTkwl0HX529n-cPCQUTVfdMOX-w0u1Oo,11045
69
- trainml/cli/__init__.py,sha256=-lA19Djkvqr1I-5PWoONgzHia-Lkg0J45dzt7qbetD8,4338
68
+ trainml/trainml.py,sha256=EBnqQ3Q291xrPKYuN6xKm5yt0mJQOJ3b7GAlR-fl8NI,10864
69
+ trainml/volumes.py,sha256=x4_QLPnPCuqEWR9FjwotCVPYFhdiZuWjp4tKKEU-Ne4,8094
70
+ trainml/cli/__init__.py,sha256=Gvj6oGSEtgpb40ACtiVeMD93GM-uy15MG6VlX6rwdwA,4346
70
71
  trainml/cli/checkpoint.py,sha256=8Rh4bmFwJ4DKlIjHK-FLTeRynABqKCgIUGRtbQhAsX4,7170
71
72
  trainml/cli/connection.py,sha256=ELV6bPL30dzttFNxDU7Fb74R8oPL_E70k7TcJEzbwtQ,1700
72
73
  trainml/cli/dataset.py,sha256=Pc00M6t7hGoRzCxznmmkijsWhG4PhIfG7UkrwtwykTY,6871
73
74
  trainml/cli/environment.py,sha256=dfm_T8OlCM5M8dLyOQBapJl3eFuVIku2P4JO6V0cYVQ,1019
74
75
  trainml/cli/gpu.py,sha256=CMcQyl2qbUgc2bc-gvUVT6X7bq2-sgiCHl3hyZ6kFWM,883
75
- trainml/cli/job.py,sha256=Xh_542a2rjAjhXyq8WG9ZEySboq_MmFQ7sCqYJYRXcM,4302
76
76
  trainml/cli/model.py,sha256=hR23E6ttRXcLk-RofkPK6wUXMO7OU6sT6jTEHTmUg9Q,6111
77
77
  trainml/cli/project.py,sha256=iar4-Aic6k0O6HWfxB5vJ5EtjS4aAaw2g8wW7H-3Bjk,3522
78
- trainml/cli/provider.py,sha256=eaklYo0IUOGeyZ0ziZZz6UhOZ50Py4ewaPtWA9UDCqU,1594
78
+ trainml/cli/volume.py,sha256=kDUss93N78DT-YlLjC6I3jEq5nBWfRNNR5M4tY_F_Zg,6246
79
79
  trainml/cli/cloudbender/__init__.py,sha256=ewat7I7PvHQEW0a8zuky28kuMwtjL-gFQOAQmG0mmJA,534
80
80
  trainml/cli/cloudbender/datastore.py,sha256=gJ-comfAq65uiPoONQ35UIDLNVN7QKMf3l_2EcTN6zY,3478
81
81
  trainml/cli/cloudbender/device.py,sha256=KGZCFwwvS4tWsWuudrhlvquu_IFtV7LCUAOmCajicic,3453
@@ -84,7 +84,7 @@ trainml/cli/cloudbender/provider.py,sha256=oFjZWKfFQjNY7OtDu7nUdfv-RTmQc_Huuug96
84
84
  trainml/cli/cloudbender/region.py,sha256=X6-FYOb-pGpOEazn-NbsYSwa9ergB7FGATFkTe4a8Pk,2892
85
85
  trainml/cli/cloudbender/reservation.py,sha256=z2oMYwp-w_Keo1DepKUtuRnwiGz2VscVHDYWEFap1gs,3569
86
86
  trainml/cli/job/__init__.py,sha256=ljY-ELeXhXQ7txASbJEKGBom7OXfNyy7sWILz3nxRAE,6545
87
- trainml/cli/job/create.py,sha256=ZKxrNAW-w2RwopERZvsjErltEejqelas4JVnppH5MWg,34282
87
+ trainml/cli/job/create.py,sha256=pfOCqs5Vfk4PAI5KZpXHJ1vp3DDe4ccvYzieh0oFexY,34288
88
88
  trainml/cloudbender/__init__.py,sha256=iE29obtC0_9f0IhRvHQcG5aY58fVhVYipTakpjAhdss,64
89
89
  trainml/cloudbender/cloudbender.py,sha256=rAc93mtMaa3tWLi-NT4tiqO664MCPkhlGG9oNYSmimQ,634
90
90
  trainml/cloudbender/datastores.py,sha256=biVGifedc3r1DcuxsfCQh-f1Tw4HcJMMJfdgHxPfkKM,3506
@@ -94,9 +94,9 @@ trainml/cloudbender/nodes.py,sha256=7HV2VLmxiUcJ-Kc6AAXS3M8C_XO-HKmaVgJpPdVnBQk,
94
94
  trainml/cloudbender/providers.py,sha256=-gkdiTu6Ah2znUuyyc3ZuRALagW8s1-OgqVjtlvc1AU,2036
95
95
  trainml/cloudbender/regions.py,sha256=Aqc_MeLVAeEv21e-lR5u8x1eintqUhZT2DBiQG3AcEE,3570
96
96
  trainml/cloudbender/reservations.py,sha256=rOrGXWIUHON4ad2aufEcvK4Yv_Mv3dDoScUtLJE8LWw,3586
97
- trainml-0.5.3.dist-info/LICENSE,sha256=s0lpBxhSSUEpMavwde-Vb6K_K7xDCTTvSpNznVqVGR0,1069
98
- trainml-0.5.3.dist-info/METADATA,sha256=W7Sn5k_EiNFps2DcBDODKT8h7Gh72d-ZWvALheMDRt8,7345
99
- trainml-0.5.3.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
100
- trainml-0.5.3.dist-info/entry_points.txt,sha256=OzBDm2wXby1bSGF02jTVxzRFZLejnbFiLHXhKdW3Bds,63
101
- trainml-0.5.3.dist-info/top_level.txt,sha256=Y1kLFRWKUW7RG8BX7cvejHF_yW8wBOaRYF1JQHENY4w,23
102
- trainml-0.5.3.dist-info/RECORD,,
97
+ trainml-0.5.5.dist-info/LICENSE,sha256=s0lpBxhSSUEpMavwde-Vb6K_K7xDCTTvSpNznVqVGR0,1069
98
+ trainml-0.5.5.dist-info/METADATA,sha256=1yX0kzexWk9MrF3vMdP87PRTbmV34UgsgF-zV-MHAg4,7345
99
+ trainml-0.5.5.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
100
+ trainml-0.5.5.dist-info/entry_points.txt,sha256=OzBDm2wXby1bSGF02jTVxzRFZLejnbFiLHXhKdW3Bds,63
101
+ trainml-0.5.5.dist-info/top_level.txt,sha256=Y1kLFRWKUW7RG8BX7cvejHF_yW8wBOaRYF1JQHENY4w,23
102
+ trainml-0.5.5.dist-info/RECORD,,
@@ -1,46 +0,0 @@
1
- import re
2
- import sys
3
- import asyncio
4
- from pytest import mark, fixture
5
-
6
pytestmark = [mark.sdk, mark.integration, mark.providers]


@mark.create
@mark.asyncio
class GetProvidersTests:
    """Integration tests for listing and fetching providers.

    One provider of type "test" is enabled for the whole class and removed
    again after the class's tests have run.
    """

    @fixture(scope="class")
    async def provider(self, trainml):
        # Enable a throwaway provider, hand it to the tests, then clean up.
        enabled = await trainml.providers.enable(type="test")
        yield enabled
        await enabled.remove()

    async def test_get_providers(self, trainml):
        assert len(await trainml.providers.list()) > 0

    async def test_get_provider(self, trainml, provider):
        fetched = await trainml.providers.get(provider.id)
        assert fetched.id == provider.id

    async def test_provider_properties(self, provider):
        assert isinstance(provider.id, str)
        assert isinstance(provider.type, str)
        assert provider.type == "test"
        assert provider.credits == 0

    async def test_provider_str(self, provider):
        text = str(provider)
        pattern = r"^{.*\"provider_uuid\": \"" + provider.id + r"\".*}$"
        assert isinstance(text, str)
        assert re.match(pattern, text)

    async def test_provider_repr(self, provider):
        text = repr(provider)
        pattern = (
            r"^Provider\( trainml , \*\*{.*'provider_uuid': '"
            + provider.id
            + r"'.*}\)$"
        )
        assert isinstance(text, str)
        assert re.match(pattern, text)
@@ -1,125 +0,0 @@
1
- import re
2
- import json
3
- import logging
4
- from unittest.mock import AsyncMock, patch
5
- from pytest import mark, fixture, raises
6
- from aiohttp import WSMessage, WSMsgType
7
-
8
- import trainml.providers as specimen
9
- from trainml.exceptions import (
10
- ApiError,
11
- SpecificationError,
12
- TrainMLException,
13
- )
14
-
15
pytestmark = [mark.sdk, mark.unit, mark.providers]


@fixture
def providers(mock_trainml):
    """Yield a Providers collection bound to the mocked client."""
    collection = specimen.Providers(mock_trainml)
    yield collection


@fixture
def provider(mock_trainml):
    """Yield a single "physical" Provider populated with sample attributes."""
    attributes = {
        "customer_uuid": "a",
        "provider_uuid": "1",
        "type": "physical",
        "payment_mode": "credits",
        "createdAt": "2020-12-31T23:59:59.000Z",
        "credits": 0.0,
    }
    yield specimen.Provider(mock_trainml, **attributes)
34
-
35
-
36
class ProvidersTests:
    """Unit tests for the Providers collection against a mocked ``_query``."""

    @mark.asyncio
    async def test_get_provider(self, providers, mock_trainml):
        mock_trainml._query = AsyncMock(return_value=dict())
        await providers.get("1234")
        mock_trainml._query.assert_called_once_with("/provider/1234", "GET")

    @mark.asyncio
    async def test_list_providers(self, providers, mock_trainml):
        mock_trainml._query = AsyncMock(return_value=dict())
        await providers.list()
        mock_trainml._query.assert_called_once_with("/provider", "GET")

    @mark.asyncio
    async def test_remove_provider(self, providers, mock_trainml):
        mock_trainml._query = AsyncMock(return_value=dict())
        await providers.remove("4567")
        mock_trainml._query.assert_called_once_with("/provider/4567", "DELETE")

    @mark.asyncio
    async def test_enable_provider_simple(self, providers, mock_trainml):
        request = {"type": "physical"}
        payload = {"type": "physical"}
        reply = dict(
            customer_uuid="cust-id-1",
            provider_uuid="provider-id-1",
            type="new provider",
            credits=0.0,
            payment_mode="credits",
            createdAt="2020-12-31T23:59:59.000Z",
        )

        mock_trainml._query = AsyncMock(return_value=reply)
        created = await providers.enable(**request)
        mock_trainml._query.assert_called_once_with(
            "/provider", "POST", None, payload
        )
        assert created.id == "provider-id-1"
91
-
92
-
93
class providerTests:
    """Unit tests for a single Provider instance."""

    def test_provider_properties(self, provider):
        assert isinstance(provider.id, str)
        assert isinstance(provider.type, str)
        assert isinstance(provider.credits, float)

    def test_provider_str(self, provider):
        rendered = str(provider)
        expected = r"^{.*\"provider_uuid\": \"" + provider.id + r"\".*}$"
        assert isinstance(rendered, str)
        assert re.match(expected, rendered)

    def test_provider_repr(self, provider):
        rendered = repr(provider)
        expected = (
            r"^Provider\( trainml , \*\*{.*'provider_uuid': '"
            + provider.id
            + r"'.*}\)$"
        )
        assert isinstance(rendered, str)
        assert re.match(expected, rendered)

    def test_provider_bool(self, provider, mock_trainml):
        blank = specimen.Provider(mock_trainml)
        assert bool(provider)
        assert not bool(blank)

    @mark.asyncio
    async def test_provider_remove(self, provider, mock_trainml):
        mock_trainml._query = AsyncMock(return_value=dict())
        await provider.remove()
        mock_trainml._query.assert_called_once_with("/provider/1", "DELETE")