clarifai 10.2.1__py3-none-any.whl → 10.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
clarifai/client/model.py CHANGED
@@ -1,17 +1,20 @@
+ import json
  import os
  import time
- from typing import Any, Dict, Generator, List, Union
+ from typing import Any, Dict, Generator, List, Tuple, Union
 
+ import numpy as np
  import requests
  import yaml
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
  from clarifai_grpc.grpc.api.resources_pb2 import Input
  from clarifai_grpc.grpc.api.status import status_code_pb2
  from google.protobuf.json_format import MessageToDict
- from google.protobuf.struct_pb2 import Struct
+ from google.protobuf.struct_pb2 import Struct, Value
  from tqdm import tqdm
 
  from clarifai.client.base import BaseClient
+ from clarifai.client.dataset import Dataset
  from clarifai.client.input import Inputs
  from clarifai.client.lister import Lister
  from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS, TRAINABLE_MODEL_TYPES
@@ -23,6 +26,10 @@ from clarifai.utils.model_train import (find_and_replace_key, params_parser,
  response_to_model_params, response_to_param_info,
  response_to_templates)
 
+ MAX_SIZE_PER_STREAM = int(89_128_960)  # 85MiB
+ MIN_CHUNK_FOR_UPLOAD_FILE = int(5_242_880)  # 5MiB
+ MAX_CHUNK_FOR_UPLOAD_FILE = int(5_242_880_000)  # 5GiB
+
 
  class Model(Lister, BaseClient):
  """Model is a class that provides access to Clarifai API endpoints related to Model information."""
@@ -34,6 +41,7 @@ class Model(Lister, BaseClient):
  base_url: str = "https://api.clarifai.com",
  pat: str = None,
  token: str = None,
+ root_certificates_path: str = None,
  **kwargs):
  """Initializes a Model object.
 
@@ -44,6 +52,7 @@ class Model(Lister, BaseClient):
  base_url (str): Base API url. Default "https://api.clarifai.com"
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
  **kwargs: Additional keyword arguments to be passed to the Model.
  """
  if url and model_id:
@@ -54,12 +63,18 @@ class Model(Lister, BaseClient):
  user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(url)
  model_version = {'id': model_version_id}
  kwargs = {'user_id': user_id, 'app_id': app_id}
- self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version,}
+ self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version, }
  self.model_info = resources_pb2.Model(**self.kwargs)
  self.logger = get_logger(logger_level="INFO", name=__name__)
  self.training_params = {}
  BaseClient.__init__(
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
+ self,
+ user_id=self.user_id,
+ app_id=self.app_id,
+ base=base_url,
+ pat=pat,
+ token=token,
+ root_certificates_path=root_certificates_path)
  Lister.__init__(self)
 
  def list_training_templates(self) -> List[str]:
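The new `root_certificates_path` is threaded through to `BaseClient`, so clients that talk to a private Clarifai deployment or sit behind a TLS-intercepting proxy can pin their trust root at construction time. A hedged usage sketch (the URL, PAT, and certificate path are placeholders):

```python
from clarifai.client.model import Model

# Hypothetical values; substitute your own model URL and CA bundle.
model = Model(
    url="https://clarifai.com/user_id/app_id/models/model_id",
    pat="YOUR_PAT",  # or set the CLARIFAI_PAT env var
    root_certificates_path="/etc/ssl/certs/private_root_ca.pem",
)
```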
@@ -122,11 +137,11 @@ class Model(Lister, BaseClient):
  raise Exception(response.status)
  params = response_to_model_params(
  response=response, model_type_id=self.model_info.model_type_id, template=template)
- #yaml file
+ # yaml file
  assert save_to.endswith('.yaml'), "File extension should be .yaml"
  with open(save_to, 'w') as f:
  yaml.dump(params, f, default_flow_style=False, sort_keys=False)
- #updating the global model params
+ # updating the global model params
  self.training_params.update(params)
 
  return params
@@ -149,14 +164,14 @@ class Model(Lister, BaseClient):
  raise UserError(
  f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
  )
- #getting all the keys in nested dictionary
+ # getting all the keys in nested dictionary
  all_keys = [key for key in self.training_params.keys()] + [
  key for key in self.training_params.values() if isinstance(key, dict) for key in key
  ]
- #checking if the given params are valid
+ # checking if the given params are valid
  if not set(kwargs.keys()).issubset(all_keys):
  raise UserError("Invalid params")
- #updating the global model params
+ # updating the global model params
  for key, value in kwargs.items():
  find_and_replace_key(self.training_params, key, value)
 
@@ -228,7 +243,7 @@ class Model(Lister, BaseClient):
  params_dict = yaml.safe_load(file)
  else:
  params_dict = self.training_params
- #getting all the concepts for the model type
+ # getting all the concepts for the model type
  if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
  concepts = self._list_concepts()
  train_dict = params_parser(params_dict, concepts)
@@ -243,7 +258,7 @@ class Model(Lister, BaseClient):
 
  return response.model.model_version.id
 
- def training_status(self, version_id: str, training_logs: bool = False) -> Dict[str, str]:
+ def training_status(self, version_id: str = None, training_logs: bool = False) -> Dict[str, str]:
  """Get the training status for the model version. Also stores training logs
 
  Args:
@@ -258,19 +273,20 @@ class Model(Lister, BaseClient):
  >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
  >>> model.training_status(version_id='version_id',training_logs=True)
  """
+ if not version_id and not self.model_info.model_version.id:
+ raise UserError(
+ "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+ )
+
+ if not self.model_info.model_type_id or not self.model_info.model_version.train_log:
+ self.load_info()
  if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
  raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
 
- request = service_pb2.GetModelVersionRequest(
- user_app_id=self.user_app_id, model_id=self.id, version_id=version_id)
- response = self._grpc_request(self.STUB.GetModelVersion, request)
- if response.status.code != status_code_pb2.SUCCESS:
- raise Exception(response.status)
-
  if training_logs:
  try:
- if response.model_version.train_log:
- log_response = requests.get(response.model_version.train_log)
+ if self.model_info.model_version.train_log:
+ log_response = requests.get(self.model_info.model_version.train_log)
  log_response.raise_for_status()  # Check for any HTTP errors
  with open(version_id + '.log', 'wb') as file:
  for chunk in log_response.iter_content(chunk_size=4096):  # 4KB
@@ -280,7 +296,7 @@ class Model(Lister, BaseClient):
  except requests.exceptions.RequestException as e:
  raise Exception(f"An error occurred while getting training logs: {e}")
 
- return response.model_version.status
+ return self.model_info.model_version.status
 
  def delete_version(self, version_id: str) -> None:
  """Deletes a model version for the Model.
@@ -412,7 +428,7 @@ class Model(Lister, BaseClient):
  response = self._grpc_request(self.STUB.PostModelOutputs, request)
 
  if response.status.code == status_code_pb2.MODEL_DEPLOYING and \
- time.time() - start_time < 60 * 10: # 10 minutes
+ time.time() - start_time < 60 * 10:  # 10 minutes
  self.logger.info(f"{self.id} model is still deploying, please wait...")
  time.sleep(next(backoff_iterator))
  continue
@@ -617,18 +633,22 @@ class Model(Lister, BaseClient):
  return response.eval_metrics
 
  def evaluate(self,
- dataset_id: str,
+ dataset: Dataset = None,
+ dataset_id: str = None,
  dataset_app_id: str = None,
  dataset_user_id: str = None,
+ dataset_version_id: str = None,
  eval_id: str = None,
  extended_metrics: dict = None,
  eval_info: dict = None) -> resources_pb2.EvalMetrics:
  """ Run evaluation
 
  Args:
- dataset_id (str): Dataset Id.
+ dataset (Dataset): If a Clarifai Dataset is provided, the other arguments prefixed with 'dataset_' are ignored.
+ dataset_id (str): Dataset Id. Default is None.
  dataset_app_id (str): App ID for cross-app evaluation; leave it as None to use the Model App ID. Default is None.
  dataset_user_id (str): User ID for cross-app evaluation; leave it as None to use the Model User ID. Default is None.
+ dataset_version_id (str): Dataset version Id. Default is None.
  eval_id (str): Specific ID for the evaluation. Specify this parameter to overwrite a previous result for the same dataset or to name your evaluation informatively; otherwise the system uses a random ID. Default is None.
  extended_metrics (dict): user custom metrics result. Default is None.
  eval_info (dict): custom eval info. Default is None.
@@ -638,6 +658,23 @@ class Model(Lister, BaseClient):
 
  """
  assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as an argument or with a URL in the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+
+ if dataset:
+ self.logger.info("Using dataset; ignoring other arguments prefixed with 'dataset_'")
+ dataset_id = dataset.id
+ dataset_app_id = dataset.app_id
+ dataset_user_id = dataset.user_id
+ dataset_version_id = dataset.version.id
+ else:
+ self.logger.warning(
+ "Arguments prefixed with `dataset_` will be removed soon, please use `dataset`")
+
+ gt_dataset = resources_pb2.Dataset(
+ id=dataset_id,
+ app_id=dataset_app_id or self.auth_helper.app_id,
+ user_id=dataset_user_id or self.auth_helper.user_id,
+ version=resources_pb2.DatasetVersion(id=dataset_version_id))
+
  metrics = None
  if isinstance(extended_metrics, dict):
  metrics = Struct()
@@ -659,11 +696,7 @@ class Model(Lister, BaseClient):
  model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
  ),
  extended_metrics=metrics,
- ground_truth_dataset=resources_pb2.Dataset(
- id=dataset_id,
- app_id=dataset_app_id or self.auth_helper.app_id,
- user_id=dataset_user_id or self.auth_helper.user_id,
- ),
+ ground_truth_dataset=gt_dataset,
  eval_info=eval_info_params,
  )
  request = service_pb2.PostEvaluationsRequest(
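With the new `dataset` parameter, one object carries the dataset's id, app, user, and version into the evaluation request. A hedged usage sketch (all IDs are placeholders):

```python
from clarifai.client.model import Model
from clarifai.client.dataset import Dataset

# Hypothetical IDs; the dataset's version id is picked up automatically.
model = Model(model_id="model_id", user_id="user_id", app_id="app_id",
              model_version={"id": "version_id"})
dataset = Dataset(dataset_id="dataset_id", user_id="user_id", app_id="app_id")
eval_metrics = model.evaluate(dataset=dataset, eval_id="my-eval")
```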
@@ -761,6 +794,157 @@ class Model(Lister, BaseClient):
 
  return result
 
+ def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
+ """Get all eval data for a dataset
+
+ Args:
+ dataset (Dataset): Clarifai dataset
+
+ Returns:
+ List[resources_pb2.EvalMetrics]
+ """
+ _id = dataset.id
+ app = dataset.app_id or self.app_id
+ user_id = dataset.user_id or self.user_id
+ version = dataset.version.id
+
+ list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
+ outputs = []
+ for _eval in list_eval:
+ if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
+ gt_ds = _eval.ground_truth_dataset
+ if (_id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id):
+ if not version or version == gt_ds.version.id:
+ outputs.append(_eval)
+
+ return outputs
+
+ def get_raw_eval(self,
+ dataset: Dataset = None,
+ eval_id: str = None,
+ return_format: str = 'array') -> Union[resources_pb2.EvalTestSetEntry, Tuple[
+ np.array, np.array, list, List[Input]], Tuple[List[dict], List[dict]]]:
+ """Get ground truths, predictions, and input information. Do not pass `dataset` and `eval_id` at the same time.
+
+ Args:
+ dataset (Dataset): Clarifai dataset; gets eval data of the latest eval result for the dataset.
+ eval_id (str): Evaluation ID; gets eval data for a specific eval id.
+ return_format (str, optional): Choice of {proto, array, coco}. Note that `coco` is only applicable for 'visual-detector'. Defaults to 'array'.
+
+ Returns:
+
+ Depends on `return_format`.
+
+ * if return_format == proto
+ `resources_pb2.EvalTestSetEntry`
+
+ * if return_format == array
+ `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
+ y, y_pred, and concept_ids can be used to compute metrics; 'inputs' can be used to download the inputs.
+ - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
+ - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is an array of shape (num_annotations, 6) consisting of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'
+
+ * if return_format == coco: Applicable only for 'visual-detector'
+ `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elements, where the first element is the COCO Ground Truth annotation and the second is the COCO Prediction annotation
+
+ Example Usages:
+ ------
+ * Evaluate `visual-classifier` using sklearn
+
+ ```python
+ import os
+ from sklearn.metrics import accuracy_score
+ from sklearn.metrics import classification_report
+ import numpy as np
+ from clarifai.client.model import Model
+ from clarifai.client.dataset import Dataset
+ os.environ["CLARIFAI_PAT"] = "???"
+ model = Model(url="url/of/model/includes/version-id")
+ dataset = Dataset(dataset_id="dataset-id")
+ y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
+ y = np.argmax(y, axis=1)
+ y_pred = np.argmax(y_pred, axis=1)
+ report = classification_report(y, y_pred, target_names=clss)
+ print(report)
+ acc = accuracy_score(y, y_pred)
+ print("acc ", acc)
+ ```
+
+ * Evaluate `visual-detector` using COCOeval
+
+ ```python
+ import os
+ import json
+ from pycocotools.coco import COCO
+ from pycocotools.cocoeval import COCOeval
+ from clarifai.client.model import Model
+ from clarifai.client.dataset import Dataset
+ os.environ["CLARIFAI_PAT"] = "???"  # Insert your PAT
+ model = Model(url=model_url)
+ dataset = Dataset(url=dataset_url)
+ y, y_pred = model.get_raw_eval(dataset, return_format="coco")
+ # save as files to load in COCO API
+ def save_annot(d, path):
+ with open(path, "w") as fp:
+ json.dump(d, fp, indent=2)
+ gt_path = os.path.join("gt.json")
+ pred_path = os.path.join("pred.json")
+ save_annot(y, gt_path)
+ save_annot(y_pred, pred_path)
+
+ cocoGt = COCO(gt_path)
+ cocoPred = COCO(pred_path)
+ cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()  # Print out results for all classes with all area types
+ # Example:
+ # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
+ # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.973
+ # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.939
+ # ...
+ ```
+
+ """
+ from clarifai.utils.evaluation.testset_annotation_parser import (
+ parse_eval_annotation_classifier, parse_eval_annotation_detector,
+ parse_eval_annotation_detector_coco)
+
+ valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
+ supported_format = ['proto', 'array', 'coco']
+ assert return_format in supported_format, ValueError(
+ f"Expected return_format in {supported_format}, got {return_format}")
+ self.load_info()
+ model_type_id = self.model_info.model_type_id
+ assert model_type_id in valid_model_types, \
+ f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
+ assert not (dataset and
+ eval_id), "Using both `dataset` and `eval_id`, but only one should be passed."
+ assert dataset or eval_id, "Please provide either `dataset` or `eval_id`; nothing was passed."
+ if model_type_id.endswith("-classifier") and return_format == "coco":
+ raise ValueError(
+ f"return_format coco only applies to `visual-detector`, however your model is `{model_type_id}`"
+ )
+
+ if dataset:
+ eval_by_ds = self.get_eval_by_dataset(dataset)
+ if len(eval_by_ds) == 0:
+ raise Exception(f"Model is not evaluated with dataset: {dataset}")
+ eval_id = eval_by_ds[0].id
+
+ detail_eval_data = self.get_eval_by_id(eval_id=eval_id, test_set=True, metrics_by_class=True)
+
+ if return_format == "proto":
+ return detail_eval_data.test_set
+ else:
+ if model_type_id.endswith("-classifier"):
+ return parse_eval_annotation_classifier(detail_eval_data)
+ elif model_type_id == "visual-detector":
+ if return_format == "array":
+ return parse_eval_annotation_detector(detail_eval_data)
+ elif return_format == "coco":
+ return parse_eval_annotation_detector_coco(detail_eval_data)
+
  def export(self, export_dir: str = None) -> None:
  """Export the model, stores the exported model as model.tar file
 
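When multiple evaluations exist, `get_eval_by_dataset` filters the completed ones (`MODEL_EVALUATED`) for a given dataset, and `get_raw_eval` uses the first match (documented as the latest result). A hedged sketch of using them together (IDs are placeholders):

```python
from clarifai.client.model import Model
from clarifai.client.dataset import Dataset

# Hypothetical IDs.
model = Model(url="url/of/model/includes/version-id")
dataset = Dataset(dataset_id="dataset-id", user_id="user_id", app_id="app_id")
completed = model.get_eval_by_dataset(dataset)  # only MODEL_EVALUATED entries
if completed:
    # Pull the raw test-set protos for the first matching evaluation.
    test_set = model.get_raw_eval(eval_id=completed[0].id, return_format="proto")
```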
@@ -834,7 +1018,7 @@ class Model(Lister, BaseClient):
  while True:
  get_export_response = _get_export_response()
  if get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING and \
- time.time() - start_time < 60 * 30: # 30 minutes
+ time.time() - start_time < 60 * 30:  # 30 minutes
  self.logger.info(
  f"Model ID {self.id} with version {self.model_info.model_version.id} is still exporting, please wait..."
  )
@@ -848,3 +1032,302 @@ class Model(Lister, BaseClient):
  Req ID: {get_export_response.status.req_id}""")
  elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
  _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
+
+ @staticmethod
+ def _make_pretrained_config_proto(input_field_maps: dict,
+ output_field_maps: dict,
+ url: str = None):
+ """Make a PretrainedModelConfig for uploading a new version
+
+ Args:
+ input_field_maps (dict): dict
+ output_field_maps (dict): dict
+ url (str, optional): direct download url. Defaults to None.
+ """
+
+ def _parse_fields_map(x):
+ """parse inputs/outputs to a Struct"""
+ _fields_map = Struct()
+ _fields_map.update(x)
+ return _fields_map
+
+ input_fields_map = _parse_fields_map(input_field_maps)
+ output_fields_map = _parse_fields_map(output_field_maps)
+
+ return resources_pb2.PretrainedModelConfig(
+ input_fields_map=input_fields_map, output_fields_map=output_fields_map, model_zip_url=url)
+
+ @staticmethod
+ def _make_inference_params_proto(
+ inference_parameters: List[Dict]) -> List[resources_pb2.ModelTypeField]:
+ """Convert a list of Clarifai inference parameters to protos for uploading a new version
+
+ Args:
+ inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
+
+ Returns:
+ List[resources_pb2.ModelTypeField]
+ """
+
+ def _make_default_value_proto(dtype, value):
+ if dtype == 1:
+ return Value(bool_value=value)
+ elif dtype == 2 or dtype == 21:
+ return Value(string_value=value)
+ elif dtype == 3:
+ return Value(number_value=value)
+
+ iterative_proto_params = []
+ for param in inference_parameters:
+ dtype = param.get("field_type")
+ proto_param = resources_pb2.ModelTypeField(
+ path=param.get("path"),
+ field_type=dtype,
+ default_value=_make_default_value_proto(dtype=dtype, value=param.get("default_value")),
+ description=param.get("description"),
+ )
+ iterative_proto_params.append(proto_param)
+ return iterative_proto_params
+
+ def create_version_by_file(self,
+ file_path: str,
+ input_field_maps: dict,
+ output_field_maps: dict,
+ inference_parameter_configs: dict = None,
+ model_version: str = None,
+ part_id: int = 1,
+ range_start: int = 0,
+ no_cache: bool = False,
+ no_resume: bool = False,
+ description: str = "") -> 'Model':
+ """Create a model version by uploading a local file
+
+ Args:
+ file_path (str): path to the built file.
+ input_field_maps (dict): a dict where the key is the clarifai input field and the value is the triton model input,
+ {clarifai_input_field: triton_input_field}.
+ output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
+ {clarifai_output_field1: triton_output_field1, clarifai_output_field2: triton_output_field2,...}.
+ inference_parameter_configs (List[dict]): list of dicts whose keys are path, field_type, default_value, description. Default is None.
+ model_version (str, optional): Custom model version. Defaults to None.
+ part_id (int, optional): part id of the file. Defaults to 1.
+ range_start (int, optional): start of the uploaded range. Defaults to 0.
+ no_cache (bool, optional): do not save the upload cache that is used to resume uploading. Defaults to False.
+ no_resume (bool, optional): disable auto-resume of the upload. Defaults to False.
+ description (str): Model description.
+
+ Return:
+ Model: instance of Model with the newly created version
+
+ """
+ file_size = os.path.getsize(file_path)
+ assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
+
+ pretrained_proto = Model._make_pretrained_config_proto(
+ input_field_maps=input_field_maps, output_field_maps=output_field_maps)
+ inference_param_proto = Model._make_inference_params_proto(
+ inference_parameter_configs) if inference_parameter_configs else None
+
+ if file_size >= 1e9:
+ chunk_size = 1024 * 50_000  # 50MB
+ else:
+ chunk_size = 1024 * 10_000  # 10MB
+
+ #self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
+ #self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
+
+ cache_dir = os.path.join(file_path, '..', '.cache')
+ cache_upload_file = os.path.join(cache_dir, "upload.json")
+ last_percent = 0
+ if os.path.exists(cache_upload_file) and not no_resume:
+ with open(cache_upload_file, "r") as fp:
+ try:
+ cache_info = json.load(fp)
+ if isinstance(cache_info, dict):
+ part_id = cache_info.get("part_id", part_id)
+ chunk_size = cache_info.get("chunk_size", chunk_size)
+ range_start = cache_info.get("range_start", range_start)
+ model_version = cache_info.get("model_version", model_version)
+ last_percent = cache_info.get("last_percent", last_percent)
+ except Exception as e:
+ self.logger.error(f"Skipping loading the upload cache due to error {e}.")
+
+ def init_model_version_upload(model_version):
+ return service_pb2.PostModelVersionsUploadRequest(
+ upload_config=service_pb2.PostModelVersionsUploadConfig(
+ user_app_id=self.user_app_id,
+ model_id=self.id,
+ total_size=file_size,
+ model_version=resources_pb2.ModelVersion(
+ id=model_version,
+ pretrained_model_config=pretrained_proto,
+ description=description,
+ output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto)),
+ ))
+
+ def _uploading(chunk, part_id, range_start, model_version):
+ return service_pb2.PostModelVersionsUploadRequest(
+ content_part=resources_pb2.UploadContentPart(
+ data=chunk, part_number=part_id, range_start=range_start))
+
+ finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
+ uploading_in_progress_status = [
+ status_code_pb2.UPLOAD_IN_PROGRESS, status_code_pb2.MODEL_UPLOADING
+ ]
+
+ def _save_cache(cache: dict):
+ if not no_cache:
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(cache_upload_file, "w") as fp:
+ json.dump(cache, fp, indent=2)
+
+ def stream_request(fp, part_id, end_part_id, chunk_size, version):
+ yield init_model_version_upload(version)
+ for iter_part_id in range(part_id, end_part_id):
+ chunk = fp.read(chunk_size)
+ if not chunk:
+ return
+ yield _uploading(
+ chunk=chunk,
+ part_id=iter_part_id,
+ range_start=chunk_size * (iter_part_id - 1),
+ model_version=version)
+
+ tqdm_loader = tqdm(total=100)
+ if model_version:
+ desc = f"Uploading model `{self.id}` version `{model_version}` ..."
+ else:
+ desc = f"Uploading model `{self.id}` ..."
+ tqdm_loader.set_description(desc)
+
+ cache_uploading_info = {}
+ cache_uploading_info["part_id"] = part_id
+ cache_uploading_info["model_version"] = model_version
+ cache_uploading_info["range_start"] = range_start
+ cache_uploading_info["chunk_size"] = chunk_size
+ cache_uploading_info["last_percent"] = last_percent
+ tqdm_loader.update(last_percent)
+ last_part_id = part_id
+ n_chunks = file_size // chunk_size
+ n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
+
+ def stream_and_logging(request, tqdm_loader, cache_uploading_info, expected_steps: int = None):
+ for st_step, st_response in enumerate(self.auth_helper.get_stub().PostModelVersionsUpload(
+ request, metadata=self.auth_helper.metadata)):
+ if st_response.status.code in uploading_in_progress_status:
+ if cache_uploading_info["model_version"]:
+ assert st_response.model_version_id == cache_uploading_info[
+ "model_version"], RuntimeError
+ else:
+ cache_uploading_info["model_version"] = st_response.model_version_id
+ if st_step > 0:
+ cache_uploading_info["part_id"] += 1
+ cache_uploading_info["range_start"] += chunk_size
+ _save_cache(cache_uploading_info)
+
+ if st_response.status.percent_completed:
+ step_percent = st_response.status.percent_completed - cache_uploading_info["last_percent"]
+ cache_uploading_info["last_percent"] += step_percent
+ tqdm_loader.set_description(
+ f"{st_response.status.description}, {st_response.status.details}, version id {cache_uploading_info.get('model_version')}"
+ )
+ tqdm_loader.update(step_percent)
+ elif st_response.status.code not in finished_status + uploading_in_progress_status:
+ # TODO: Find a better way to handle errors
+ if expected_steps and st_step < expected_steps:
+ raise Exception(f"Failed to upload model, error: {st_response.status}")
+
+ with open(file_path, 'rb') as fp:
+ # seek past already-uploaded parts
+ for _ in range(1, last_part_id):
+ fp.read(chunk_size)
+ # Stream the evenly-sized parts
+ end_part_id = n_chunks or 1
+ for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
+ end_part_id = iter_part_id + n_chunk_per_stream
+ if end_part_id >= n_chunks:
+ end_part_id = n_chunks
+ expected_steps = end_part_id - iter_part_id + 1  # init step
+ st_reqs = stream_request(
+ fp,
+ iter_part_id,
+ end_part_id=end_part_id,
+ chunk_size=chunk_size,
+ version=cache_uploading_info["model_version"])
+ stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
+ # Stream the last part
+ accum_size = (end_part_id - 1) * chunk_size
+ remained_size = file_size - accum_size if accum_size >= 0 else file_size
+ st_reqs = stream_request(
+ fp,
+ end_part_id,
+ end_part_id=end_part_id + 1,
+ chunk_size=remained_size,
+ version=cache_uploading_info["model_version"])
+ stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
+
+ # clean up cache
+ if not no_cache:
+ try:
+ os.remove(cache_upload_file)
+ except Exception:
+ _save_cache({})
+
+ if cache_uploading_info["last_percent"] <= 100:
+ tqdm_loader.update(100 - cache_uploading_info["last_percent"])
+ tqdm_loader.set_description("Upload done")
+
+ tqdm_loader.set_description(
+ f"Successfully uploaded model {self.id}, new version {cache_uploading_info.get('model_version')}"
+ )
+
+ return Model.from_auth_helper(
+ auth=self.auth_helper,
+ model_id=self.id,
+ model_version=dict(id=cache_uploading_info.get('model_version')))
+
+ def create_version_by_url(self,
+ url: str,
+ input_field_maps: dict,
+ output_field_maps: dict,
+ inference_parameter_configs: List[dict] = None,
+ description: str = "") -> 'Model':
+ """Upload a new version of an existing model in the Clarifai platform using a direct download url.
+
+ Args:
+ url (str): url of the model zip.
+ input_field_maps (dict): a dict where the key is the clarifai input field and the value is the triton model input,
+ {clarifai_input_field: triton_input_field}.
+ output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
+ {clarifai_output_field1: triton_output_field1, clarifai_output_field2: triton_output_field2,...}.
+ inference_parameter_configs (List[dict]): list of dicts whose keys are path, field_type, default_value, description. Default is None.
+ description (str): Model description.
+
+ Return:
+ Model: instance of Model with the newly created version
+ """
+
+ pretrained_proto = Model._make_pretrained_config_proto(
+ input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url)
+ inference_param_proto = Model._make_inference_params_proto(
+ inference_parameter_configs) if inference_parameter_configs else None
+ request = service_pb2.PostModelVersionsRequest(
+ user_app_id=self.user_app_id,
+ model_id=self.id,
+ model_versions=[
+ resources_pb2.ModelVersion(
+ pretrained_model_config=pretrained_proto,
+ description=description,
+ output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto))
+ ])
+ response = self._grpc_request(self.STUB.PostModelVersions, request)
+
+ if response.status.code != status_code_pb2.SUCCESS:
+ raise Exception(f"Failed to upload model, error: {response.status}")
+ self.logger.info(
+ f"Successfully uploaded model {self.id}, new version {response.model.model_version.id}")
+
+ return Model.from_auth_helper(
+ auth=self.auth_helper,
+ model_id=self.id,
+ model_version=dict(id=response.model.model_version.id))
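
A hedged usage sketch for the new upload APIs (URLs, file paths, and field names are placeholders; the `field_type` codes follow the mapping in `_make_default_value_proto`: 1 bool, 2/21 string, 3 number):

```python
from clarifai.client.model import Model

# Hypothetical identifiers and mappings.
model = Model(model_id="my-triton-model", user_id="user_id", app_id="app_id")

new_version = model.create_version_by_url(
    url="https://example.com/model.zip",          # direct download URL of the zip
    input_field_maps={"image": "input__0"},       # clarifai field -> triton input
    output_field_maps={"concepts": "output__0"},  # clarifai field -> triton output
    inference_parameter_configs=[{
        "path": "threshold",
        "field_type": 3,  # number
        "default_value": 0.5,
        "description": "score threshold",
    }],
    description="Uploaded via direct URL",
)

# Or stream a local zip (5MiB to 5GiB) with resumable upload caching:
# new_version = model.create_version_by_file(
#     file_path="/path/to/model.zip",
#     input_field_maps={"image": "input__0"},
#     output_field_maps={"concepts": "output__0"},
# )
```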