clarifai 10.3.0__py3-none-any.whl → 10.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/input.py +32 -9
- clarifai/client/model.py +355 -36
- clarifai/client/search.py +90 -15
- clarifai/constants/model.py +1 -0
- clarifai/constants/search.py +2 -1
- clarifai/datasets/upload/features.py +4 -0
- clarifai/datasets/upload/image.py +25 -2
- clarifai/datasets/upload/loaders/coco_captions.py +7 -2
- clarifai/datasets/upload/loaders/coco_detection.py +7 -2
- clarifai/datasets/upload/text.py +2 -0
- clarifai/models/model_serving/README.md +3 -0
- clarifai/models/model_serving/cli/upload.py +65 -68
- clarifai/models/model_serving/docs/cli.md +17 -6
- clarifai/rag/rag.py +1 -1
- clarifai/rag/utils.py +2 -2
- clarifai/versions.py +1 -1
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/METADATA +20 -3
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/RECORD +22 -22
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/LICENSE +0 -0
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/WHEEL +0 -0
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/entry_points.txt +0 -0
- {clarifai-10.3.0.dist-info → clarifai-10.3.2.dist-info}/top_level.txt +0 -0
clarifai/client/input.py
CHANGED
@@ -72,6 +72,7 @@ class Inputs(Lister, BaseClient):
|
|
72
72
|
text_pb: Text = None,
|
73
73
|
geo_info: List = None,
|
74
74
|
labels: List = None,
|
75
|
+
label_ids: List = None,
|
75
76
|
metadata: Struct = None) -> Input:
|
76
77
|
"""Create input proto for image data type.
|
77
78
|
Args:
|
@@ -82,7 +83,8 @@ class Inputs(Lister, BaseClient):
|
|
82
83
|
audio_pb (Audio): The audio proto to be used for the input.
|
83
84
|
text_pb (Text): The text proto to be used for the input.
|
84
85
|
geo_info (list): A list of longitude and latitude for the geo point.
|
85
|
-
labels (list): A list of
|
86
|
+
labels (list): A list of label names for the input.
|
87
|
+
label_ids (list): A list of label ids for the input.
|
86
88
|
metadata (Struct): A Struct of metadata for the input.
|
87
89
|
Returns:
|
88
90
|
Input: An Input object for the specified input ID.
|
@@ -90,14 +92,26 @@ class Inputs(Lister, BaseClient):
|
|
90
92
|
assert geo_info is None or isinstance(
|
91
93
|
geo_info, list), "geo_info must be a list of longitude and latitude"
|
92
94
|
assert labels is None or isinstance(labels, list), "labels must be a list of strings"
|
95
|
+
assert label_ids is None or isinstance(label_ids, list), "label_ids must be a list of strings"
|
93
96
|
assert metadata is None or isinstance(metadata, Struct), "metadata must be a Struct"
|
94
97
|
geo_pb = resources_pb2.Geo(geo_point=resources_pb2.GeoPoint(
|
95
98
|
longitude=geo_info[0], latitude=geo_info[1])) if geo_info else None
|
96
|
-
|
99
|
+
if labels:
|
100
|
+
if not label_ids:
|
101
|
+
concepts=[
|
97
102
|
resources_pb2.Concept(
|
98
103
|
id=f"id-{''.join(_label.split(' '))}", name=_label, value=1.)\
|
99
104
|
for _label in labels
|
100
|
-
]
|
105
|
+
]
|
106
|
+
else:
|
107
|
+
assert len(labels) == len(label_ids), "labels and label_ids must be of the same length"
|
108
|
+
concepts=[
|
109
|
+
resources_pb2.Concept(
|
110
|
+
id=label_id, name=_label, value=1.)\
|
111
|
+
for label_id, _label in zip(label_ids, labels)
|
112
|
+
]
|
113
|
+
else:
|
114
|
+
concepts = None
|
101
115
|
|
102
116
|
if dataset_id:
|
103
117
|
return resources_pb2.Input(
|
@@ -467,13 +481,14 @@ class Inputs(Lister, BaseClient):
|
|
467
481
|
return input_protos
|
468
482
|
|
469
483
|
@staticmethod
|
470
|
-
def get_bbox_proto(input_id: str, label: str, bbox: List) -> Annotation:
|
484
|
+
def get_bbox_proto(input_id: str, label: str, bbox: List, label_id: str = None) -> Annotation:
|
471
485
|
"""Create an annotation proto for each bounding box, label input pair.
|
472
486
|
|
473
487
|
Args:
|
474
488
|
input_id (str): The input ID for the annotation to create.
|
475
|
-
label (str): annotation label
|
489
|
+
label (str): annotation label name
|
476
490
|
bbox (List): a list of a single bbox's coordinates. # bbox ordering: [xmin, ymin, xmax, ymax]
|
491
|
+
label_id (str): annotation label ID
|
477
492
|
|
478
493
|
Returns:
|
479
494
|
An annotation object for the specified input ID.
|
@@ -500,19 +515,22 @@ class Inputs(Lister, BaseClient):
|
|
500
515
|
data=resources_pb2.Data(concepts=[
|
501
516
|
resources_pb2.Concept(
|
502
517
|
id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
|
518
|
+
if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
|
503
519
|
]))
|
504
520
|
]))
|
505
521
|
|
506
522
|
return input_annot_proto
|
507
523
|
|
508
524
|
@staticmethod
|
509
|
-
def get_mask_proto(input_id: str, label: str, polygons: List[List[float]]
|
525
|
+
def get_mask_proto(input_id: str, label: str, polygons: List[List[float]],
|
526
|
+
label_id: str = None) -> Annotation:
|
510
527
|
"""Create an annotation proto for each polygon box, label input pair.
|
511
528
|
|
512
529
|
Args:
|
513
530
|
input_id (str): The input ID for the annotation to create.
|
514
|
-
label (str): annotation label
|
531
|
+
label (str): annotation label name
|
515
532
|
polygons (List): Polygon x,y points iterable
|
533
|
+
label_id (str): annotation label ID
|
516
534
|
|
517
535
|
Returns:
|
518
536
|
An annotation object for the specified input ID.
|
@@ -537,6 +555,7 @@ class Inputs(Lister, BaseClient):
|
|
537
555
|
data=resources_pb2.Data(concepts=[
|
538
556
|
resources_pb2.Concept(
|
539
557
|
id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
|
558
|
+
if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
|
540
559
|
]))
|
541
560
|
]))
|
542
561
|
|
@@ -726,16 +745,20 @@ class Inputs(Lister, BaseClient):
|
|
726
745
|
request = service_pb2.PostAnnotationsRequest(
|
727
746
|
user_app_id=self.user_app_id, annotations=batch_annot)
|
728
747
|
response = self._grpc_request(self.STUB.PostAnnotations, request)
|
748
|
+
response_dict = MessageToDict(response)
|
729
749
|
if response.status.code != status_code_pb2.SUCCESS:
|
730
750
|
try:
|
731
|
-
|
751
|
+
for annot in response_dict["annotations"]:
|
752
|
+
if annot['status']['code'] != status_code_pb2.ANNOTATION_SUCCESS:
|
753
|
+
self.logger.warning(f"Post annotations failed, status: {annot['status']}")
|
732
754
|
except Exception:
|
733
|
-
self.logger.warning(f"Post annotations failed
|
755
|
+
self.logger.warning(f"Post annotations failed due to {response.status}")
|
734
756
|
finally:
|
735
757
|
retry_upload.extend(batch_annot)
|
736
758
|
else:
|
737
759
|
if show_log:
|
738
760
|
self.logger.info("\nAnnotations Uploaded\n%s", response.status)
|
761
|
+
|
739
762
|
return retry_upload
|
740
763
|
|
741
764
|
def _upload_batch(self, inputs: List[Input]) -> List[Input]:
|
clarifai/client/model.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import json
|
1
2
|
import os
|
2
3
|
import time
|
3
4
|
from typing import Any, Dict, Generator, List, Tuple, Union
|
@@ -9,14 +10,15 @@ from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
|
9
10
|
from clarifai_grpc.grpc.api.resources_pb2 import Input
|
10
11
|
from clarifai_grpc.grpc.api.status import status_code_pb2
|
11
12
|
from google.protobuf.json_format import MessageToDict
|
12
|
-
from google.protobuf.struct_pb2 import Struct
|
13
|
+
from google.protobuf.struct_pb2 import Struct, Value
|
13
14
|
from tqdm import tqdm
|
14
15
|
|
15
16
|
from clarifai.client.base import BaseClient
|
16
17
|
from clarifai.client.dataset import Dataset
|
17
18
|
from clarifai.client.input import Inputs
|
18
19
|
from clarifai.client.lister import Lister
|
19
|
-
from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS,
|
20
|
+
from clarifai.constants.model import (MAX_MODEL_PREDICT_INPUTS, MODEL_EXPORT_TIMEOUT,
|
21
|
+
TRAINABLE_MODEL_TYPES)
|
20
22
|
from clarifai.errors import UserError
|
21
23
|
from clarifai.urls.helper import ClarifaiUrlHelper
|
22
24
|
from clarifai.utils.logging import get_logger
|
@@ -25,6 +27,10 @@ from clarifai.utils.model_train import (find_and_replace_key, params_parser,
|
|
25
27
|
response_to_model_params, response_to_param_info,
|
26
28
|
response_to_templates)
|
27
29
|
|
30
|
+
MAX_SIZE_PER_STREAM = int(89_128_960) # 85GiB
|
31
|
+
MIN_CHUNK_FOR_UPLOAD_FILE = int(5_242_880) # 5MiB
|
32
|
+
MAX_CHUNK_FOR_UPLOAD_FILE = int(5_242_880_000) # 5GiB
|
33
|
+
|
28
34
|
|
29
35
|
class Model(Lister, BaseClient):
|
30
36
|
"""Model is a class that provides access to Clarifai API endpoints related to Model information."""
|
@@ -58,7 +64,7 @@ class Model(Lister, BaseClient):
|
|
58
64
|
user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(url)
|
59
65
|
model_version = {'id': model_version_id}
|
60
66
|
kwargs = {'user_id': user_id, 'app_id': app_id}
|
61
|
-
self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version,}
|
67
|
+
self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version, }
|
62
68
|
self.model_info = resources_pb2.Model(**self.kwargs)
|
63
69
|
self.logger = get_logger(logger_level="INFO", name=__name__)
|
64
70
|
self.training_params = {}
|
@@ -132,11 +138,11 @@ class Model(Lister, BaseClient):
|
|
132
138
|
raise Exception(response.status)
|
133
139
|
params = response_to_model_params(
|
134
140
|
response=response, model_type_id=self.model_info.model_type_id, template=template)
|
135
|
-
#yaml file
|
141
|
+
# yaml file
|
136
142
|
assert save_to.endswith('.yaml'), "File extension should be .yaml"
|
137
143
|
with open(save_to, 'w') as f:
|
138
144
|
yaml.dump(params, f, default_flow_style=False, sort_keys=False)
|
139
|
-
#updating the global model params
|
145
|
+
# updating the global model params
|
140
146
|
self.training_params.update(params)
|
141
147
|
|
142
148
|
return params
|
@@ -159,14 +165,14 @@ class Model(Lister, BaseClient):
|
|
159
165
|
raise UserError(
|
160
166
|
f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
|
161
167
|
)
|
162
|
-
#getting all the keys in nested dictionary
|
168
|
+
# getting all the keys in nested dictionary
|
163
169
|
all_keys = [key for key in self.training_params.keys()] + [
|
164
170
|
key for key in self.training_params.values() if isinstance(key, dict) for key in key
|
165
171
|
]
|
166
|
-
#checking if the given params are valid
|
172
|
+
# checking if the given params are valid
|
167
173
|
if not set(kwargs.keys()).issubset(all_keys):
|
168
174
|
raise UserError("Invalid params")
|
169
|
-
#updating the global model params
|
175
|
+
# updating the global model params
|
170
176
|
for key, value in kwargs.items():
|
171
177
|
find_and_replace_key(self.training_params, key, value)
|
172
178
|
|
@@ -238,7 +244,7 @@ class Model(Lister, BaseClient):
|
|
238
244
|
params_dict = yaml.safe_load(file)
|
239
245
|
else:
|
240
246
|
params_dict = self.training_params
|
241
|
-
#getting all the concepts for the model type
|
247
|
+
# getting all the concepts for the model type
|
242
248
|
if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
|
243
249
|
concepts = self._list_concepts()
|
244
250
|
train_dict = params_parser(params_dict, concepts)
|
@@ -423,7 +429,7 @@ class Model(Lister, BaseClient):
|
|
423
429
|
response = self._grpc_request(self.STUB.PostModelOutputs, request)
|
424
430
|
|
425
431
|
if response.status.code == status_code_pb2.MODEL_DEPLOYING and \
|
426
|
-
|
432
|
+
time.time() - start_time < 60 * 10: # 10 minutes
|
427
433
|
self.logger.info(f"{self.id} model is still deploying, please wait...")
|
428
434
|
time.sleep(next(backoff_iterator))
|
429
435
|
continue
|
@@ -944,19 +950,22 @@ class Model(Lister, BaseClient):
|
|
944
950
|
"""Export the model, stores the exported model as model.tar file
|
945
951
|
|
946
952
|
Args:
|
947
|
-
export_dir (str):
|
953
|
+
export_dir (str, optional): If provided, the exported model will be saved in the specified directory else export status will be shown. Defaults to None.
|
948
954
|
|
949
955
|
Example:
|
950
956
|
>>> from clarifai.client.model import Model
|
951
957
|
>>> model = Model("url")
|
958
|
+
>>> model.export()
|
959
|
+
or
|
952
960
|
>>> model.export('/path/to/export_model_dir')
|
953
961
|
"""
|
954
962
|
assert self.model_info.model_version.id, "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
|
955
|
-
|
956
|
-
|
957
|
-
os.
|
958
|
-
|
959
|
-
|
963
|
+
if export_dir:
|
964
|
+
try:
|
965
|
+
if not os.path.exists(export_dir):
|
966
|
+
os.makedirs(export_dir)
|
967
|
+
except OSError as e:
|
968
|
+
raise Exception(f"An error occurred while creating the directory: {e}")
|
960
969
|
|
961
970
|
def _get_export_response():
|
962
971
|
get_export_request = service_pb2.GetModelVersionExportRequest(
|
@@ -1005,25 +1014,335 @@ class Model(Lister, BaseClient):
|
|
1005
1014
|
raise Exception(response.status)
|
1006
1015
|
|
1007
1016
|
self.logger.info(
|
1008
|
-
f"Model ID {self.id}
|
1017
|
+
f"Export process has started for Model ID {self.id}, Version {self.model_info.model_version.id}"
|
1009
1018
|
)
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1019
|
+
if export_dir:
|
1020
|
+
start_time = time.time()
|
1021
|
+
backoff_iterator = BackoffIterator(10)
|
1022
|
+
while True:
|
1023
|
+
get_export_response = _get_export_response()
|
1024
|
+
if (get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
|
1025
|
+
get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING) and \
|
1026
|
+
time.time() - start_time < MODEL_EXPORT_TIMEOUT:
|
1027
|
+
self.logger.info(
|
1028
|
+
f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
|
1029
|
+
)
|
1030
|
+
time.sleep(next(backoff_iterator))
|
1031
|
+
elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
|
1032
|
+
_download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
|
1033
|
+
break
|
1034
|
+
elif time.time() - start_time > MODEL_EXPORT_TIMEOUT:
|
1035
|
+
raise Exception(
|
1036
|
+
f"""Model Export took too long. Please try again or contact support@clarifai.com
|
1037
|
+
Req ID: {get_export_response.status.req_id}""")
|
1028
1038
|
elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
|
1029
|
-
|
1039
|
+
if export_dir:
|
1040
|
+
_download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
|
1041
|
+
else:
|
1042
|
+
self.logger.info(
|
1043
|
+
f"Model ID {self.id} with version {self.model_info.model_version.id} is already exported, you can download it from the following URL: {get_export_response.export.url}"
|
1044
|
+
)
|
1045
|
+
elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
|
1046
|
+
get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING:
|
1047
|
+
self.logger.info(
|
1048
|
+
f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
|
1049
|
+
)
|
1050
|
+
|
1051
|
+
@staticmethod
|
1052
|
+
def _make_pretrained_config_proto(input_field_maps: dict,
|
1053
|
+
output_field_maps: dict,
|
1054
|
+
url: str = None):
|
1055
|
+
"""Make PretrainedModelConfig for uploading new version
|
1056
|
+
|
1057
|
+
Args:
|
1058
|
+
input_field_maps (dict): dict
|
1059
|
+
output_field_maps (dict): dict
|
1060
|
+
url (str, optional): direct download url. Defaults to None.
|
1061
|
+
"""
|
1062
|
+
|
1063
|
+
def _parse_fields_map(x):
|
1064
|
+
"""parse input, outputs to Struct"""
|
1065
|
+
_fields_map = Struct()
|
1066
|
+
_fields_map.update(x)
|
1067
|
+
return _fields_map
|
1068
|
+
|
1069
|
+
input_fields_map = _parse_fields_map(input_field_maps)
|
1070
|
+
output_fields_map = _parse_fields_map(output_field_maps)
|
1071
|
+
|
1072
|
+
return resources_pb2.PretrainedModelConfig(
|
1073
|
+
input_fields_map=input_fields_map, output_fields_map=output_fields_map, model_zip_url=url)
|
1074
|
+
|
1075
|
+
@staticmethod
|
1076
|
+
def _make_inference_params_proto(
|
1077
|
+
inference_parameters: List[Dict]) -> List[resources_pb2.ModelTypeField]:
|
1078
|
+
"""Convert list of Clarifai inference parameters to proto for uploading new version
|
1079
|
+
|
1080
|
+
Args:
|
1081
|
+
inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
|
1082
|
+
|
1083
|
+
Returns:
|
1084
|
+
List[resources_pb2.ModelTypeField]
|
1085
|
+
"""
|
1086
|
+
|
1087
|
+
def _make_default_value_proto(dtype, value):
|
1088
|
+
if dtype == 1:
|
1089
|
+
return Value(bool_value=value)
|
1090
|
+
elif dtype == 2 or dtype == 21:
|
1091
|
+
return Value(string_value=value)
|
1092
|
+
elif dtype == 3:
|
1093
|
+
return Value(number_value=value)
|
1094
|
+
|
1095
|
+
iterative_proto_params = []
|
1096
|
+
for param in inference_parameters:
|
1097
|
+
dtype = param.get("field_type")
|
1098
|
+
proto_param = resources_pb2.ModelTypeField(
|
1099
|
+
path=param.get("path"),
|
1100
|
+
field_type=dtype,
|
1101
|
+
default_value=_make_default_value_proto(dtype=dtype, value=param.get("default_value")),
|
1102
|
+
description=param.get("description"),
|
1103
|
+
)
|
1104
|
+
iterative_proto_params.append(proto_param)
|
1105
|
+
return iterative_proto_params
|
1106
|
+
|
1107
|
+
def create_version_by_file(self,
|
1108
|
+
file_path: str,
|
1109
|
+
input_field_maps: dict,
|
1110
|
+
output_field_maps: dict,
|
1111
|
+
inference_parameter_configs: dict = None,
|
1112
|
+
model_version: str = None,
|
1113
|
+
part_id: int = 1,
|
1114
|
+
range_start: int = 0,
|
1115
|
+
no_cache: bool = False,
|
1116
|
+
no_resume: bool = False,
|
1117
|
+
description: str = "") -> 'Model':
|
1118
|
+
"""Create model version by uploading local file
|
1119
|
+
|
1120
|
+
Args:
|
1121
|
+
file_path (str): path to built file.
|
1122
|
+
input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
|
1123
|
+
{clarifai_input_field: triton_input_filed}.
|
1124
|
+
output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
|
1125
|
+
{clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
|
1126
|
+
inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
|
1127
|
+
model_version (str, optional): Custom model version. Defaults to None.
|
1128
|
+
part_id (int, optional): part id of file. Defaults to 1.
|
1129
|
+
range_start (int, optional): range of uploaded size. Defaults to 0.
|
1130
|
+
no_cache (bool, optional): not saving uploading cache that is used to resume uploading. Defaults to False.
|
1131
|
+
no_resume (bool, optional): disable auto resume upload. Defaults to False.
|
1132
|
+
description (str): Model description.
|
1133
|
+
|
1134
|
+
Return:
|
1135
|
+
Model: instance of Model with new created version
|
1136
|
+
|
1137
|
+
"""
|
1138
|
+
file_size = os.path.getsize(file_path)
|
1139
|
+
assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
|
1140
|
+
|
1141
|
+
pretrained_proto = Model._make_pretrained_config_proto(
|
1142
|
+
input_field_maps=input_field_maps, output_field_maps=output_field_maps)
|
1143
|
+
inference_param_proto = Model._make_inference_params_proto(
|
1144
|
+
inference_parameter_configs) if inference_parameter_configs else None
|
1145
|
+
|
1146
|
+
if file_size >= 1e9:
|
1147
|
+
chunk_size = 1024 * 50_000 # 50MB
|
1148
|
+
else:
|
1149
|
+
chunk_size = 1024 * 10_000 # 10MB
|
1150
|
+
|
1151
|
+
#self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
|
1152
|
+
#self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
|
1153
|
+
|
1154
|
+
cache_dir = os.path.join(file_path, '..', '.cache')
|
1155
|
+
cache_upload_file = os.path.join(cache_dir, "upload.json")
|
1156
|
+
last_percent = 0
|
1157
|
+
if os.path.exists(cache_upload_file) and not no_resume:
|
1158
|
+
with open(cache_upload_file, "r") as fp:
|
1159
|
+
try:
|
1160
|
+
cache_info = json.load(fp)
|
1161
|
+
if isinstance(cache_info, dict):
|
1162
|
+
part_id = cache_info.get("part_id", part_id)
|
1163
|
+
chunk_size = cache_info.get("chunk_size", chunk_size)
|
1164
|
+
range_start = cache_info.get("range_start", range_start)
|
1165
|
+
model_version = cache_info.get("model_version", model_version)
|
1166
|
+
last_percent = cache_info.get("last_percent", last_percent)
|
1167
|
+
except Exception as e:
|
1168
|
+
self.logger.error(f"Skipping loading the upload cache due to error {e}.")
|
1169
|
+
|
1170
|
+
def init_model_version_upload(model_version):
|
1171
|
+
return service_pb2.PostModelVersionsUploadRequest(
|
1172
|
+
upload_config=service_pb2.PostModelVersionsUploadConfig(
|
1173
|
+
user_app_id=self.user_app_id,
|
1174
|
+
model_id=self.id,
|
1175
|
+
total_size=file_size,
|
1176
|
+
model_version=resources_pb2.ModelVersion(
|
1177
|
+
id=model_version,
|
1178
|
+
pretrained_model_config=pretrained_proto,
|
1179
|
+
description=description,
|
1180
|
+
output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto)),
|
1181
|
+
))
|
1182
|
+
|
1183
|
+
def _uploading(chunk, part_id, range_start, model_version):
|
1184
|
+
return service_pb2.PostModelVersionsUploadRequest(
|
1185
|
+
content_part=resources_pb2.UploadContentPart(
|
1186
|
+
data=chunk, part_number=part_id, range_start=range_start))
|
1187
|
+
|
1188
|
+
finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
|
1189
|
+
uploading_in_progress_status = [
|
1190
|
+
status_code_pb2.UPLOAD_IN_PROGRESS, status_code_pb2.MODEL_UPLOADING
|
1191
|
+
]
|
1192
|
+
|
1193
|
+
def _save_cache(cache: dict):
|
1194
|
+
if not no_cache:
|
1195
|
+
os.makedirs(cache_dir, exist_ok=True)
|
1196
|
+
with open(cache_upload_file, "w") as fp:
|
1197
|
+
json.dump(cache, fp, indent=2)
|
1198
|
+
|
1199
|
+
def stream_request(fp, part_id, end_part_id, chunk_size, version):
|
1200
|
+
yield init_model_version_upload(version)
|
1201
|
+
for iter_part_id in range(part_id, end_part_id):
|
1202
|
+
chunk = fp.read(chunk_size)
|
1203
|
+
if not chunk:
|
1204
|
+
return
|
1205
|
+
yield _uploading(
|
1206
|
+
chunk=chunk,
|
1207
|
+
part_id=iter_part_id,
|
1208
|
+
range_start=chunk_size * (iter_part_id - 1),
|
1209
|
+
model_version=version)
|
1210
|
+
|
1211
|
+
tqdm_loader = tqdm(total=100)
|
1212
|
+
if model_version:
|
1213
|
+
desc = f"Uploading model `{self.id}` version `{model_version}` ..."
|
1214
|
+
else:
|
1215
|
+
desc = f"Uploading model `{self.id}` ..."
|
1216
|
+
tqdm_loader.set_description(desc)
|
1217
|
+
|
1218
|
+
cache_uploading_info = {}
|
1219
|
+
cache_uploading_info["part_id"] = part_id
|
1220
|
+
cache_uploading_info["model_version"] = model_version
|
1221
|
+
cache_uploading_info["range_start"] = range_start
|
1222
|
+
cache_uploading_info["chunk_size"] = chunk_size
|
1223
|
+
cache_uploading_info["last_percent"] = last_percent
|
1224
|
+
tqdm_loader.update(last_percent)
|
1225
|
+
last_part_id = part_id
|
1226
|
+
n_chunks = file_size // chunk_size
|
1227
|
+
n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
|
1228
|
+
|
1229
|
+
def stream_and_logging(request, tqdm_loader, cache_uploading_info, expected_steps: int = None):
|
1230
|
+
for st_step, st_response in enumerate(self.auth_helper.get_stub().PostModelVersionsUpload(
|
1231
|
+
request, metadata=self.auth_helper.metadata)):
|
1232
|
+
if st_response.status.code in uploading_in_progress_status:
|
1233
|
+
if cache_uploading_info["model_version"]:
|
1234
|
+
assert st_response.model_version_id == cache_uploading_info[
|
1235
|
+
"model_version"], RuntimeError
|
1236
|
+
else:
|
1237
|
+
cache_uploading_info["model_version"] = st_response.model_version_id
|
1238
|
+
if st_step > 0:
|
1239
|
+
cache_uploading_info["part_id"] += 1
|
1240
|
+
cache_uploading_info["range_start"] += chunk_size
|
1241
|
+
_save_cache(cache_uploading_info)
|
1242
|
+
|
1243
|
+
if st_response.status.percent_completed:
|
1244
|
+
step_percent = st_response.status.percent_completed - cache_uploading_info["last_percent"]
|
1245
|
+
cache_uploading_info["last_percent"] += step_percent
|
1246
|
+
tqdm_loader.set_description(
|
1247
|
+
f"{st_response.status.description}, {st_response.status.details}, version id {cache_uploading_info.get('model_version')}"
|
1248
|
+
)
|
1249
|
+
tqdm_loader.update(step_percent)
|
1250
|
+
elif st_response.status.code not in finished_status + uploading_in_progress_status:
|
1251
|
+
# TODO: Find better way to handle error
|
1252
|
+
if expected_steps and st_step < expected_steps:
|
1253
|
+
raise Exception(f"Failed to upload model, error: {st_response.status}")
|
1254
|
+
|
1255
|
+
with open(file_path, 'rb') as fp:
|
1256
|
+
# seeking
|
1257
|
+
for _ in range(1, last_part_id):
|
1258
|
+
fp.read(chunk_size)
|
1259
|
+
# Stream even part
|
1260
|
+
end_part_id = n_chunks or 1
|
1261
|
+
for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
|
1262
|
+
end_part_id = iter_part_id + n_chunk_per_stream
|
1263
|
+
if end_part_id >= n_chunks:
|
1264
|
+
end_part_id = n_chunks
|
1265
|
+
expected_steps = end_part_id - iter_part_id + 1 # init step
|
1266
|
+
st_reqs = stream_request(
|
1267
|
+
fp,
|
1268
|
+
iter_part_id,
|
1269
|
+
end_part_id=end_part_id,
|
1270
|
+
chunk_size=chunk_size,
|
1271
|
+
version=cache_uploading_info["model_version"])
|
1272
|
+
stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
|
1273
|
+
# Stream last part
|
1274
|
+
accum_size = (end_part_id - 1) * chunk_size
|
1275
|
+
remained_size = file_size - accum_size if accum_size >= 0 else file_size
|
1276
|
+
st_reqs = stream_request(
|
1277
|
+
fp,
|
1278
|
+
end_part_id,
|
1279
|
+
end_part_id=end_part_id + 1,
|
1280
|
+
chunk_size=remained_size,
|
1281
|
+
version=cache_uploading_info["model_version"])
|
1282
|
+
stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
|
1283
|
+
|
1284
|
+
# clean up cache
|
1285
|
+
if not no_cache:
|
1286
|
+
try:
|
1287
|
+
os.remove(cache_upload_file)
|
1288
|
+
except Exception:
|
1289
|
+
_save_cache({})
|
1290
|
+
|
1291
|
+
if cache_uploading_info["last_percent"] <= 100:
|
1292
|
+
tqdm_loader.update(100 - cache_uploading_info["last_percent"])
|
1293
|
+
tqdm_loader.set_description("Upload done")
|
1294
|
+
|
1295
|
+
tqdm_loader.set_description(
|
1296
|
+
f"Success uploading model {self.id}, new version {cache_uploading_info.get('model_version')}"
|
1297
|
+
)
|
1298
|
+
|
1299
|
+
return Model.from_auth_helper(
|
1300
|
+
auth=self.auth_helper,
|
1301
|
+
model_id=self.id,
|
1302
|
+
model_version=dict(id=cache_uploading_info.get('model_version')))
|
1303
|
+
|
1304
|
+
def create_version_by_url(self,
|
1305
|
+
url: str,
|
1306
|
+
input_field_maps: dict,
|
1307
|
+
output_field_maps: dict,
|
1308
|
+
inference_parameter_configs: List[dict] = None,
|
1309
|
+
description: str = "") -> 'Model':
|
1310
|
+
"""Upload a new version of an existing model in the Clarifai platform using direct download url.
|
1311
|
+
|
1312
|
+
Args:
|
1313
|
+
url (str]): url of zip of model
|
1314
|
+
input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
|
1315
|
+
{clarifai_input_field: triton_input_filed}.
|
1316
|
+
output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
|
1317
|
+
{clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
|
1318
|
+
inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
|
1319
|
+
description (str): Model description.
|
1320
|
+
|
1321
|
+
Return:
|
1322
|
+
Model: instance of Model with new created version
|
1323
|
+
"""
|
1324
|
+
|
1325
|
+
pretrained_proto = Model._make_pretrained_config_proto(
|
1326
|
+
input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url)
|
1327
|
+
inference_param_proto = Model._make_inference_params_proto(
|
1328
|
+
inference_parameter_configs) if inference_parameter_configs else None
|
1329
|
+
request = service_pb2.PostModelVersionsRequest(
|
1330
|
+
user_app_id=self.user_app_id,
|
1331
|
+
model_id=self.id,
|
1332
|
+
model_versions=[
|
1333
|
+
resources_pb2.ModelVersion(
|
1334
|
+
pretrained_model_config=pretrained_proto,
|
1335
|
+
description=description,
|
1336
|
+
output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto))
|
1337
|
+
])
|
1338
|
+
response = self._grpc_request(self.STUB.PostModelVersions, request)
|
1339
|
+
|
1340
|
+
if response.status.code != status_code_pb2.SUCCESS:
|
1341
|
+
raise Exception(f"Failed to upload model, error: {response.status}")
|
1342
|
+
self.logger.info(
|
1343
|
+
f"Success uploading model {self.id}, new version {response.model.model_version.id}")
|
1344
|
+
|
1345
|
+
return Model.from_auth_helper(
|
1346
|
+
auth=self.auth_helper,
|
1347
|
+
model_id=self.id,
|
1348
|
+
model_version=dict(id=response.model.model_version.id))
|