clarifai 10.8.4__py3-none-any.whl → 10.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/client/dataset.py +9 -3
- clarifai/constants/dataset.py +1 -1
- clarifai/datasets/upload/base.py +6 -3
- clarifai/datasets/upload/features.py +10 -0
- clarifai/datasets/upload/image.py +22 -13
- clarifai/datasets/upload/multimodal.py +70 -0
- clarifai/datasets/upload/text.py +8 -5
- clarifai/utils/misc.py +6 -0
- {clarifai-10.8.4.dist-info → clarifai-10.8.5.dist-info}/METADATA +2 -1
- {clarifai-10.8.4.dist-info → clarifai-10.8.5.dist-info}/RECORD +15 -58
- clarifai/models/model_serving/README.md +0 -158
- clarifai/models/model_serving/__init__.py +0 -14
- clarifai/models/model_serving/cli/__init__.py +0 -12
- clarifai/models/model_serving/cli/_utils.py +0 -53
- clarifai/models/model_serving/cli/base.py +0 -14
- clarifai/models/model_serving/cli/build.py +0 -79
- clarifai/models/model_serving/cli/clarifai_clis.py +0 -33
- clarifai/models/model_serving/cli/create.py +0 -171
- clarifai/models/model_serving/cli/example_cli.py +0 -34
- clarifai/models/model_serving/cli/login.py +0 -26
- clarifai/models/model_serving/cli/upload.py +0 -183
- clarifai/models/model_serving/constants.py +0 -21
- clarifai/models/model_serving/docs/cli.md +0 -161
- clarifai/models/model_serving/docs/concepts.md +0 -229
- clarifai/models/model_serving/docs/dependencies.md +0 -11
- clarifai/models/model_serving/docs/inference_parameters.md +0 -139
- clarifai/models/model_serving/docs/model_types.md +0 -19
- clarifai/models/model_serving/model_config/__init__.py +0 -16
- clarifai/models/model_serving/model_config/base.py +0 -369
- clarifai/models/model_serving/model_config/config.py +0 -312
- clarifai/models/model_serving/model_config/inference_parameter.py +0 -129
- clarifai/models/model_serving/model_config/model_types_config/multimodal-embedder.yaml +0 -25
- clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml +0 -20
- clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml +0 -22
- clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml +0 -32
- clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml +0 -19
- clarifai/models/model_serving/model_config/output.py +0 -133
- clarifai/models/model_serving/model_config/triton/__init__.py +0 -14
- clarifai/models/model_serving/model_config/triton/serializer.py +0 -136
- clarifai/models/model_serving/model_config/triton/triton_config.py +0 -182
- clarifai/models/model_serving/model_config/triton/wrappers.py +0 -281
- clarifai/models/model_serving/repo_build/__init__.py +0 -14
- clarifai/models/model_serving/repo_build/build.py +0 -198
- clarifai/models/model_serving/repo_build/static_files/_requirements.txt +0 -2
- clarifai/models/model_serving/repo_build/static_files/base_test.py +0 -169
- clarifai/models/model_serving/repo_build/static_files/inference.py +0 -26
- clarifai/models/model_serving/repo_build/static_files/sample_clarifai_config.yaml +0 -25
- clarifai/models/model_serving/repo_build/static_files/test.py +0 -40
- clarifai/models/model_serving/repo_build/static_files/triton/model.py +0 -75
- clarifai/models/model_serving/utils.py +0 -31
- {clarifai-10.8.4.dist-info → clarifai-10.8.5.dist-info}/LICENSE +0 -0
- {clarifai-10.8.4.dist-info → clarifai-10.8.5.dist-info}/WHEEL +0 -0
- {clarifai-10.8.4.dist-info → clarifai-10.8.5.dist-info}/entry_points.txt +0 -0
- {clarifai-10.8.4.dist-info → clarifai-10.8.5.dist-info}/top_level.txt +0 -0
@@ -1,129 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
from dataclasses import asdict, dataclass, field
|
3
|
-
from typing import Any, List
|
4
|
-
|
5
|
-
|
6
|
-
@dataclass(frozen=True)
class InferParamType:
    """Integer codes for the value types an inference parameter may take.

    Used as a namespace of constants (``InferParamType.BOOL`` etc.) that
    ``InferParam.field_type`` is compared against.
    """
    BOOL: int = 1
    STRING: int = 2
    NUMBER: int = 3
    ENCRYPTED_STRING: int = 21


@dataclass
class InferParam:
    """A single named inference parameter with a type code and default value.

    ``path`` must be a valid Python identifier. ``default_value`` is
    type-checked against ``field_type`` on construction and on every later
    reassignment; ``None`` means "no default" and is never type-checked.
    """
    path: str
    # One of the ``InferParamType`` integer codes. STRING is the default
    # because it is a plain int, so ``asdict``/``todict`` serialize it as a
    # number rather than as a nested dict.
    field_type: int = InferParamType.STRING
    default_value: Any = None
    description: str = ""

    def __post_init__(self):
        """Validate ``path`` and, if one is given, the default value."""
        assert self.path.isidentifier(
        ), f"`path` must be valid for creating python variable, got {self.path}"
        if self.default_value is not None:
            self.validate_type(self.default_value)

    def validate_type(self, value):
        """Assert that ``value`` matches ``field_type``.

        BOOL requires ``bool`` and NUMBER requires ``float`` or ``int``. Any
        other code (STRING, ENCRYPTED_STRING, or anything unrecognized)
        requires ``str``.

        Raises:
            AssertionError: if ``value`` has the wrong type.
        """
        if self.field_type == InferParamType.BOOL:
            assert isinstance(value, bool), f"`field_type` is `BOOL` (bool), however got {type(value)}"
        elif self.field_type == InferParamType.NUMBER:
            assert isinstance(
                value, (float, int)), f"`field_type` is `NUMBER` (float or int), however got {type(value)}"
        else:
            assert isinstance(
                value,
                str), f"`field_type` is `STRING` or `ENCRYPTED_STRING` (str), however got {type(value)}"

    def __setattr__(self, name: str, value: Any) -> None:
        # The dataclass __init__ assigns fields through this hook. ``None`` is
        # the "no default" sentinel, so it must be accepted here; otherwise a
        # parameter can never be created without a default value.
        if name == "default_value" and value is not None:
            self.validate_type(value)
        super().__setattr__(name, value)

    def todict(self):
        """Return this parameter as a plain dict of its fields."""
        return asdict(self)
|
45
|
-
|
46
|
-
|
47
|
-
@dataclass
class InferParamManager:
    """Collection of ``InferParam`` objects, indexed by their ``path``.

    Parameters come either from ``params`` directly or, when ``params`` is
    empty and ``json_path`` is set, from a JSON file holding one parameter
    object or a list of them.
    """
    json_path: str = ""
    params: List[InferParam] = field(default_factory=list)
    # path -> InferParam lookup, rebuilt in __post_init__.
    _dict_params: dict = field(init=False)

    @classmethod
    def from_kwargs(cls, **kwargs):
        """Build a manager whose parameters mirror ``kwargs``.

        Each value's type selects the parameter type:
        - ``str`` becomes STRING, or ENCRYPTED_STRING when the name starts
          with ``_``;
        - ``bool`` becomes BOOL;
        - ``int``/``float`` become NUMBER.

        The value becomes the default, and the name doubles as the description.

        Raises:
            TypeError: if a value has any other type.
        """
        collected = []
        for name, default in kwargs.items():
            if isinstance(default, str):
                kind = (InferParamType.ENCRYPTED_STRING
                        if name.startswith("_") else InferParamType.STRING)
            elif isinstance(default, bool):
                # Must precede the numeric check: bool is a subclass of int.
                kind = InferParamType.BOOL
            elif isinstance(default, (float, int)):
                kind = InferParamType.NUMBER
            else:
                raise TypeError(
                    f"Unsupported type {type(default)} of argument {name}, support {InferParamType}")
            collected.append(
                InferParam(path=name, field_type=kind, default_value=default, description=name))
        return cls(params=collected)

    def __post_init__(self):
        """Optionally load params from ``json_path``, then index them by path."""
        self._dict_params = {}
        if self.params == [] and self.json_path:
            with open(self.json_path, "r") as fp:
                loaded = json.load(fp)
            if not isinstance(loaded, list):
                loaded = [loaded]
            self.params = [InferParam(**entry) for entry in loaded]
        for param in self.params:
            self._dict_params[param.path] = param

    def get_list_params(self):
        """Return every parameter as a plain dict, in ``params`` order."""
        return [param.todict() for param in self.params]

    def export(self, path: str):
        """Write ``get_list_params()`` to ``path`` as indented JSON."""
        # Build the payload before opening, so a failure leaves ``path`` untouched.
        payload = self.get_list_params()
        with open(path, "w") as fp:
            json.dump(payload, fp, indent=2)

    def validate(self, **kwargs) -> dict:
        """Merge ``kwargs`` over the parameter defaults after type-checking them.

        Raises:
            AssertionError: if kwargs are given but no params exist, if a key
                is not a known parameter, or if a value has the wrong type.
        """
        merged = {path: param.default_value for path, param in self._dict_params.items()}
        assert kwargs == {} or self.params != [], "kwargs are rejected since `params` is empty"

        for key, value in kwargs.items():
            assert key in self._dict_params, f"param `{key}` is not in setting: {list(self._dict_params.keys())}"
            if key in self._dict_params:
                self._dict_params[key].validate_type(value)
                merged[key] = value
        return merged
|
104
|
-
|
105
|
-
|
106
|
-
def is_number(v: str):
    """Return True when ``float(v)`` succeeds, False when it raises ValueError.

    Strings such as "nan" or "inf" are accepted too, because ``float`` parses
    them.
    """
    try:
        float(v)
    except ValueError:
        return False
    return True
|
112
|
-
|
113
|
-
|
114
|
-
def str_to_number(v: str):
    """Convert ``v`` to an ``int`` when possible, otherwise to a ``float``.

    Raises:
        ValueError: if ``v`` is not parseable as either.
    """
    try:
        result = int(v)
    except ValueError:
        result = float(v)
    return result
|
119
|
-
|
120
|
-
|
121
|
-
def parse_req_parameters(req_params: str):
    """Decode a JSON object string, turning numeric-looking string values into numbers.

    String values accepted by ``is_number`` are replaced with the result of
    ``str_to_number``. All other values are returned unchanged.
    """
    parsed = json.loads(req_params)
    for key, val in parsed.items():
        # Overwriting an existing key keeps the dict size fixed, so this is
        # safe during iteration.
        if isinstance(val, str) and is_number(val):
            parsed[key] = str_to_number(val)
    return parsed
|
@@ -1,25 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: image
|
5
|
-
data_type: TYPE_UINT8
|
6
|
-
dims: [-1, -1, 3]
|
7
|
-
optional: true
|
8
|
-
- name: text
|
9
|
-
data_type: TYPE_STRING
|
10
|
-
dims: [1]
|
11
|
-
optional: true
|
12
|
-
output:
|
13
|
-
- name: embeddings
|
14
|
-
data_type: TYPE_FP32
|
15
|
-
dims: [-1]
|
16
|
-
label_filename: null
|
17
|
-
clarifai_model:
|
18
|
-
type: multimodal-embedder
|
19
|
-
output_type: EmbeddingOutput
|
20
|
-
field_maps:
|
21
|
-
input_fields_map:
|
22
|
-
image: image
|
23
|
-
text: text
|
24
|
-
output_fields_map:
|
25
|
-
embeddings: embeddings
|
@@ -1,19 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: text
|
5
|
-
data_type: TYPE_STRING
|
6
|
-
dims: [1]
|
7
|
-
output:
|
8
|
-
- name: softmax_predictions
|
9
|
-
data_type: TYPE_FP32
|
10
|
-
dims: [-1]
|
11
|
-
label_filename: "labels.txt"
|
12
|
-
clarifai_model:
|
13
|
-
type: text-classifier
|
14
|
-
output_type: ClassifierOutput
|
15
|
-
field_maps:
|
16
|
-
input_fields_map:
|
17
|
-
text: text
|
18
|
-
output_fields_map:
|
19
|
-
concepts: softmax_predictions
|
@@ -1,20 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: text
|
5
|
-
data_type: TYPE_STRING
|
6
|
-
dims: [1]
|
7
|
-
output:
|
8
|
-
- name: embeddings
|
9
|
-
data_type: TYPE_FP32
|
10
|
-
dims: [-1]
|
11
|
-
label_filename: null
|
12
|
-
|
13
|
-
clarifai_model:
|
14
|
-
type: text-embedder
|
15
|
-
output_type: EmbeddingOutput
|
16
|
-
field_maps:
|
17
|
-
input_fields_map:
|
18
|
-
text: text
|
19
|
-
output_fields_map:
|
20
|
-
embeddings: embeddings
|
@@ -1,19 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: text
|
5
|
-
data_type: TYPE_STRING
|
6
|
-
dims: [1]
|
7
|
-
output:
|
8
|
-
- name: image
|
9
|
-
data_type: TYPE_UINT8
|
10
|
-
dims: [-1, -1, 3]
|
11
|
-
label_filename: null
|
12
|
-
clarifai_model:
|
13
|
-
type: text-to-image
|
14
|
-
output_type: ImageOutput
|
15
|
-
field_maps:
|
16
|
-
input_fields_map:
|
17
|
-
text: text
|
18
|
-
output_fields_map:
|
19
|
-
image: image
|
@@ -1,19 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: text
|
5
|
-
data_type: TYPE_STRING
|
6
|
-
dims: [1]
|
7
|
-
output:
|
8
|
-
- name: text
|
9
|
-
data_type: TYPE_STRING
|
10
|
-
dims: [1]
|
11
|
-
label_filename: null
|
12
|
-
clarifai_model:
|
13
|
-
type: text-to-text
|
14
|
-
output_type: TextOutput
|
15
|
-
field_maps:
|
16
|
-
input_fields_map:
|
17
|
-
text: text
|
18
|
-
output_fields_map:
|
19
|
-
text: text
|
@@ -1,22 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: image
|
5
|
-
data_type: TYPE_UINT8
|
6
|
-
dims: [-1, -1, 3]
|
7
|
-
output:
|
8
|
-
- name: softmax_predictions
|
9
|
-
data_type: TYPE_FP32
|
10
|
-
dims: [-1]
|
11
|
-
label_filename: "labels.txt"
|
12
|
-
|
13
|
-
clarifai_model:
|
14
|
-
field_maps:
|
15
|
-
input_fields_map:
|
16
|
-
image: image
|
17
|
-
output_fields_map:
|
18
|
-
concepts: softmax_predictions
|
19
|
-
output_type: ClassifierOutput
|
20
|
-
type: visual-classifier
|
21
|
-
labels:
|
22
|
-
inference_parameters:
|
@@ -1,32 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: image
|
5
|
-
data_type: TYPE_UINT8
|
6
|
-
dims: [-1, -1, 3]
|
7
|
-
output:
|
8
|
-
- name: predicted_bboxes
|
9
|
-
data_type: TYPE_FP32
|
10
|
-
dims: [-1, 4]
|
11
|
-
label_filename: null
|
12
|
-
- name: predicted_labels
|
13
|
-
data_type: TYPE_INT32
|
14
|
-
dims: [-1, 1]
|
15
|
-
label_filename: "labels.txt"
|
16
|
-
- name: predicted_scores
|
17
|
-
data_type: TYPE_FP32
|
18
|
-
dims: [-1, 1]
|
19
|
-
label_filename: null
|
20
|
-
|
21
|
-
clarifai_model:
|
22
|
-
field_maps:
|
23
|
-
input_fields_map:
|
24
|
-
image: image
|
25
|
-
output_fields_map:
|
26
|
-
"regions[...].region_info.bounding_box": "predicted_bboxes"
|
27
|
-
"regions[...].data.concepts[...].id": "predicted_labels"
|
28
|
-
"regions[...].data.concepts[...].value": "predicted_scores"
|
29
|
-
output_type: VisualDetectorOutput
|
30
|
-
type: visual-detector
|
31
|
-
labels:
|
32
|
-
inference_parameters:
|
@@ -1,19 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: image
|
5
|
-
data_type: TYPE_UINT8
|
6
|
-
dims: [-1, -1, 3]
|
7
|
-
output:
|
8
|
-
- name: embeddings
|
9
|
-
data_type: TYPE_FP32
|
10
|
-
dims: [-1]
|
11
|
-
label_filename: null
|
12
|
-
clarifai_model:
|
13
|
-
type: visual-embedder
|
14
|
-
output_type: EmbeddingOutput
|
15
|
-
field_maps:
|
16
|
-
input_fields_map:
|
17
|
-
image: image
|
18
|
-
output_fields_map:
|
19
|
-
embeddings: embeddings
|
@@ -1,19 +0,0 @@
|
|
1
|
-
serving_backend:
|
2
|
-
triton:
|
3
|
-
input:
|
4
|
-
- name: image
|
5
|
-
data_type: TYPE_UINT8
|
6
|
-
dims: [-1, -1, 3]
|
7
|
-
output:
|
8
|
-
- name: predicted_mask
|
9
|
-
data_type: TYPE_INT64
|
10
|
-
dims: [-1, -1]
|
11
|
-
label_filename: "labels.txt"
|
12
|
-
clarifai_model:
|
13
|
-
type: visual-segmenter
|
14
|
-
output_type: MasksOutput
|
15
|
-
field_maps:
|
16
|
-
input_fields_map:
|
17
|
-
image: image
|
18
|
-
output_fields_map:
|
19
|
-
"regions[...].region_info.mask,regions[...].data.concepts": "predicted_mask"
|
@@ -1,133 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
"""
|
14
|
-
Output Predictions format for different model types.
|
15
|
-
"""
|
16
|
-
|
17
|
-
from dataclasses import dataclass
|
18
|
-
import numpy as np
|
19
|
-
|
20
|
-
|
21
|
-
@dataclass
class VisualDetectorOutput:
    """Detection predictions, one row per detected object.

    Attributes:
        predicted_bboxes: 2-D array. When non-empty, each row holds 4 box
            coordinates, all within [0, 1].
        predicted_labels: 2-D array with the same number of rows.
        predicted_scores: 2-D array with the same number of rows.
    """
    predicted_bboxes: np.ndarray
    predicted_labels: np.ndarray
    predicted_scores: np.ndarray

    def __post_init__(self):
        """
        Validate input upon initialization.

        Raises:
            AssertionError: on a non-ndarray input, wrong rank, mismatched
                row counts, or malformed/out-of-range box coordinates.
        """
        # Check every field's type first so the ``.ndim``/``.shape`` accesses
        # below cannot fail with AttributeError.
        assert isinstance(self.predicted_bboxes, np.ndarray), "`predicted_bboxes` must be numpy array"
        assert isinstance(self.predicted_labels, np.ndarray), "`predicted_labels` must be numpy array"
        assert isinstance(self.predicted_scores, np.ndarray), "`predicted_scores` must be numpy array"

        assert self.predicted_bboxes.ndim == self.predicted_labels.ndim == self.predicted_scores.ndim == 2, (
            "All predictions must be 2-dimensional, "
            f"Got bbox-dims: {self.predicted_bboxes.ndim}, label-dims: {self.predicted_labels.ndim}, "
            f"scores-dims: {self.predicted_scores.ndim} instead.")
        assert (self.predicted_bboxes.shape[0] == self.predicted_labels.shape[0] ==
                self.predicted_scores.shape[0]), (
                    "The Number of predicted bounding boxes, predicted labels and predicted scores "
                    f"MUST match. Got {len(self.predicted_bboxes)}, "
                    f"{self.predicted_labels.shape[0]}, {self.predicted_scores.shape[0]} instead.")

        # Box geometry is checked only when there is at least one detection.
        if len(self.predicted_labels) > 0:
            assert self.predicted_bboxes.shape[1] == 4, \
                f"Box coordinates must have a length of 4. Actual:{self.predicted_bboxes.shape[1]}"
            assert np.all(np.logical_and(0 <= self.predicted_bboxes, self.predicted_bboxes <= 1)), \
                "Bounding box coordinates must be between 0 and 1"
|
49
|
-
|
50
|
-
|
51
|
-
@dataclass
class ClassifierOutput:
    """
    Takes model softmax predictions.

    Position i of ``predicted_scores`` must correspond to the label at
    index i in the labels.txt file.
    """
    predicted_scores: np.ndarray

    def __post_init__(self):
        """Assert that ``predicted_scores`` is a 1-D numpy array."""
        scores = self.predicted_scores
        assert isinstance(scores, np.ndarray), "`predicted_scores` must be numpy array"
        rank = scores.ndim
        assert rank == 1, f"All predictions must be 1-dimensional, Got scores-dims: {rank} instead."
|
68
|
-
|
69
|
-
|
70
|
-
@dataclass
class TextOutput:
    """
    Takes model text predictions.

    After construction, ``predicted_text`` holds a 0-d object-dtype numpy
    array that wraps the original value.
    """
    predicted_text: str

    def __post_init__(self):
        """Wrap the text in a 0-d object array and assert it is a scalar."""
        wrapped = np.array(self.predicted_text, dtype=object)
        self.predicted_text = wrapped
        assert wrapped.ndim == 0, f"All predictions must be 0-dimensional, Got text-dims: {wrapped.ndim} instead."
|
84
|
-
|
85
|
-
|
86
|
-
@dataclass
class EmbeddingOutput:
    """
    Takes the embedding vector produced by a model; it must be a 1-D numpy array.
    """
    embedding_vector: np.ndarray

    def __post_init__(self):
        """Assert that ``embedding_vector`` is a 1-D numpy array."""
        vec = self.embedding_vector
        assert isinstance(vec, np.ndarray), "`embedding_vector` must be numpy array"
        dims = vec.ndim
        assert dims == 1, f"Embeddings must be 1-dimensional, Got embedding-dims: {dims} instead."
|
100
|
-
|
101
|
-
|
102
|
-
@dataclass
class MasksOutput:
    """
    Takes an image segmentation mask produced by a model; it must be a 2-D numpy array.
    """
    predicted_mask: np.ndarray

    def __post_init__(self):
        """Assert that ``predicted_mask`` is a 2-D numpy array."""
        mask = self.predicted_mask
        assert isinstance(mask, np.ndarray), "`predicted_mask` must be numpy array"
        dims = mask.ndim
        assert dims == 2, f"predicted_mask must be 2-dimensional, Got mask dims: {dims} instead."
|
116
|
-
|
117
|
-
|
118
|
-
@dataclass
class ImageOutput:
    """
    Takes a predicted/generated image produced by a model; it must be a numpy
    array of shape (H, W, 3).
    """
    image: np.ndarray

    def __post_init__(self):
        """Assert that ``image`` is a 3-D numpy array with 3 channels in the last axis."""
        img = self.image
        assert isinstance(img, np.ndarray), "`image` must be numpy array"
        assert img.ndim == 3, f"Generated image must be 3-dimensional, Got image-dims: {img.ndim} instead."
        channels = img.shape[2]
        assert channels == 3, f"The image channels dimension must equal 3, Got channel dim: {channels} instead."
|
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
from .serializer import Serializer # noqa # pylint: disable=unused-import
|
14
|
-
from .triton_config import * # noqa # pylint: disable=unused-import
|
@@ -1,136 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
"""
|
14
|
-
Parse & Serialize TritonModelConfig objects into proto format.
|
15
|
-
"""
|
16
|
-
|
17
|
-
import os
|
18
|
-
from pathlib import Path
|
19
|
-
from typing import Type
|
20
|
-
|
21
|
-
from google.protobuf.text_format import MessageToString
|
22
|
-
from tritonclient.grpc import model_config_pb2
|
23
|
-
|
24
|
-
from .triton_config import TritonModelConfig
|
25
|
-
|
26
|
-
|
27
|
-
class Serializer:
    """
    Serialize TritonModelConfig type object.

    On construction, the python dataclass config is converted into a
    ``model_config_pb2.ModelConfig`` proto. Read it back with ``get_config``
    or write it out with ``to_file``.
    """

    def __init__(self, model_config: Type[TritonModelConfig]) -> None:
        self.model_config = model_config  # python dataclass config
        self.config_proto = model_config_pb2.ModelConfig()  # holds parsed python config

        self._set_all_fields()

    @staticmethod
    def _assign_field(message, key, value) -> None:
        """Set ``key`` on proto ``message``, extending it if ``key`` is a repeated field.

        Direct assignment to a repeated proto field raises AttributeError. In
        that case the existing container is extended with ``value``; a
        non-list value is wrapped in a one-element list first.
        """
        try:
            setattr(message, key, value)
        except AttributeError:  # Proto Repeated Field assignment not allowed
            container = getattr(message, key)
            container.extend(value if isinstance(value, list) else [value])

    def _set_input(self) -> None:
        """
        Parse InputConfig object to proto.
        """
        for in_field in self.model_config.input:
            input_config = self.config_proto.input.add()
            for key, value in in_field.__dict__.items():
                self._assign_field(input_config, key, value)

    def _set_output(self) -> None:
        """
        Parse OutputConfig object to proto.

        Falsy field values (e.g. ``None``, empty string/list) are skipped.
        """
        # loop over output dataclass list
        for out_field in self.model_config.output:
            output_config = self.config_proto.output.add()
            for key, value in out_field.__dict__.items():
                if not value:
                    continue
                self._assign_field(output_config, key, value)

    def _set_instance_group(self) -> None:
        """
        Parse triton model instance group settings to proto.
        """
        instance = self.config_proto.instance_group.add()
        for field_name, value in self.model_config.instance_group.__dict__.items():
            try:
                setattr(instance, field_name, value)
            except AttributeError:
                # NOTE(review): unlike input/output, fields that cannot be set
                # directly (including repeated fields) are skipped rather than
                # extended — confirm this is intended.
                continue

    def _set_batch_info(self) -> model_config_pb2.ModelDynamicBatching:
        """
        Parse triton model dynamic batching settings to proto.
        """
        dbatch_msg = model_config_pb2.ModelDynamicBatching()
        for key, value in self.model_config.dynamic_batching.__dict__.items():
            self._assign_field(dbatch_msg, key, value)
        return dbatch_msg

    def _set_all_fields(self) -> None:
        """
        Set all config fields.
        """
        self.config_proto.name = self.model_config.model_name
        self.config_proto.backend = self.model_config.backend
        self.config_proto.max_batch_size = self.model_config.max_batch_size
        self._set_input()
        self._set_output()
        self._set_instance_group()
        dynamic_batch_msg = self._set_batch_info()
        self.config_proto.dynamic_batching.CopyFrom(dynamic_batch_msg)

    @property
    def get_config(self) -> model_config_pb2.ModelConfig:
        """
        Return model config proto.
        """
        return self.config_proto

    def to_file(self, save_dir: Path = Path(".")) -> None:
        """
        Serialize all triton config parameters and save output
        to file.
        Args:
        -----
        save_dir: Directory where to save resultant config.pbtxt file.
            Defaults to the current working dir.
        """
        msg_string = MessageToString(self.config_proto)

        with open(os.path.join(save_dir, "config.pbtxt"), "w") as cfile:
            cfile.write(msg_string)
|