trainml 0.4.13__py3-none-any.whl → 0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -314,7 +314,13 @@ class JobTests:
314
314
 
315
315
  @mark.asyncio
316
316
  async def test_job_start(self, job, mock_trainml):
317
- api_response = None
317
+ api_response = {
318
+ "customer_uuid": "cus-id-1",
319
+ "job_uuid": "job-id-1",
320
+ "name": "test notebook",
321
+ "type": "notebook",
322
+ "status": "starting",
323
+ }
318
324
  mock_trainml._query = AsyncMock(return_value=api_response)
319
325
  await job.start()
320
326
  mock_trainml._query.assert_called_once_with(
@@ -326,7 +332,13 @@ class JobTests:
326
332
 
327
333
  @mark.asyncio
328
334
  async def test_job_stop(self, job, mock_trainml):
329
- api_response = None
335
+ api_response = {
336
+ "customer_uuid": "cus-id-1",
337
+ "job_uuid": "job-id-1",
338
+ "name": "test notebook",
339
+ "type": "notebook",
340
+ "status": "stopping",
341
+ }
330
342
  mock_trainml._query = AsyncMock(return_value=api_response)
331
343
  await job.stop()
332
344
  mock_trainml._query.assert_called_once_with(
trainml/__init__.py CHANGED
@@ -13,5 +13,5 @@ logging.basicConfig(
13
13
  logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
- __version__ = "0.4.13"
16
+ __version__ = "0.4.15"
17
17
  __all__ = "TrainML"
trainml/checkpoints.py CHANGED
@@ -126,8 +126,12 @@ class Checkpoint:
126
126
  ssh_port=self._checkpoint.get("vpn")
127
127
  .get("client")
128
128
  .get("ssh_port"),
129
- input_path=self._checkpoint.get("source_uri"),
130
- output_path=None,
129
+ input_path=self._checkpoint.get("source_uri")
130
+ if self.status in ["new", "downloading"]
131
+ else None,
132
+ output_path=self._checkpoint.get("output_uri")
133
+ if self.status == "exporting"
134
+ else None,
131
135
  )
132
136
  else:
133
137
  details = dict()
@@ -137,7 +141,7 @@ class Checkpoint:
137
141
  if self.status in ["ready", "failed"]:
138
142
  raise SpecificationError(
139
143
  "status",
140
- f"You can only connect to new or downloading checkpoints.",
144
+ f"You can only connect to downloading or exporting checkpoints.",
141
145
  )
142
146
  if self.status == "new":
143
147
  await self.wait_for("downloading")
@@ -162,12 +166,28 @@ class Checkpoint:
162
166
  )
163
167
 
164
168
  async def rename(self, name):
165
- await self.trainml._query(
169
+ resp = await self.trainml._query(
166
170
  f"/checkpoint/{self._id}",
167
171
  "PATCH",
168
- None,
172
+ dict(project_uuid=self._project_uuid),
169
173
  dict(name=name),
170
174
  )
175
+ self.__init__(self.trainml, **resp)
176
+ return self
177
+
178
+ async def export(self, output_type, output_uri, output_options=dict()):
179
+ resp = await self.trainml._query(
180
+ f"/checkpoint/{self._id}/export",
181
+ "POST",
182
+ dict(project_uuid=self._project_uuid),
183
+ dict(
184
+ output_type=output_type,
185
+ output_uri=output_uri,
186
+ output_options=output_options,
187
+ ),
188
+ )
189
+ self.__init__(self.trainml, **resp)
190
+ return self
171
191
 
172
192
  def _get_msg_handler(self, msg_handler):
173
193
  def handler(data):
trainml/datasets.py CHANGED
@@ -158,12 +158,14 @@ class Dataset:
158
158
  )
159
159
 
160
160
  async def rename(self, name):
161
- await self.trainml._query(
161
+ resp = await self.trainml._query(
162
162
  f"/dataset/{self._id}",
163
163
  "PATCH",
164
164
  None,
165
165
  dict(name=name),
166
166
  )
167
+ self.__init__(self.trainml, **resp)
168
+ return self
167
169
 
168
170
  def _get_msg_handler(self, msg_handler):
169
171
  def handler(data):
trainml/jobs.py CHANGED
@@ -253,20 +253,24 @@ class Job:
253
253
  return create_json
254
254
 
255
255
  async def start(self):
256
- await self.trainml._query(
256
+ resp = await self.trainml._query(
257
257
  f"/job/{self._id}",
258
258
  "PATCH",
259
259
  dict(project_uuid=self._project_uuid),
260
260
  dict(command="start"),
261
261
  )
262
+ self.__init__(self.trainml, **resp)
263
+ return self
262
264
 
263
265
  async def stop(self):
264
- await self.trainml._query(
266
+ resp = await self.trainml._query(
265
267
  f"/job/{self._id}",
266
268
  "PATCH",
267
269
  dict(project_uuid=self._project_uuid),
268
270
  dict(command="stop"),
269
271
  )
272
+ self.__init__(self.trainml, **resp)
273
+ return self
270
274
 
271
275
  async def update(self, data):
272
276
  if self.type != "notebook":
@@ -274,12 +278,14 @@ class Job:
274
278
  "type",
275
279
  "Only notebook jobs can be modified.",
276
280
  )
277
- await self.trainml._query(
281
+ resp = await self.trainml._query(
278
282
  f"/job/{self._id}",
279
283
  "PATCH",
280
284
  dict(project_uuid=self._project_uuid),
281
285
  data,
282
286
  )
287
+ self.__init__(self.trainml, **resp)
288
+ return self
283
289
 
284
290
  async def get_worker_log_url(self, job_worker_uuid):
285
291
  resp = await self.trainml._query(
trainml/models.py CHANGED
@@ -151,12 +151,14 @@ class Model:
151
151
  )
152
152
 
153
153
  async def rename(self, name):
154
- await self.trainml._query(
154
+ resp = await self.trainml._query(
155
155
  f"/model/{self._id}",
156
156
  "PATCH",
157
157
  None,
158
158
  dict(name=name),
159
159
  )
160
+ self.__init__(self.trainml, **resp)
161
+ return self
160
162
 
161
163
  def _get_msg_handler(self, msg_handler):
162
164
  def handler(data):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: trainml
3
- Version: 0.4.13
3
+ Version: 0.4.15
4
4
  Summary: trainML client SDK and command line utilities
5
5
  Home-page: https://github.com/trainML/trainml-cli
6
6
  Author: trainML
@@ -21,7 +21,7 @@ tests/unit/test_datasets_unit.py,sha256=lVNoBZu4RIiJK26gbUPOUAra_k0YS2GcnjJDnT7U
21
21
  tests/unit/test_environments_unit.py,sha256=1QFGf1xwM0yKCyVHT_Xi0DX8g0Aelr0mcqAImEXJfQU,1882
22
22
  tests/unit/test_exceptions.py,sha256=3tAok6kAU1QRjN7qTNVYuSGWDg7IEoK__OXFLyzLr7k,906
23
23
  tests/unit/test_gpu_types_unit.py,sha256=6v_n_AytYjQZxv2OtcUYBxQz7iRjigSS7xmBo6wSJk0,1703
24
- tests/unit/test_jobs_unit.py,sha256=zG--HP1NOWDNx-5ZYWfv4R10RW9s-DR77OQC2EN6kaM,26459
24
+ tests/unit/test_jobs_unit.py,sha256=N-51OkTU_nHdPxCcLzGyz4A0LD7dmupE5TGr7JXjWVM,26833
25
25
  tests/unit/test_models_unit.py,sha256=uezWF7FUHGmCSQBtpyyKhBttTnCTRjxU22NsHdJLYYg,15064
26
26
  tests/unit/test_projects_unit.py,sha256=mV0CejcTSNJEYfFl-vYcPbQ2HnDw31xc04Asn6-jZrM,3871
27
27
  tests/unit/test_providers_unit.py,sha256=nEizghnC8pfDubkCw-kMmS_QQOUUWBk3i8D44pnyljo,3700
@@ -29,17 +29,17 @@ tests/unit/test_trainml.py,sha256=8vAKvFD1xYsx_VY4HFVa0b1MUlMoNApY6TO8r7vI-UQ,17
29
29
  tests/unit/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  tests/unit/cli/conftest.py,sha256=w6p_2URicywJKUCtY79tSD_mx8cwJtxHbK_Lu3grOYs,236
31
31
  tests/unit/cli/test_cli_environment_unit.py,sha256=7FFKFPVa6yqJqujTXD_tsSUDruW5KhlLgShPZTn1Eio,683
32
- trainml/__init__.py,sha256=YCDyPnS-RTLikmOhBm90HYDsjdcnIut8_uloxGBQr08,433
32
+ trainml/__init__.py,sha256=SZqtnWUI0eSxdRWMWS8SVB1kXKfEJSLzw5GhbC3r4SA,433
33
33
  trainml/__main__.py,sha256=JgErYkiskih8Y6oRwowALtR-rwQhAAdqOYWjQraRIPI,59
34
34
  trainml/auth.py,sha256=VI8wXgmXbwiBpgw0Sm-DzC4VmNqKg_zniua-Z61PafA,26646
35
- trainml/checkpoints.py,sha256=2Mq7VWs1_Jn9hhTwylrbh5XMocDup4cb1Urweod_X_E,7497
35
+ trainml/checkpoints.py,sha256=-P8FOkw1ihA5KCSQll-FDY1VtuKZUOboBjjEipUltF4,8262
36
36
  trainml/connections.py,sha256=xkcFoyB5AXztRh4DPkpg3GPIZPxBoDvb5FZfON46aVs,20059
37
- trainml/datasets.py,sha256=jhzA0t72oGiYYXOiMckJECap57TlfFUCYCvZZhaQaFU,7311
37
+ trainml/datasets.py,sha256=eImi3BO74M7mV9kmUsH03O1ybP_mcMA_FPdLMC--ioM,7382
38
38
  trainml/environments.py,sha256=hctwxZX7zNO-StgMnD9Zh4JS8fc3mUc98Ws3xM4c9QQ,1537
39
39
  trainml/exceptions.py,sha256=hhR78fI8rbU3fWQ-kUsgxOLyYP8D2bQcjLjrskqPG0Q,3724
40
40
  trainml/gpu_types.py,sha256=VXB4XqE0vEwtRyj5nOhWHlR2hRh_lU1VuURU3vyO_gE,1831
41
- trainml/jobs.py,sha256=YxRta4vVIob3OTRkX0vz45F9uXRtEJ34rUDYuScZR1Q,17594
42
- trainml/models.py,sha256=nZ1GzzvdXUyt30riev6mC4K3DG7gLu4BBJCFavGzbz0,6960
41
+ trainml/jobs.py,sha256=cUuiMvtJSFSupv2jy7fqAvUOkO8FRLfBDe3RuSZUjfE,17807
42
+ trainml/models.py,sha256=jQ84rqrTbOuCstcZYlh8agWeZOQkrApZEpy_m9bfKWU,7031
43
43
  trainml/projects.py,sha256=-lFWwii5cnaqm6vIVZbiVsolZ9kNRsawdd8HjF4BLJE,2063
44
44
  trainml/providers.py,sha256=97VegYVSeK0BuYv04hfBY_awNBbGz_GR_mdDrknfO-A,1844
45
45
  trainml/trainml.py,sha256=1xGKaI7mAUkB4I7bDyPVDNNrDVmgC6Sg8tJx2UXkKH4,10974
@@ -55,9 +55,9 @@ trainml/cli/project.py,sha256=sBId3S4K7kRwIqURkabSXU3iL2TXbg8XMcjMxHwtt7g,1654
55
55
  trainml/cli/provider.py,sha256=eaklYo0IUOGeyZ0ziZZz6UhOZ50Py4ewaPtWA9UDCqU,1594
56
56
  trainml/cli/job/__init__.py,sha256=KP3_j6aaokVYNJyi8BUReUdJ6WbZ3ObEBmRJTFWtS10,6544
57
57
  trainml/cli/job/create.py,sha256=JyyqVH0Oe9wktRRxxqMC5GEVWd6nzEdKkE42jn1jjtc,35272
58
- trainml-0.4.13.dist-info/LICENSE,sha256=s0lpBxhSSUEpMavwde-Vb6K_K7xDCTTvSpNznVqVGR0,1069
59
- trainml-0.4.13.dist-info/METADATA,sha256=qTw565C9HMTZ3zDTOx6ktUiki5EdBk7i02v_Q-JINiM,7346
60
- trainml-0.4.13.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
61
- trainml-0.4.13.dist-info/entry_points.txt,sha256=OzBDm2wXby1bSGF02jTVxzRFZLejnbFiLHXhKdW3Bds,63
62
- trainml-0.4.13.dist-info/top_level.txt,sha256=Y1kLFRWKUW7RG8BX7cvejHF_yW8wBOaRYF1JQHENY4w,23
63
- trainml-0.4.13.dist-info/RECORD,,
58
+ trainml-0.4.15.dist-info/LICENSE,sha256=s0lpBxhSSUEpMavwde-Vb6K_K7xDCTTvSpNznVqVGR0,1069
59
+ trainml-0.4.15.dist-info/METADATA,sha256=XY6ZStH-E7tb-tBeJ4zPPjHYbUhYRh2GkFrwqe1Pk4s,7346
60
+ trainml-0.4.15.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
61
+ trainml-0.4.15.dist-info/entry_points.txt,sha256=OzBDm2wXby1bSGF02jTVxzRFZLejnbFiLHXhKdW3Bds,63
62
+ trainml-0.4.15.dist-info/top_level.txt,sha256=Y1kLFRWKUW7RG8BX7cvejHF_yW8wBOaRYF1JQHENY4w,23
63
+ trainml-0.4.15.dist-info/RECORD,,