clarifai 11.1.4rc2__py3-none-any.whl → 11.1.5rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  3. clarifai/cli/model.py +46 -10
  4. clarifai/client/model.py +89 -364
  5. clarifai/client/model_client.py +400 -0
  6. clarifai/client/workflow.py +2 -2
  7. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
  8. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
  9. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  10. clarifai/runners/__init__.py +2 -7
  11. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  12. clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
  13. clarifai/runners/dockerfile_template/Dockerfile.template +4 -32
  14. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  15. clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
  16. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  17. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  18. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  19. clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
  20. clarifai/runners/models/model_builder.py +47 -20
  21. clarifai/runners/models/model_class.py +249 -25
  22. clarifai/runners/models/model_run_locally.py +5 -2
  23. clarifai/runners/models/model_runner.py +2 -0
  24. clarifai/runners/models/model_servicer.py +11 -2
  25. clarifai/runners/server.py +26 -9
  26. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  27. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  28. clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
  29. clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
  30. clarifai/runners/utils/const.py +1 -1
  31. clarifai/runners/utils/data_handler.py +308 -205
  32. clarifai/runners/utils/method_signatures.py +437 -0
  33. clarifai/runners/utils/serializers.py +132 -0
  34. clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
  35. clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
  36. clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
  37. clarifai/utils/misc.py +12 -0
  38. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/METADATA +3 -2
  39. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/RECORD +43 -36
  40. clarifai/runners/models/base_typed_model.py +0 -238
  41. clarifai/runners/models/model_upload.py +0 -607
  42. clarifai/runners/utils/#const.py# +0 -30
  43. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/LICENSE +0 -0
  44. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/WHEEL +0 -0
  45. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/entry_points.txt +0 -0
  46. {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,14 @@ from google.protobuf import json_format
14
14
  from rich import print
15
15
  from rich.markup import escape
16
16
 
17
- from clarifai.client import BaseClient
17
+ from clarifai.client.base import BaseClient
18
18
  from clarifai.runners.models.model_class import ModelClass
19
19
  from clarifai.runners.utils.const import (
20
20
  AVAILABLE_PYTHON_IMAGES, AVAILABLE_TORCH_IMAGES, CONCEPTS_REQUIRED_MODEL_TYPE,
21
21
  DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, DEFAULT_PYTHON_VERSION, DEFAULT_RUNTIME_DOWNLOAD_PATH,
22
22
  PYTHON_BASE_IMAGE, TORCH_BASE_IMAGE)
23
23
  from clarifai.runners.utils.loader import HuggingFaceLoader
24
+ from clarifai.runners.utils.method_signatures import signatures_to_yaml
24
25
  from clarifai.urls.helper import ClarifaiUrlHelper
25
26
  from clarifai.utils.logging import logger
26
27
  from clarifai.versions import CLIENT_VERSION
@@ -69,6 +70,18 @@ class ModelBuilder:
69
70
  """
70
71
  Create an instance of the model class, as specified in the config file.
71
72
  """
73
+ model_class = self.load_model_class()
74
+
75
+ # initialize the model
76
+ model = model_class()
77
+ if load_model:
78
+ model.load_model()
79
+ return model
80
+
81
+ def load_model_class(self):
82
+ """
83
+ Import the model class from the model.py file.
84
+ """
72
85
  # look for default model.py file location
73
86
  for loc in ["model.py", "1/model.py"]:
74
87
  model_file = os.path.join(self.folder, loc)
@@ -107,12 +120,7 @@ class ModelBuilder:
107
120
  "Could not determine model class. There should be exactly one model inheriting from ModelClass defined in the model.py"
108
121
  )
109
122
  model_class = classes[0]
110
-
111
- # initialize the model
112
- model = model_class()
113
- if load_model:
114
- model.load_model()
115
- return model
123
+ return model_class
116
124
 
117
125
  def _validate_folder(self, folder):
118
126
  if folder == ".":
@@ -159,7 +167,6 @@ class ModelBuilder:
159
167
  f"No 'when' specified in the config file for checkpoints, defaulting to download at {DEFAULT_DOWNLOAD_CHECKPOINT_WHEN}"
160
168
  )
161
169
  when = checkpoints.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN)
162
- # In the config.yaml we don't allow "any", that's only used in download_checkpoints to force download.
163
170
  assert when in [
164
171
  "upload",
165
172
  "build",
@@ -227,6 +234,15 @@ class ModelBuilder:
227
234
  )
228
235
  logger.info("Continuing without Hugging Face token")
229
236
 
237
+ num_threads = self.config.get("num_threads")
238
+ if num_threads or num_threads == 0:
239
+ assert isinstance(num_threads, int) and num_threads >= 1, ValueError(
240
+ f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}."
241
+ )
242
+ else:
243
+ num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 1))
244
+ self.config["num_threads"] = num_threads
245
+
230
246
  @staticmethod
231
247
  def _get_tar_file_content_size(tar_file_path):
232
248
  """
@@ -245,6 +261,15 @@ class ModelBuilder:
245
261
  total_size += member.size
246
262
  return total_size
247
263
 
264
+ def method_signatures_yaml(self):
265
+ """
266
+ Returns the method signatures for the model class in YAML format.
267
+ """
268
+ model_class = self.load_model_class()
269
+ method_info = model_class._get_method_info()
270
+ signatures = {name: m.signature for name, m in method_info.items()}
271
+ return signatures_to_yaml(signatures)
272
+
248
273
  @property
249
274
  def client(self):
250
275
  if self._client is None:
@@ -366,10 +391,10 @@ class ModelBuilder:
366
391
  if 'python_version' in build_info:
367
392
  python_version = build_info['python_version']
368
393
  if python_version not in AVAILABLE_PYTHON_IMAGES:
369
- logger.error(
370
- f"Python version {python_version} not supported, please use one of the following versions: {AVAILABLE_PYTHON_IMAGES}"
394
+ raise Exception(
395
+ f"Python version {python_version} not supported, please use one of the following versions: {AVAILABLE_PYTHON_IMAGES} in your config.yaml"
371
396
  )
372
- return
397
+
373
398
  logger.info(
374
399
  f"Using Python version {python_version} from the config file to build the Dockerfile")
375
400
  else:
@@ -443,7 +468,7 @@ class ModelBuilder:
443
468
 
444
469
  @property
445
470
  def checkpoint_suffix(self):
446
- return '1/checkpoints'
471
+ return os.path.join('1', 'checkpoints')
447
472
 
448
473
  @property
449
474
  def tar_file(self):
@@ -452,13 +477,15 @@ class ModelBuilder:
452
477
  def default_runtime_checkpoint_path(self):
453
478
  return DEFAULT_RUNTIME_DOWNLOAD_PATH
454
479
 
455
- def download_checkpoints(self, stage: str, checkpoint_path_override: str = None):
480
+ def download_checkpoints(self,
481
+ stage: str = DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
482
+ checkpoint_path_override: str = None):
456
483
  """
457
484
  Downloads the checkpoints specified in the config file.
458
485
 
459
486
  :param stage: The stage of the build process. This is used to determine when to download the
460
- checkpoints. The stage can be one of ['build', 'upload', 'runtime', 'any']. If "any" it will always try to download
461
- regardless of what is specified in config.yaml. Otherwise it must match what is in config.yaml
487
+ checkpoints. The stage can be one of ['build', 'upload', 'runtime']. If you want to force
488
+ downloading now then set stage to match e when field of the checkpoints section of you config.yaml.
462
489
  :param checkpoint_path_override: The path to download the checkpoints to (with 1/checkpoints added as suffix). If not provided, the
463
490
  default path is used based on the folder ModelUploader was initialized with. The checkpoint_suffix will be appended to the path.
464
491
  If stage is 'runtime' and checkpoint_path_override is None, the default runtime path will be used.
@@ -471,9 +498,9 @@ class ModelBuilder:
471
498
  return path
472
499
 
473
500
  loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
474
- if stage not in ["build", "upload", "runtime", "any"]:
501
+ if stage not in ["build", "upload", "runtime"]:
475
502
  raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
476
- if when != stage and stage != "any":
503
+ if when != stage:
477
504
  logger.info(
478
505
  f"Skipping downloading checkpoints for stage {stage} since config.yaml says to download them at stage {when}"
479
506
  )
@@ -588,7 +615,7 @@ class ModelBuilder:
588
615
 
589
616
  def filter_func(tarinfo):
590
617
  name = tarinfo.name
591
- exclude = [self.tar_file, "*~"]
618
+ exclude = [self.tar_file, "*~", "*.pyc", "*.pyo", "__pycache__"]
592
619
  if when != "upload":
593
620
  exclude.append(self.checkpoint_suffix)
594
621
  return None if any(name.endswith(ex) for ex in exclude) else tarinfo
@@ -739,8 +766,8 @@ def upload_model(folder, stage, skip_dockerfile):
739
766
  Uploads a model to Clarifai.
740
767
 
741
768
  :param folder: The folder containing the model files.
742
- :param stage: The stage of when you're uploading this model. This is used to determine when to download the checkpoints based on a match with the "when" field in the config.yaml checkpoints section or if you set stage to "any" it will always download the checkpoints.
743
- :param skip_dockerfile: If True, skips creating the Dockerfile so you can re-use the local one.
769
+ :param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.
770
+ :param skip_dockerfile: If True, will not create a Dockerfile.
744
771
  """
745
772
  builder = ModelBuilder(folder)
746
773
  builder.download_checkpoints(stage=stage)
@@ -1,41 +1,265 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Iterator
1
+ import inspect
2
+ import itertools
3
+ import logging
4
+ import os
5
+ import traceback
6
+ import types
7
+ from abc import ABC
8
+ from typing import Any, Dict, Iterator, List
3
9
 
4
- from clarifai_grpc.grpc.api import service_pb2
10
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2
11
+ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
12
+
13
+ from clarifai.runners.utils import data_handler
14
+ from clarifai.runners.utils.method_signatures import (build_function_signature, deserialize,
15
+ serialize, signatures_to_json)
16
+
17
+ _METHOD_INFO_ATTR = '_cf_method_info'
18
+
19
+ _RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() == "true"
5
20
 
6
21
 
7
22
  class ModelClass(ABC):
8
23
 
24
+ def load_model(self):
25
+ """Load the model."""
26
+ pass
27
+
28
+ def predict(self, **kwargs):
29
+ """Predict method for single or batched inputs."""
30
+ raise NotImplementedError("predict() not implemented")
31
+
32
+ def generate(self, **kwargs) -> Iterator:
33
+ """Generate method for streaming outputs."""
34
+ raise NotImplementedError("generate() not implemented")
35
+
36
+ def stream(self, **kwargs) -> Iterator:
37
+ """Stream method for streaming inputs and outputs."""
38
+ raise NotImplementedError("stream() not implemented")
39
+
40
+ def _handle_get_signatures_request(self) -> service_pb2.MultiOutputResponse:
41
+ methods = self._get_method_info()
42
+ signatures = {method.name: method.signature for method in methods.values()}
43
+ resp = service_pb2.MultiOutputResponse(status=status_pb2.Status(code=status_code_pb2.SUCCESS))
44
+ resp.outputs.add().data.string_value = signatures_to_json(signatures)
45
+ return resp
46
+
47
+ def batch_predict(self, method, inputs: List[Dict[str, Any]]) -> List[Any]:
48
+ """Batch predict method for multiple inputs."""
49
+ outputs = []
50
+ for input in inputs:
51
+ output = method(**input)
52
+ outputs.append(output)
53
+ return outputs
54
+
55
+ def batch_generate(self, method, inputs: List[Dict[str, Any]]) -> Iterator[List[Any]]:
56
+ """Batch generate method for multiple inputs."""
57
+ generators = [method(**input) for input in inputs]
58
+ for outputs in itertools.zip_longest(*generators):
59
+ yield outputs
60
+
9
61
  def predict_wrapper(
10
62
  self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
11
- """This method is used for input/output proto data conversion"""
12
- return self.predict(request)
63
+ outputs = []
64
+ try:
65
+ # TODO add method name field to proto
66
+ method_name = request.model.model_version.output_info.params['_method_name']
67
+ if method_name == '_GET_SIGNATURES': # special case to fetch signatures, TODO add endpoint for this
68
+ return self._handle_get_signatures_request()
69
+ if method_name not in self._get_method_info():
70
+ raise ValueError(f"Method {method_name} not found in model class")
71
+ method = getattr(self, method_name)
72
+ method_info = method._cf_method_info
73
+ signature = method_info.signature
74
+ python_param_types = method_info.python_param_types
75
+ inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
76
+ python_param_types)
77
+ if len(inputs) == 1:
78
+ inputs = inputs[0]
79
+ output = method(**inputs)
80
+ outputs.append(self._convert_output_to_proto(output, signature.outputs))
81
+ else:
82
+ outputs = self.batch_predict(method, inputs)
83
+ outputs = [self._convert_output_to_proto(output, signature.outputs) for output in outputs]
84
+
85
+ return service_pb2.MultiOutputResponse(
86
+ outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
87
+ except Exception as e:
88
+ if _RAISE_EXCEPTIONS:
89
+ raise
90
+ logging.exception("Error in predict")
91
+ return service_pb2.MultiOutputResponse(status=status_pb2.Status(
92
+ code=status_code_pb2.FAILURE,
93
+ details=str(e),
94
+ stack_trace=traceback.format_exc().split('\n')))
13
95
 
14
96
  def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
15
97
  ) -> Iterator[service_pb2.MultiOutputResponse]:
16
- """This method is used for input/output proto data conversion and yield outcome"""
17
- return self.generate(request)
98
+ try:
99
+ method_name = request.model.model_version.output_info.params['_method_name']
100
+ method = getattr(self, method_name)
101
+ method_info = method._cf_method_info
102
+ signature = method_info.signature
103
+ python_param_types = method_info.python_param_types
18
104
 
19
- def stream_wrapper(self, request: service_pb2.PostModelOutputsRequest
105
+ inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
106
+ python_param_types)
107
+ if len(inputs) == 1:
108
+ inputs = inputs[0]
109
+ for output in method(**inputs):
110
+ resp = service_pb2.MultiOutputResponse()
111
+ self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
112
+ resp.status.code = status_code_pb2.SUCCESS
113
+ yield resp
114
+ else:
115
+ for outputs in self.batch_generate(method, inputs):
116
+ resp = service_pb2.MultiOutputResponse()
117
+ for output in outputs:
118
+ self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
119
+ resp.status.code = status_code_pb2.SUCCESS
120
+ yield resp
121
+ except Exception as e:
122
+ if _RAISE_EXCEPTIONS:
123
+ raise
124
+ logging.exception("Error in generate")
125
+ yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
126
+ code=status_code_pb2.FAILURE,
127
+ details=str(e),
128
+ stack_trace=traceback.format_exc().split('\n')))
129
+
130
+ def stream_wrapper(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
20
131
  ) -> Iterator[service_pb2.MultiOutputResponse]:
21
- """This method is used for input/output proto data conversion and yield outcome"""
22
- return self.stream(request)
132
+ try:
133
+ request = next(request_iterator) # get first request to determine method
134
+ assert len(request.inputs) == 1, "Streaming requires exactly one input"
23
135
 
24
- @abstractmethod
25
- def load_model(self):
26
- raise NotImplementedError("load_model() not implemented")
136
+ method_name = request.model.model_version.output_info.params['_method_name']
137
+ method = getattr(self, method_name)
138
+ method_info = method._cf_method_info
139
+ signature = method_info.signature
140
+ python_param_types = method_info.python_param_types
27
141
 
28
- @abstractmethod
29
- def predict(self,
30
- request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
31
- raise NotImplementedError("run_input() not implemented")
142
+ # find the streaming vars in the signature
143
+ streaming_var_signatures = [var for var in signature.inputs if var.streaming]
144
+ stream_argname = set([var.name.split('.', 1)[0] for var in streaming_var_signatures])
145
+ assert len(
146
+ stream_argname) == 1, 'streaming methods must have exactly one streaming function arg'
147
+ stream_argname = stream_argname.pop()
32
148
 
33
- @abstractmethod
34
- def generate(self, request: service_pb2.PostModelOutputsRequest
35
- ) -> Iterator[service_pb2.MultiOutputResponse]:
36
- raise NotImplementedError("generate() not implemented")
149
+ # convert all inputs for the first request, including the first stream value
150
+ inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
151
+ python_param_types)
152
+ kwargs = inputs[0]
37
153
 
38
- @abstractmethod
39
- def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
40
- ) -> Iterator[service_pb2.MultiOutputResponse]:
41
- raise NotImplementedError("stream() not implemented")
154
+ # first streaming item
155
+ first_item = kwargs.pop(stream_argname)
156
+
157
+ # streaming generator
158
+ def InputStream():
159
+ yield first_item
160
+ # subsequent streaming items contain only the streaming input
161
+ for request in request_iterator:
162
+ item = self._convert_input_protos_to_python(request.inputs, streaming_var_signatures,
163
+ python_param_types)
164
+ item = item[0][stream_argname]
165
+ yield item
166
+
167
+ # add stream generator back to the input kwargs
168
+ kwargs[stream_argname] = InputStream()
169
+
170
+ for output in method(**kwargs):
171
+ resp = service_pb2.MultiOutputResponse()
172
+ self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
173
+ resp.status.code = status_code_pb2.SUCCESS
174
+ yield resp
175
+ except Exception as e:
176
+ if _RAISE_EXCEPTIONS:
177
+ raise
178
+ logging.exception("Error in stream")
179
+ yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
180
+ code=status_code_pb2.FAILURE,
181
+ details=str(e),
182
+ stack_trace=traceback.format_exc().split('\n')))
183
+
184
+ def _convert_input_protos_to_python(self, inputs: List[resources_pb2.Input], variables_signature,
185
+ python_param_types) -> List[Dict[str, Any]]:
186
+ result = []
187
+ for input in inputs:
188
+ kwargs = deserialize(input.data, variables_signature)
189
+ # dynamic cast to annotated types
190
+ for k, v in kwargs.items():
191
+ if k not in python_param_types:
192
+ continue
193
+ kwargs[k] = data_handler.cast(v, python_param_types[k])
194
+ result.append(kwargs)
195
+ return result
196
+
197
+ def _convert_output_to_proto(self, output: Any, variables_signature,
198
+ proto=None) -> resources_pb2.Output:
199
+ if proto is None:
200
+ proto = resources_pb2.Output()
201
+ if isinstance(output, tuple):
202
+ output = {f'return.{i}': item for i, item in enumerate(output)}
203
+ if not isinstance(output, dict): # TODO Output type, not just dict
204
+ output = {'return': output}
205
+ serialize(output, variables_signature, proto.data, is_output=True)
206
+ return proto
207
+
208
+ @classmethod
209
+ def _register_model_methods(cls):
210
+ # go up the class hierarchy to find all decorated methods, and add to registry of current class
211
+ methods = {}
212
+ for base in reversed(cls.__mro__):
213
+ for name, method in base.__dict__.items():
214
+ method_info = getattr(method, _METHOD_INFO_ATTR, None)
215
+ if not method_info: # regular function, not a model method
216
+ continue
217
+ methods[name] = method_info
218
+ # check for generic predict(request) -> response, etc. methods
219
+ #for name in ('predict', 'generate', 'stream'):
220
+ # if hasattr(cls, name):
221
+ # method = getattr(cls, name)
222
+ # if not hasattr(method, _METHOD_INFO_ATTR): # not already put in registry
223
+ # methods[name] = _MethodInfo(method, method_type=name)
224
+ # set method table for this class in the registry
225
+ return methods
226
+
227
+ @classmethod
228
+ def _get_method_info(cls, func_name=None):
229
+ if not hasattr(cls, _METHOD_INFO_ATTR):
230
+ setattr(cls, _METHOD_INFO_ATTR, cls._register_model_methods())
231
+ method_info = getattr(cls, _METHOD_INFO_ATTR)
232
+ if func_name:
233
+ return method_info[func_name]
234
+ return method_info
235
+
236
+
237
+ class _MethodInfo:
238
+
239
+ def __init__(self, method, method_type):
240
+ self.name = method.__name__
241
+ self.signature = build_function_signature(method, method_type)
242
+ self.python_param_types = {
243
+ p.name: p.annotation
244
+ for p in inspect.signature(method).parameters.values()
245
+ if p.annotation != inspect.Parameter.empty
246
+ }
247
+ self.python_param_types.pop('self', None)
248
+
249
+
250
+ def predict(method):
251
+ setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'predict'))
252
+ return method
253
+
254
+
255
+ def generate(method):
256
+ setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'generate'))
257
+ return method
258
+
259
+
260
+ def stream(method):
261
+ setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'stream'))
262
+ return method
263
+
264
+
265
+ methods = types.SimpleNamespace(predict=predict, generate=generate, stream=stream)
@@ -481,8 +481,11 @@ def main(model_path,
481
481
  )
482
482
  sys.exit(1)
483
483
  manager = ModelRunLocally(model_path)
484
- # stage="any" forces downloaded now regardless of config.yaml
485
- manager.builder.download_checkpoints(stage="any")
484
+ # get whatever stage is in config.yaml to force download now
485
+ # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
486
+ _, _, _, when = manager.builder._validate_config_checkpoints()
487
+ manager.builder.download_checkpoints(
488
+ stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
486
489
  if inside_container:
487
490
  if not manager.is_docker_installed():
488
491
  sys.exit(1)
@@ -82,6 +82,8 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
82
82
  ensure_urls_downloaded(request)
83
83
 
84
84
  resp = self.model.predict_wrapper(request)
85
+ if resp.status.code != status_code_pb2.SUCCESS:
86
+ return service_pb2.RunnerItemOutput(multi_output_response=resp)
85
87
  successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
86
88
  if all(successes):
87
89
  status = status_pb2.Status(
@@ -1,3 +1,4 @@
1
+ import os
1
2
  from itertools import tee
2
3
  from typing import Iterator
3
4
 
@@ -6,6 +7,8 @@ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
6
7
 
7
8
  from ..utils.url_fetcher import ensure_urls_downloaded
8
9
 
10
+ _RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "1")
11
+
9
12
 
10
13
  class ModelServicer(service_pb2_grpc.V2Servicer):
11
14
  """
@@ -33,6 +36,8 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
33
36
  try:
34
37
  return self.model.predict_wrapper(request)
35
38
  except Exception as e:
39
+ if _RAISE_EXCEPTIONS:
40
+ raise
36
41
  return service_pb2.MultiOutputResponse(status=status_pb2.Status(
37
42
  code=status_code_pb2.MODEL_PREDICTION_FAILED,
38
43
  description="Failed",
@@ -50,8 +55,10 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
50
55
  ensure_urls_downloaded(request)
51
56
 
52
57
  try:
53
- return self.model.generate_wrapper(request)
58
+ yield from self.model.generate_wrapper(request)
54
59
  except Exception as e:
60
+ if _RAISE_EXCEPTIONS:
61
+ raise
55
62
  yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
56
63
  code=status_code_pb2.MODEL_PREDICTION_FAILED,
57
64
  description="Failed",
@@ -74,8 +81,10 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
74
81
  ensure_urls_downloaded(req)
75
82
 
76
83
  try:
77
- return self.model.stream_wrapper(request_copy)
84
+ yield from self.model.stream_wrapper(request_copy)
78
85
  except Exception as e:
86
+ if _RAISE_EXCEPTIONS:
87
+ raise
79
88
  yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
80
89
  code=status_code_pb2.MODEL_PREDICTION_FAILED,
81
90
  description="Failed",
@@ -68,30 +68,47 @@ def main():
68
68
 
69
69
  parsed_args = parser.parse_args()
70
70
 
71
- builder = ModelBuilder(parsed_args.model_path, download_validation_only=True)
71
+ serve(parsed_args.model_path, parsed_args.port, parsed_args.pool_size,
72
+ parsed_args.max_queue_size, parsed_args.max_msg_length, parsed_args.enable_tls,
73
+ parsed_args.grpc)
74
+
75
+
76
+ def serve(model_path,
77
+ port=8000,
78
+ pool_size=32,
79
+ max_queue_size=10,
80
+ max_msg_length=1024 * 1024 * 1024,
81
+ enable_tls=False,
82
+ grpc=False):
83
+
84
+ builder = ModelBuilder(model_path, download_validation_only=True)
72
85
 
73
86
  model = builder.create_model_instance()
74
87
 
88
+ # `num_threads` can be set in config.yaml or via the environment variable CLARIFAI_NUM_THREADS="<integer>".
89
+ # Note: The value in config.yaml takes precedence over the environment variable.
90
+ num_threads = builder.config.get("num_threads")
91
+
75
92
  # Setup the grpc server for local development.
76
- if parsed_args.grpc:
93
+ if grpc:
77
94
 
78
95
  # initialize the servicer with the runner so that it gets the predict(), generate(), stream() classes.
79
96
  servicer = ModelServicer(model)
80
97
 
81
98
  server = GRPCServer(
82
99
  futures.ThreadPoolExecutor(
83
- max_workers=parsed_args.pool_size,
100
+ max_workers=pool_size,
84
101
  thread_name_prefix="ServeCalls",
85
102
  ),
86
- parsed_args.max_msg_length,
87
- parsed_args.max_queue_size,
103
+ max_msg_length,
104
+ max_queue_size,
88
105
  )
89
- server.add_port_to_server('[::]:%s' % parsed_args.port, parsed_args.enable_tls)
106
+ server.add_port_to_server('[::]:%s' % port, enable_tls)
90
107
 
91
108
  service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
92
109
  server.start()
93
- logger.info("Started server on port %s", parsed_args.port)
94
- logger.info(f"Access the model at http://localhost:{parsed_args.port}")
110
+ logger.info("Started server on port %s", port)
111
+ logger.info(f"Access the model at http://localhost:{port}")
95
112
  server.wait_for_termination()
96
113
  else: # start the runner with the proper env variables and as a runner protocol.
97
114
 
@@ -102,7 +119,7 @@ def main():
102
119
  nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
103
120
  compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
104
121
  base_url=os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com"),
105
- num_parallel_polls=int(os.environ.get("CLARIFAI_NUM_THREADS", 1)),
122
+ num_parallel_polls=num_threads,
106
123
  )
107
124
  runner.start() # start the runner to fetch work from the API.
108
125
 
@@ -16,7 +16,7 @@ DEFAULT_PYTHON_VERSION = 3.12
16
16
  DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime"
17
17
 
18
18
  # Folder for downloading checkpoints at runtime.
19
- DEFAULT_RUNTIME_DOWNLOAD_PATH = "/tmp/.cache"
19
+ DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
20
20
 
21
21
  # List of available torch images
22
22
  # Keep sorted by most recent cuda version.