clarifai 10.3.0__py3-none-any.whl → 10.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/input.py +32 -9
- clarifai/client/model.py +355 -36
- clarifai/client/search.py +90 -15
- clarifai/constants/model.py +1 -0
- clarifai/constants/search.py +2 -1
- clarifai/datasets/upload/features.py +4 -0
- clarifai/datasets/upload/image.py +25 -2
- clarifai/datasets/upload/loaders/coco_captions.py +7 -2
- clarifai/datasets/upload/loaders/coco_detection.py +7 -2
- clarifai/datasets/upload/text.py +2 -0
- clarifai/models/model_serving/README.md +3 -0
- clarifai/models/model_serving/cli/upload.py +65 -68
- clarifai/models/model_serving/docs/cli.md +17 -6
- clarifai/rag/rag.py +1 -1
- clarifai/rag/utils.py +2 -2
- clarifai/versions.py +1 -1
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/METADATA +20 -3
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/RECORD +22 -22
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/LICENSE +0 -0
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/WHEEL +0 -0
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/entry_points.txt +0 -0
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/top_level.txt +0 -0
    
        clarifai/client/input.py
    CHANGED
    
@@ -72,6 +72,7 @@ class Inputs(Lister, BaseClient):
                        text_pb: Text = None,
                        geo_info: List = None,
                        labels: List = None,
+                       label_ids: List = None,
                        metadata: Struct = None) -> Input:
     """Create input proto for image data type.
     Args:
@@ -82,7 +83,8 @@ class Inputs(Lister, BaseClient):
         audio_pb (Audio): The audio proto to be used for the input.
         text_pb (Text): The text proto to be used for the input.
         geo_info (list): A list of longitude and latitude for the geo point.
-        labels (list): A list of labels for the input.
+        labels (list): A list of label names for the input.
+        label_ids (list): A list of label ids for the input.
         metadata (Struct): A Struct of metadata for the input.
     Returns:
         Input: An Input object for the specified input ID.
@@ -90,14 +92,26 @@ class Inputs(Lister, BaseClient):
     assert geo_info is None or isinstance(
         geo_info, list), "geo_info must be a list of longitude and latitude"
     assert labels is None or isinstance(labels, list), "labels must be a list of strings"
+    assert label_ids is None or isinstance(label_ids, list), "label_ids must be a list of strings"
     assert metadata is None or isinstance(metadata, Struct), "metadata must be a Struct"
     geo_pb = resources_pb2.Geo(geo_point=resources_pb2.GeoPoint(
         longitude=geo_info[0], latitude=geo_info[1])) if geo_info else None
-    concepts=[
+    if labels:
+      if not label_ids:
+        concepts=[
             resources_pb2.Concept(
             id=f"id-{''.join(_label.split(' '))}", name=_label, value=1.)\
             for _label in labels
-    ] if labels else None
+        ]
+      else:
+        assert len(labels) == len(label_ids), "labels and label_ids must be of the same length"
+        concepts=[
+            resources_pb2.Concept(
+            id=label_id, name=_label, value=1.)\
+            for label_id, _label in zip(label_ids, labels)
+        ]
+    else:
+      concepts = None

     if dataset_id:
       return resources_pb2.Input(
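The practical upshot of this hunk: when `label_ids` is supplied it must pair one-to-one with `labels`, and each concept keeps the caller's ID instead of the auto-derived `id-<name>` form. A minimal sketch of the two paths (the standalone variable names here are illustrative, not part of the SDK); the same pairing rule reappears below in `get_bbox_proto`/`get_mask_proto` via the new `label_id` argument:

```python
from clarifai_grpc.grpc.api import resources_pb2

labels = ["cat", "striped cat"]
label_ids = ["cat-001", "cat-002"]  # must match labels in length

if label_ids:
    # New path: explicit concept IDs are kept as-is.
    concepts = [
        resources_pb2.Concept(id=label_id, name=label, value=1.)
        for label_id, label in zip(label_ids, labels)
    ]
else:
    # Pre-existing path: IDs are derived from names ("striped cat" -> "id-stripedcat").
    concepts = [
        resources_pb2.Concept(id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
        for label in labels
    ]
```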
@@ -467,13 +481,14 @@ class Inputs(Lister, BaseClient):
     return input_protos

   @staticmethod
-  def get_bbox_proto(input_id: str, label: str, bbox: List) -> Annotation:
+  def get_bbox_proto(input_id: str, label: str, bbox: List, label_id: str = None) -> Annotation:
     """Create an annotation proto for each bounding box, label input pair.

     Args:
         input_id (str): The input ID for the annotation to create.
-        label (str): annotation label
+        label (str): annotation label name
         bbox (List): a list of a single bbox's coordinates. # bbox ordering: [xmin, ymin, xmax, ymax]
+        label_id (str): annotation label ID

     Returns:
         An annotation object for the specified input ID.
@@ -500,19 +515,22 @@ class Inputs(Lister, BaseClient):
                 data=resources_pb2.Data(concepts=[
                     resources_pb2.Concept(
                         id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
+                    if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
                 ]))
         ]))

     return input_annot_proto

   @staticmethod
-  def get_mask_proto(input_id: str, label: str, polygons: List[List[float]]) -> Annotation:
+  def get_mask_proto(input_id: str, label: str, polygons: List[List[float]],
+                     label_id: str = None) -> Annotation:
     """Create an annotation proto for each polygon box, label input pair.

     Args:
         input_id (str): The input ID for the annotation to create.
-        label (str): annotation label
+        label (str): annotation label name
         polygons (List): Polygon x,y points iterable
+        label_id (str): annotation label ID

     Returns:
         An annotation object for the specified input ID.
@@ -537,6 +555,7 @@ class Inputs(Lister, BaseClient):
                 data=resources_pb2.Data(concepts=[
                     resources_pb2.Concept(
                         id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
+                    if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
                 ]))
         ]))

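`get_bbox_proto` and `get_mask_proto` gain the same optional `label_id`, falling back to the derived ID when it is omitted. A hedged usage sketch, with made-up input ID, label, and coordinate values:

```python
from clarifai.client.input import Inputs

# bbox ordering per the docstring above: [xmin, ymin, xmax, ymax]
bbox_annot = Inputs.get_bbox_proto(
    input_id="demo-input", label="dog", bbox=[0.1, 0.2, 0.6, 0.9], label_id="dog-001")

# Polygon x,y points iterable; same optional label_id.
mask_annot = Inputs.get_mask_proto(
    input_id="demo-input",
    label="dog",
    polygons=[[0.10, 0.10], [0.90, 0.10], [0.50, 0.90]],
    label_id="dog-001")
```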
@@ -726,16 +745,20 @@ class Inputs(Lister, BaseClient):
     request = service_pb2.PostAnnotationsRequest(
         user_app_id=self.user_app_id, annotations=batch_annot)
     response = self._grpc_request(self.STUB.PostAnnotations, request)
+    response_dict = MessageToDict(response)
     if response.status.code != status_code_pb2.SUCCESS:
       try:
-        …
+        for annot in response_dict["annotations"]:
+          if annot['status']['code'] != status_code_pb2.ANNOTATION_SUCCESS:
+            self.logger.warning(f"Post annotations failed, status: {annot['status']}")
       except Exception:
-        self.logger.warning(f"Post annotations failed …
+        self.logger.warning(f"Post annotations failed due to {response.status}")
       finally:
         retry_upload.extend(batch_annot)
     else:
       if show_log:
         self.logger.info("\nAnnotations Uploaded\n%s", response.status)
+
     return retry_upload

   def _upload_batch(self, inputs: List[Input]) -> List[Input]:
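This last input.py hunk reworks failure logging in the batch annotation uploader: it converts the response with `MessageToDict`, warns per failed annotation (falling back to the overall status), and still queues the whole batch for retry. Assuming this helper backs the SDK's public `upload_annotations` method (an assumption; the method name is not visible in this hunk), usage stays the same:

```python
# inputs_client is a hypothetical, already-constructed Inputs instance;
# bbox_annot and mask_annot are the protos built in the sketch above.
failed = inputs_client.upload_annotations([bbox_annot, mask_annot])
if failed:
    # The helper returns the annotation protos that should be retried.
    failed = inputs_client.upload_annotations(failed)
```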
    
        clarifai/client/model.py
    CHANGED
    
@@ -1,3 +1,4 @@
+import json
 import os
 import time
 from typing import Any, Dict, Generator, List, Tuple, Union
@@ -9,14 +10,15 @@ from clarifai_grpc.grpc.api import resources_pb2, service_pb2
 from clarifai_grpc.grpc.api.resources_pb2 import Input
 from clarifai_grpc.grpc.api.status import status_code_pb2
 from google.protobuf.json_format import MessageToDict
-from google.protobuf.struct_pb2 import Struct
+from google.protobuf.struct_pb2 import Struct, Value
 from tqdm import tqdm

 from clarifai.client.base import BaseClient
 from clarifai.client.dataset import Dataset
 from clarifai.client.input import Inputs
 from clarifai.client.lister import Lister
-from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS, TRAINABLE_MODEL_TYPES
+from clarifai.constants.model import (MAX_MODEL_PREDICT_INPUTS, MODEL_EXPORT_TIMEOUT,
+                                      TRAINABLE_MODEL_TYPES)
 from clarifai.errors import UserError
 from clarifai.urls.helper import ClarifaiUrlHelper
 from clarifai.utils.logging import get_logger
@@ -25,6 +27,10 @@ from clarifai.utils.model_train import (find_and_replace_key, params_parser,
                                         response_to_model_params, response_to_param_info,
                                         response_to_templates)

+MAX_SIZE_PER_STREAM = int(89_128_960)  # 85GiB
+MIN_CHUNK_FOR_UPLOAD_FILE = int(5_242_880)  # 5MiB
+MAX_CHUNK_FOR_UPLOAD_FILE = int(5_242_880_000)  # 5GiB
+

 class Model(Lister, BaseClient):
   """Model is a class that provides access to Clarifai API endpoints related to Model information."""
@@ -58,7 +64,7 @@ class Model(Lister, BaseClient):
       user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(url)
       model_version = {'id': model_version_id}
       kwargs = {'user_id': user_id, 'app_id': app_id}
-    self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version,}
+    self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version, }
     self.model_info = resources_pb2.Model(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
     self.training_params = {}
@@ -132,11 +138,11 @@ class Model(Lister, BaseClient):
       raise Exception(response.status)
     params = response_to_model_params(
         response=response, model_type_id=self.model_info.model_type_id, template=template)
-    #yaml file
+    # yaml file
     assert save_to.endswith('.yaml'), "File extension should be .yaml"
     with open(save_to, 'w') as f:
       yaml.dump(params, f, default_flow_style=False, sort_keys=False)
-    #updating the global model params
+    # updating the global model params
     self.training_params.update(params)

     return params
@@ -159,14 +165,14 @@ class Model(Lister, BaseClient):
       raise UserError(
           f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
       )
-    #getting all the keys in nested dictionary
+    # getting all the keys in nested dictionary
     all_keys = [key for key in self.training_params.keys()] + [
         key for key in self.training_params.values() if isinstance(key, dict) for key in key
     ]
-    #checking if the given params are valid
+    # checking if the given params are valid
     if not set(kwargs.keys()).issubset(all_keys):
       raise UserError("Invalid params")
-    #updating the global model params
+    # updating the global model params
     for key, value in kwargs.items():
       find_and_replace_key(self.training_params, key, value)

@@ -238,7 +244,7 @@ class Model(Lister, BaseClient):
       params_dict = yaml.safe_load(file)
     else:
       params_dict = self.training_params
-    #getting all the concepts for the model type
+    # getting all the concepts for the model type
     if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
       concepts = self._list_concepts()
     train_dict = params_parser(params_dict, concepts)
@@ -423,7 +429,7 @@ class Model(Lister, BaseClient):
       response = self._grpc_request(self.STUB.PostModelOutputs, request)

       if response.status.code == status_code_pb2.MODEL_DEPLOYING and \
-          …
+              time.time() - start_time < 60 * 10:  # 10 minutes
         self.logger.info(f"{self.id} model is still deploying, please wait...")
         time.sleep(next(backoff_iterator))
         continue
@@ -944,19 +950,22 @@ class Model(Lister, BaseClient):
     """Export the model, stores the exported model as model.tar file

     Args:
-        export_dir (str): …
+        export_dir (str, optional): If provided, the exported model will be saved in the specified directory else export status will be shown. Defaults to None.

     Example:
         >>> from clarifai.client.model import Model
         >>> model = Model("url")
+        >>> model.export()
+                or
         >>> model.export('/path/to/export_model_dir')
     """
     assert self.model_info.model_version.id, "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
-    try:
-      if not os.path.exists(export_dir):
-        os.makedirs(export_dir)
-    except OSError as e:
-      raise Exception(f"An error occurred while creating the directory: {e}")
+    if export_dir:
+      try:
+        if not os.path.exists(export_dir):
+          os.makedirs(export_dir)
+      except OSError as e:
+        raise Exception(f"An error occurred while creating the directory: {e}")

     def _get_export_response():
       get_export_request = service_pb2.GetModelVersionExportRequest(
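Net effect of the two export hunks (this one and the next): `export_dir` becomes optional. With a directory, `export()` creates it if needed, polls until the export finishes, and downloads `model.tar`; without one, it only starts the export or reports its current status. The polling deadline now comes from the `MODEL_EXPORT_TIMEOUT` constant imported above. A sketch following the docstring's own example (the URL is a placeholder in the format the assertion message describes):

```python
from clarifai.client.model import Model

model = Model("https://clarifai.com/user_id/app_id/models/model_id/model_version_id/version_id")

model.export()                       # kick off the export (or log its current status)
model.export('/tmp/exported_model')  # block until done, saving /tmp/exported_model/model.tar
```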
@@ -1005,25 +1014,335 @@ class Model(Lister, BaseClient):
         raise Exception(response.status)

       self.logger.info(
-          f"Model ID {self.id} …
+          f"Export process has started for Model ID {self.id}, Version {self.model_info.model_version.id}"
       )
-      … (18 removed lines not rendered in this view)
+      if export_dir:
+        start_time = time.time()
+        backoff_iterator = BackoffIterator(10)
+        while True:
+          get_export_response = _get_export_response()
+          if (get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
+                get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING) and \
+                  time.time() - start_time < MODEL_EXPORT_TIMEOUT:
+            self.logger.info(
+                f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
+            )
+            time.sleep(next(backoff_iterator))
+          elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
+            _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
+            break
+          elif time.time() - start_time > MODEL_EXPORT_TIMEOUT:
+            raise Exception(
+                f"""Model Export took too long. Please try again or contact support@clarifai.com
+                Req ID: {get_export_response.status.req_id}""")
     elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
-      _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
+      if export_dir:
+        _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
+      else:
+        self.logger.info(
+            f"Model ID {self.id} with version {self.model_info.model_version.id} is already exported, you can download it from the following URL: {get_export_response.export.url}"
+        )
+    elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
+            get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING:
+      self.logger.info(
+          f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
+      )
+
+  @staticmethod
+  def _make_pretrained_config_proto(input_field_maps: dict,
+                                    output_field_maps: dict,
+                                    url: str = None):
+    """Make PretrainedModelConfig for uploading new version
+
+    Args:
+        input_field_maps (dict): dict
+        output_field_maps (dict): dict
+        url (str, optional): direct download url. Defaults to None.
+    """
+
+    def _parse_fields_map(x):
+      """parse input, outputs to Struct"""
+      _fields_map = Struct()
+      _fields_map.update(x)
+      return _fields_map
+
+    input_fields_map = _parse_fields_map(input_field_maps)
+    output_fields_map = _parse_fields_map(output_field_maps)
+
+    return resources_pb2.PretrainedModelConfig(
+        input_fields_map=input_fields_map, output_fields_map=output_fields_map, model_zip_url=url)
+
+  @staticmethod
+  def _make_inference_params_proto(
+      inference_parameters: List[Dict]) -> List[resources_pb2.ModelTypeField]:
+    """Convert list of Clarifai inference parameters to proto for uploading new version
+
+    Args:
+        inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
+
+    Returns:
+        List[resources_pb2.ModelTypeField]
+    """
+
+    def _make_default_value_proto(dtype, value):
+      if dtype == 1:
+        return Value(bool_value=value)
+      elif dtype == 2 or dtype == 21:
+        return Value(string_value=value)
+      elif dtype == 3:
+        return Value(number_value=value)
+
+    iterative_proto_params = []
+    for param in inference_parameters:
+      dtype = param.get("field_type")
+      proto_param = resources_pb2.ModelTypeField(
+          path=param.get("path"),
+          field_type=dtype,
+          default_value=_make_default_value_proto(dtype=dtype, value=param.get("default_value")),
+          description=param.get("description"),
+      )
+      iterative_proto_params.append(proto_param)
+    return iterative_proto_params
+
+  def create_version_by_file(self,
+                             file_path: str,
+                             input_field_maps: dict,
+                             output_field_maps: dict,
+                             inference_parameter_configs: dict = None,
+                             model_version: str = None,
+                             part_id: int = 1,
+                             range_start: int = 0,
+                             no_cache: bool = False,
+                             no_resume: bool = False,
+                             description: str = "") -> 'Model':
+    """Create model version by uploading local file
+
+    Args:
+        file_path (str): path to built file.
+        input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
+          {clarifai_input_field: triton_input_filed}.
+        output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
+          {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
+        inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
+        model_version (str, optional): Custom model version. Defaults to None.
+        part_id (int, optional): part id of file. Defaults to 1.
+        range_start (int, optional): range of uploaded size. Defaults to 0.
+        no_cache (bool, optional): not saving uploading cache that is used to resume uploading. Defaults to False.
+        no_resume (bool, optional): disable auto resume upload. Defaults to False.
+        description (str): Model description.
+
+    Return:
+      Model: instance of Model with new created version
+
+    """
+    file_size = os.path.getsize(file_path)
+    assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
+
+    pretrained_proto = Model._make_pretrained_config_proto(
+        input_field_maps=input_field_maps, output_field_maps=output_field_maps)
+    inference_param_proto = Model._make_inference_params_proto(
+        inference_parameter_configs) if inference_parameter_configs else None
+
+    if file_size >= 1e9:
+      chunk_size = 1024 * 50_000  # 50MB
+    else:
+      chunk_size = 1024 * 10_000  # 10MB
+
+    #self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
+    #self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
+
+    cache_dir = os.path.join(file_path, '..', '.cache')
+    cache_upload_file = os.path.join(cache_dir, "upload.json")
+    last_percent = 0
+    if os.path.exists(cache_upload_file) and not no_resume:
+      with open(cache_upload_file, "r") as fp:
+        try:
+          cache_info = json.load(fp)
+          if isinstance(cache_info, dict):
+            part_id = cache_info.get("part_id", part_id)
+            chunk_size = cache_info.get("chunk_size", chunk_size)
+            range_start = cache_info.get("range_start", range_start)
+            model_version = cache_info.get("model_version", model_version)
+            last_percent = cache_info.get("last_percent", last_percent)
+        except Exception as e:
+          self.logger.error(f"Skipping loading the upload cache due to error {e}.")
+
+    def init_model_version_upload(model_version):
+      return service_pb2.PostModelVersionsUploadRequest(
+          upload_config=service_pb2.PostModelVersionsUploadConfig(
+              user_app_id=self.user_app_id,
+              model_id=self.id,
+              total_size=file_size,
+              model_version=resources_pb2.ModelVersion(
+                  id=model_version,
+                  pretrained_model_config=pretrained_proto,
+                  description=description,
+                  output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto)),
+          ))
+
+    def _uploading(chunk, part_id, range_start, model_version):
+      return service_pb2.PostModelVersionsUploadRequest(
+          content_part=resources_pb2.UploadContentPart(
+              data=chunk, part_number=part_id, range_start=range_start))
+
+    finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
+    uploading_in_progress_status = [
+        status_code_pb2.UPLOAD_IN_PROGRESS, status_code_pb2.MODEL_UPLOADING
+    ]
+
+    def _save_cache(cache: dict):
+      if not no_cache:
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(cache_upload_file, "w") as fp:
+          json.dump(cache, fp, indent=2)
+
+    def stream_request(fp, part_id, end_part_id, chunk_size, version):
+      yield init_model_version_upload(version)
+      for iter_part_id in range(part_id, end_part_id):
+        chunk = fp.read(chunk_size)
+        if not chunk:
+          return
+        yield _uploading(
+            chunk=chunk,
+            part_id=iter_part_id,
+            range_start=chunk_size * (iter_part_id - 1),
+            model_version=version)
+
+    tqdm_loader = tqdm(total=100)
+    if model_version:
+      desc = f"Uploading model `{self.id}` version `{model_version}` ..."
+    else:
+      desc = f"Uploading model `{self.id}` ..."
+    tqdm_loader.set_description(desc)
+
+    cache_uploading_info = {}
+    cache_uploading_info["part_id"] = part_id
+    cache_uploading_info["model_version"] = model_version
+    cache_uploading_info["range_start"] = range_start
+    cache_uploading_info["chunk_size"] = chunk_size
+    cache_uploading_info["last_percent"] = last_percent
+    tqdm_loader.update(last_percent)
+    last_part_id = part_id
+    n_chunks = file_size // chunk_size
+    n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
+
+    def stream_and_logging(request, tqdm_loader, cache_uploading_info, expected_steps: int = None):
+      for st_step, st_response in enumerate(self.auth_helper.get_stub().PostModelVersionsUpload(
+          request, metadata=self.auth_helper.metadata)):
+        if st_response.status.code in uploading_in_progress_status:
+          if cache_uploading_info["model_version"]:
+            assert st_response.model_version_id == cache_uploading_info[
+                "model_version"], RuntimeError
+          else:
+            cache_uploading_info["model_version"] = st_response.model_version_id
+          if st_step > 0:
+            cache_uploading_info["part_id"] += 1
+            cache_uploading_info["range_start"] += chunk_size
+            _save_cache(cache_uploading_info)

+            if st_response.status.percent_completed:
+              step_percent = st_response.status.percent_completed - cache_uploading_info["last_percent"]
+              cache_uploading_info["last_percent"] += step_percent
+              tqdm_loader.set_description(
+                  f"{st_response.status.description}, {st_response.status.details}, version id  {cache_uploading_info.get('model_version')}"
+              )
+              tqdm_loader.update(step_percent)
+        elif st_response.status.code not in finished_status + uploading_in_progress_status:
+          # TODO: Find better way to handle error
+          if expected_steps and st_step < expected_steps:
+            raise Exception(f"Failed to upload model, error: {st_response.status}")
+
+    with open(file_path, 'rb') as fp:
+      # seeking
+      for _ in range(1, last_part_id):
+        fp.read(chunk_size)
+      # Stream even part
+      end_part_id = n_chunks or 1
+      for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
+        end_part_id = iter_part_id + n_chunk_per_stream
+        if end_part_id >= n_chunks:
+          end_part_id = n_chunks
+        expected_steps = end_part_id - iter_part_id + 1  # init step
+        st_reqs = stream_request(
+            fp,
+            iter_part_id,
+            end_part_id=end_part_id,
+            chunk_size=chunk_size,
+            version=cache_uploading_info["model_version"])
+        stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
+      # Stream last part
+      accum_size = (end_part_id - 1) * chunk_size
+      remained_size = file_size - accum_size if accum_size >= 0 else file_size
+      st_reqs = stream_request(
+          fp,
+          end_part_id,
+          end_part_id=end_part_id + 1,
+          chunk_size=remained_size,
+          version=cache_uploading_info["model_version"])
+      stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
+
+    # clean up cache
+    if not no_cache:
+      try:
+        os.remove(cache_upload_file)
+      except Exception:
+        _save_cache({})
+
+    if cache_uploading_info["last_percent"] <= 100:
+      tqdm_loader.update(100 - cache_uploading_info["last_percent"])
+      tqdm_loader.set_description("Upload done")
+
+    tqdm_loader.set_description(
+        f"Success uploading model {self.id}, new version {cache_uploading_info.get('model_version')}"
+    )
+
+    return Model.from_auth_helper(
+        auth=self.auth_helper,
+        model_id=self.id,
+        model_version=dict(id=cache_uploading_info.get('model_version')))
+
+  def create_version_by_url(self,
+                            url: str,
+                            input_field_maps: dict,
+                            output_field_maps: dict,
+                            inference_parameter_configs: List[dict] = None,
+                            description: str = "") -> 'Model':
+    """Upload a new version of an existing model in the Clarifai platform using direct download url.
+
+    Args:
+      url (str]): url of zip of model
+      input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
+          {clarifai_input_field: triton_input_filed}.
+      output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
+          {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
+      inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
+      description (str): Model description.
+
+    Return:
+      Model: instance of Model with new created version
+    """
+
+    pretrained_proto = Model._make_pretrained_config_proto(
+        input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url)
+    inference_param_proto = Model._make_inference_params_proto(
+        inference_parameter_configs) if inference_parameter_configs else None
+    request = service_pb2.PostModelVersionsRequest(
+        user_app_id=self.user_app_id,
+        model_id=self.id,
+        model_versions=[
+            resources_pb2.ModelVersion(
+                pretrained_model_config=pretrained_proto,
+                description=description,
+                output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto))
+        ])
+    response = self._grpc_request(self.STUB.PostModelVersions, request)
+
+    if response.status.code != status_code_pb2.SUCCESS:
+      raise Exception(f"Failed to upload model, error: {response.status}")
+    self.logger.info(
+        f"Success uploading model {self.id}, new version {response.model.model_version.id}")
+
+    return Model.from_auth_helper(
+        auth=self.auth_helper,
+        model_id=self.id,
+        model_version=dict(id=response.model.model_version.id))
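`create_version_by_file` is the bulk of the new code: a chunked, resumable upload for model archives between 5 MiB and 5 GiB, streamed in 10-50 MB parts with progress cached in a `.cache/upload.json` beside the file (disable with `no_cache`/`no_resume`). A hedged invocation sketch; the constructor arguments, file path, and field names below are illustrative assumptions, and `field_type=3` maps to a number default per `_make_default_value_proto` above:

```python
from clarifai.client.model import Model

# Hypothetical initialization; adjust to however you normally construct Model.
model = Model(model_id="my-model", user_id="me", app_id="my-app", pat="MY_PAT")

new_version = model.create_version_by_file(
    file_path="./triton_model.zip",               # built model archive, 5MiB-5GiB
    input_field_maps={"image": "input__0"},       # clarifai field -> triton input
    output_field_maps={"concepts": "output__0"},  # clarifai field -> triton output
    inference_parameter_configs=[{
        "path": "temperature",
        "field_type": 3,          # number-typed inference parameter
        "default_value": 0.7,
        "description": "sampling temperature",
    }],
    description="uploaded via create_version_by_file",
)
```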
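`create_version_by_url` covers the same upload without streaming: a single `PostModelVersions` request carrying a direct download URL in the `PretrainedModelConfig`. A minimal sketch with placeholder values, reusing the `model` object from the previous example:

```python
new_version = model.create_version_by_url(
    url="https://example.com/triton_model.zip",   # direct download url of the model zip
    input_field_maps={"image": "input__0"},
    output_field_maps={"concepts": "output__0"},
    description="uploaded via direct download url",
)
print(new_version.model_info.model_version.id)    # hypothetical attribute access
```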