clarifai 11.1.4rc2__py3-none-any.whl → 11.1.5rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/cli/model.py +46 -10
- clarifai/client/model.py +89 -364
- clarifai/client/model_client.py +400 -0
- clarifai/client/workflow.py +2 -2
- clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
- clarifai/runners/__init__.py +2 -7
- clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/__pycache__/server.cpython-310.pyc +0 -0
- clarifai/runners/dockerfile_template/Dockerfile.template +4 -32
- clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_builder.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_servicer.cpython-310.pyc +0 -0
- clarifai/runners/models/model_builder.py +47 -20
- clarifai/runners/models/model_class.py +249 -25
- clarifai/runners/models/model_run_locally.py +5 -2
- clarifai/runners/models/model_runner.py +2 -0
- clarifai/runners/models/model_servicer.py +11 -2
- clarifai/runners/server.py +26 -9
- clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/method_signatures.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/serializers.cpython-310.pyc +0 -0
- clarifai/runners/utils/const.py +1 -1
- clarifai/runners/utils/data_handler.py +308 -205
- clarifai/runners/utils/method_signatures.py +437 -0
- clarifai/runners/utils/serializers.py +132 -0
- clarifai/utils/evaluation/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/helpers.cpython-310.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/main.cpython-310.pyc +0 -0
- clarifai/utils/misc.py +12 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/METADATA +3 -2
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/RECORD +43 -36
- clarifai/runners/models/base_typed_model.py +0 -238
- clarifai/runners/models/model_upload.py +0 -607
- clarifai/runners/utils/#const.py# +0 -30
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/LICENSE +0 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/WHEEL +0 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/entry_points.txt +0 -0
- {clarifai-11.1.4rc2.dist-info → clarifai-11.1.5rc1.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,14 @@ from google.protobuf import json_format
|
|
14
14
|
from rich import print
|
15
15
|
from rich.markup import escape
|
16
16
|
|
17
|
-
from clarifai.client import BaseClient
|
17
|
+
from clarifai.client.base import BaseClient
|
18
18
|
from clarifai.runners.models.model_class import ModelClass
|
19
19
|
from clarifai.runners.utils.const import (
|
20
20
|
AVAILABLE_PYTHON_IMAGES, AVAILABLE_TORCH_IMAGES, CONCEPTS_REQUIRED_MODEL_TYPE,
|
21
21
|
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, DEFAULT_PYTHON_VERSION, DEFAULT_RUNTIME_DOWNLOAD_PATH,
|
22
22
|
PYTHON_BASE_IMAGE, TORCH_BASE_IMAGE)
|
23
23
|
from clarifai.runners.utils.loader import HuggingFaceLoader
|
24
|
+
from clarifai.runners.utils.method_signatures import signatures_to_yaml
|
24
25
|
from clarifai.urls.helper import ClarifaiUrlHelper
|
25
26
|
from clarifai.utils.logging import logger
|
26
27
|
from clarifai.versions import CLIENT_VERSION
|
@@ -69,6 +70,18 @@ class ModelBuilder:
|
|
69
70
|
"""
|
70
71
|
Create an instance of the model class, as specified in the config file.
|
71
72
|
"""
|
73
|
+
model_class = self.load_model_class()
|
74
|
+
|
75
|
+
# initialize the model
|
76
|
+
model = model_class()
|
77
|
+
if load_model:
|
78
|
+
model.load_model()
|
79
|
+
return model
|
80
|
+
|
81
|
+
def load_model_class(self):
|
82
|
+
"""
|
83
|
+
Import the model class from the model.py file.
|
84
|
+
"""
|
72
85
|
# look for default model.py file location
|
73
86
|
for loc in ["model.py", "1/model.py"]:
|
74
87
|
model_file = os.path.join(self.folder, loc)
|
@@ -107,12 +120,7 @@ class ModelBuilder:
|
|
107
120
|
"Could not determine model class. There should be exactly one model inheriting from ModelClass defined in the model.py"
|
108
121
|
)
|
109
122
|
model_class = classes[0]
|
110
|
-
|
111
|
-
# initialize the model
|
112
|
-
model = model_class()
|
113
|
-
if load_model:
|
114
|
-
model.load_model()
|
115
|
-
return model
|
123
|
+
return model_class
|
116
124
|
|
117
125
|
def _validate_folder(self, folder):
|
118
126
|
if folder == ".":
|
@@ -159,7 +167,6 @@ class ModelBuilder:
|
|
159
167
|
f"No 'when' specified in the config file for checkpoints, defaulting to download at {DEFAULT_DOWNLOAD_CHECKPOINT_WHEN}"
|
160
168
|
)
|
161
169
|
when = checkpoints.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN)
|
162
|
-
# In the config.yaml we don't allow "any", that's only used in download_checkpoints to force download.
|
163
170
|
assert when in [
|
164
171
|
"upload",
|
165
172
|
"build",
|
@@ -227,6 +234,15 @@ class ModelBuilder:
|
|
227
234
|
)
|
228
235
|
logger.info("Continuing without Hugging Face token")
|
229
236
|
|
237
|
+
num_threads = self.config.get("num_threads")
|
238
|
+
if num_threads or num_threads == 0:
|
239
|
+
assert isinstance(num_threads, int) and num_threads >= 1, ValueError(
|
240
|
+
f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}."
|
241
|
+
)
|
242
|
+
else:
|
243
|
+
num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 1))
|
244
|
+
self.config["num_threads"] = num_threads
|
245
|
+
|
230
246
|
@staticmethod
|
231
247
|
def _get_tar_file_content_size(tar_file_path):
|
232
248
|
"""
|
@@ -245,6 +261,15 @@ class ModelBuilder:
|
|
245
261
|
total_size += member.size
|
246
262
|
return total_size
|
247
263
|
|
264
|
+
def method_signatures_yaml(self):
|
265
|
+
"""
|
266
|
+
Returns the method signatures for the model class in YAML format.
|
267
|
+
"""
|
268
|
+
model_class = self.load_model_class()
|
269
|
+
method_info = model_class._get_method_info()
|
270
|
+
signatures = {name: m.signature for name, m in method_info.items()}
|
271
|
+
return signatures_to_yaml(signatures)
|
272
|
+
|
248
273
|
@property
|
249
274
|
def client(self):
|
250
275
|
if self._client is None:
|
@@ -366,10 +391,10 @@ class ModelBuilder:
|
|
366
391
|
if 'python_version' in build_info:
|
367
392
|
python_version = build_info['python_version']
|
368
393
|
if python_version not in AVAILABLE_PYTHON_IMAGES:
|
369
|
-
|
370
|
-
f"Python version {python_version} not supported, please use one of the following versions: {AVAILABLE_PYTHON_IMAGES}"
|
394
|
+
raise Exception(
|
395
|
+
f"Python version {python_version} not supported, please use one of the following versions: {AVAILABLE_PYTHON_IMAGES} in your config.yaml"
|
371
396
|
)
|
372
|
-
|
397
|
+
|
373
398
|
logger.info(
|
374
399
|
f"Using Python version {python_version} from the config file to build the Dockerfile")
|
375
400
|
else:
|
@@ -443,7 +468,7 @@ class ModelBuilder:
|
|
443
468
|
|
444
469
|
@property
|
445
470
|
def checkpoint_suffix(self):
|
446
|
-
return '1
|
471
|
+
return os.path.join('1', 'checkpoints')
|
447
472
|
|
448
473
|
@property
|
449
474
|
def tar_file(self):
|
@@ -452,13 +477,15 @@ class ModelBuilder:
|
|
452
477
|
def default_runtime_checkpoint_path(self):
|
453
478
|
return DEFAULT_RUNTIME_DOWNLOAD_PATH
|
454
479
|
|
455
|
-
def download_checkpoints(self,
|
480
|
+
def download_checkpoints(self,
|
481
|
+
stage: str = DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
|
482
|
+
checkpoint_path_override: str = None):
|
456
483
|
"""
|
457
484
|
Downloads the checkpoints specified in the config file.
|
458
485
|
|
459
486
|
:param stage: The stage of the build process. This is used to determine when to download the
|
460
|
-
checkpoints. The stage can be one of ['build', 'upload', 'runtime'
|
461
|
-
|
487
|
+
checkpoints. The stage can be one of ['build', 'upload', 'runtime']. If you want to force
|
488
|
+
downloading now then set stage to match the when field of the checkpoints section of your config.yaml.
|
462
489
|
:param checkpoint_path_override: The path to download the checkpoints to (with 1/checkpoints added as suffix). If not provided, the
|
463
490
|
default path is used based on the folder ModelBuilder was initialized with. The checkpoint_suffix will be appended to the path.
|
464
491
|
If stage is 'runtime' and checkpoint_path_override is None, the default runtime path will be used.
|
@@ -471,9 +498,9 @@ class ModelBuilder:
|
|
471
498
|
return path
|
472
499
|
|
473
500
|
loader_type, repo_id, hf_token, when = self._validate_config_checkpoints()
|
474
|
-
if stage not in ["build", "upload", "runtime"
|
501
|
+
if stage not in ["build", "upload", "runtime"]:
|
475
502
|
raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
|
476
|
-
if when != stage
|
503
|
+
if when != stage:
|
477
504
|
logger.info(
|
478
505
|
f"Skipping downloading checkpoints for stage {stage} since config.yaml says to download them at stage {when}"
|
479
506
|
)
|
@@ -588,7 +615,7 @@ class ModelBuilder:
|
|
588
615
|
|
589
616
|
def filter_func(tarinfo):
|
590
617
|
name = tarinfo.name
|
591
|
-
exclude = [self.tar_file, "*~"]
|
618
|
+
exclude = [self.tar_file, "*~", "*.pyc", "*.pyo", "__pycache__"]
|
592
619
|
if when != "upload":
|
593
620
|
exclude.append(self.checkpoint_suffix)
|
594
621
|
return None if any(name.endswith(ex) for ex in exclude) else tarinfo
|
@@ -739,8 +766,8 @@ def upload_model(folder, stage, skip_dockerfile):
|
|
739
766
|
Uploads a model to Clarifai.
|
740
767
|
|
741
768
|
:param folder: The folder containing the model files.
|
742
|
-
:param stage: The stage
|
743
|
-
:param skip_dockerfile: If True,
|
769
|
+
:param stage: The stage we are calling download checkpoints from. Typically this would be "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.
|
770
|
+
:param skip_dockerfile: If True, will not create a Dockerfile.
|
744
771
|
"""
|
745
772
|
builder = ModelBuilder(folder)
|
746
773
|
builder.download_checkpoints(stage=stage)
|
@@ -1,41 +1,265 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
import inspect
|
2
|
+
import itertools
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import traceback
|
6
|
+
import types
|
7
|
+
from abc import ABC
|
8
|
+
from typing import Any, Dict, Iterator, List
|
3
9
|
|
4
|
-
from clarifai_grpc.grpc.api import service_pb2
|
10
|
+
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
11
|
+
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
|
12
|
+
|
13
|
+
from clarifai.runners.utils import data_handler
|
14
|
+
from clarifai.runners.utils.method_signatures import (build_function_signature, deserialize,
|
15
|
+
serialize, signatures_to_json)
|
16
|
+
|
17
|
+
_METHOD_INFO_ATTR = '_cf_method_info'
|
18
|
+
|
19
|
+
_RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() == "true"
|
5
20
|
|
6
21
|
|
7
22
|
class ModelClass(ABC):
|
8
23
|
|
24
|
+
def load_model(self):
|
25
|
+
"""Load the model."""
|
26
|
+
pass
|
27
|
+
|
28
|
+
def predict(self, **kwargs):
|
29
|
+
"""Predict method for single or batched inputs."""
|
30
|
+
raise NotImplementedError("predict() not implemented")
|
31
|
+
|
32
|
+
def generate(self, **kwargs) -> Iterator:
|
33
|
+
"""Generate method for streaming outputs."""
|
34
|
+
raise NotImplementedError("generate() not implemented")
|
35
|
+
|
36
|
+
def stream(self, **kwargs) -> Iterator:
|
37
|
+
"""Stream method for streaming inputs and outputs."""
|
38
|
+
raise NotImplementedError("stream() not implemented")
|
39
|
+
|
40
|
+
def _handle_get_signatures_request(self) -> service_pb2.MultiOutputResponse:
|
41
|
+
methods = self._get_method_info()
|
42
|
+
signatures = {method.name: method.signature for method in methods.values()}
|
43
|
+
resp = service_pb2.MultiOutputResponse(status=status_pb2.Status(code=status_code_pb2.SUCCESS))
|
44
|
+
resp.outputs.add().data.string_value = signatures_to_json(signatures)
|
45
|
+
return resp
|
46
|
+
|
47
|
+
def batch_predict(self, method, inputs: List[Dict[str, Any]]) -> List[Any]:
|
48
|
+
"""Batch predict method for multiple inputs."""
|
49
|
+
outputs = []
|
50
|
+
for input in inputs:
|
51
|
+
output = method(**input)
|
52
|
+
outputs.append(output)
|
53
|
+
return outputs
|
54
|
+
|
55
|
+
def batch_generate(self, method, inputs: List[Dict[str, Any]]) -> Iterator[List[Any]]:
|
56
|
+
"""Batch generate method for multiple inputs."""
|
57
|
+
generators = [method(**input) for input in inputs]
|
58
|
+
for outputs in itertools.zip_longest(*generators):
|
59
|
+
yield outputs
|
60
|
+
|
9
61
|
def predict_wrapper(
|
10
62
|
self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
|
11
|
-
|
12
|
-
|
63
|
+
outputs = []
|
64
|
+
try:
|
65
|
+
# TODO add method name field to proto
|
66
|
+
method_name = request.model.model_version.output_info.params['_method_name']
|
67
|
+
if method_name == '_GET_SIGNATURES': # special case to fetch signatures, TODO add endpoint for this
|
68
|
+
return self._handle_get_signatures_request()
|
69
|
+
if method_name not in self._get_method_info():
|
70
|
+
raise ValueError(f"Method {method_name} not found in model class")
|
71
|
+
method = getattr(self, method_name)
|
72
|
+
method_info = method._cf_method_info
|
73
|
+
signature = method_info.signature
|
74
|
+
python_param_types = method_info.python_param_types
|
75
|
+
inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
|
76
|
+
python_param_types)
|
77
|
+
if len(inputs) == 1:
|
78
|
+
inputs = inputs[0]
|
79
|
+
output = method(**inputs)
|
80
|
+
outputs.append(self._convert_output_to_proto(output, signature.outputs))
|
81
|
+
else:
|
82
|
+
outputs = self.batch_predict(method, inputs)
|
83
|
+
outputs = [self._convert_output_to_proto(output, signature.outputs) for output in outputs]
|
84
|
+
|
85
|
+
return service_pb2.MultiOutputResponse(
|
86
|
+
outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
|
87
|
+
except Exception as e:
|
88
|
+
if _RAISE_EXCEPTIONS:
|
89
|
+
raise
|
90
|
+
logging.exception("Error in predict")
|
91
|
+
return service_pb2.MultiOutputResponse(status=status_pb2.Status(
|
92
|
+
code=status_code_pb2.FAILURE,
|
93
|
+
details=str(e),
|
94
|
+
stack_trace=traceback.format_exc().split('\n')))
|
13
95
|
|
14
96
|
def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
|
15
97
|
) -> Iterator[service_pb2.MultiOutputResponse]:
|
16
|
-
|
17
|
-
|
98
|
+
try:
|
99
|
+
method_name = request.model.model_version.output_info.params['_method_name']
|
100
|
+
method = getattr(self, method_name)
|
101
|
+
method_info = method._cf_method_info
|
102
|
+
signature = method_info.signature
|
103
|
+
python_param_types = method_info.python_param_types
|
18
104
|
|
19
|
-
|
105
|
+
inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
|
106
|
+
python_param_types)
|
107
|
+
if len(inputs) == 1:
|
108
|
+
inputs = inputs[0]
|
109
|
+
for output in method(**inputs):
|
110
|
+
resp = service_pb2.MultiOutputResponse()
|
111
|
+
self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
|
112
|
+
resp.status.code = status_code_pb2.SUCCESS
|
113
|
+
yield resp
|
114
|
+
else:
|
115
|
+
for outputs in self.batch_generate(method, inputs):
|
116
|
+
resp = service_pb2.MultiOutputResponse()
|
117
|
+
for output in outputs:
|
118
|
+
self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
|
119
|
+
resp.status.code = status_code_pb2.SUCCESS
|
120
|
+
yield resp
|
121
|
+
except Exception as e:
|
122
|
+
if _RAISE_EXCEPTIONS:
|
123
|
+
raise
|
124
|
+
logging.exception("Error in generate")
|
125
|
+
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
|
126
|
+
code=status_code_pb2.FAILURE,
|
127
|
+
details=str(e),
|
128
|
+
stack_trace=traceback.format_exc().split('\n')))
|
129
|
+
|
130
|
+
def stream_wrapper(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
|
20
131
|
) -> Iterator[service_pb2.MultiOutputResponse]:
|
21
|
-
|
22
|
-
|
132
|
+
try:
|
133
|
+
request = next(request_iterator) # get first request to determine method
|
134
|
+
assert len(request.inputs) == 1, "Streaming requires exactly one input"
|
23
135
|
|
24
|
-
|
25
|
-
|
26
|
-
|
136
|
+
method_name = request.model.model_version.output_info.params['_method_name']
|
137
|
+
method = getattr(self, method_name)
|
138
|
+
method_info = method._cf_method_info
|
139
|
+
signature = method_info.signature
|
140
|
+
python_param_types = method_info.python_param_types
|
27
141
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
142
|
+
# find the streaming vars in the signature
|
143
|
+
streaming_var_signatures = [var for var in signature.inputs if var.streaming]
|
144
|
+
stream_argname = set([var.name.split('.', 1)[0] for var in streaming_var_signatures])
|
145
|
+
assert len(
|
146
|
+
stream_argname) == 1, 'streaming methods must have exactly one streaming function arg'
|
147
|
+
stream_argname = stream_argname.pop()
|
32
148
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
149
|
+
# convert all inputs for the first request, including the first stream value
|
150
|
+
inputs = self._convert_input_protos_to_python(request.inputs, signature.inputs,
|
151
|
+
python_param_types)
|
152
|
+
kwargs = inputs[0]
|
37
153
|
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
154
|
+
# first streaming item
|
155
|
+
first_item = kwargs.pop(stream_argname)
|
156
|
+
|
157
|
+
# streaming generator
|
158
|
+
def InputStream():
|
159
|
+
yield first_item
|
160
|
+
# subsequent streaming items contain only the streaming input
|
161
|
+
for request in request_iterator:
|
162
|
+
item = self._convert_input_protos_to_python(request.inputs, streaming_var_signatures,
|
163
|
+
python_param_types)
|
164
|
+
item = item[0][stream_argname]
|
165
|
+
yield item
|
166
|
+
|
167
|
+
# add stream generator back to the input kwargs
|
168
|
+
kwargs[stream_argname] = InputStream()
|
169
|
+
|
170
|
+
for output in method(**kwargs):
|
171
|
+
resp = service_pb2.MultiOutputResponse()
|
172
|
+
self._convert_output_to_proto(output, signature.outputs, proto=resp.outputs.add())
|
173
|
+
resp.status.code = status_code_pb2.SUCCESS
|
174
|
+
yield resp
|
175
|
+
except Exception as e:
|
176
|
+
if _RAISE_EXCEPTIONS:
|
177
|
+
raise
|
178
|
+
logging.exception("Error in stream")
|
179
|
+
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
|
180
|
+
code=status_code_pb2.FAILURE,
|
181
|
+
details=str(e),
|
182
|
+
stack_trace=traceback.format_exc().split('\n')))
|
183
|
+
|
184
|
+
def _convert_input_protos_to_python(self, inputs: List[resources_pb2.Input], variables_signature,
|
185
|
+
python_param_types) -> List[Dict[str, Any]]:
|
186
|
+
result = []
|
187
|
+
for input in inputs:
|
188
|
+
kwargs = deserialize(input.data, variables_signature)
|
189
|
+
# dynamic cast to annotated types
|
190
|
+
for k, v in kwargs.items():
|
191
|
+
if k not in python_param_types:
|
192
|
+
continue
|
193
|
+
kwargs[k] = data_handler.cast(v, python_param_types[k])
|
194
|
+
result.append(kwargs)
|
195
|
+
return result
|
196
|
+
|
197
|
+
def _convert_output_to_proto(self, output: Any, variables_signature,
|
198
|
+
proto=None) -> resources_pb2.Output:
|
199
|
+
if proto is None:
|
200
|
+
proto = resources_pb2.Output()
|
201
|
+
if isinstance(output, tuple):
|
202
|
+
output = {f'return.{i}': item for i, item in enumerate(output)}
|
203
|
+
if not isinstance(output, dict): # TODO Output type, not just dict
|
204
|
+
output = {'return': output}
|
205
|
+
serialize(output, variables_signature, proto.data, is_output=True)
|
206
|
+
return proto
|
207
|
+
|
208
|
+
@classmethod
|
209
|
+
def _register_model_methods(cls):
|
210
|
+
# go up the class hierarchy to find all decorated methods, and add to registry of current class
|
211
|
+
methods = {}
|
212
|
+
for base in reversed(cls.__mro__):
|
213
|
+
for name, method in base.__dict__.items():
|
214
|
+
method_info = getattr(method, _METHOD_INFO_ATTR, None)
|
215
|
+
if not method_info: # regular function, not a model method
|
216
|
+
continue
|
217
|
+
methods[name] = method_info
|
218
|
+
# check for generic predict(request) -> response, etc. methods
|
219
|
+
#for name in ('predict', 'generate', 'stream'):
|
220
|
+
# if hasattr(cls, name):
|
221
|
+
# method = getattr(cls, name)
|
222
|
+
# if not hasattr(method, _METHOD_INFO_ATTR): # not already put in registry
|
223
|
+
# methods[name] = _MethodInfo(method, method_type=name)
|
224
|
+
# set method table for this class in the registry
|
225
|
+
return methods
|
226
|
+
|
227
|
+
@classmethod
|
228
|
+
def _get_method_info(cls, func_name=None):
|
229
|
+
if not hasattr(cls, _METHOD_INFO_ATTR):
|
230
|
+
setattr(cls, _METHOD_INFO_ATTR, cls._register_model_methods())
|
231
|
+
method_info = getattr(cls, _METHOD_INFO_ATTR)
|
232
|
+
if func_name:
|
233
|
+
return method_info[func_name]
|
234
|
+
return method_info
|
235
|
+
|
236
|
+
|
237
|
+
class _MethodInfo:
|
238
|
+
|
239
|
+
def __init__(self, method, method_type):
|
240
|
+
self.name = method.__name__
|
241
|
+
self.signature = build_function_signature(method, method_type)
|
242
|
+
self.python_param_types = {
|
243
|
+
p.name: p.annotation
|
244
|
+
for p in inspect.signature(method).parameters.values()
|
245
|
+
if p.annotation != inspect.Parameter.empty
|
246
|
+
}
|
247
|
+
self.python_param_types.pop('self', None)
|
248
|
+
|
249
|
+
|
250
|
+
def predict(method):
|
251
|
+
setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'predict'))
|
252
|
+
return method
|
253
|
+
|
254
|
+
|
255
|
+
def generate(method):
|
256
|
+
setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'generate'))
|
257
|
+
return method
|
258
|
+
|
259
|
+
|
260
|
+
def stream(method):
|
261
|
+
setattr(method, _METHOD_INFO_ATTR, _MethodInfo(method, 'stream'))
|
262
|
+
return method
|
263
|
+
|
264
|
+
|
265
|
+
methods = types.SimpleNamespace(predict=predict, generate=generate, stream=stream)
|
@@ -481,8 +481,11 @@ def main(model_path,
|
|
481
481
|
)
|
482
482
|
sys.exit(1)
|
483
483
|
manager = ModelRunLocally(model_path)
|
484
|
-
#
|
485
|
-
|
484
|
+
# get whatever stage is in config.yaml to force download now
|
485
|
+
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
|
486
|
+
_, _, _, when = manager.builder._validate_config_checkpoints()
|
487
|
+
manager.builder.download_checkpoints(
|
488
|
+
stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
|
486
489
|
if inside_container:
|
487
490
|
if not manager.is_docker_installed():
|
488
491
|
sys.exit(1)
|
@@ -82,6 +82,8 @@ class ModelRunner(BaseRunner, HealthProbeRequestHandler):
|
|
82
82
|
ensure_urls_downloaded(request)
|
83
83
|
|
84
84
|
resp = self.model.predict_wrapper(request)
|
85
|
+
if resp.status.code != status_code_pb2.SUCCESS:
|
86
|
+
return service_pb2.RunnerItemOutput(multi_output_response=resp)
|
85
87
|
successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
|
86
88
|
if all(successes):
|
87
89
|
status = status_pb2.Status(
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import os
|
1
2
|
from itertools import tee
|
2
3
|
from typing import Iterator
|
3
4
|
|
@@ -6,6 +7,8 @@ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
|
|
6
7
|
|
7
8
|
from ..utils.url_fetcher import ensure_urls_downloaded
|
8
9
|
|
10
|
+
_RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "1")
|
11
|
+
|
9
12
|
|
10
13
|
class ModelServicer(service_pb2_grpc.V2Servicer):
|
11
14
|
"""
|
@@ -33,6 +36,8 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
|
|
33
36
|
try:
|
34
37
|
return self.model.predict_wrapper(request)
|
35
38
|
except Exception as e:
|
39
|
+
if _RAISE_EXCEPTIONS:
|
40
|
+
raise
|
36
41
|
return service_pb2.MultiOutputResponse(status=status_pb2.Status(
|
37
42
|
code=status_code_pb2.MODEL_PREDICTION_FAILED,
|
38
43
|
description="Failed",
|
@@ -50,8 +55,10 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
|
|
50
55
|
ensure_urls_downloaded(request)
|
51
56
|
|
52
57
|
try:
|
53
|
-
|
58
|
+
yield from self.model.generate_wrapper(request)
|
54
59
|
except Exception as e:
|
60
|
+
if _RAISE_EXCEPTIONS:
|
61
|
+
raise
|
55
62
|
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
|
56
63
|
code=status_code_pb2.MODEL_PREDICTION_FAILED,
|
57
64
|
description="Failed",
|
@@ -74,8 +81,10 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
|
|
74
81
|
ensure_urls_downloaded(req)
|
75
82
|
|
76
83
|
try:
|
77
|
-
|
84
|
+
yield from self.model.stream_wrapper(request_copy)
|
78
85
|
except Exception as e:
|
86
|
+
if _RAISE_EXCEPTIONS:
|
87
|
+
raise
|
79
88
|
yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
|
80
89
|
code=status_code_pb2.MODEL_PREDICTION_FAILED,
|
81
90
|
description="Failed",
|
clarifai/runners/server.py
CHANGED
@@ -68,30 +68,47 @@ def main():
|
|
68
68
|
|
69
69
|
parsed_args = parser.parse_args()
|
70
70
|
|
71
|
-
|
71
|
+
serve(parsed_args.model_path, parsed_args.port, parsed_args.pool_size,
|
72
|
+
parsed_args.max_queue_size, parsed_args.max_msg_length, parsed_args.enable_tls,
|
73
|
+
parsed_args.grpc)
|
74
|
+
|
75
|
+
|
76
|
+
def serve(model_path,
|
77
|
+
port=8000,
|
78
|
+
pool_size=32,
|
79
|
+
max_queue_size=10,
|
80
|
+
max_msg_length=1024 * 1024 * 1024,
|
81
|
+
enable_tls=False,
|
82
|
+
grpc=False):
|
83
|
+
|
84
|
+
builder = ModelBuilder(model_path, download_validation_only=True)
|
72
85
|
|
73
86
|
model = builder.create_model_instance()
|
74
87
|
|
88
|
+
# `num_threads` can be set in config.yaml or via the environment variable CLARIFAI_NUM_THREADS="<integer>".
|
89
|
+
# Note: The value in config.yaml takes precedence over the environment variable.
|
90
|
+
num_threads = builder.config.get("num_threads")
|
91
|
+
|
75
92
|
# Setup the grpc server for local development.
|
76
|
-
if
|
93
|
+
if grpc:
|
77
94
|
|
78
95
|
# initialize the servicer with the runner so that it gets the predict(), generate(), stream() classes.
|
79
96
|
servicer = ModelServicer(model)
|
80
97
|
|
81
98
|
server = GRPCServer(
|
82
99
|
futures.ThreadPoolExecutor(
|
83
|
-
max_workers=
|
100
|
+
max_workers=pool_size,
|
84
101
|
thread_name_prefix="ServeCalls",
|
85
102
|
),
|
86
|
-
|
87
|
-
|
103
|
+
max_msg_length,
|
104
|
+
max_queue_size,
|
88
105
|
)
|
89
|
-
server.add_port_to_server('[::]:%s' %
|
106
|
+
server.add_port_to_server('[::]:%s' % port, enable_tls)
|
90
107
|
|
91
108
|
service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
|
92
109
|
server.start()
|
93
|
-
logger.info("Started server on port %s",
|
94
|
-
logger.info(f"Access the model at http://localhost:{
|
110
|
+
logger.info("Started server on port %s", port)
|
111
|
+
logger.info(f"Access the model at http://localhost:{port}")
|
95
112
|
server.wait_for_termination()
|
96
113
|
else: # start the runner with the proper env variables and as a runner protocol.
|
97
114
|
|
@@ -102,7 +119,7 @@ def main():
|
|
102
119
|
nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
|
103
120
|
compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
|
104
121
|
base_url=os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com"),
|
105
|
-
num_parallel_polls=
|
122
|
+
num_parallel_polls=num_threads,
|
106
123
|
)
|
107
124
|
runner.start() # start the runner to fetch work from the API.
|
108
125
|
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
clarifai/runners/utils/const.py
CHANGED
@@ -16,7 +16,7 @@ DEFAULT_PYTHON_VERSION = 3.12
|
|
16
16
|
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime"
|
17
17
|
|
18
18
|
# Folder for downloading checkpoints at runtime.
|
19
|
-
DEFAULT_RUNTIME_DOWNLOAD_PATH = "
|
19
|
+
DEFAULT_RUNTIME_DOWNLOAD_PATH = os.path.join(os.sep, "tmp", ".cache")
|
20
20
|
|
21
21
|
# List of available torch images
|
22
22
|
# Keep sorted by most recent cuda version.
|