clarifai 10.8.1__py3-none-any.whl → 10.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,175 @@
+ from typing import Iterator
+
+ from clarifai_grpc.grpc.api import service_pb2
+ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
+
+ from clarifai_protocol import BaseRunner
+ from clarifai_protocol.utils.health import HealthProbeRequestHandler
+ from ..utils.url_fetcher import ensure_urls_downloaded
+
+ from .model_class import ModelClass
+
+
+ class ModelRunner(BaseRunner, ModelClass, HealthProbeRequestHandler):
+   """
+   A subclass of the runner class that handles only the work items relevant to models.
+
+   It is also a subclass of ModelClass, so a subclass of ModelRunner only needs to implement
+   the predict(), generate() and stream() methods, plus load_model() if needed.
+   """
+
+   def __init__(
+       self,
+       runner_id: str,
+       nodepool_id: str,
+       compute_cluster_id: str,
+       user_id: str = None,
+       check_runner_exists: bool = True,
+       base_url: str = "https://api.clarifai.com",
+       pat: str = None,
+       token: str = None,
+       num_parallel_polls: int = 4,
+       **kwargs,
+   ) -> None:
+     super().__init__(
+         runner_id,
+         nodepool_id,
+         compute_cluster_id,
+         user_id,
+         check_runner_exists,
+         base_url,
+         pat,
+         token,
+         num_parallel_polls,
+         **kwargs,
+     )
+     self.load_model()
+
+     # After the model loads successfully, mark the health probes as ready and started up.
+     HealthProbeRequestHandler.is_ready = True
+     HealthProbeRequestHandler.is_startup = True
+
+   def get_runner_item_output_for_status(self,
+                                         status: status_pb2.Status) -> service_pb2.RunnerItemOutput:
+     """
+     Set the error message in the RunnerItemOutput message subfield, used during exception handling
+     where we may only have a status to return.
+
+     Args:
+       status: status_pb2.Status - the status to return
+
+     Returns:
+       service_pb2.RunnerItemOutput - the RunnerItemOutput message with the status set
+     """
+     rio = service_pb2.RunnerItemOutput(
+         multi_output_response=service_pb2.MultiOutputResponse(status=status))
+     return rio
+
+   def runner_item_predict(self,
+                           runner_item: service_pb2.RunnerItem) -> service_pb2.RunnerItemOutput:
+     """
+     Run the model on the given work item. You shouldn't need to override this method; the
+     underlying predict implementation processes each input in the request.
+
+     Args:
+       runner_item: service_pb2.RunnerItem - the work item wrapping the PostModelOutputsRequest
+         to run the model on
+
+     Returns:
+       service_pb2.RunnerItemOutput - wraps the MultiOutputResponse from the model's predict
+         implementation.
+     """
+
+     if not runner_item.HasField('post_model_outputs_request'):
+       raise Exception("Unexpected work item type: {}".format(runner_item))
+     request = runner_item.post_model_outputs_request
+     ensure_urls_downloaded(request)
+
+     resp = self.predict_wrapper(request)
+     successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
+     if all(successes):
+       status = status_pb2.Status(
+           code=status_code_pb2.SUCCESS,
+           description="Success",
+       )
+     elif any(successes):
+       status = status_pb2.Status(
+           code=status_code_pb2.MIXED_STATUS,
+           description="Mixed Status",
+       )
+     else:
+       status = status_pb2.Status(
+           code=status_code_pb2.FAILURE,
+           description="Failed",
+       )
+
+     resp.status.CopyFrom(status)
+     return service_pb2.RunnerItemOutput(multi_output_response=resp)
+
+   def runner_item_generate(
+       self, runner_item: service_pb2.RunnerItem) -> Iterator[service_pb2.RunnerItemOutput]:
+     # Call the generate() method the underlying model implements.
+
+     if not runner_item.HasField('post_model_outputs_request'):
+       raise Exception("Unexpected work item type: {}".format(runner_item))
+     request = runner_item.post_model_outputs_request
+     ensure_urls_downloaded(request)
+
+     for resp in self.generate_wrapper(request):
+       successes = []
+       for output in resp.outputs:
+         if not output.HasField('status') or not output.status.code:
+           raise Exception("Output must have a status code, please check the model implementation.")
+         successes.append(output.status.code == status_code_pb2.SUCCESS)
+       if all(successes):
+         status = status_pb2.Status(
+             code=status_code_pb2.SUCCESS,
+             description="Success",
+         )
+       elif any(successes):
+         status = status_pb2.Status(
+             code=status_code_pb2.MIXED_STATUS,
+             description="Mixed Status",
+         )
+       else:
+         status = status_pb2.Status(
+             code=status_code_pb2.FAILURE,
+             description="Failed",
+         )
+       resp.status.CopyFrom(status)
+
+       yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+
+   def runner_item_stream(self, runner_item_iterator: Iterator[service_pb2.RunnerItem]
+                         ) -> Iterator[service_pb2.RunnerItemOutput]:
+     # Call the stream() method the underlying model implements.
+     for resp in self.stream_wrapper(pmo_iterator(runner_item_iterator)):
+       successes = []
+       for output in resp.outputs:
+         if not output.HasField('status') or not output.status.code:
+           raise Exception("Output must have a status code, please check the model implementation.")
+         successes.append(output.status.code == status_code_pb2.SUCCESS)
+       if all(successes):
+         status = status_pb2.Status(
+             code=status_code_pb2.SUCCESS,
+             description="Success",
+         )
+       elif any(successes):
+         status = status_pb2.Status(
+             code=status_code_pb2.MIXED_STATUS,
+             description="Mixed Status",
+         )
+       else:
+         status = status_pb2.Status(
+             code=status_code_pb2.FAILURE,
+             description="Failed",
+         )
+       resp.status.CopyFrom(status)
+
+       yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+
+
+ def pmo_iterator(runner_item_iterator):
+   for runner_item in runner_item_iterator:
+     if not runner_item.HasField('post_model_outputs_request'):
+       raise Exception("Unexpected work item type: {}".format(runner_item))
+     ensure_urls_downloaded(runner_item.post_model_outputs_request)
+     yield runner_item.post_model_outputs_request
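The new ModelRunner expects a user-provided subclass that implements load_model(), predict(), generate() and stream(), per its docstring. Below is a minimal sketch of such a subclass; the class name, the echo logic and the text-only payloads are illustrative assumptions rather than part of the package, and the import path mirrors the `clarifai.runners.models` layout used by the server entrypoint further down.

```python
from typing import Iterator

from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2

# Assumed module path, matching the package layout seen in this diff.
from clarifai.runners.models.model_runner import ModelRunner


class MyModelRunner(ModelRunner):
  """Hypothetical model: echoes text inputs back with a prefix."""

  def load_model(self):
    # Load weights/tokenizers once at startup; kept trivial here.
    self.prefix = "echo: "

  def predict(self, request: service_pb2.PostModelOutputsRequest
             ) -> service_pb2.MultiOutputResponse:
    # Every output carries a status code, as runner_item_generate/stream require.
    outputs = []
    for inp in request.inputs:
      text = inp.data.text.raw
      outputs.append(
          resources_pb2.Output(
              status=status_pb2.Status(code=status_code_pb2.SUCCESS),
              data=resources_pb2.Data(text=resources_pb2.Text(raw=self.prefix + text))))
    return service_pb2.MultiOutputResponse(outputs=outputs)

  def generate(self, request: service_pb2.PostModelOutputsRequest
              ) -> Iterator[service_pb2.MultiOutputResponse]:
    # Stream back one response per input instead of a single batched response.
    for inp in request.inputs:
      yield self.predict(service_pb2.PostModelOutputsRequest(inputs=[inp]))

  def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
            ) -> Iterator[service_pb2.MultiOutputResponse]:
    for request in request_iterator:
      yield from self.generate(request)
```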
@@ -0,0 +1,79 @@
+ from itertools import tee
+ from typing import Iterator
+
+ from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
+ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
+
+ from ..utils.url_fetcher import ensure_urls_downloaded
+
+
+ class ModelServicer(service_pb2_grpc.V2Servicer):
+   """
+   The servicer that handles gRPC requests coming from either the dev server or the runner loop.
+   """
+
+   def __init__(self, model_class):
+     self.model_class = model_class
+
+   def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
+                        context=None) -> service_pb2.MultiOutputResponse:
+     """
+     Handle a unary predict call by delegating to the model class's predict() method.
+     """
+
+     # Download any urls that are not already bytes.
+     ensure_urls_downloaded(request)
+
+     try:
+       return self.model_class.predict(request)
+     except Exception as e:
+       return service_pb2.MultiOutputResponse(status=status_pb2.Status(
+           code=status_code_pb2.MODEL_PREDICTION_FAILED,
+           description="Failed",
+           details="",
+           internal_details=str(e),
+       ))
+
+   def GenerateModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
+                            context=None) -> Iterator[service_pb2.MultiOutputResponse]:
+     """
+     Handle a server-streaming call by delegating to the model class's generate() method and
+     yielding each response it produces.
+     """
+     # Download any urls that are not already bytes.
+     ensure_urls_downloaded(request)
+
+     try:
+       yield from self.model_class.generate(request)
+     except Exception as e:
+       yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
+           code=status_code_pb2.MODEL_PREDICTION_FAILED,
+           description="Failed",
+           details="",
+           internal_details=str(e),
+       ))
+
+   def StreamModelOutputs(self,
+                          request: Iterator[service_pb2.PostModelOutputsRequest],
+                          context=None) -> Iterator[service_pb2.MultiOutputResponse]:
+     """
+     Handle a bidirectional-streaming call by delegating to the model class's stream() method and
+     yielding each response it produces.
+     """
+     # Duplicate the iterator so the requests can be inspected before being passed to the model.
+     request, request_copy = tee(request)
+
+     # Download any urls that are not already bytes.
+     for req in request:
+       ensure_urls_downloaded(req)
+
+     try:
+       yield from self.model_class.stream(request_copy)
+     except Exception as e:
+       yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
+           code=status_code_pb2.MODEL_PREDICTION_FAILED,
+           description="Failed",
+           details="",
+           internal_details=str(e),
+       ))
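Because the servicer only needs an object exposing predict(), generate() and stream(), it can be exercised in-process without starting a gRPC server. The sketch below is a hypothetical smoke test: EchoModel and the text payload are made-up stand-ins, and it assumes ensure_urls_downloaded() is a no-op for inputs that carry no URLs.

```python
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2

# Module path as imported by the server entrypoint later in this diff.
from clarifai.runners.models.model_servicer import ModelServicer


class EchoModel:
  """Stand-in for a ModelClass subclass; only the methods the servicer calls are defined."""

  def predict(self, request):
    outputs = [
        resources_pb2.Output(
            status=status_pb2.Status(code=status_code_pb2.SUCCESS),
            data=inp.data) for inp in request.inputs
    ]
    return service_pb2.MultiOutputResponse(outputs=outputs)

  def generate(self, request):
    yield self.predict(request)

  def stream(self, request_iterator):
    for request in request_iterator:
      yield self.predict(request)


servicer = ModelServicer(EchoModel())
request = service_pb2.PostModelOutputsRequest(
    inputs=[resources_pb2.Input(data=resources_pb2.Data(text=resources_pb2.Text(raw="hi")))])
response = servicer.PostModelOutputs(request)
assert response.outputs[0].data.text.raw == "hi"
```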
@@ -0,0 +1,315 @@
+ import argparse
+ import os
+ import time
+ from string import Template
+
+ import yaml
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2
+ from clarifai_grpc.grpc.api.status import status_code_pb2
+ from google.protobuf import json_format
+ from rich import print
+
+ from clarifai.client import BaseClient
+
+ from clarifai.runners.utils.loader import HuggingFaceLoarder
+
+
+ def _clear_line(n: int = 1) -> None:
+   LINE_UP = '\033[1A'  # Move cursor up one line
+   LINE_CLEAR = '\x1b[2K'  # Clear the entire line
+   for _ in range(n):
+     print(LINE_UP, end=LINE_CLEAR, flush=True)
+
+
+ class ModelUploader:
+   DEFAULT_PYTHON_VERSION = 3.11
+   CONCEPTS_REQUIRED_MODEL_TYPE = [
+       'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
+   ]
+
+   def __init__(self, folder: str):
+     self.folder = self._validate_folder(folder)
+     self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
+     self.initialize_client()
+     self.model_proto = self._get_model_proto()
+     self.model_id = self.model_proto.id
+     self.user_app_id = self.client.user_app_id
+     self.inference_compute_info = self._get_inference_compute_info()
+     self.is_v3 = True  # Do model build for v3
+
+   @staticmethod
+   def _validate_folder(folder):
+     if not folder.startswith("/"):
+       folder = os.path.join(os.getcwd(), folder)
+     print(f"Validating folder: {folder}")
+     files = os.listdir(folder)
+     assert "requirements.txt" in files, "requirements.txt not found in the folder"
+     assert "config.yaml" in files, "config.yaml not found in the folder"
+     assert "1" in files, "Subfolder '1' not found in the folder"
+     subfolder_files = os.listdir(os.path.join(folder, '1'))
+     assert 'model.py' in subfolder_files, "model.py not found in the folder"
+     return folder
+
+   @staticmethod
+   def _load_config(config_file: str):
+     with open(config_file, 'r') as file:
+       config = yaml.safe_load(file)
+     return config
+
+   def initialize_client(self):
+     assert "model" in self.config, "model info not found in the config file"
+     model = self.config.get('model')
+     assert "user_id" in model, "user_id not found in the config file"
+     assert "app_id" in model, "app_id not found in the config file"
+     user_id = model.get('user_id')
+     app_id = model.get('app_id')
+
+     base = os.environ.get('CLARIFAI_API_BASE', 'https://api-dev.clarifai.com')
+
+     self.client = BaseClient(user_id=user_id, app_id=app_id, base=base)
+     print(f"Client initialized for user {user_id} and app {app_id}")
+
+   def _get_model_proto(self):
+     assert "model" in self.config, "model info not found in the config file"
+     model = self.config.get('model')
+
+     assert "model_type_id" in model, "model_type_id not found in the config file"
+     assert "id" in model, "model_id not found in the config file"
+     assert "user_id" in model, "user_id not found in the config file"
+     assert "app_id" in model, "app_id not found in the config file"
+
+     model_proto = json_format.ParseDict(model, resources_pb2.Model())
+     assert model_proto.id == model_proto.id.lower(), "Model ID must be lowercase"
+     assert model_proto.user_id == model_proto.user_id.lower(), "User ID must be lowercase"
+     assert model_proto.app_id == model_proto.app_id.lower(), "App ID must be lowercase"
+
+     return model_proto
+
+   def _get_inference_compute_info(self):
+     assert ("inference_compute_info" in self.config
+            ), "inference_compute_info not found in the config file"
+     inference_compute_info = self.config.get('inference_compute_info')
+     return json_format.ParseDict(inference_compute_info, resources_pb2.ComputeInfo())
+
+   def maybe_create_model(self):
+     resp = self.client.STUB.GetModel(
+         service_pb2.GetModelRequest(
+             user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
+     if resp.status.code == status_code_pb2.SUCCESS:
+       print(
+           f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
+           f"will create a new version for it.")
+       return resp
+
+     request = service_pb2.PostModelsRequest(
+         user_app_id=self.client.user_app_id,
+         models=[self.model_proto],
+     )
+     return self.client.STUB.PostModels(request)
+
+   def create_dockerfile(self):
+     num_accelerators = self.inference_compute_info.num_accelerators
+     if num_accelerators:
+       dockerfile_template = os.path.join(
+           os.path.dirname(os.path.dirname(__file__)),
+           'dockerfile_template',
+           'Dockerfile.cuda.template',
+       )
+     else:
+       dockerfile_template = os.path.join(
+           os.path.dirname(os.path.dirname(__file__)), 'dockerfile_template',
+           'Dockerfile.cpu.template')
+
+     with open(dockerfile_template, 'r') as template_file:
+       dockerfile_template = template_file.read()
+
+     dockerfile_template = Template(dockerfile_template)
+
+     # Get the Python version from the config file
+     build_info = self.config.get('build_info', {})
+     python_version = build_info.get('python_version', self.DEFAULT_PYTHON_VERSION)
+
+     # Replace placeholders with actual values
+     dockerfile_content = dockerfile_template.safe_substitute(
+         PYTHON_VERSION=python_version,
+         name='main',
+     )
+
+     # Write Dockerfile
+     with open(os.path.join(self.folder, 'Dockerfile'), 'w') as dockerfile:
+       dockerfile.write(dockerfile_content)
+
+   def download_checkpoints(self):
+     if not self.config.get("checkpoints"):
+       print("No checkpoints specified in the config file")
+       return
+
+     assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
+     loader_type = self.config.get("checkpoints").get("type")
+     if not loader_type:
+       print("No loader type specified in the config file for checkpoints")
+     assert loader_type == "huggingface", "Only huggingface loader supported for now"
+     if loader_type == "huggingface":
+       assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
+       repo_id = self.config.get("checkpoints").get("repo_id")
+
+       hf_token = self.config.get("checkpoints").get("hf_token", None)
+       loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
+
+       checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
+       loader.download_checkpoints(checkpoint_path)
+
+       print(f"Downloaded checkpoints for model {repo_id}")
+
+   def _concepts_protos_from_concepts(self, concepts):
+     concept_protos = []
+     for concept in concepts:
+       concept_protos.append(resources_pb2.Concept(
+           id=str(concept[0]),
+           name=concept[1],
+       ))
+     return concept_protos
+
+   def hf_labels_to_config(self, labels, config_file):
+     with open(config_file, 'r') as file:
+       config = yaml.safe_load(file)
+     model = config.get('model')
+     model_type_id = model.get('model_type_id')
+     assert model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
+     concept_protos = self._concepts_protos_from_concepts(labels)
+
+     config['concepts'] = [{'id': concept.id, 'name': concept.name} for concept in concept_protos]
+
+     with open(config_file, 'w') as file:
+       yaml.dump(config, file, sort_keys=False)
+     concepts = config.get('concepts')
+     print(f"Updated config.yaml with {len(concepts)} concepts.")
+
+   def _get_model_version_proto(self):
+
+     model_version = resources_pb2.ModelVersion(
+         pretrained_model_config=resources_pb2.PretrainedModelConfig(),
+         inference_compute_info=self.inference_compute_info,
+     )
+
+     model_type_id = self.config.get('model').get('model_type_id')
+     if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE:
+
+       loader = HuggingFaceLoarder()
+       checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
+       labels = loader.fetch_labels(checkpoint_path)
+       # sort the concepts by id and then update the config file
+       labels = sorted(labels.items(), key=lambda x: int(x[0]))
+
+       config_file = os.path.join(self.folder, 'config.yaml')
+       self.hf_labels_to_config(labels, config_file)
+
+       model_version.output_info.data.concepts.extend(self._concepts_protos_from_concepts(labels))
+     return model_version
+
+   def upload_model_version(self):
+     file_path = f"{self.folder}.tar.gz"
+     print(f"Will tar it into file: {file_path}")
+
+     # Tar the folder
+     os.system(f"tar --exclude=*~ -czvf {self.folder}.tar.gz -C {self.folder} .")
+     print("Tarring complete, about to start upload.")
+
+     model_version = self._get_model_version_proto()
+
+     response = self.maybe_create_model()
+
+     for response in self.client.STUB.PostModelVersionsUpload(
+         self.model_version_stream_upload_iterator(model_version, file_path),):
+       percent_completed = 0
+       if response.status.code == status_code_pb2.UPLOAD_IN_PROGRESS:
+         percent_completed = response.status.percent_completed
+       details = response.status.details
+
+       _clear_line()
+       print(
+           f"Status: {response.status.description}, "
+           f"Progress: {percent_completed}% - {details} ",
+           end='\r',
+           flush=True)
+     print()
+     if response.status.code != status_code_pb2.MODEL_BUILDING:
+       print(f"Failed to upload model version: {response.status.description}")
+       return
+     model_version_id = response.model_version_id
+     print(f"Created Model Version ID: {model_version_id}")
+
+     self.monitor_model_build(model_version_id)
+
+   def model_version_stream_upload_iterator(self, model_version, file_path):
+     yield self.init_upload_model_version(model_version, file_path)
+     with open(file_path, "rb") as f:
+       file_size = os.path.getsize(file_path)
+       chunk_size = int(127 * 1024 * 1024)  # 127MB chunk size
+       num_chunks = (file_size // chunk_size) + 1
+
+       read_so_far = 0
+       for part_id in range(num_chunks):
+         chunk = f.read(chunk_size)
+         read_so_far += len(chunk)
+         yield service_pb2.PostModelVersionsUploadRequest(
+             content_part=resources_pb2.UploadContentPart(
+                 data=chunk,
+                 part_number=part_id + 1,
+                 range_start=read_so_far,
+             ))
+     print("\nUpload complete!, waiting for model build...")
+
+   def init_upload_model_version(self, model_version, file_path):
+     file_size = os.path.getsize(file_path)
+     print(
+         f"Uploading model version '{model_version.id}' with file '{os.path.basename(file_path)}' of size {file_size} bytes..."
+     )
+     return service_pb2.PostModelVersionsUploadRequest(
+         upload_config=service_pb2.PostModelVersionsUploadConfig(
+             user_app_id=self.client.user_app_id,
+             model_id=self.model_proto.id,
+             model_version=model_version,
+             total_size=file_size,
+             is_v3=self.is_v3,
+         ))
+
+   def monitor_model_build(self, model_version_id):
+     st = time.time()
+     while True:
+       resp = self.client.STUB.GetModelVersion(
+           service_pb2.GetModelVersionRequest(
+               user_app_id=self.client.user_app_id,
+               model_id=self.model_proto.id,
+               version_id=model_version_id,
+           ))
+       status_code = resp.model_version.status.code
+       if status_code == status_code_pb2.MODEL_BUILDING:
+         print(f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
+         time.sleep(1)
+       elif status_code == status_code_pb2.MODEL_TRAINED:
+         print("\nModel build complete!")
+         print(
+             f"Check out the model at https://clarifai.com/{self.user_app_id.user_id}/apps/{self.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
+         )
+         break
+       else:
+         print(f"\nModel build failed with status: {resp.model_version.status}")
+         break
+
+
+ def main(folder):
+   uploader = ModelUploader(folder)
+   uploader.download_checkpoints()
+   uploader.create_dockerfile()
+   input("Press Enter to continue...")
+   uploader.upload_model_version()
+
+
+ if __name__ == "__main__":
+   parser = argparse.ArgumentParser()
+   parser.add_argument(
+       '--model_path', type=str, help='Path of the model folder to upload', required=True)
+   args = parser.parse_args()
+
+   main(args.model_path)
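The uploader's asserts collectively define what a model folder must look like: requirements.txt, config.yaml, and a "1/" subfolder containing model.py, with `model`, `inference_compute_info` and (optionally) `checkpoints` sections in the config. The scaffold below is a hedged sketch of a folder that would pass those checks; all IDs, compute values and the Hugging Face repo are illustrative placeholders, not package defaults.

```python
import os

import yaml

# Illustrative config; only the keys asserted on by ModelUploader are guaranteed to matter.
config = {
    'model': {
        'id': 'my-text-model',            # must be lowercase
        'user_id': 'my-user',             # must be lowercase
        'app_id': 'my-app',               # must be lowercase
        'model_type_id': 'text-classifier',
    },
    'build_info': {
        'python_version': '3.11',
    },
    'inference_compute_info': {
        'num_accelerators': 0,            # >0 selects the CUDA Dockerfile template
        'cpu_limit': '1',                 # assumed ComputeInfo fields, shown for illustration
        'cpu_memory': '4Gi',
    },
    'checkpoints': {
        'type': 'huggingface',            # only loader type supported at this version
        'repo_id': 'distilbert-base-uncased',  # placeholder repo
        'hf_token': None,
    },
}

folder = 'my_model'
os.makedirs(os.path.join(folder, '1'), exist_ok=True)
with open(os.path.join(folder, 'config.yaml'), 'w') as f:
  yaml.dump(config, f, sort_keys=False)
with open(os.path.join(folder, 'requirements.txt'), 'w') as f:
  f.write('clarifai\n')
with open(os.path.join(folder, '1', 'model.py'), 'w') as f:
  f.write('# subclass of ModelRunner goes here\n')
```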
@@ -0,0 +1,130 @@
+ """
+ This is simply the main file for the server. It imports the user's ModelRunner implementation
+ and starts the server.
+ """
+
+ import argparse
+ import importlib.util
+ import inspect
+ import os
+ import sys
+ from concurrent import futures
+
+ from clarifai_grpc.grpc.api import service_pb2_grpc
+ from clarifai_protocol import BaseRunner
+ from clarifai_protocol.utils.grpc_server import GRPCServer
+
+ from clarifai.runners.models.model_servicer import ModelServicer
+ from clarifai.runners.utils.logging import logger
+
+
+ def main():
+   parser = argparse.ArgumentParser()
+   parser.add_argument(
+       '--port',
+       type=int,
+       default=8000,
+       help="The port to host the gRPC server at.",
+       choices=range(1024, 65535),
+   )
+   parser.add_argument(
+       '--pool_size',
+       type=int,
+       default=32,
+       help="The number of threads to use for the gRPC server.",
+       choices=range(1, 129),
+   )  # pylint: disable=range-builtin-not-iterating
+   parser.add_argument(
+       '--max_queue_size',
+       type=int,
+       default=10,
+       help='Max queue size of requests before we begin to reject requests (default: 10).',
+       choices=range(1, 21),
+   )  # pylint: disable=range-builtin-not-iterating
+   parser.add_argument(
+       '--max_msg_length',
+       type=int,
+       default=1024 * 1024 * 1024,
+       help='Max message length of grpc requests (default: 1 GB).',
+   )
+   parser.add_argument(
+       '--enable_tls',
+       action='store_true',
+       default=False,
+       help=
+       'Set to true to enable TLS (default: False) since this server is meant for local development only.',
+   )
+   parser.add_argument(
+       '--start_dev_server',
+       action='store_true',
+       default=False,
+       help=
+       'Set to true to start the gRPC server (default: False). If set to false, the server will not start and only the runner loop will start to fetch work from the API.',
+   )
+   parser.add_argument(
+       '--model_path',
+       type=str,
+       required=True,
+       help='The path to the model directory that contains the implementation of the model.',
+   )
+
+   parsed_args = parser.parse_args()
+
+   # Import the runner class that is to be implemented by the user.
+   runner_path = os.path.join(parsed_args.model_path, "1", "model.py")
+
+   # arbitrary name given to the module to be imported
+   module = "runner_module"
+
+   spec = importlib.util.spec_from_file_location(module, runner_path)
+   runner_module = importlib.util.module_from_spec(spec)
+   sys.modules[module] = runner_module
+   spec.loader.exec_module(runner_module)
+
+   # Find all classes in the model.py file that are subclasses of BaseRunner
+   classes = [
+       cls for _, cls in inspect.getmembers(runner_module, inspect.isclass)
+       if issubclass(cls, BaseRunner) and cls.__module__ == runner_module.__name__
+   ]
+
+   # Ensure there is exactly one subclass of BaseRunner in the model.py file
+   if len(classes) != 1:
+     raise Exception("Expected exactly one subclass of BaseRunner, found: {}".format(len(classes)))
+
+   MyRunner = classes[0]
+
+   # initialize the Runner class. This is what the user implements.
+   # (Note) do we want to set runner_id, nodepool_id, compute_cluster_id, base_url, num_parallel_polls as env vars? or as args?
+   runner = MyRunner(
+       runner_id=os.environ["CLARIFAI_RUNNER_ID"],
+       nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
+       compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
+       base_url=os.environ["CLARIFAI_API_BASE"],
+       num_parallel_polls=int(os.environ.get("CLARIFAI_NUM_THREADS", 1)),
+   )
+
+   # initialize the servicer
+   servicer = ModelServicer(runner)
+
+   # Set up the gRPC server for local development.
+   if parsed_args.start_dev_server:
+     server = GRPCServer(
+         futures.ThreadPoolExecutor(
+             max_workers=parsed_args.pool_size,
+             thread_name_prefix="ServeCalls",
+         ),
+         parsed_args.max_msg_length,
+         parsed_args.max_queue_size,
+     )
+     server.add_port_to_server('[::]:%s' % parsed_args.port, parsed_args.enable_tls)
+
+     service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
+     server.start()
+     logger.info("Started server on port %s", parsed_args.port)
+     # server.wait_for_termination()  # won't get here currently.
+
+   runner.start()  # start the runner loop to fetch work from the API.
+
+
+ if __name__ == '__main__':
+   main()
File without changes
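The server entrypoint above pulls the runner's identity from environment variables and everything else from CLI flags. The snippet below is a hypothetical local launch showing which variables must be set before the runner loop starts; the IDs and port are placeholders, and the `clarifai.runners.server` module path is assumed from the package layout.

```python
import os
import sys

# Assumed module path for the entrypoint shown above.
from clarifai.runners.server import main

# Identity of this runner within a Clarifai compute cluster/nodepool (placeholders).
os.environ.setdefault("CLARIFAI_RUNNER_ID", "runner-id")
os.environ.setdefault("CLARIFAI_NODEPOOL_ID", "nodepool-id")
os.environ.setdefault("CLARIFAI_COMPUTE_CLUSTER_ID", "compute-cluster-id")
os.environ.setdefault("CLARIFAI_API_BASE", "https://api.clarifai.com")
os.environ.setdefault("CLARIFAI_NUM_THREADS", "1")
# A Clarifai PAT is typically also needed for API access (e.g. via CLARIFAI_PAT),
# depending on how BaseRunner resolves credentials; treated as an assumption here.

# Equivalent to: python server.py --model_path ./my_model --start_dev_server --port 8000
sys.argv = ["server.py", "--model_path", "./my_model", "--start_dev_server", "--port", "8000"]
main()
```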