clarifai 10.8.3__py3-none-any.whl → 10.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/client/dataset.py +9 -3
- clarifai/constants/dataset.py +1 -1
- clarifai/datasets/upload/base.py +6 -3
- clarifai/datasets/upload/features.py +10 -0
- clarifai/datasets/upload/image.py +22 -13
- clarifai/datasets/upload/multimodal.py +70 -0
- clarifai/datasets/upload/text.py +8 -5
- clarifai/runners/utils/data_handler.py +31 -44
- clarifai/runners/utils/loader.py +6 -5
- clarifai/utils/misc.py +6 -0
- {clarifai-10.8.3.dist-info → clarifai-10.8.5.dist-info}/METADATA +2 -1
- {clarifai-10.8.3.dist-info → clarifai-10.8.5.dist-info}/RECORD +17 -60
- clarifai/models/model_serving/README.md +0 -158
- clarifai/models/model_serving/__init__.py +0 -14
- clarifai/models/model_serving/cli/__init__.py +0 -12
- clarifai/models/model_serving/cli/_utils.py +0 -53
- clarifai/models/model_serving/cli/base.py +0 -14
- clarifai/models/model_serving/cli/build.py +0 -79
- clarifai/models/model_serving/cli/clarifai_clis.py +0 -33
- clarifai/models/model_serving/cli/create.py +0 -171
- clarifai/models/model_serving/cli/example_cli.py +0 -34
- clarifai/models/model_serving/cli/login.py +0 -26
- clarifai/models/model_serving/cli/upload.py +0 -183
- clarifai/models/model_serving/constants.py +0 -21
- clarifai/models/model_serving/docs/cli.md +0 -161
- clarifai/models/model_serving/docs/concepts.md +0 -229
- clarifai/models/model_serving/docs/dependencies.md +0 -11
- clarifai/models/model_serving/docs/inference_parameters.md +0 -139
- clarifai/models/model_serving/docs/model_types.md +0 -19
- clarifai/models/model_serving/model_config/__init__.py +0 -16
- clarifai/models/model_serving/model_config/base.py +0 -369
- clarifai/models/model_serving/model_config/config.py +0 -312
- clarifai/models/model_serving/model_config/inference_parameter.py +0 -129
- clarifai/models/model_serving/model_config/model_types_config/multimodal-embedder.yaml +0 -25
- clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml +0 -20
- clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml +0 -22
- clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml +0 -32
- clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml +0 -19
- clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml +0 -19
- clarifai/models/model_serving/model_config/output.py +0 -133
- clarifai/models/model_serving/model_config/triton/__init__.py +0 -14
- clarifai/models/model_serving/model_config/triton/serializer.py +0 -136
- clarifai/models/model_serving/model_config/triton/triton_config.py +0 -182
- clarifai/models/model_serving/model_config/triton/wrappers.py +0 -281
- clarifai/models/model_serving/repo_build/__init__.py +0 -14
- clarifai/models/model_serving/repo_build/build.py +0 -198
- clarifai/models/model_serving/repo_build/static_files/_requirements.txt +0 -2
- clarifai/models/model_serving/repo_build/static_files/base_test.py +0 -169
- clarifai/models/model_serving/repo_build/static_files/inference.py +0 -26
- clarifai/models/model_serving/repo_build/static_files/sample_clarifai_config.yaml +0 -25
- clarifai/models/model_serving/repo_build/static_files/test.py +0 -40
- clarifai/models/model_serving/repo_build/static_files/triton/model.py +0 -75
- clarifai/models/model_serving/utils.py +0 -31
- {clarifai-10.8.3.dist-info → clarifai-10.8.5.dist-info}/LICENSE +0 -0
- {clarifai-10.8.3.dist-info → clarifai-10.8.5.dist-info}/WHEEL +0 -0
- {clarifai-10.8.3.dist-info → clarifai-10.8.5.dist-info}/entry_points.txt +0 -0
- {clarifai-10.8.3.dist-info → clarifai-10.8.5.dist-info}/top_level.txt +0 -0
@@ -1,133 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
"""
|
14
|
-
Output Predictions format for different model types.
|
15
|
-
"""
|
16
|
-
|
17
|
-
from dataclasses import dataclass
|
18
|
-
import numpy as np
|
19
|
-
|
20
|
-
|
21
|
-
@dataclass
class VisualDetectorOutput:
  """
  Detection predictions for a single input.

  Attributes:
    predicted_bboxes: 2-D numpy array with one row per detection. When there is at least
      one detection, each row must hold exactly 4 coordinates, all within [0, 1].
    predicted_labels: 2-D numpy array with one row per detection.
    predicted_scores: 2-D numpy array with one row per detection.
  """
  predicted_bboxes: np.ndarray
  predicted_labels: np.ndarray
  predicted_scores: np.ndarray

  def __post_init__(self):
    """
    Validate input upon initialization.

    Raises:
      AssertionError: if a field is not a numpy array, any array is not 2-dimensional,
        the arrays disagree on the number of detections, or (when non-empty) the boxes
        are not 4 wide or have coordinates outside [0, 1].
    """
    # Type-check every field. `predicted_scores` used to be checked twice and
    # `predicted_bboxes` never, so a non-array box input failed later on `.ndim`.
    assert isinstance(self.predicted_bboxes, np.ndarray), "`predicted_bboxes` must be numpy array"
    assert isinstance(self.predicted_labels, np.ndarray), "`predicted_labels` must be numpy array"
    assert isinstance(self.predicted_scores, np.ndarray), "`predicted_scores` must be numpy array"

    bboxes = self.predicted_bboxes
    labels = self.predicted_labels
    scores = self.predicted_scores

    # Messages use implicit string concatenation. A backslash continuation inside the
    # literal would embed the next line's indentation in the error text.
    assert bboxes.ndim == labels.ndim == scores.ndim == 2, (
        "All predictions must be 2-dimensional, "
        f"Got bbox-dims: {bboxes.ndim}, label-dims: {labels.ndim}, "
        f"scores-dims: {scores.ndim} instead.")
    assert bboxes.shape[0] == labels.shape[0] == scores.shape[0], (
        "The Number of predicted bounding boxes, predicted labels and predicted scores "
        f"MUST match. Got {bboxes.shape[0]}, {labels.shape[0]}, {scores.shape[0]} instead.")

    # Box geometry can only be checked when there is at least one detection.
    if len(labels) > 0:
      assert bboxes.shape[1] == 4, \
          f"Box coordinates must have a length of 4. Actual:{bboxes.shape[1]}"
      assert np.all(np.logical_and(0 <= bboxes, bboxes <= 1)), \
          "Bounding box coordinates must be between 0 and 1"
|
49
|
-
|
50
|
-
|
51
|
-
@dataclass
class ClassifierOutput:
  """
  Per-class scores produced by a classifier (e.g. its softmax output).

  Position ``i`` of ``predicted_scores`` must correspond to the label at index ``i``
  of the model's labels.txt file.
  """
  predicted_scores: np.ndarray

  def __post_init__(self):
    """Raise AssertionError unless ``predicted_scores`` is a 1-D numpy array."""
    scores = self.predicted_scores
    assert isinstance(scores, np.ndarray), "`predicted_scores` must be numpy array"
    rank = scores.ndim
    assert rank == 1, \
        f"All predictions must be 1-dimensional, Got scores-dims: {rank} instead."
|
68
|
-
|
69
|
-
|
70
|
-
@dataclass
class TextOutput:
  """
  A single text prediction.

  After construction ``predicted_text`` is replaced by a 0-dimensional numpy object
  array wrapping the value that was passed in.
  """
  predicted_text: str

  def __post_init__(self):
    """Wrap ``predicted_text`` in an object array and require that array to be 0-D."""
    wrapped = np.array(self.predicted_text, dtype=object)
    self.predicted_text = wrapped
    assert wrapped.ndim == 0, \
        f"All predictions must be 0-dimensional, Got text-dims: {wrapped.ndim} instead."
|
84
|
-
|
85
|
-
|
86
|
-
@dataclass
class EmbeddingOutput:
  """
  The embedding vector a model returns for one input.

  ``embedding_vector`` must be a 1-D numpy array.
  """
  embedding_vector: np.ndarray

  def __post_init__(self):
    """Raise AssertionError unless ``embedding_vector`` is a 1-D numpy array."""
    vec = self.embedding_vector
    assert isinstance(vec, np.ndarray), "`embedding_vector` must be numpy array"
    assert vec.ndim == 1, \
        f"Embeddings must be 1-dimensional, Got embedding-dims: {vec.ndim} instead."
|
100
|
-
|
101
|
-
|
102
|
-
@dataclass
class MasksOutput:
  """
  An image segmentation mask returned by a model.

  ``predicted_mask`` must be a 2-D numpy array.
  """
  predicted_mask: np.ndarray

  def __post_init__(self):
    """Raise AssertionError unless ``predicted_mask`` is a 2-D numpy array."""
    mask = self.predicted_mask
    assert isinstance(mask, np.ndarray), "`predicted_mask` must be numpy array"
    assert mask.ndim == 2, \
        f"predicted_mask must be 2-dimensional, Got mask dims: {mask.ndim} instead."
|
116
|
-
|
117
|
-
|
118
|
-
@dataclass
class ImageOutput:
  """
  A predicted or generated image returned by a model.

  ``image`` must be a 3-D numpy array whose last axis has length 3.
  """
  image: np.ndarray

  def __post_init__(self):
    """Raise AssertionError unless ``image`` is a 3-D numpy array with 3 channels."""
    img = self.image
    assert isinstance(img, np.ndarray), "`image` must be numpy array"
    assert img.ndim == 3, \
        f"Generated image must be 3-dimensional, Got image-dims: {img.ndim} instead."
    # Safe to index axis 2 here: the previous assertion guarantees three axes.
    channels = img.shape[2]
    assert channels == 3, \
        f"The image channels dimension must equal 3, Got channel dim: {channels} instead."
|
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
from .serializer import Serializer # noqa # pylint: disable=unused-import
|
14
|
-
from .triton_config import * # noqa # pylint: disable=unused-import
|
@@ -1,136 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
"""
|
14
|
-
Parse & Serialize TritonModelConfig objects into proto format.
|
15
|
-
"""
|
16
|
-
|
17
|
-
import os
|
18
|
-
from pathlib import Path
|
19
|
-
from typing import Type
|
20
|
-
|
21
|
-
from google.protobuf.text_format import MessageToString
|
22
|
-
from tritonclient.grpc import model_config_pb2
|
23
|
-
|
24
|
-
from .triton_config import TritonModelConfig
|
25
|
-
|
26
|
-
|
27
|
-
class Serializer:
  """
  Serialize a TritonModelConfig object into a Triton ``ModelConfig`` proto.

  The proto is built eagerly in ``__init__``. Read it via the ``get_config`` property
  or write it as text with ``to_file``.
  """

  def __init__(self, model_config: TritonModelConfig) -> None:
    """
    Args:
      model_config: a populated TritonModelConfig instance (an instance, not the class).
    """
    self.model_config = model_config  # python dataclass config
    self.config_proto = model_config_pb2.ModelConfig()  # holds parsed python config

    self._set_all_fields()

  @staticmethod
  def _copy_fields(message, source, skip_falsy: bool = False) -> None:
    """
    Copy each attribute in ``source.__dict__`` onto the proto ``message`` by name.

    Scalar fields are assigned directly. Assigning to a repeated proto field raises
    AttributeError, so such fields are extended instead; a non-list value is wrapped
    in a one-element list.

    Args:
      message: protobuf message populated in place.
      source: object whose instance attributes name fields of ``message``.
      skip_falsy: when True, attributes with a falsy value are not copied.
    """
    for key, value in source.__dict__.items():
      if skip_falsy and not value:
        continue
      try:
        setattr(message, key, value)
      except AttributeError:  # Proto Repeated Field assignment not allowed
        repeated = getattr(message, key)
        repeated.extend(value if isinstance(value, list) else [value])

  def _set_input(self) -> None:
    """
    Parse InputConfig objects to proto, adding one ``input`` entry per object.
    """
    for in_field in self.model_config.input:
      self._copy_fields(self.config_proto.input.add(), in_field)

  def _set_output(self) -> None:
    """
    Parse OutputConfig objects to proto, adding one ``output`` entry per object.

    Falsy values (e.g. an empty ``label_filename`` or empty ``dims``) are left unset.
    """
    for out_field in self.model_config.output:
      self._copy_fields(self.config_proto.output.add(), out_field, skip_falsy=True)

  def _set_instance_group(self) -> None:
    """
    Parse triton model instance group settings to proto.

    Attributes with no matching proto field (e.g. ``use_gpu``) are skipped.
    """
    instance = self.config_proto.instance_group.add()
    for field_name, value in self.model_config.instance_group.__dict__.items():
      try:
        setattr(instance, field_name, value)
      except AttributeError:
        continue

  def _set_batch_info(self) -> model_config_pb2.ModelDynamicBatching:
    """
    Parse triton model dynamic batching settings to a new proto message and return it.
    """
    dbatch_msg = model_config_pb2.ModelDynamicBatching()
    self._copy_fields(dbatch_msg, self.model_config.dynamic_batching)
    return dbatch_msg

  def _set_all_fields(self) -> None:
    """
    Set all config fields.
    """
    self.config_proto.name = self.model_config.model_name
    self.config_proto.backend = self.model_config.backend
    self.config_proto.max_batch_size = self.model_config.max_batch_size
    self._set_input()
    self._set_output()
    self._set_instance_group()
    dynamic_batch_msg = self._set_batch_info()
    self.config_proto.dynamic_batching.CopyFrom(dynamic_batch_msg)

  @property
  def get_config(self) -> model_config_pb2.ModelConfig:
    """
    Return model config proto.
    """
    return self.config_proto

  def to_file(self, save_dir: Path = Path(".")) -> None:
    """
    Serialize all triton config parameters and save output to file.

    Args:
    -----
      save_dir: Directory where to save resultant config.pbtxt file.
        Defaults to the current working dir.
    """
    msg_string = MessageToString(self.config_proto)

    # Explicit encoding so the written file does not depend on the platform locale.
    with open(os.path.join(save_dir, "config.pbtxt"), "w", encoding="utf-8") as cfile:
      cfile.write(msg_string)
|
@@ -1,182 +0,0 @@
|
|
1
|
-
# Copyright 2023 Clarifai, Inc.
|
2
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
-
# you may not use this file except in compliance with the License.
|
4
|
-
# You may obtain a copy of the License at
|
5
|
-
#
|
6
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
-
#
|
8
|
-
# Unless required by applicable law or agreed to in writing, software
|
9
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
-
# See the License for the specific language governing permissions and
|
12
|
-
# limitations under the License.
|
13
|
-
""" Model Config classes."""
|
14
|
-
from __future__ import annotations # isort: skip
|
15
|
-
|
16
|
-
from copy import deepcopy
|
17
|
-
from dataclasses import dataclass, field
|
18
|
-
from typing import Any, List, Union
|
19
|
-
|
20
|
-
from ...constants import IMAGE_TENSOR_NAME, MAX_HW_DIM
|
21
|
-
|
22
|
-
|
23
|
-
### Triton Model Config classes.###
|
24
|
-
@dataclass
class DType:
  """
  Triton Model Config data types.

  Integer codes written into the Triton model config proto.
  - The ``TYPE_*`` values are data-type codes taken from the proto linked below.
  - ``KIND_GPU`` and ``KIND_CPU`` are the values ``Device`` assigns to its ``kind``
    attribute.

  Callers read these as class attributes (e.g. ``DType.KIND_GPU``), not through instances.
  """
  # https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
  # NOTE(review): the numbering has gaps (e.g. no 3-5, 12) because only a subset of the
  # proto's data types is exposed here. The values must match the proto exactly.
  TYPE_UINT8: int = 2
  TYPE_INT8: int = 6
  TYPE_INT16: int = 7
  TYPE_INT32: int = 8
  TYPE_INT64: int = 9
  TYPE_FP16: int = 10
  TYPE_FP32: int = 11
  TYPE_STRING: int = 13
  # Instance-group device kinds. Presumably these mirror the proto's instance-group
  # Kind enum; confirm against the proto above.
  KIND_GPU: int = 1
  KIND_CPU: int = 2
|
40
|
-
|
41
|
-
|
42
|
-
@dataclass
class InputConfig:
  """
  One entry of a Triton model's ``input`` list.

  Fields (in constructor order); each is copied by name onto the Triton input proto:
    name: input tensor name.
    data_type: integer data-type code (typically a ``DType.TYPE_*`` value).
    dims: pre-defined input data shape(s); a fresh empty list per instance by default.
    optional: value for the input proto's ``optional`` flag; False by default.
  """
  name: str
  data_type: int
  dims: List = field(default_factory=lambda: [])
  optional: bool = field(default=False)
|
60
|
-
|
61
|
-
|
62
|
-
@dataclass
class OutputConfig:
  """
  One entry of a Triton model's ``output`` list.

  Fields (in constructor order):
    name: output tensor name.
    data_type: integer data-type code (typically a ``DType.TYPE_*`` value).
    dims: pre-defined output data shape(s); a fresh empty list per instance by default.
    label_filename: name of the labels file used at inference. Empty string (the
      default) means no labels file.
  """
  name: str
  data_type: int
  dims: List = field(default_factory=lambda: [])
  label_filename: str = field(default="")
|
81
|
-
|
82
|
-
|
83
|
-
@dataclass
class Device:
  """
  Triton instance_group.
  Define the type of inference device and number of devices to use.

  Params:
  -------
  count: number of devices
  use_gpu: True to run on GPU, False to run on CPU.

  After initialization, ``kind`` holds ``DType.KIND_GPU`` or ``DType.KIND_CPU``
  according to ``use_gpu``. ``kind`` is not a dataclass field, so it is not an
  ``__init__`` argument.

  Returns:
  --------
  Device object
  """
  count: int = 1
  use_gpu: bool = True

  def __post_init__(self):
    # DType.KIND_* are ints, so `kind` is annotated int (it was previously mislabelled str).
    self.kind: int = DType.KIND_GPU if self.use_gpu else DType.KIND_CPU
|
105
|
-
|
106
|
-
|
107
|
-
@dataclass
class DynamicBatching:
  """
  Triton dynamic_batching config.

  Fields:
    max_queue_delay_microseconds: max queue delay for a request batch (default 500).

  ``preferred_batch_size`` is deliberately not exposed; leaving it unset is recommended.
  """
  max_queue_delay_microseconds: int = field(default=500)
|
122
|
-
|
123
|
-
|
124
|
-
@dataclass
class TritonModelConfig:
  """
  Triton Model Config base.

  Params:
  -------
  model_name: triton inference model name
  model_version: model version string
  input: a list of InputConfig objects
  output: a list of OutputConfig objects
  instance_group: Device. see Device
  dynamic_batching: Triton dynamic batching settings.
  max_batch_size: max request batch size
  backend: Triton Python Backend. Constant
  image_shape: Height and Width of input image, as a list or tuple. *

  (*): This attribute won't be serialized in config.pbtxt

  Returns:
  --------
  TritonModelConfig
  """
  model_name: str = ""
  model_version: str = "1"
  input: List[InputConfig] = field(default_factory=list)
  output: List[OutputConfig] = field(default_factory=list)
  instance_group: Device = field(default_factory=Device)
  dynamic_batching: DynamicBatching = field(default_factory=DynamicBatching)
  max_batch_size: int = 1
  backend: str = "python"
  image_shape: tuple[Union[int, float], Union[int, float]] = field(
      default_factory=lambda: [-1, -1])  #(H, W)

  def __setattr__(self, name: str, value: Any) -> None:
    # Every assignment to `image_shape` (including the one made by the generated
    # __init__, which runs after `input` is set) goes through validation. Validation also
    # updates the dims of the image inputs.
    if name == "image_shape":
      value = self._check_and_assign_image_shape_value(value)

    super().__setattr__(name, value)

  def _check_and_assign_image_shape_value(self, value):
    """
    Validate an ``image_shape`` value and propagate it to the image inputs.

    Every input whose name contains IMAGE_TENSOR_NAME gets ``dims = [H, W, 3]``.

    Args:
      value: (H, W) as a list or tuple.

    Returns:
      The value to store. This is ``value`` itself, except that ``[-1, -1]`` is returned
      when the config has inputs but none of them is an image input.

    Raises:
      ValueError: if an image input exists and ``value`` does not have exactly 2
        entries, or either entry exceeds MAX_HW_DIM.
    """
    image_inputs = [each for each in self.input if IMAGE_TENSOR_NAME in each.name]
    if not image_inputs:
      # The shape is irrelevant without an image input. Reset it to the (-1, -1)
      # sentinel when other inputs exist; otherwise keep the value untouched.
      return [-1, -1] if self.input else value

    # The value is the same for every image input, so validate it once.
    if len(value) != 2:
      raise ValueError(
          f"image_shape takes 2 values, Height and Width. Got {len(value)} values instead.")
    if value[0] > MAX_HW_DIM or value[1] > MAX_HW_DIM:
      raise ValueError(
          f"H and W each have a maximum value of {MAX_HW_DIM}. Got H: {value[0]}, W: {value[1]}"
      )

    for each in image_inputs:
      # list() accepts tuples (which have no .append) as well as lists, and gives each
      # input its own copy. The trailing 3 is the channel dim.
      each.dims = list(value) + [3]

    return value
|