clarifai-9.8.1-py3-none-any.whl → clarifai-9.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. clarifai/client/app.py +115 -14
  2. clarifai/client/base.py +11 -4
  3. clarifai/client/dataset.py +8 -3
  4. clarifai/client/input.py +34 -28
  5. clarifai/client/model.py +71 -2
  6. clarifai/client/module.py +4 -2
  7. clarifai/client/runner.py +161 -0
  8. clarifai/client/search.py +173 -0
  9. clarifai/client/user.py +110 -4
  10. clarifai/client/workflow.py +27 -2
  11. clarifai/constants/search.py +2 -0
  12. clarifai/datasets/upload/loaders/xview_detection.py +1 -1
  13. clarifai/models/model_serving/README.md +3 -3
  14. clarifai/models/model_serving/cli/deploy_cli.py +2 -3
  15. clarifai/models/model_serving/cli/repository.py +3 -5
  16. clarifai/models/model_serving/constants.py +1 -5
  17. clarifai/models/model_serving/docs/custom_config.md +5 -6
  18. clarifai/models/model_serving/docs/dependencies.md +5 -10
  19. clarifai/models/model_serving/examples/image_classification/age_vit/requirements.txt +1 -0
  20. clarifai/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt +1 -0
  21. clarifai/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt +1 -0
  22. clarifai/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt +1 -0
  23. clarifai/models/model_serving/examples/visual_detection/yolov5x/requirements.txt +1 -1
  24. clarifai/models/model_serving/examples/visual_embedding/vit-base/requirements.txt +1 -0
  25. clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt +1 -0
  26. clarifai/models/model_serving/model_config/__init__.py +2 -0
  27. clarifai/models/model_serving/model_config/config.py +298 -0
  28. clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml +18 -0
  29. clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml +18 -0
  30. clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml +18 -0
  31. clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml +18 -0
  32. clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml +18 -0
  33. clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml +28 -0
  34. clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml +18 -0
  35. clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml +18 -0
  36. clarifai/models/model_serving/model_config/serializer.py +1 -1
  37. clarifai/models/model_serving/models/default_test.py +22 -21
  38. clarifai/models/model_serving/models/output.py +2 -2
  39. clarifai/models/model_serving/pb_model_repository.py +2 -5
  40. clarifai/runners/__init__.py +0 -0
  41. clarifai/runners/example.py +33 -0
  42. clarifai/schema/search.py +60 -0
  43. clarifai/utils/logging.py +53 -3
  44. clarifai/versions.py +1 -1
  45. clarifai/workflows/__init__.py +0 -0
  46. clarifai/workflows/export.py +68 -0
  47. clarifai/workflows/utils.py +59 -0
  48. clarifai/workflows/validate.py +67 -0
  49. {clarifai-9.8.1.dist-info → clarifai-9.9.0.dist-info}/METADATA +20 -2
  50. {clarifai-9.8.1.dist-info → clarifai-9.9.0.dist-info}/RECORD +102 -86
  51. clarifai_utils/client/app.py +115 -14
  52. clarifai_utils/client/base.py +11 -4
  53. clarifai_utils/client/dataset.py +8 -3
  54. clarifai_utils/client/input.py +34 -28
  55. clarifai_utils/client/model.py +71 -2
  56. clarifai_utils/client/module.py +4 -2
  57. clarifai_utils/client/runner.py +161 -0
  58. clarifai_utils/client/search.py +173 -0
  59. clarifai_utils/client/user.py +110 -4
  60. clarifai_utils/client/workflow.py +27 -2
  61. clarifai_utils/constants/search.py +2 -0
  62. clarifai_utils/datasets/upload/loaders/xview_detection.py +1 -1
  63. clarifai_utils/models/model_serving/README.md +3 -3
  64. clarifai_utils/models/model_serving/cli/deploy_cli.py +2 -3
  65. clarifai_utils/models/model_serving/cli/repository.py +3 -5
  66. clarifai_utils/models/model_serving/constants.py +1 -5
  67. clarifai_utils/models/model_serving/docs/custom_config.md +5 -6
  68. clarifai_utils/models/model_serving/docs/dependencies.md +5 -10
  69. clarifai_utils/models/model_serving/examples/image_classification/age_vit/requirements.txt +1 -0
  70. clarifai_utils/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt +1 -0
  71. clarifai_utils/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt +1 -0
  72. clarifai_utils/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt +1 -0
  73. clarifai_utils/models/model_serving/examples/visual_detection/yolov5x/requirements.txt +1 -1
  74. clarifai_utils/models/model_serving/examples/visual_embedding/vit-base/requirements.txt +1 -0
  75. clarifai_utils/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt +1 -0
  76. clarifai_utils/models/model_serving/model_config/__init__.py +2 -0
  77. clarifai_utils/models/model_serving/model_config/config.py +298 -0
  78. clarifai_utils/models/model_serving/model_config/model_types_config/text-classifier.yaml +18 -0
  79. clarifai_utils/models/model_serving/model_config/model_types_config/text-embedder.yaml +18 -0
  80. clarifai_utils/models/model_serving/model_config/model_types_config/text-to-image.yaml +18 -0
  81. clarifai_utils/models/model_serving/model_config/model_types_config/text-to-text.yaml +18 -0
  82. clarifai_utils/models/model_serving/model_config/model_types_config/visual-classifier.yaml +18 -0
  83. clarifai_utils/models/model_serving/model_config/model_types_config/visual-detector.yaml +28 -0
  84. clarifai_utils/models/model_serving/model_config/model_types_config/visual-embedder.yaml +18 -0
  85. clarifai_utils/models/model_serving/model_config/model_types_config/visual-segmenter.yaml +18 -0
  86. clarifai_utils/models/model_serving/model_config/serializer.py +1 -1
  87. clarifai_utils/models/model_serving/models/default_test.py +22 -21
  88. clarifai_utils/models/model_serving/models/output.py +2 -2
  89. clarifai_utils/models/model_serving/pb_model_repository.py +2 -5
  90. clarifai_utils/runners/__init__.py +0 -0
  91. clarifai_utils/runners/example.py +33 -0
  92. clarifai_utils/schema/search.py +60 -0
  93. clarifai_utils/utils/logging.py +53 -3
  94. clarifai_utils/versions.py +1 -1
  95. clarifai_utils/workflows/__init__.py +0 -0
  96. clarifai_utils/workflows/export.py +68 -0
  97. clarifai_utils/workflows/utils.py +59 -0
  98. clarifai_utils/workflows/validate.py +67 -0
  99. clarifai/models/model_serving/envs/triton_conda-cp3.8-torch1.13.1-19f97078.yaml +0 -35
  100. clarifai/models/model_serving/envs/triton_conda-cp3.8-torch2.0.0-ce980f28.yaml +0 -51
  101. clarifai/models/model_serving/examples/image_classification/age_vit/triton_conda.yaml +0 -1
  102. clarifai/models/model_serving/examples/text_classification/xlm-roberta/triton_conda.yaml +0 -1
  103. clarifai/models/model_serving/examples/text_to_image/sd-v1.5/triton_conda.yaml +0 -1
  104. clarifai/models/model_serving/examples/text_to_text/bart-summarize/triton_conda.yaml +0 -1
  105. clarifai/models/model_serving/examples/visual_detection/yolov5x/triton_conda.yaml +0 -1
  106. clarifai/models/model_serving/examples/visual_embedding/vit-base/triton_conda.yaml +0 -1
  107. clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/triton_conda.yaml +0 -1
  108. clarifai/models/model_serving/model_config/deploy.py +0 -75
  109. clarifai/models/model_serving/model_config/triton_config.py +0 -226
  110. clarifai_utils/models/model_serving/envs/triton_conda-cp3.8-torch1.13.1-19f97078.yaml +0 -35
  111. clarifai_utils/models/model_serving/envs/triton_conda-cp3.8-torch2.0.0-ce980f28.yaml +0 -51
  112. clarifai_utils/models/model_serving/examples/image_classification/age_vit/triton_conda.yaml +0 -1
  113. clarifai_utils/models/model_serving/examples/text_classification/xlm-roberta/triton_conda.yaml +0 -1
  114. clarifai_utils/models/model_serving/examples/text_to_image/sd-v1.5/triton_conda.yaml +0 -1
  115. clarifai_utils/models/model_serving/examples/text_to_text/bart-summarize/triton_conda.yaml +0 -1
  116. clarifai_utils/models/model_serving/examples/visual_detection/yolov5x/triton_conda.yaml +0 -1
  117. clarifai_utils/models/model_serving/examples/visual_embedding/vit-base/triton_conda.yaml +0 -1
  118. clarifai_utils/models/model_serving/examples/visual_segmentation/segformer-b2/triton_conda.yaml +0 -1
  119. clarifai_utils/models/model_serving/model_config/deploy.py +0 -75
  120. clarifai_utils/models/model_serving/model_config/triton_config.py +0 -226
  121. {clarifai-9.8.1.dist-info → clarifai-9.9.0.dist-info}/LICENSE +0 -0
  122. {clarifai-9.8.1.dist-info → clarifai-9.9.0.dist-info}/WHEEL +0 -0
  123. {clarifai-9.8.1.dist-info → clarifai-9.9.0.dist-info}/entry_points.txt +0 -0
  124. {clarifai-9.8.1.dist-info → clarifai-9.9.0.dist-info}/top_level.txt +0 -0
clarifai/models/model_serving/model_config/config.py
@@ -0,0 +1,298 @@
+ # Copyright 2023 Clarifai, Inc.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Model Config classes."""
+
+ from dataclasses import asdict, dataclass, field
+ from typing import List
+
+ import yaml
+
+ from ..models.model_types import *  # noqa # pylint: disable=unused-import
+ from ..models.output import *  # noqa # pylint: disable=unused-import
+
+ __all__ = ["get_model_config", "MODEL_TYPES", "TritonModelConfig", "ModelTypes"]
+
+ ### Triton Model Config classes.###
+
+
+ @dataclass
+ class DType:
+   """
+   Triton Model Config data types.
+   """
+   # https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
+   TYPE_UINT8: int = 2
+   TYPE_INT8: int = 6
+   TYPE_INT16: int = 7
+   TYPE_INT32: int = 8
+   TYPE_INT64: int = 9
+   TYPE_FP16: int = 10
+   TYPE_FP32: int = 11
+   TYPE_STRING: int = 13
+   KIND_GPU: int = 1
+   KIND_CPU: int = 2
+
+
+ @dataclass
+ class InputConfig:
+   """
+   Triton Input definition.
+   Params:
+   -------
+   name: input name
+   data_type: input data type
+   dims: Pre-defined input data shape(s).
+
+   Returns:
+   --------
+   InputConfig
+   """
+   name: str
+   data_type: int
+   dims: List = field(default_factory=list)
+
+
+ @dataclass
+ class OutputConfig:
+   """
+   Triton Output definition.
+   Params:
+   -------
+   name: output name
+   data_type: output data type
+   dims: Pre-defined output data shape(s).
+   labels (bool): If labels file is required for inference.
+
+   Returns:
+   --------
+   OutputConfig
+   """
+   name: str
+   data_type: int
+   dims: List = field(default_factory=list)
+   labels: bool = False
+
+   def __post_init__(self):
+     if self.labels:
+       self.label_filename = "labels.txt"
+     del self.labels
+
+
+ @dataclass
+ class Device:
+   """
+   Triton instance_group.
+   Define the type of inference device and number of devices to use.
+   Params:
+   -------
+   count: number of devices
+   use_gpu: whether to use cpu or gpu.
+
+   Returns:
+   --------
+   Device object
+   """
+   count: int = 1
+   use_gpu: bool = True
+
+   def __post_init__(self):
+     if self.use_gpu:
+       self.kind: str = DType.KIND_GPU
+     else:
+       self.kind: str = DType.KIND_CPU
+
+
+ @dataclass
+ class DynamicBatching:
+   """
+   Triton dynamic_batching config.
+   Params:
+   -------
+   preferred_batch_size: batch size
+   max_queue_delay_microseconds: max queue delay for a request batch
+
+   Returns:
+   --------
+   DynamicBatching object
+   """
+   #preferred_batch_size: List[int] = [1] # recommended not to set
+   max_queue_delay_microseconds: int = 500
+
+
+ @dataclass
+ class TritonModelConfig:
+   """
+   Triton Model Config base.
+   Params:
+   -------
+   name: triton inference model name
+   input: a list of an InputConfig field
+   output: a list of OutputConfig fields/dicts
+   instance_group: Device. see Device
+   dynamic_batching: Triton dynamic batching settings.
+   max_batch_size: max request batch size
+   backend: Triton Python Backend. Constant
+
+   Returns:
+   --------
+   TritonModelConfig
+   """
+   model_type: str
+   model_name: str
+   model_version: str
+   image_shape: List  #(H, W)
+   input: List[InputConfig] = field(default_factory=list)
+   output: List[OutputConfig] = field(default_factory=list)
+   instance_group: Device = field(default_factory=Device)
+   dynamic_batching: DynamicBatching = field(default_factory=DynamicBatching)
+   max_batch_size: int = 1
+   backend: str = "python"
+
+   def __post_init__(self):
+     if "image" in [each.name for each in self.input]:
+       image_dims = self.image_shape
+       image_dims.append(3)  # add channel dim
+       self.input[0].dims = image_dims
+
+
+ ### General Model Config classes & functions ###
+
+
+ # Clarifai model types
+ @dataclass
+ class ModelTypes:
+   visual_detector: str = "visual-detector"
+   visual_classifier: str = "visual-classifier"
+   text_classifier: str = "text-classifier"
+   text_to_text: str = "text-to-text"
+   text_embedder: str = "text-embedder"
+   text_to_image: str = "text-to-image"
+   visual_embedder: str = "visual-embedder"
+   visual_segmenter: str = "visual-segmenter"
+
+   def __post_init__(self):
+     self.all = list(asdict(self).values())
+
+
+ @dataclass
+ class InferenceConfig:
+   wrap_func: callable
+   return_type: dataclass
+
+
+ @dataclass
+ class FieldMapsConfig:
+   input_fields_map: dict
+   output_fields_map: dict
+
+
+ @dataclass
+ class DefaultTritonConfig:
+   input: List[InputConfig] = field(default_factory=list)
+   output: List[OutputConfig] = field(default_factory=list)
+
+
+ @dataclass
+ class ModelConfigClass:
+   type: str = field(init=False)
+   triton: DefaultTritonConfig
+   inference: InferenceConfig
+   field_maps: FieldMapsConfig
+
+   def make_triton_model_config(
+       self,
+       model_name: str,
+       model_version: str,
+       image_shape: List = None,
+       instance_group: Device = Device(),
+       dynamic_batching: DynamicBatching = DynamicBatching(),
+       max_batch_size: int = 1,
+       backend: str = "python",
+   ) -> TritonModelConfig:
+
+     return TritonModelConfig(
+         model_type=self.type,
+         model_name=model_name,
+         model_version=model_version,
+         image_shape=image_shape,
+         instance_group=instance_group,
+         dynamic_batching=dynamic_batching,
+         max_batch_size=max_batch_size,
+         backend=backend,
+         input=self.triton.input,
+         output=self.triton.output)
+
+
+ def read_config(cfg: str):
+   with open(cfg, encoding="utf-8") as f:
+     config = yaml.safe_load(f)  # model dict
+
+   # parse default triton
+   input_triton_configs = config["triton"]["input"]
+   output_triton_configs = config["triton"]["output"]
+   triton = DefaultTritonConfig(
+       input=[
+           InputConfig(
+               name=input["name"],
+               data_type=eval(f"DType.{input['data_type']}"),
+               dims=input["dims"]) for input in input_triton_configs
+       ],
+       output=[
+           OutputConfig(
+               name=output["name"],
+               data_type=eval(f"DType.{output['data_type']}"),
+               dims=output["dims"],
+               labels=output["labels"],
+           ) for output in output_triton_configs
+       ])
+
+   # parse inference config
+   inference = InferenceConfig(
+       wrap_func=eval(config["inference"]["wrap_func"]),
+       return_type=eval(config["inference"]["return_type"]),
+   )
+
+   # parse field maps for deployment
+   field_maps = FieldMapsConfig(**config["field_maps"])
+
+   return ModelConfigClass(triton=triton, inference=inference, field_maps=field_maps)
+
+
+ def get_model_config(model_type: str) -> ModelConfigClass:
+   """
+   Get model config by model type
+
+   Args:
+
+     model_type (str): One of field value of ModelTypes
+
+   Return:
+     ModelConfigClass
+
+   ### Example:
+   >>> cfg = get_model_config(ModelTypes.text_classifier)
+   >>> custom_triton_config = cfg.make_triton_model_config(**kwargs)
+
+
+   """
+   import os
+   assert model_type in MODEL_TYPES, f"`model_type` must be in {MODEL_TYPES}"
+   cfg = read_config(
+       os.path.join(os.path.dirname(__file__), "model_types_config", f"{model_type}.yaml"))
+   cfg.type = model_type
+   return cfg
+
+
+ _model_types = ModelTypes()
+ MODEL_TYPES = _model_types.all
+ del _model_types
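
For orientation, here is a minimal usage sketch of the new config module. It mirrors the `get_model_config` docstring example; the import paths follow the updated imports in `default_test.py` further down, and the model name is a placeholder:

```python
from clarifai.models.model_serving.model_config import ModelTypes
from clarifai.models.model_serving.model_config.config import get_model_config

# Load the packaged per-type YAML config (see the model_types_config files below).
cfg = get_model_config(ModelTypes.text_classifier)

# Derive a concrete Triton config from it. Only "image" inputs get the channel
# dim appended from image_shape, so for a text model the value is effectively unused.
triton_cfg = cfg.make_triton_model_config(
    model_name="my-text-classifier",  # placeholder name
    model_version="1",
    image_shape=[-1, -1])

print(triton_cfg.backend)         # "python"
print(triton_cfg.output[0].name)  # "softmax_predictions"
```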
clarifai/models/model_serving/model_config/model_types_config/text-classifier.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: text
+     data_type: TYPE_STRING
+     dims: [1]
+   output:
+   - name: softmax_predictions
+     data_type: TYPE_FP32
+     dims: [-1]
+     labels: true
+ inference:
+   wrap_func: text_classifier
+   return_type: ClassifierOutput
+ field_maps:
+   input_fields_map:
+     text: text
+   output_fields_map:
+     concepts: softmax_predictions
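
Each of these per-type YAML files is parsed by the `read_config` function shown in `config.py` above into the dataclasses it defines. A rough sketch of the resulting object for the text-classifier file, with an illustrative local path:

```python
from clarifai.models.model_serving.model_config.config import read_config

cfg = read_config("model_types_config/text-classifier.yaml")  # illustrative path

assert cfg.triton.input[0].name == "text"                     # triton.input -> InputConfig
assert cfg.triton.output[0].label_filename == "labels.txt"    # `labels: true` becomes label_filename
assert cfg.field_maps.output_fields_map == {"concepts": "softmax_predictions"}
```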
clarifai/models/model_serving/model_config/model_types_config/text-embedder.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: text
+     data_type: TYPE_STRING
+     dims: [1]
+   output:
+   - name: embeddings
+     data_type: TYPE_FP32
+     dims: [-1]
+     labels: false
+ inference:
+   wrap_func: text_embedder
+   return_type: EmbeddingOutput
+ field_maps:
+   input_fields_map:
+     text: text
+   output_fields_map:
+     embeddings: embeddings
clarifai/models/model_serving/model_config/model_types_config/text-to-image.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: text
+     data_type: TYPE_STRING
+     dims: [1]
+   output:
+   - name: image
+     data_type: TYPE_UINT8
+     dims: [-1, -1, 3]
+     labels: false
+ inference:
+   wrap_func: text_to_image
+   return_type: ImageOutput
+ field_maps:
+   input_fields_map:
+     text: text
+   output_fields_map:
+     image: image
clarifai/models/model_serving/model_config/model_types_config/text-to-text.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: text
+     data_type: TYPE_STRING
+     dims: [1]
+   output:
+   - name: text
+     data_type: TYPE_STRING
+     dims: [1]
+     labels: false
+ inference:
+   wrap_func: text_to_text
+   return_type: TextOutput
+ field_maps:
+   input_fields_map:
+     text: text
+   output_fields_map:
+     text: text
clarifai/models/model_serving/model_config/model_types_config/visual-classifier.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: image
+     data_type: TYPE_UINT8
+     dims: [-1, -1, 3]
+   output:
+   - name: softmax_predictions
+     data_type: TYPE_FP32
+     dims: [-1]
+     labels: true
+ inference:
+   wrap_func: visual_classifier
+   return_type: ClassifierOutput
+ field_maps:
+   input_fields_map:
+     image: image
+   output_fields_map:
+     concepts: softmax_predictions
clarifai/models/model_serving/model_config/model_types_config/visual-detector.yaml
@@ -0,0 +1,28 @@
+ triton:
+   input:
+   - name: image
+     data_type: TYPE_UINT8
+     dims: [-1, -1, 3]
+   output:
+   - name: predicted_bboxes
+     data_type: TYPE_FP32
+     dims: [-1, 4]
+     labels: false
+   - name: predicted_labels
+     data_type: TYPE_INT32
+     dims: [-1, 1]
+     labels: true
+   - name: predicted_scores
+     data_type: TYPE_FP32
+     dims: [-1, 1]
+     labels: false
+ inference:
+   wrap_func: visual_detector
+   return_type: VisualDetectorOutput
+ field_maps:
+   input_fields_map:
+     image: image
+   output_fields_map:
+     "regions[...].region_info.bounding_box": "predicted_bboxes"
+     "regions[...].data.concepts[...].id": "predicted_labels"
+     "regions[...].data.concepts[...].value": "predicted_scores"
clarifai/models/model_serving/model_config/model_types_config/visual-embedder.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: image
+     data_type: TYPE_UINT8
+     dims: [-1, -1, 3]
+   output:
+   - name: embeddings
+     data_type: TYPE_FP32
+     dims: [-1]
+     labels: false
+ inference:
+   wrap_func: visual_embedder
+   return_type: EmbeddingOutput
+ field_maps:
+   input_fields_map:
+     image: image
+   output_fields_map:
+     embeddings: embeddings
clarifai/models/model_serving/model_config/model_types_config/visual-segmenter.yaml
@@ -0,0 +1,18 @@
+ triton:
+   input:
+   - name: image
+     data_type: TYPE_UINT8
+     dims: [-1, -1, 3]
+   output:
+   - name: predicted_mask
+     data_type: TYPE_INT64
+     dims: [-1, -1]
+     labels: true
+ inference:
+   wrap_func: visual_segmenter
+   return_type: MasksOutput
+ field_maps:
+   input_fields_map:
+     image: image
+   output_fields_map:
+     "regions[...].region_info.mask,regions[...].data.concepts": "predicted_mask"
clarifai/models/model_serving/model_config/serializer.py
@@ -21,7 +21,7 @@ from typing import Type
  from google.protobuf.text_format import MessageToString
  from tritonclient.grpc import model_config_pb2

- from .triton_config import TritonModelConfig
+ from .config import TritonModelConfig


  class Serializer:
clarifai/models/model_serving/models/default_test.py
@@ -6,7 +6,8 @@ import unittest

  import numpy as np

- from ..model_config.triton_config import TritonModelConfig
+ from ..model_config import ModelTypes
+ from ..model_config.config import get_model_config
  from .output import (ClassifierOutput, EmbeddingOutput, ImageOutput, MasksOutput, TextOutput,
                       VisualDetectorOutput)

@@ -75,19 +76,18 @@ class DefaultTestInferenceModel(unittest.TestCase):
          model_repository=os.path.join(repo_version_dir, ".."),
          model_instance_kind="GPU" if self.is_instance_kind_gpu else "cpu"))
      # Get default config of model and model_type
-     self.default_triton_model_config = TritonModelConfig(
-         model_name=self.model_type,
-         model_version="1",
-         model_type=self.model_type,
-         image_shape=[-1, -1])
+     self.default_triton_model_config = get_model_config(self.model_type).make_triton_model_config(
+         model_name=self.model_type, model_version="1", image_shape=[-1, -1])
      # Get current model config
      self.triton_model_config = self.triton_model.config_msg
      self.triton_model_input_name = self.triton_model.input_name
      self.preprocess = self._get_preprocess()
      # load labels
      self._required_label_model_types = [
-         "visual-detector", "visual-classifier", "text-classifier", "visual-segmenter"
+         ModelTypes.visual_detector, ModelTypes.visual_classifier, ModelTypes.text_classifier,
+         ModelTypes.visual_segmenter
      ]
+     self._output_text_models = [ModelTypes.text_to_text]
      self.labels = []
      if self.model_type in self._required_label_model_types:
        with open(os.path.join(repo_version_dir, "../labels.txt"), 'r') as fp:
@@ -144,14 +144,15 @@ class DefaultTestInferenceModel(unittest.TestCase):
    for inp, output in zip(inputs, outputs):

      field = dataclasses.fields(output)[0].name
-     self.assertEqual(
-         len(self.triton_model_config.output[0].dims),
-         len(getattr(output, field).shape),
-         "Length of 'dims' of config and output must be matched, but get "
-         f"Config {len(self.triton_model_config.output[0].dims)} != Output {len(getattr(output, field).shape)}"
-     )
+     if self.model_type not in self._output_text_models:
+       self.assertEqual(
+           len(self.triton_model_config.output[0].dims),
+           len(getattr(output, field).shape),
+           "Length of 'dims' of config and output must be matched, but get "
+           f"Config {len(self.triton_model_config.output[0].dims)} != Output {len(getattr(output, field).shape)}"
+       )

-     if self.model_type == "visual-detector":
+     if self.model_type == ModelTypes.visual_detector:
        logging.info(output.predicted_labels)
        self.assertEqual(
            type(output), VisualDetectorOutput,
@@ -166,7 +167,7 @@ class DefaultTestInferenceModel(unittest.TestCase):
            f"`predicted_labels` must be in [0, {len(self.labels) - 1}]")
        self.assertTrue(_is_integer(output.predicted_labels), "`predicted_labels` must be integer")

-     elif self.model_type == "visual-classifier":
+     elif self.model_type == ModelTypes.visual_classifier:
        self.assertEqual(
            type(output), ClassifierOutput,
            f"Output type must be `ClassifierOutput`, but got {type(output)}")
@@ -179,7 +180,7 @@ class DefaultTestInferenceModel(unittest.TestCase):
            f"`predicted_labels` must equal to {len(self.labels)}, however got {len(output.predicted_scores)}"
        )

-     elif self.model_type == "text-classifier":
+     elif self.model_type == ModelTypes.text_classifier:
        self.assertEqual(
            type(output), ClassifierOutput,
            f"Output type must be `ClassifierOutput`, but got {type(output)}")
@@ -192,29 +193,29 @@ class DefaultTestInferenceModel(unittest.TestCase):
            f"`predicted_labels` must equal to {len(self.labels)}, however got {len(output.predicted_scores)}"
        )

-     elif self.model_type == "text-embedder":
+     elif self.model_type == ModelTypes.text_embedder:
        self.assertEqual(
            type(output), EmbeddingOutput,
            f"Output type must be `EmbeddingOutput`, but got {type(output)}")
        self.assertNotEqual(output.embedding_vector.shape, [])

-     elif self.model_type == "text-to-text":
+     elif self.model_type == ModelTypes.text_to_text:
        self.assertEqual(
            type(output), TextOutput, f"Output type must be `TextOutput`, but got {type(output)}")

-     elif self.model_type == "text-to-image":
+     elif self.model_type == ModelTypes.text_to_image:
        self.assertEqual(
            type(output), ImageOutput,
            f"Output type must be `ImageOutput`, but got {type(output)}")
        self.assertTrue(_is_non_negative(output.image), "`image` elements must be >= 0")

-     elif self.model_type == "visual-embedder":
+     elif self.model_type == ModelTypes.visual_embedder:
        self.assertEqual(
            type(output), EmbeddingOutput,
            f"Output type must be `EmbeddingOutput`, but got {type(output)}")
        self.assertNotEqual(output.embedding_vector.shape, [])

-     elif self.model_type == "visual-segmenter":
+     elif self.model_type == ModelTypes.visual_segmenter:
        self.assertEqual(
            type(output), MasksOutput,
            f"Output type must be `MasksOutput`, but got {type(output)}")
clarifai/models/model_serving/models/output.py
@@ -72,8 +72,8 @@ class TextOutput:
    """
    Validate input upon initialization.
    """
-   assert self.predicted_text.ndim == 1, \
-       f"All predictions must be 1-dimensional, Got text-dims: {self.predicted_text.ndim} instead."
+   assert self.predicted_text.ndim == 0, \
+       f"All predictions must be 0-dimensional, Got text-dims: {self.predicted_text.ndim} instead."


  @dataclass
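
The `TextOutput` change above tightens the shape check from 1-d to 0-d text arrays, consistent with the skipped dims-length assertion for `_output_text_models` in `default_test.py`. A hedged illustration (the `predicted_text` field name comes from the assertion itself; keyword construction assumes `TextOutput` is a plain dataclass):

```python
import numpy as np
from clarifai.models.model_serving.models.output import TextOutput

# A 0-d numpy string scalar now passes validation...
out = TextOutput(predicted_text=np.asarray("a summary"))
assert out.predicted_text.ndim == 0

# ...while the previously valid 1-d form would now raise AssertionError:
# TextOutput(predicted_text=np.asarray(["a summary"]))  # ndim == 1
```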
clarifai/models/model_serving/pb_model_repository.py
@@ -19,8 +19,7 @@ import os
  from pathlib import Path
  from typing import Callable, Type

- from .model_config.serializer import Serializer
- from .model_config.triton_config import TritonModelConfig
+ from .model_config import Serializer, TritonModelConfig
  from .models import inference, pb_model, test


@@ -79,11 +78,9 @@ class TritonModelRepository:
            pass
          else:
            continue
-       # gen requirements & conda yaml
+       # gen requirements
        with open(os.path.join(repository_path, "requirements.txt"), "w") as f:
          f.write("clarifai>9.5.3\ntritonclient[all]")  # for model upload utils
-       with open(os.path.join(repository_path, "triton_conda.yaml"), "w") as conda_env:
-         conda_env.write("name: triton_conda-cp3.8-torch1.13.1-19f97078")

        if not os.path.isdir(model_version_path):
          os.mkdir(model_version_path)
clarifai/runners/__init__.py (file without changes)
clarifai/runners/example.py
@@ -0,0 +1,33 @@
+ from clarifai_grpc.grpc.api import resources_pb2
+
+ from clarifai.client.runner import Runner
+
+
+ class MyRunner(Runner):
+   """A custom runner that adds "Hello World" to the end of the text and replaces the domain of the
+   image URL as an example.
+   """
+
+   def run_input(self, input: resources_pb2.Input) -> resources_pb2.Output:
+     """This is the method that will be called when the runner is run. It takes in an input and
+     returns an output.
+     """
+
+     output = resources_pb2.Output()
+
+     data = input.data
+
+     if data.text.raw != "":
+       output.data.text.raw = data.text.raw + "Hello World"
+     if data.image.url != "":
+       output.data.text.raw = data.image.url.replace("samples.clarifai.com", "newdomain.com")
+     return output
+
+
+ if __name__ == '__main__':
+   # Make sure you set these env vars before running the example.
+   # CLARIFAI_PAT
+   # CLARIFAI_USER_ID
+
+   # You need to first create a runner in the Clarifai API and then use the ID here.
+   MyRunner(runner_id="sdk-test-runner").start()
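
A small hedged smoke test of the per-input logic in this example. It assumes `MyRunner` can be constructed without contacting the Clarifai API, which depends on the `Runner` base class (`start()` additionally needs CLARIFAI_PAT, CLARIFAI_USER_ID, and a runner pre-created in the API):

```python
from clarifai_grpc.grpc.api import resources_pb2

# Assumption: constructing the runner does not itself call the API.
runner = MyRunner(runner_id="sdk-test-runner")

inp = resources_pb2.Input()
inp.data.text.raw = "Testing... "
print(runner.run_input(inp).data.text.raw)  # expected: "Testing... Hello World"
```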