wizata-dsapi 1.3.40__py3-none-any.whl → 1.3.41__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
wizata_dsapi/api_dto.py CHANGED
@@ -2,6 +2,12 @@ import uuid
2
2
  from enum import Enum
3
3
 
4
4
 
5
+ class ApiDtoInterface:
6
+
7
+ def load_model(self, model):
8
+ pass
9
+
10
+
5
11
  class VarType(Enum):
6
12
  """
7
13
  defines possible type for a defined variable.
wizata_dsapi/mlmodel.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from typing import List, Iterator, Union, Optional
2
2
  import os
3
- from .api_dto import ApiDto
3
+ from .api_dto import ApiDto, ApiDtoInterface
4
4
  from datetime import datetime, timezone
5
5
 
6
6
 
@@ -165,6 +165,17 @@ class ModelInfo:
165
165
  self.has_target_feat = False
166
166
  self.label_counts = 0
167
167
 
168
+ # api
169
+ self._api = None
170
+
171
+ def bind_api(self, api:ApiDtoInterface):
172
+ """
173
+ internal method to bind the api to the dto.
174
+ :param api: api client
175
+ :return: None
176
+ """
177
+ self._api = api
178
+
168
179
  @classmethod
169
180
  def split_identifier(cls, identifier: str):
170
181
  """
@@ -279,6 +290,19 @@ class ModelInfo:
279
290
  self.is_active = get_bool(obj, name="is_active")
280
291
  if "updatedDate" in obj.keys() and obj["updatedDate"] is not None:
281
292
  self.updated_date = obj["updatedDate"]
293
+ if "files" in obj.keys():
294
+ for obj_file in obj["files"]:
295
+ model_file = ModelFile()
296
+ model_file.from_json(obj_file)
297
+ self.add_file(model_file)
298
+
299
+ def load(self):
300
+ """
301
+ load the trained model from the repository.
302
+ """
303
+ if self._api is None:
304
+ raise RuntimeError("api is not bound to the dto use bind_api()")
305
+ self._api.load_model(self)
282
306
 
283
307
 
284
308
  class ModelList:
@@ -304,15 +328,23 @@ class ModelList:
304
328
  """
305
329
  return any(model_in_list.identifier(include_alias=True) == model.identifier(include_alias=True) for model_in_list in self.models)
306
330
 
307
- def __getitem__(self, key: Union[int, ModelInfo]) -> ModelInfo:
331
+ def __getitem__(self, key: Union[int, str, ModelInfo]) -> ModelInfo:
308
332
  """
309
- find a model within list based on key or ModelInfo.
310
- :param key: ModelInfo or index
333
+ find a model within list based on index, identifier or ModelInfo.
334
+ :param key: identifier or ModelInfo or index
311
335
  :return: the model_info
312
336
  """
313
337
  if isinstance(key, int):
314
338
  return self.models[key]
315
339
 
340
+ elif isinstance(key, str):
341
+ if "@" not in key:
342
+ return self.select_active_model(identifier=key)
343
+ for model in self.models:
344
+ if model.identifier(include_alias=True) == key:
345
+ return model
346
+ raise KeyError(f"model with identifier '{key}' not found within this ModelList.")
347
+
316
348
  elif isinstance(key, ModelInfo):
317
349
  identifier = key.identifier(include_alias=True)
318
350
  for model in self.models:
@@ -321,7 +353,22 @@ class ModelList:
321
353
  raise KeyError(f"model with identifier '{identifier}' not found within this ModelList.")
322
354
 
323
355
  else:
324
- raise TypeError("ModelList indices must be int or ModelInfo.")
356
+ raise TypeError("ModelList indices must be int, str or ModelInfo.")
357
+
358
+ def select_active_model(self, identifier: str) -> ModelInfo:
359
+ """
360
+ return the active model based on active status or latest one if none active.
361
+ :param identifier: identifier
362
+ :return: active model
363
+ """
364
+ models = []
365
+ for model in self.models:
366
+ if model.identifier(include_alias=False) == identifier:
367
+ if model.is_active:
368
+ return model
369
+ else:
370
+ models.append(model)
371
+ return max(models, key=lambda f: f.updated_date, default=None)
325
372
 
326
373
  def append(self, model: ModelInfo):
327
374
  self.models.append(model)
wizata_dsapi/pipeline.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import uuid
2
2
  import json
3
+ import sys
3
4
 
4
5
  import pandas
5
6
  import wizata_dsapi
@@ -159,13 +160,13 @@ class PipelineIO(ApiDto):
159
160
  obj["columns"] = self.columns
160
161
  return obj
161
162
 
162
- def prepare(self, df: pandas.DataFrame) -> pandas.DataFrame:
163
+ def _prepare_df(self, df: pandas.DataFrame) -> pandas.DataFrame:
163
164
  """
164
- prepare the dataframe based in information from this pipeline I/O definition.
165
- perform mapping, selection and drops.
165
+ prepare dataframe in both 3.9 and 3.11+
166
166
  :param pandas.DataFrame df: dataframe to prepare.
167
167
  :return: prepared dataframe.
168
168
  """
169
+
169
170
  try:
170
171
  prepare_df = df.copy()
171
172
 
@@ -186,6 +187,25 @@ class PipelineIO(ApiDto):
186
187
  except Exception as e:
187
188
  raise RuntimeError(f'not able to prepare your dataframe following Pipeline I/O directives {e}')
188
189
 
190
+ if sys.version_info >= (3, 11):
191
+
192
+ ## with torch support
193
+ import torch
194
+ from typing import Union
195
+
196
+ def prepare(self, df: Union[pandas.DataFrame, torch.Tensor]) -> Union[pandas.DataFrame, torch.Tensor]:
197
+ import torch
198
+ if isinstance(df, torch.Tensor):
199
+ return df
200
+ else:
201
+ return self._prepare_df(df)
202
+
203
+ else:
204
+
205
+ ## without torch support
206
+ def prepare(self, df: pandas.DataFrame) -> pandas.DataFrame:
207
+ return self._prepare_df(df)
208
+
189
209
  @classmethod
190
210
  def from_obj(cls, obj):
191
211
  """
wizata_dsapi/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.3.40"
1
+ __version__ = "1.3.41"
@@ -16,11 +16,13 @@ import types
16
16
  import wizata_dsapi
17
17
  import urllib.parse
18
18
  import base64
19
+ import joblib
20
+ import io
19
21
 
20
22
  import string
21
23
  import random
22
24
 
23
- from .api_dto import ApiDto, VarType
25
+ from .api_dto import ApiDto, VarType, ApiDtoInterface
24
26
  from .business_label import BusinessLabel
25
27
  from .plot import Plot
26
28
  from .request import Request
@@ -69,7 +71,7 @@ def parse_string_list(s):
69
71
  return []
70
72
 
71
73
 
72
- class WizataDSAPIClient(ApiInterface):
74
+ class WizataDSAPIClient(ApiInterface, ApiDtoInterface):
73
75
  """
74
76
  client wrapper to cloud data science API
75
77
 
@@ -1100,17 +1102,24 @@ class WizataDSAPIClient(ApiInterface):
1100
1102
  raise self.__raise_error(response)
1101
1103
 
1102
1104
  def upload_model(self,
1103
- model_info: ModelInfo):
1105
+ model_info: ModelInfo,
1106
+ bytes_content = None):
1104
1107
  """
1105
1108
  upload a model within the model repository.
1109
+ - by default use model_info.trained_model and convert it to a pickle
1110
+ - for already torch or pickle please pass the bytes_content
1111
+ - model_info.file_format must be set properly to 'pkl' or 'pt'
1106
1112
  :param model_info: model info, with at least key (+twin, +property, +alias) and trained_model.
1113
+ :param bytes_content: bytes[] of your torch or pickle model.
1107
1114
  """
1108
- if model_info.trained_model is None:
1109
- raise ValueError("model_info must have a trained model as bytes content pkl or pt")
1115
+ if model_info.trained_model is None and bytes_content is None:
1116
+ raise ValueError("model_info must have a trained model (to pickle) or bytes content")
1117
+ if bytes_content is None:
1118
+ bytes_content = pickle.dumps(model_info.trained_model)
1110
1119
  files = {
1111
1120
  "trained_model": (
1112
1121
  "trained_model." + model_info.file_format ,
1113
- model_info.trained_model,
1122
+ bytes_content,
1114
1123
  "application/octet-stream",
1115
1124
  )
1116
1125
  }
@@ -1849,6 +1858,27 @@ class WizataDSAPIClient(ApiInterface):
1849
1858
  else:
1850
1859
  raise self.__raise_error(response)
1851
1860
 
1861
+ def search_models(self) -> ModelList:
1862
+ """
1863
+ get all information related to models stored on Wizata.
1864
+ :return: ModelList structure model list
1865
+ """
1866
+ response = requests.request("GET",
1867
+ self.__url() + "models",
1868
+ headers=self.__header()
1869
+ )
1870
+ if response.status_code == 200:
1871
+ response_json = response.json()
1872
+ model_list = ModelList()
1873
+ for model_json in response_json:
1874
+ model_info = ModelInfo(model_json["key"])
1875
+ model_info.from_json(model_json)
1876
+ model_info._api = self
1877
+ model_list.append(model_info)
1878
+ return model_list
1879
+ else:
1880
+ raise self.__raise_error(response)
1881
+
1852
1882
  def abort(self, executions: list) -> str:
1853
1883
  """
1854
1884
  send an abort request for executions and return a result message
@@ -1905,6 +1935,32 @@ class WizataDSAPIClient(ApiInterface):
1905
1935
  image = PipelineImage.loads(pipeline_image_id=pipeline_image_id, g_bytes=response_bytes)
1906
1936
  return image
1907
1937
 
1938
+ def load_model(self, model):
1939
+ """
1940
+ load a model pickle or torch from the repository ready to be used.
1941
+ :param model: ModelInfo to load
1942
+ :return: ModelInfo with the trained model.
1943
+ """
1944
+ if not isinstance(model, ModelInfo):
1945
+ raise TypeError('model must be an instance of ModelInfo')
1946
+
1947
+ identifier = model.identifier(include_alias=True)
1948
+ extension = model.file_format
1949
+ response = requests.get(self.__url() + f"models/{identifier}/files/trained_model.{extension}/",
1950
+ headers=self.__header())
1951
+ if response.status_code == 200:
1952
+ if extension == 'pkl':
1953
+ model.trained_model = joblib.load(io.BytesIO(response.content))
1954
+ return model
1955
+ elif extension == 'pt':
1956
+ import torch
1957
+ model.trained_model = torch.jit.load(io.BytesIO(response.content))
1958
+ return model
1959
+ else:
1960
+ raise ValueError(f'unsupported file format {extension}')
1961
+ else:
1962
+ self.__raise_error(response)
1963
+
1908
1964
 
1909
1965
  def api() -> WizataDSAPIClient:
1910
1966
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wizata_dsapi
3
- Version: 1.3.40
3
+ Version: 1.3.41
4
4
  Summary: Wizata Data Science Toolkit
5
5
  Author: Wizata S.A.
6
6
  Author-email: info@wizata.com
@@ -1,6 +1,6 @@
1
1
  wizata_dsapi/__init__.py,sha256=aVeizk2zzmAX0TLNG9o7vcmG8KB73XiXw1_qozrJN2w,2072
2
2
  wizata_dsapi/api_config.py,sha256=6Pnnv62X_QrTUXaa1MtFZeQaqMUJC-9Y5BW7B4gef10,5139
3
- wizata_dsapi/api_dto.py,sha256=-NdaTRvw5jW5xFGpIhY8U0-SdvzW2t6QD26y0UPApU0,2238
3
+ wizata_dsapi/api_dto.py,sha256=zyzj6-Kcxi59vVAvThfsuq7RLUE2O_FZ3TJ5LpKYPbE,2310
4
4
  wizata_dsapi/api_interface.py,sha256=DURk-0ey16T8sV5e2Y2G_YybPEusJvZuY0oD5L7AnXo,10903
5
5
  wizata_dsapi/bucket.py,sha256=Zz9olv-pymikAutGitSuGWrAPiawOTW86JDDHG4ugTc,1150
6
6
  wizata_dsapi/business_label.py,sha256=u0TVfUNfoR9qSv8lzpf6rNjlg3G9xTiz6itefcKfeak,4151
@@ -16,10 +16,10 @@ wizata_dsapi/experiment.py,sha256=QYQ1CJ-MTWsXq08xYbm5sAp95dRxbPOmGDgaAOoBMDQ,46
16
16
  wizata_dsapi/group_system.py,sha256=6rUKe0_J3YWACysyBlzuw_TEpKNXgLOMxhpWsNxOzwY,1708
17
17
  wizata_dsapi/ilogger.py,sha256=iYnID-Z-qrYhie26C43404aIuU4_tHSKXbDeQIdo82Q,807
18
18
  wizata_dsapi/insight.py,sha256=ABFZ04DqYxxzqAEfU1tzlTZqqrigM-zN-8Lbetko3g0,6468
19
- wizata_dsapi/mlmodel.py,sha256=RF9hWv2vEkovieVv4hr99X5uFhABnttAxFTnZX6Pcec,22874
19
+ wizata_dsapi/mlmodel.py,sha256=Bdx4bZdrvcR4IZyVH1_yDtU94qe4M8LOFK7I0QqdknQ,24541
20
20
  wizata_dsapi/model_toolkit.py,sha256=UNyw5CFSgZeXydQFsiDIRTjoMeqIsdyIIuiwumLW5bA,1574
21
21
  wizata_dsapi/paged_query_result.py,sha256=0Iyt2Kd4tvrfthhT-tk9EmSERsbJTaPNON2euHcBn6k,1150
22
- wizata_dsapi/pipeline.py,sha256=WDJeOxPZJiYW1qwTNZUm3jom2epIxqrSoiUwcrTF9EE,31300
22
+ wizata_dsapi/pipeline.py,sha256=CtB6-HwJ2OtqUIbwAVTcawLmvYfudgIwGiPZARwTobM,31778
23
23
  wizata_dsapi/pipeline_deployment.py,sha256=grekBaxUK0EhL9w7lDB8vNuW_wzLnHVm9Mq8Lkbkguk,1722
24
24
  wizata_dsapi/pipeline_image.py,sha256=FUxaDDAOZHG8MA2xpZDoG7m1xbtiRSB8YuLFObUSd8c,5274
25
25
  wizata_dsapi/plot.py,sha256=SPGKFWWYNcRvHcqvvnPIIIBKsd5UwhdsxLW7b2dG2rs,2360
@@ -31,10 +31,10 @@ wizata_dsapi/template.py,sha256=wtCRKKk3PchH4RrNgNYlEF_9C6bzZwKIeLyEvgv6Fdo,1370
31
31
  wizata_dsapi/trigger.py,sha256=w3BZYP-L3SUwvaT0oCTanh_Ewn57peZvlt7vxzHv9J8,5129
32
32
  wizata_dsapi/twin.py,sha256=S0DUzQf1smZXZTdXpXZPtkZYCfKIhw53EecCnsl9i4Q,11017
33
33
  wizata_dsapi/twinregistration.py,sha256=Mi6-YuwroiEXc0c1hgrOaphh4hNVoHupxOnXedVtJtE,13377
34
- wizata_dsapi/version.py,sha256=eI6M8nfSYqsdcfjfFStmYRSgA4ytJVy7jwU-LA5MtPc,23
34
+ wizata_dsapi/version.py,sha256=Reh2o_lvciobRRYtnopGjIb_a7RRkURGM3Ay0rAOFfg,23
35
35
  wizata_dsapi/wizard_function.py,sha256=RbM7W7Gf-6Rhp_1dU9DBYkHaciknGAGvuAndhAS_vyo,942
36
36
  wizata_dsapi/wizard_request.py,sha256=v6BaqKLKvTWmUSo0_gda9FabAQz5x_-GOH1Av50GzFo,3762
37
- wizata_dsapi/wizata_dsapi_client.py,sha256=K7xukE_1O3-t1l2Arq6f0rlbMib2xL0oEFGvLqEBjzI,78067
37
+ wizata_dsapi/wizata_dsapi_client.py,sha256=7CpAeo0uBI9XAPpPXWcTUfkTAWJYlCP0CYZTBUNCafU,80517
38
38
  wizata_dsapi/words.py,sha256=tV8CqzCqODZCV7PgBxBF5exBxeF_ya9t5DiUy-cg6Sg,1535
39
39
  wizata_dsapi/models/__init__.py,sha256=O5PHqw8lKILw4apO-MfDxPz73wK0vADD9y3xjuzX7Tw,104
40
40
  wizata_dsapi/models/common.py,sha256=1dTqE80-mFJnUwEdNlJdhJzfZ2N5Kp8Nb3LQ8uwPtLc,3808
@@ -42,8 +42,8 @@ wizata_dsapi/plots/__init__.py,sha256=qgnSFqrjOPur-807M8uh5awIfjM1ZHXUXcAqHc-r2l
42
42
  wizata_dsapi/plots/common.py,sha256=jdPsJqLHBwSKc6dX83BSGPqSRxzIVNHSYO5yI_8sjGk,6568
43
43
  wizata_dsapi/scripts/__init__.py,sha256=hAxiETSQf0qOHde1si1tEAJU48seqEgHrchCzS2-LvQ,80
44
44
  wizata_dsapi/scripts/common.py,sha256=efwq-Rd0lvYljIs3gSFz9izogBD7asOU2cTK-IvHTkM,4244
45
- wizata_dsapi-1.3.40.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
46
- wizata_dsapi-1.3.40.dist-info/METADATA,sha256=B_SMDSgjWAASV-xiK5BDmlPCJmQYKlM9csU6__mRseo,5651
47
- wizata_dsapi-1.3.40.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
- wizata_dsapi-1.3.40.dist-info/top_level.txt,sha256=-OeTJbEnh5DuWyTOHtvw0Dw3LRg3G27TNS6W4ZtfwPs,13
49
- wizata_dsapi-1.3.40.dist-info/RECORD,,
45
+ wizata_dsapi-1.3.41.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
46
+ wizata_dsapi-1.3.41.dist-info/METADATA,sha256=ElNx6WEHlfid-Z9AXSRGwl-5s-7UQu8b8-Wz-FtIrCA,5651
47
+ wizata_dsapi-1.3.41.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
48
+ wizata_dsapi-1.3.41.dist-info/top_level.txt,sha256=-OeTJbEnh5DuWyTOHtvw0Dw3LRg3G27TNS6W4ZtfwPs,13
49
+ wizata_dsapi-1.3.41.dist-info/RECORD,,