clarifai 11.1.5__py3-none-any.whl → 11.1.5rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. clarifai/__init__.py +1 -1
  2. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  3. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  4. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  5. clarifai/cli/__main__.py~ +4 -0
  6. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  7. clarifai/cli/__pycache__/__main__.cpython-310.pyc +0 -0
  8. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  9. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  10. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  11. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  12. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  13. clarifai/cli/model.py +25 -0
  14. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  15. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  16. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  17. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  18. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  19. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  20. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  21. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  22. clarifai/client/__pycache__/runner.cpython-310.pyc +0 -0
  23. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  24. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  25. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  26. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  27. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  28. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  29. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  30. clarifai/client/model.py +90 -365
  31. clarifai/client/model_client.py +400 -0
  32. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  33. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  34. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  35. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  36. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  37. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  38. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  39. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  40. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  41. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  42. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  43. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  44. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  45. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
  46. clarifai/models/__pycache__/__init__.cpython-310.pyc +0 -0
  47. clarifai/models/model_serving/__pycache__/__init__.cpython-310.pyc +0 -0
  48. clarifai/models/model_serving/__pycache__/constants.cpython-310.pyc +0 -0
  49. clarifai/models/model_serving/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  50. clarifai/models/model_serving/cli/__pycache__/_utils.cpython-310.pyc +0 -0
  51. clarifai/models/model_serving/cli/__pycache__/base.cpython-310.pyc +0 -0
  52. clarifai/models/model_serving/cli/__pycache__/build.cpython-310.pyc +0 -0
  53. clarifai/models/model_serving/cli/__pycache__/create.cpython-310.pyc +0 -0
  54. clarifai/models/model_serving/model_config/__pycache__/__init__.cpython-310.pyc +0 -0
  55. clarifai/models/model_serving/model_config/__pycache__/base.cpython-310.pyc +0 -0
  56. clarifai/models/model_serving/model_config/__pycache__/config.cpython-310.pyc +0 -0
  57. clarifai/models/model_serving/model_config/__pycache__/inference_parameter.cpython-310.pyc +0 -0
  58. clarifai/models/model_serving/model_config/__pycache__/output.cpython-310.pyc +0 -0
  59. clarifai/models/model_serving/model_config/triton/__pycache__/__init__.cpython-310.pyc +0 -0
  60. clarifai/models/model_serving/model_config/triton/__pycache__/serializer.cpython-310.pyc +0 -0
  61. clarifai/models/model_serving/model_config/triton/__pycache__/triton_config.cpython-310.pyc +0 -0
  62. clarifai/models/model_serving/model_config/triton/__pycache__/wrappers.cpython-310.pyc +0 -0
  63. clarifai/models/model_serving/repo_build/__pycache__/__init__.cpython-310.pyc +0 -0
  64. clarifai/models/model_serving/repo_build/__pycache__/build.cpython-310.pyc +0 -0
  65. clarifai/models/model_serving/repo_build/static_files/__pycache__/base_test.cpython-310-pytest-7.2.0.pyc +0 -0
  66. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  67. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  68. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  69. clarifai/runners/__init__.py +2 -7
  70. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  71. clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
  72. clarifai/runners/dockerfile_template/Dockerfile.debug +11 -0
  73. clarifai/runners/dockerfile_template/Dockerfile.debug~ +9 -0
  74. clarifai/runners/dockerfile_template/Dockerfile.template +3 -0
  75. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  76. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  77. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  78. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  79. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  80. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  81. clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
  82. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  83. clarifai/runners/models/model_builder.py +33 -7
  84. clarifai/runners/models/model_class.py +249 -25
  85. clarifai/runners/models/model_runner.py +2 -0
  86. clarifai/runners/models/model_servicer.py +11 -2
  87. clarifai/runners/server.py +5 -1
  88. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  90. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  91. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  92. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  93. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  94. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  95. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  96. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  97. clarifai/runners/utils/data_handler.py +308 -205
  98. clarifai/runners/utils/method_signatures.py +437 -0
  99. clarifai/runners/utils/serializers.py +132 -0
  100. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  101. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  102. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  104. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  105. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  106. clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
  107. clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
  108. clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
  109. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  110. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  111. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  112. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  113. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc1.dist-info}/METADATA +16 -26
  114. clarifai-11.1.5rc1.dist-info/RECORD +201 -0
  115. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc1.dist-info}/WHEEL +1 -1
  116. clarifai/runners/models/base_typed_model.py +0 -238
  117. clarifai-11.1.5.dist-info/RECORD +0 -101
  118. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc1.dist-info}/LICENSE +0 -0
  119. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc1.dist-info}/entry_points.txt +0 -0
  120. {clarifai-11.1.5.dist-info → clarifai-11.1.5rc1.dist-info}/top_level.txt +0 -0

clarifai/runners/models/model_class.py
@@ -1,41 +1,265 @@
-from abc import ABC, abstractmethod
-from typing import Iterator
+import inspect
+import itertools
+import logging
+import os
+import traceback
+import types
+from abc import ABC
+from typing import Any, Dict, Iterator, List

-from clarifai_grpc.grpc.api import service_pb2
+from clarifai_grpc.grpc.api import resources_pb2, service_pb2
+from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
+
+from clarifai.runners.utils import data_handler
+from clarifai.runners.utils.method_signatures import (build_function_signature, deserialize,
+                                                      serialize, signatures_to_json)
+
+_METHOD_INFO_ATTR = '_cf_method_info'
+
+_RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() == "true"


 class ModelClass(ABC):

+  def load_model(self):
+    """Load the model."""
+    pass
+
+  def predict(self, **kwargs):
+    """Predict method for single or batched inputs."""
+    raise NotImplementedError("predict() not implemented")
+
+  def generate(self, **kwargs) -> Iterator:
+    """Generate method for streaming outputs."""
+    raise NotImplementedError("generate() not implemented")
+
+  def stream(self, **kwargs) -> Iterator:
+    """Stream method for streaming inputs and outputs."""
+    raise NotImplementedError("stream() not implemented")
+
+  def _handle_get_signatures_request(self) -> service_pb2.MultiOutputResponse:
+    methods = self._get_method_info()
+    signatures = {method.name: method.signature for method in methods.values()}
+    resp = service_pb2.MultiOutputResponse(status=status_pb2.Status(code=status_code_pb2.SUCCESS))
+    resp.outputs.add().data.string_value = signatures_to_json(signatures)
+    return resp
+
+  def batch_predict(self, method, inputs: List[Dict[str, Any]]) -> List[Any]:
+    """Batch predict method for multiple inputs."""
+    outputs = []
+    for input in inputs:
+      output = method(**input)
+      outputs.append(output)
+    return outputs
+
+  def batch_generate(self, method, inputs: List[Dict[str, Any]]) -> Iterator[List[Any]]:
+    """Batch generate method for multiple inputs."""
+    generators = [method(**input) for input in inputs]
+    for outputs in itertools.zip_longest(*generators):
+      yield outputs
+
   def predict_wrapper(
       self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
-    """This method is used for input/output proto data conversion"""
-    return self.predict(request)
+    outputs = []
+    try:
+      # TODO add method name field to proto
+      method_name = request.model.model_version.output_info.params['_method_name']
+      if method_name == '_GET_SIGNATURES':  # special case to fetch signatures, TODO add endpoint for this
+        return self._handle_get_signatures_request()
+      if method_name not in self._get_method_info():
+        raise ValueError(f"Method {method_name} not found in model class")
+      method = getattr(self, method_name)
+      method_info = method._cf_method_info
+      signature = method_info.signature
+      python_param_types = method_info.python_param_types
+      inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
+                                                    python_param_types)
+      if len(inputs) == 1:
+        inputs = inputs[0]
+        output = method(**inputs)
+        outputs.append(self._convert_output_to_proto(output, signature.outputs))
+      else:
+        outputs = self.batch_predict(method, inputs)
+        outputs = [self._convert_output_to_proto(output, signature.outputs) for output in outputs]
+
+      return service_pb2.MultiOutputResponse(
+          outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
+    except Exception as e:
+      if _RAISE_EXCEPTIONS:
+        raise
+      logging.exception("Error in predict")
+      return service_pb2.MultiOutputResponse(status=status_pb2.Status(
+          code=status_code_pb2.FAILURE,
+          details=str(e),
+          stack_trace=traceback.format_exc().split('\n')))

   def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
                       ) -> Iterator[service_pb2.MultiOutputResponse]:
-    """This method is used for input/output proto data conversion and yield outcome"""
-    return self.generate(request)
+    try:
+      method_name = request.model.model_version.output_info.params['_method_name']
+      method = getattr(self, method_name)
+      method_info = method._cf_method_info
+      signature = method_info.signature
+      python_param_types = method_info.python_param_types

-  def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest
+      inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
+                                                    python_param_types)
+      if len(inputs) == 1:
+        inputs = inputs[0]
+        for output in method(**inputs):
+          resp = service_pb2.MultiOutputResponse()
+          self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
+          resp.status.code = status_code_pb2.SUCCESS
+          yield resp
+      else:
+        for outputs in self.batch_generate(method, inputs):
+          resp = service_pb2.MultiOutputResponse()
+          for output in outputs:
+            self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
+          resp.status.code = status_code_pb2.SUCCESS
+          yield resp
+    except Exception as e:
+      if _RAISE_EXCEPTIONS:
+        raise
+      logging.exception("Error in generate")
+      yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
+          code=status_code_pb2.FAILURE,
+          details=str(e),
+          stack_trace=traceback.format_exc().split('\n')))
+
+  def stream_wrapper(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
                     ) -> Iterator[service_pb2.MultiOutputResponse]:
-    """This method is used for input/output proto data conversion and yield outcome"""
-    return self.stream(request)
+    try:
+      request = next(request_iterator)  # get first request to determine method
+      assert len(request.inputs) == 1, "Streaming requires exactly one input"

-  @abstractmethod
-  def load_model(self):
-    raise NotImplementedError("load_model() not implemented")
+      method_name = request.model.model_version.output_info.params['_method_name']
+      method = getattr(self, method_name)
+      method_info = method._cf_method_info
+      signature = method_info.signature
+      python_param_types = method_info.python_param_types

-  @abstractmethod
-  def predict(self,
-              request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
-    raise NotImplementedError("run_input() not implemented")
+      # find the streaming vars in the signature
+      streaming_var_signatures = [var for var in signature.inputs if var.streaming]
+      stream_argname = set([var.name.split('.', 1)[0] for var in streaming_var_signatures])
+      assert len(
+          stream_argname) == 1, 'streaming methods must have exactly one streaming function arg'
+      stream_argname = stream_argname.pop()

-  @abstractmethod
-  def generate(self, request: service_pb2.PostModelOutputsRequest
-              ) -> Iterator[service_pb2.MultiOutputResponse]:
-    raise NotImplementedError("generate() not implemented")
+      # convert all inputs for the first request, including the first stream value
+      inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
+                                                    python_param_types)
+      kwargs = inputs[0]

-  @abstractmethod
-  def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
-            ) -> Iterator[service_pb2.MultiOutputResponse]:
-    raise NotImplementedError("stream() not implemented")
+      # first streaming item
+      first_item = kwargs.pop(stream_argname)
+
+      # streaming generator
+      def InputStream():
+        yield first_item
+        # subsequent streaming items contain only the streaming input
+        for request in request_iterator:
+          item = self._convert_input_protos_to_python(request.inputs, streaming_var_signatures,
+                                                      python_param_types)
+          item = item[0][stream_argname]
+          yield item
+
+      # add stream generator back to the input kwargs
+      kwargs[stream_argname] = InputStream()
+
+      for output in method(**kwargs):
+        resp = service_pb2.MultiOutputResponse()
+        self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
+        resp.status.code = status_code_pb2.SUCCESS
+        yield resp
+    except Exception as e:
+      if _RAISE_EXCEPTIONS:
+        raise
+      logging.exception("Error in stream")
+      yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
+          code=status_code_pb2.FAILURE,
+          details=str(e),
+          stack_trace=traceback.format_exc().split('\n')))
+
+  def _convert_input_protos_to_python(self, inputs: List[resources_pb2.Input], variables_signature,
+                                      python_param_types) -> List[Dict[str, Any]]:
+    result = []
+    for input in inputs:
+      kwargs = deserialize(input.data, variables_signature)
+      # dynamic cast to annotated types
+      for k, v in kwargs.items():
+        if k not in python_param_types:
+          continue
+        kwargs[k] = data_handler.cast(v, python_param_types[k])
+      result.append(kwargs)
+    return result
+
+  def _convert_output_to_proto(self, output: Any, variables_signature,
+                               proto=None) -> resources_pb2.Output:
+    if proto is None:
+      proto = resources_pb2.Output()
+    if isinstance(output, tuple):
+      output = {f'return.{i}': item for i, item in enumerate(output)}
+    if not isinstance(output, dict):  # TODO Output type, not just dict
+      output = {'return': output}
+    serialize(output, variables_signature, proto.data, is_output=True)
+    return proto
+
+  @classmethod
+  def _register_model_methods(cls):
+    # go up the class hierarchy to find all decorated methods, and add to registry of current class
+    methods = {}
+    for base in reversed(cls.__mro__):
+      for name, method in base.__dict__.items():
+        method_info = getattr(method, _METHOD_INFO_ATTR, None)
+        if not method_info:  # regular function, not a model method
+          continue
+        methods[name] = method_info
+    # check for generic predict(request) -> response, etc. methods
+    #for name in ('predict', 'generate', 'stream'):
+    #  if hasattr(cls, name):
+    #    method = getattr(cls, name)
+    #    if not hasattr(method, _METHOD_INFO_ATTR):  # not already put in registry
+    #      methods[name] = _MethodInfo(method, method_type=name)
+    # set method table for this class in the registry
+    return methods
+
+  @classmethod
+  def _get_method_info(cls, func_name=None):
+    if not hasattr(cls, _METHOD_INFO_ATTR):
+      setattr(cls, _METHOD_INFO_ATTR, cls._register_model_methods())
+    method_info = getattr(cls, _METHOD_INFO_ATTR)
+    if func_name:
+      return method_info[func_name]
+    return method_info
+
+
+class _MethodInfo:
+
+  def __init__(self, method, method_type):
+    self.name = method.__name__
+    self.signature = build_function_signature(method, method_type)
+    self.python_param_types = {
+        p.name: p.annotation
+        for p in inspect.signature(method).parameters.values()
+        if p.annotation != inspect.Parameter.empty
+    }
+    self.python_param_types.pop('self', None)
+
+
+def predict(method):
+  setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'predict'))
+  return method
+
+
+def generate(method):
+  setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'generate'))
+  return method
+
+
+def stream(method):
+  setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'stream'))
+  return method
+
+
+methods = types.SimpleNamespace(predict=predict, generate=generate, stream=stream)
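
Note: the hunk above replaces the old abstract predict/generate/stream(request) proto-in/proto-out interface with decorator-registered methods whose Python signatures are serialized and returned to clients via the _GET_SIGNATURES path. A minimal sketch of what a subclass might look like under the new scheme; the ModelClass base, the decorator names, and the "exactly one streaming argument" rule come from the diff, while the import path, argument types, and method bodies are illustrative assumptions, not the package's documented API:

from typing import Iterator

from clarifai.runners.models.model_class import ModelClass, methods  # assumed import path


class MyModel(ModelClass):

  def load_model(self):
    # optional override; the new base-class load_model() is a no-op
    self.prefix = "echo: "

  @methods.predict
  def predict(self, text: str) -> str:
    # one call per input; predict_wrapper batches multiple inputs via batch_predict()
    return self.prefix + text

  @methods.generate
  def generate(self, text: str) -> Iterator[str]:
    # each yielded value becomes one output in a MultiOutputResponse
    for token in text.split():
      yield self.prefix + token

  @methods.stream
  def stream(self, texts: Iterator[str]) -> Iterator[str]:
    # streaming methods must take exactly one iterator argument, per the assertion in stream_wrapper
    for text in texts:
      yield self.prefix + text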
clarifai/runners/models/model_runner.py
@@ -82,6 +82,8 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
     ensure_urls_downloaded(request)

     resp = self.model.predict_wrapper(request)
+    if resp.status.code != status_code_pb2.SUCCESS:
+      return service_pb2.RunnerItemOutput(multi_output_response=resp)
     successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
     if all(successes):
       status = status_pb2.Status(
clarifai/runners/models/model_servicer.py
@@ -1,3 +1,4 @@
+import os
 from itertools import tee
 from typing import Iterator

@@ -6,6 +7,8 @@ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2

 from ..utils.url_fetcher import ensure_urls_downloaded

+_RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "1")
+

 class ModelServicer(service_pb2_grpc.V2Servicer):
   """
@@ -33,6 +36,8 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
     try:
       return self.model.predict_wrapper(request)
     except Exception as e:
+      if _RAISE_EXCEPTIONS:
+        raise
       return service_pb2.MultiOutputResponse(status=status_pb2.Status(
           code=status_code_pb2.MODEL_PREDICTION_FAILED,
           description="Failed",
@@ -50,8 +55,10 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
     ensure_urls_downloaded(request)

     try:
-      return self.model.generate_wrapper(request)
+      yield from self.model.generate_wrapper(request)
     except Exception as e:
+      if _RAISE_EXCEPTIONS:
+        raise
       yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
           code=status_code_pb2.MODEL_PREDICTION_FAILED,
           description="Failed",
@@ -74,8 +81,10 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
      ensure_urls_downloaded(req)

     try:
-      return self.model.stream_wrapper(request_copy)
+      yield from self.model.stream_wrapper(request_copy)
     except Exception as e:
+      if _RAISE_EXCEPTIONS:
+        raise
       yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
           code=status_code_pb2.MODEL_PREDICTION_FAILED,
           description="Failed",
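
Note: both model_class.py and model_servicer.py now consult a RAISE_EXCEPTIONS environment variable; when it is truthy, errors propagate instead of being folded into FAILURE / MODEL_PREDICTION_FAILED responses. A minimal sketch of toggling it for local debugging, assuming only what the diff shows (the flag is read at module import time):

import os

# must be set before the clarifai.runners modules are imported,
# since _RAISE_EXCEPTIONS is evaluated at import time
os.environ["RAISE_EXCEPTIONS"] = "true"

from clarifai.runners.models.model_class import ModelClass  # noqa: E402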
clarifai/runners/server.py
@@ -85,6 +85,10 @@ def serve(model_path,

   model = builder.create_model_instance()

+  # `num_threads` can be set in config.yaml or via the environment variable CLARIFAI_NUM_THREADS="<integer>".
+  # Note: The value in config.yaml takes precedence over the environment variable.
+  num_threads = builder.config.get("num_threads")
+
   # Setup the grpc server for local development.
   if grpc:

@@ -115,7 +119,7 @@
       nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
       compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
       base_url=os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com"),
-      num_parallel_polls=int(os.environ.get("CLARIFAI_NUM_THREADS", 1)),
+      num_parallel_polls=num_threads,
   )
   runner.start()  # start the runner to fetch work from the API.
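
Note: serve() now takes the runner's poll parallelism from the builder config instead of reading CLARIFAI_NUM_THREADS at this call site. A rough sketch of the precedence described in the comment above; the fallback default and where it is applied are assumptions, since this hunk only shows the config lookup:

import os


def resolve_num_threads(config: dict) -> int:
  # a value in config.yaml takes precedence over the environment variable
  value = config.get("num_threads")
  if value is None:
    value = os.environ.get("CLARIFAI_NUM_THREADS", 1)
  return int(value)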