clarifai 10.8.1__py3-none-any.whl → 10.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/client/app.py +3 -4
- clarifai/client/model.py +47 -15
- clarifai/constants/model.py +6 -0
- clarifai/models/model_serving/repo_build/static_files/base_test.py +4 -4
- clarifai/runners/__init__.py +14 -0
- clarifai/runners/dockerfile_template/Dockerfile.cpu.template +31 -0
- clarifai/runners/dockerfile_template/Dockerfile.cuda.template +129 -0
- clarifai/runners/models/__init__.py +0 -0
- clarifai/runners/models/base_typed_model.py +235 -0
- clarifai/runners/models/model_class.py +41 -0
- clarifai/runners/models/model_runner.py +175 -0
- clarifai/runners/models/model_servicer.py +79 -0
- clarifai/runners/models/model_upload.py +315 -0
- clarifai/runners/server.py +130 -0
- clarifai/runners/utils/__init__.py +0 -0
- clarifai/runners/utils/data_handler.py +244 -0
- clarifai/runners/utils/data_utils.py +15 -0
- clarifai/runners/utils/loader.py +70 -0
- clarifai/runners/utils/logging.py +6 -0
- clarifai/runners/utils/url_fetcher.py +42 -0
- clarifai/utils/logging.py +212 -6
- {clarifai-10.8.1.dist-info → clarifai-10.8.3.dist-info}/METADATA +3 -2
- {clarifai-10.8.1.dist-info → clarifai-10.8.3.dist-info}/RECORD +28 -12
- {clarifai-10.8.1.dist-info → clarifai-10.8.3.dist-info}/WHEEL +1 -1
- {clarifai-10.8.1.dist-info → clarifai-10.8.3.dist-info}/LICENSE +0 -0
- {clarifai-10.8.1.dist-info → clarifai-10.8.3.dist-info}/entry_points.txt +0 -0
- {clarifai-10.8.1.dist-info → clarifai-10.8.3.dist-info}/top_level.txt +0 -0
clarifai/runners/models/model_runner.py
@@ -0,0 +1,175 @@
+from typing import Iterator
+
+from clarifai_grpc.grpc.api import service_pb2
+from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
+
+from clarifai_protocol import BaseRunner
+from clarifai_protocol.utils.health import HealthProbeRequestHandler
+from ..utils.url_fetcher import ensure_urls_downloaded
+
+from .model_class import ModelClass
+
+
+class ModelRunner(BaseRunner, ModelClass, HealthProbeRequestHandler):
+  """
+  This is a subclass of the runner class which will handle only the work items relevant to models.
+
+  It is also a subclass of ModelClass so that any subclass of ModelRunner will need to just
+  implement predict(), generate() and stream() methods and load_model() if needed.
+  """
+
+  def __init__(
+      self,
+      runner_id: str,
+      nodepool_id: str,
+      compute_cluster_id: str,
+      user_id: str = None,
+      check_runner_exists: bool = True,
+      base_url: str = "https://api.clarifai.com",
+      pat: str = None,
+      token: str = None,
+      num_parallel_polls: int = 4,
+      **kwargs,
+  ) -> None:
+    super().__init__(
+        runner_id,
+        nodepool_id,
+        compute_cluster_id,
+        user_id,
+        check_runner_exists,
+        base_url,
+        pat,
+        token,
+        num_parallel_polls,
+        **kwargs,
+    )
+    self.load_model()
+
+    # After model load successfully set the health probe to ready and startup
+    HealthProbeRequestHandler.is_ready = True
+    HealthProbeRequestHandler.is_startup = True
+
+  def get_runner_item_output_for_status(self,
+                                        status: status_pb2.Status) -> service_pb2.RunnerItemOutput:
+    """
+    Set the error message in the RunnerItemOutput message subfield, used during exception handling
+    where we may only have a status to return.
+
+    Args:
+      status: status_pb2.Status - the status to return
+
+    Returns:
+      service_pb2.RunnerItemOutput - the RunnerItemOutput message with the status set
+    """
+    rio = service_pb2.RunnerItemOutput(
+        multi_output_response=service_pb2.MultiOutputResponse(status=status))
+    return rio
+
+  def runner_item_predict(self,
+                          runner_item: service_pb2.RunnerItem) -> service_pb2.RunnerItemOutput:
+    """
+    Run the model on the given request. You shouldn't need to override this method, see run_input
+    for the implementation to process each input in the request.
+
+    Args:
+      request: service_pb2.PostModelOutputsRequest - the request to run the model on
+
+    Returns:
+      service_pb2.MultiOutputResponse - the response from the model's run_input implementation.
+    """
+
+    if not runner_item.HasField('post_model_outputs_request'):
+      raise Exception("Unexpected work item type: {}".format(runner_item))
+    request = runner_item.post_model_outputs_request
+    ensure_urls_downloaded(request)
+
+    resp = self.predict_wrapper(request)
+    successes = [o.status.code == status_code_pb2.SUCCESS for o in resp.outputs]
+    if all(successes):
+      status = status_pb2.Status(
+          code=status_code_pb2.SUCCESS,
+          description="Success",
+      )
+    elif any(successes):
+      status = status_pb2.Status(
+          code=status_code_pb2.MIXED_STATUS,
+          description="Mixed Status",
+      )
+    else:
+      status = status_pb2.Status(
+          code=status_code_pb2.FAILURE,
+          description="Failed",
+      )
+
+    resp.status.CopyFrom(status)
+    return service_pb2.RunnerItemOutput(multi_output_response=resp)
+
+  def runner_item_generate(
+      self, runner_item: service_pb2.RunnerItem) -> Iterator[service_pb2.RunnerItemOutput]:
+    # Call the generate() method the underlying model implements.
+
+    if not runner_item.HasField('post_model_outputs_request'):
+      raise Exception("Unexpected work item type: {}".format(runner_item))
+    request = runner_item.post_model_outputs_request
+    ensure_urls_downloaded(request)
+
+    for resp in self.generate_wrapper(request):
+      successes = []
+      for output in resp.outputs:
+        if not output.HasField('status') or not output.status.code:
+          raise Exception("Output must have a status code, please check the model implementation.")
+        successes.append(output.status.code == status_code_pb2.SUCCESS)
+      if all(successes):
+        status = status_pb2.Status(
+            code=status_code_pb2.SUCCESS,
+            description="Success",
+        )
+      elif any(successes):
+        status = status_pb2.Status(
+            code=status_code_pb2.MIXED_STATUS,
+            description="Mixed Status",
+        )
+      else:
+        status = status_pb2.Status(
+            code=status_code_pb2.FAILURE,
+            description="Failed",
+        )
+      resp.status.CopyFrom(status)
+
+      yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+
+  def runner_item_stream(self, runner_item_iterator: Iterator[service_pb2.RunnerItem]
+                        ) -> Iterator[service_pb2.RunnerItemOutput]:
+    # Call the generate() method the underlying model implements.
+    for resp in self.stream_wrapper(pmo_iterator(runner_item_iterator)):
+      successes = []
+      for output in resp.outputs:
+        if not output.HasField('status') or not output.status.code:
+          raise Exception("Output must have a status code, please check the model implementation.")
+        successes.append(output.status.code == status_code_pb2.SUCCESS)
+      if all(successes):
+        status = status_pb2.Status(
+            code=status_code_pb2.SUCCESS,
+            description="Success",
+        )
+      elif any(successes):
+        status = status_pb2.Status(
+            code=status_code_pb2.MIXED_STATUS,
+            description="Mixed Status",
+        )
+      else:
+        status = status_pb2.Status(
+            code=status_code_pb2.FAILURE,
+            description="Failed",
+        )
+      resp.status.CopyFrom(status)
+
+      yield service_pb2.RunnerItemOutput(multi_output_response=resp)
+
+
+def pmo_iterator(runner_item_iterator):
+  for runner_item in runner_item_iterator:
+    if not runner_item.HasField('post_model_outputs_request'):
+      raise Exception("Unexpected work item type: {}".format(runner_item))
+    ensure_urls_downloaded(runner_item.post_model_outputs_request)
+    yield runner_item.post_model_outputs_request
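For orientation, the hunk above is the new clarifai/runners/models/model_runner.py (the only +175-line file in the list above). Below is a minimal, hypothetical sketch, not part of the diff, of the kind of user model this runner is designed to drive: per the class docstring, a subclass only needs load_model() plus predict()/generate()/stream(), and the sketch assumes that predict_wrapper() (from ModelClass in model_class.py, not shown in this section) delegates to predict() with a PostModelOutputsRequest and expects a MultiOutputResponse back. The class name and the echo logic are placeholders.

# Hypothetical sketch (not part of the diff): a user-defined model that ModelRunner
# could drive. server.py (further below) is what actually instantiates the class with
# runner/nodepool/compute cluster IDs, so only the methods are sketched here.
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2

from clarifai.runners.models.model_runner import ModelRunner


class MyModelRunner(ModelRunner):
  """Echoes each input's text back; stands in for real inference."""

  def load_model(self):
    # Load weights/tokenizers here; ModelRunner.__init__() calls this once at startup.
    self.prefix = "echo: "

  def predict(self, request: service_pb2.PostModelOutputsRequest
             ) -> service_pb2.MultiOutputResponse:
    outputs = []
    for inp in request.inputs:
      # Each output carries its own status code, which runner_item_predict() inspects
      # to decide between SUCCESS, MIXED_STATUS and FAILURE for the whole response.
      outputs.append(
          resources_pb2.Output(
              status=status_pb2.Status(code=status_code_pb2.SUCCESS),
              data=resources_pb2.Data(
                  text=resources_pb2.Text(raw=self.prefix + inp.data.text.raw))))
    return service_pb2.MultiOutputResponse(outputs=outputs)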
clarifai/runners/models/model_servicer.py
@@ -0,0 +1,79 @@
+from itertools import tee
+from typing import Iterator
+
+from clarifai_grpc.grpc.api import service_pb2, service_pb2_grpc
+from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
+
+from ..utils.url_fetcher import ensure_urls_downloaded
+
+
+class ModelServicer(service_pb2_grpc.V2Servicer):
+  """
+  This is the servicer that will handle the gRPC requests from either the dev server or runner loop.
+  """
+
+  def __init__(self, model_class):
+    self.model_class = model_class
+
+  def PostModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
+                       context=None) -> service_pb2.MultiOutputResponse:
+    """
+    This is the method that will be called when the servicer is run. It takes in an input and
+    returns an output.
+    """
+
+    # Download any urls that are not already bytes.
+    ensure_urls_downloaded(self.url_fetcher, request)
+
+    try:
+      return self.model_class.predict(request)
+    except Exception as e:
+      return service_pb2.MultiOutputResponse(status=status_pb2.Status(
+          code=status_code_pb2.MODEL_PREDICTION_FAILED,
+          description="Failed",
+          details="",
+          internal_details=str(e),
+      ))
+
+  def GenerateModelOutputs(self, request: service_pb2.PostModelOutputsRequest,
+                           context=None) -> Iterator[service_pb2.MultiOutputResponse]:
+    """
+    This is the method that will be called when the servicer is run. It takes in an input and
+    returns an output.
+    """
+    # Download any urls that are not already bytes.
+    ensure_urls_downloaded(self.url_fetcher, request)
+
+    try:
+      return self.model_class.generate(request)
+    except Exception as e:
+      yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
+          code=status_code_pb2.MODEL_PREDICTION_FAILED,
+          description="Failed",
+          details="",
+          internal_details=str(e),
+      ))
+
+  def StreamModelOutputs(self,
+                         request: Iterator[service_pb2.PostModelOutputsRequest],
+                         context=None) -> Iterator[service_pb2.MultiOutputResponse]:
+    """
+    This is the method that will be called when the servicer is run. It takes in an input and
+    returns an output.
+    """
+    # Duplicate the iterator
+    request, request_copy = tee(request)
+
+    # Download any urls that are not already bytes.
+    for req in request:
+      ensure_urls_downloaded(self.url_fetcher, req)
+
+    try:
+      return self.model_class.stream(request_copy)
+    except Exception as e:
+      yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
+          code=status_code_pb2.MODEL_PREDICTION_FAILED,
+          description="Failed",
+          details="",
+          internal_details=str(e),
+      ))
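The hunk above is the new clarifai/runners/models/model_servicer.py, which forwards PostModelOutputs / GenerateModelOutputs / StreamModelOutputs to the wrapped model class and converts exceptions into MODEL_PREDICTION_FAILED responses. As a hedged illustration only, once the dev server is started with --start_dev_server (see server.py further below) a plain clarifai_grpc stub can exercise it; the localhost:8000 address, the plaintext channel and the text payload are assumptions for local testing, not part of the diff.

# Hypothetical sketch (not part of the diff): exercising the servicer through a local
# dev server assumed to be listening on port 8000 without TLS.
import grpc
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc

channel = grpc.insecure_channel("localhost:8000")
stub = service_pb2_grpc.V2Stub(channel)

request = service_pb2.PostModelOutputsRequest(inputs=[
    resources_pb2.Input(data=resources_pb2.Data(text=resources_pb2.Text(raw="hello")))
])
response = stub.PostModelOutputs(request)
print(response.status.code, [o.data.text.raw for o in response.outputs])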
clarifai/runners/models/model_upload.py
@@ -0,0 +1,315 @@
+import argparse
+import os
+import time
+from string import Template
+
+import yaml
+from clarifai_grpc.grpc.api import resources_pb2, service_pb2
+from clarifai_grpc.grpc.api.status import status_code_pb2
+from google.protobuf import json_format
+from rich import print
+
+from clarifai.client import BaseClient
+
+from clarifai.runners.utils.loader import HuggingFaceLoarder
+
+
+def _clear_line(n: int = 1) -> None:
+  LINE_UP = '\033[1A'  # Move cursor up one line
+  LINE_CLEAR = '\x1b[2K'  # Clear the entire line
+  for _ in range(n):
+    print(LINE_UP, end=LINE_CLEAR, flush=True)
+
+
+class ModelUploader:
+  DEFAULT_PYTHON_VERSION = 3.11
+  CONCEPTS_REQUIRED_MODEL_TYPE = [
+      'visual-classifier', 'visual-detector', 'visual-segmenter', 'text-classifier'
+  ]
+
+  def __init__(self, folder: str):
+    self.folder = self._validate_folder(folder)
+    self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
+    self.initialize_client()
+    self.model_proto = self._get_model_proto()
+    self.model_id = self.model_proto.id
+    self.user_app_id = self.client.user_app_id
+    self.inference_compute_info = self._get_inference_compute_info()
+    self.is_v3 = True  # Do model build for v3
+
+  @staticmethod
+  def _validate_folder(folder):
+    if not folder.startswith("/"):
+      folder = os.path.join(os.getcwd(), folder)
+    print(f"Validating folder: {folder}")
+    files = os.listdir(folder)
+    assert "requirements.txt" in files, "requirements.txt not found in the folder"
+    assert "config.yaml" in files, "config.yaml not found in the folder"
+    assert "1" in files, "Subfolder '1' not found in the folder"
+    subfolder_files = os.listdir(os.path.join(folder, '1'))
+    assert 'model.py' in subfolder_files, "model.py not found in the folder"
+    return folder
+
+  @staticmethod
+  def _load_config(config_file: str):
+    with open(config_file, 'r') as file:
+      config = yaml.safe_load(file)
+    return config
+
+  def initialize_client(self):
+    assert "model" in self.config, "model info not found in the config file"
+    model = self.config.get('model')
+    assert "user_id" in model, "user_id not found in the config file"
+    assert "app_id" in model, "app_id not found in the config file"
+    user_id = model.get('user_id')
+    app_id = model.get('app_id')
+
+    base = os.environ.get('CLARIFAI_API_BASE', 'https://api-dev.clarifai.com')
+
+    self.client = BaseClient(user_id=user_id, app_id=app_id, base=base)
+    print(f"Client initialized for user {user_id} and app {app_id}")
+
+  def _get_model_proto(self):
+    assert "model" in self.config, "model info not found in the config file"
+    model = self.config.get('model')
+
+    assert "model_type_id" in model, "model_type_id not found in the config file"
+    assert "id" in model, "model_id not found in the config file"
+    assert "user_id" in model, "user_id not found in the config file"
+    assert "app_id" in model, "app_id not found in the config file"
+
+    model_proto = json_format.ParseDict(model, resources_pb2.Model())
+    assert model_proto.id == model_proto.id.lower(), "Model ID must be lowercase"
+    assert model_proto.user_id == model_proto.user_id.lower(), "User ID must be lowercase"
+    assert model_proto.app_id == model_proto.app_id.lower(), "App ID must be lowercase"
+
+    return model_proto
+
+  def _get_inference_compute_info(self):
+    assert ("inference_compute_info" in self.config
+           ), "inference_compute_info not found in the config file"
+    inference_compute_info = self.config.get('inference_compute_info')
+    return json_format.ParseDict(inference_compute_info, resources_pb2.ComputeInfo())
+
+  def maybe_create_model(self):
+    resp = self.client.STUB.GetModel(
+        service_pb2.GetModelRequest(
+            user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
+    if resp.status.code == status_code_pb2.SUCCESS:
+      print(
+          f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
+          f"will create a new version for it.")
+      return resp
+
+    request = service_pb2.PostModelsRequest(
+        user_app_id=self.client.user_app_id,
+        models=[self.model_proto],
+    )
+    return self.client.STUB.PostModels(request)
+
+  def create_dockerfile(self):
+    num_accelerators = self.inference_compute_info.num_accelerators
+    if num_accelerators:
+      dockerfile_template = os.path.join(
+          os.path.dirname(os.path.dirname(__file__)),
+          'dockerfile_template',
+          'Dockerfile.cuda.template',
+      )
+    else:
+      dockerfile_template = os.path.join(
+          os.path.dirname(os.path.dirname(__file__)), 'dockerfile_template',
+          'Dockerfile.cpu.template')
+
+    with open(dockerfile_template, 'r') as template_file:
+      dockerfile_template = template_file.read()
+
+    dockerfile_template = Template(dockerfile_template)
+
+    # Get the Python version from the config file
+    build_info = self.config.get('build_info', {})
+    python_version = build_info.get('python_version', self.DEFAULT_PYTHON_VERSION)
+
+    # Replace placeholders with actual values
+    dockerfile_content = dockerfile_template.safe_substitute(
+        PYTHON_VERSION=python_version,
+        name='main',
+    )
+
+    # Write Dockerfile
+    with open(os.path.join(self.folder, 'Dockerfile'), 'w') as dockerfile:
+      dockerfile.write(dockerfile_content)
+
+  def download_checkpoints(self):
+    if not self.config.get("checkpoints"):
+      print("No checkpoints specified in the config file")
+      return
+
+    assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
+    loader_type = self.config.get("checkpoints").get("type")
+    if not loader_type:
+      print("No loader type specified in the config file for checkpoints")
+    assert loader_type == "huggingface", "Only huggingface loader supported for now"
+    if loader_type == "huggingface":
+      assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
+      repo_id = self.config.get("checkpoints").get("repo_id")
+
+      hf_token = self.config.get("checkpoints").get("hf_token", None)
+      loader = HuggingFaceLoarder(repo_id=repo_id, token=hf_token)
+
+      checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
+      loader.download_checkpoints(checkpoint_path)
+
+      print(f"Downloaded checkpoints for model {repo_id}")
+
+  def _concepts_protos_from_concepts(self, concepts):
+    concept_protos = []
+    for concept in concepts:
+      concept_protos.append(resources_pb2.Concept(
+          id=str(concept[0]),
+          name=concept[1],
+      ))
+    return concept_protos
+
+  def hf_labels_to_config(self, labels, config_file):
+    with open(config_file, 'r') as file:
+      config = yaml.safe_load(file)
+    model = config.get('model')
+    model_type_id = model.get('model_type_id')
+    assert model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
+    concept_protos = self._concepts_protos_from_concepts(labels)
+
+    config['concepts'] = [{'id': concept.id, 'name': concept.name} for concept in concept_protos]
+
+    with open(config_file, 'w') as file:
+      yaml.dump(config, file, sort_keys=False)
+    concepts = config.get('concepts')
+    print(f"Updated config.yaml with {len(concepts)} concepts.")
+
+  def _get_model_version_proto(self):
+
+    model_version = resources_pb2.ModelVersion(
+        pretrained_model_config=resources_pb2.PretrainedModelConfig(),
+        inference_compute_info=self.inference_compute_info,
+    )
+
+    model_type_id = self.config.get('model').get('model_type_id')
+    if model_type_id in self.CONCEPTS_REQUIRED_MODEL_TYPE:
+
+      loader = HuggingFaceLoarder()
+      checkpoint_path = os.path.join(self.folder, '1', 'checkpoints')
+      labels = loader.fetch_labels(checkpoint_path)
+      # sort the concepts by id and then update the config file
+      labels = sorted(labels.items(), key=lambda x: int(x[0]))
+
+      config_file = os.path.join(self.folder, 'config.yaml')
+      self.hf_labels_to_config(labels, config_file)
+
+      model_version.output_info.data.concepts.extend(self._concepts_protos_from_concepts(labels))
+    return model_version
+
+  def upload_model_version(self):
+    file_path = f"{self.folder}.tar.gz"
+    print(f"Will tar it into file: {file_path}")
+
+    # Tar the folder
+    os.system(f"tar --exclude=*~ -czvf {self.folder}.tar.gz -C {self.folder} .")
+    print("Tarring complete, about to start upload.")
+
+    model_version = self._get_model_version_proto()
+
+    response = self.maybe_create_model()
+
+    for response in self.client.STUB.PostModelVersionsUpload(
+        self.model_version_stream_upload_iterator(model_version, file_path),):
+      percent_completed = 0
+      if response.status.code == status_code_pb2.UPLOAD_IN_PROGRESS:
+        percent_completed = response.status.percent_completed
+      details = response.status.details
+
+      _clear_line()
+      print(
+          f"Status: {response.status.description}, "
+          f"Progress: {percent_completed}% - {details} ",
+          end='\r',
+          flush=True)
+    print()
+    if response.status.code != status_code_pb2.MODEL_BUILDING:
+      print(f"Failed to upload model version: {response.status.description}")
+      return
+    model_version_id = response.model_version_id
+    print(f"Created Model Version ID: {model_version_id}")
+
+    self.monitor_model_build(model_version_id)
+
+  def model_version_stream_upload_iterator(self, model_version, file_path):
+    yield self.init_upload_model_version(model_version, file_path)
+    with open(file_path, "rb") as f:
+      file_size = os.path.getsize(file_path)
+      chunk_size = int(127 * 1024 * 1024)  # 127MB chunk size
+      num_chunks = (file_size // chunk_size) + 1
+
+      read_so_far = 0
+      for part_id in range(num_chunks):
+        chunk = f.read(chunk_size)
+        read_so_far += len(chunk)
+        yield service_pb2.PostModelVersionsUploadRequest(
+            content_part=resources_pb2.UploadContentPart(
+                data=chunk,
+                part_number=part_id + 1,
+                range_start=read_so_far,
+            ))
+    print("\nUpload complete!, waiting for model build...")
+
+  def init_upload_model_version(self, model_version, file_path):
+    file_size = os.path.getsize(file_path)
+    print(
+        f"Uploading model version '{model_version.id}' with file '{os.path.basename(file_path)}' of size {file_size} bytes..."
+    )
+    return service_pb2.PostModelVersionsUploadRequest(
+        upload_config=service_pb2.PostModelVersionsUploadConfig(
+            user_app_id=self.client.user_app_id,
+            model_id=self.model_proto.id,
+            model_version=model_version,
+            total_size=file_size,
+            is_v3=self.is_v3,
+        ))
+
+  def monitor_model_build(self, model_version_id):
+    st = time.time()
+    while True:
+      resp = self.client.STUB.GetModelVersion(
+          service_pb2.GetModelVersionRequest(
+              user_app_id=self.client.user_app_id,
+              model_id=self.model_proto.id,
+              version_id=model_version_id,
+          ))
+      status_code = resp.model_version.status.code
+      if status_code == status_code_pb2.MODEL_BUILDING:
+        print(f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
+        time.sleep(1)
+      elif status_code == status_code_pb2.MODEL_TRAINED:
+        print("\nModel build complete!")
+        print(
+            f"Check out the model at https://clarifai.com/{self.user_app_id.user_id}/apps/{self.user_app_id.app_id}/models/{self.model_id}/versions/{model_version_id}"
+        )
+        break
+      else:
+        print(f"\nModel build failed with status: {resp.model_version.status}")
+        break
+
+
+def main(folder):
+  uploader = ModelUploader(folder)
+  uploader.download_checkpoints()
+  uploader.create_dockerfile()
+  input("Press Enter to continue...")
+  uploader.upload_model_version()
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--model_path', type=str, help='Path of the model folder to upload', required=True)
+  args = parser.parse_args()
+
+  main(args.model_path)
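The hunk above is the new clarifai/runners/models/model_upload.py. It validates a model folder (requirements.txt, config.yaml and a 1/ subfolder with model.py), reads config.yaml for the model identifiers, build_info.python_version, inference_compute_info and an optional Hugging Face checkpoints section, renders a Dockerfile from the bundled templates, and streams a tarball to PostModelVersionsUpload. The following sketch is hypothetical and only mirrors the keys the code reads: every ID, token and repo name is a placeholder, and it assumes a Clarifai PAT (and optionally CLARIFAI_API_BASE) is available in the environment for BaseClient.

# Hypothetical sketch (not part of the diff): a folder and config.yaml shaped the way
# ModelUploader expects. All IDs, tokens and repo names are placeholders.
import os

import yaml

from clarifai.runners.models.model_upload import ModelUploader

folder = "my_model"
os.makedirs(os.path.join(folder, "1"), exist_ok=True)
open(os.path.join(folder, "requirements.txt"), "a").close()  # real dependencies go here
open(os.path.join(folder, "1", "model.py"), "a").close()  # real runner subclass goes here

config = {
    "model": {
        "id": "my-model",  # must be lowercase
        "user_id": "my-user",  # must be lowercase
        "app_id": "my-app",  # must be lowercase
        "model_type_id": "text-to-text",
    },
    "build_info": {
        "python_version": "3.11",  # substituted into the Dockerfile template
    },
    "inference_compute_info": {
        "num_accelerators": 0,  # 0 selects the CPU Dockerfile template
    },
    "checkpoints": {  # optional; only the huggingface loader is supported
        "type": "huggingface",
        "repo_id": "some-org/some-repo",
        "hf_token": "hf_placeholder",
    },
}
with open(os.path.join(folder, "config.yaml"), "w") as f:
  yaml.safe_dump(config, f, sort_keys=False)

uploader = ModelUploader(folder)
uploader.download_checkpoints()
uploader.create_dockerfile()
uploader.upload_model_version()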
clarifai/runners/server.py
@@ -0,0 +1,130 @@
+"""
+This is simply the main file for the server that imports ModelRunner implementation
+and starts the server.
+"""
+
+import argparse
+import importlib.util
+import inspect
+import os
+import sys
+from concurrent import futures
+
+from clarifai_grpc.grpc.api import service_pb2_grpc
+from clarifai_protocol import BaseRunner
+from clarifai_protocol.utils.grpc_server import GRPCServer
+
+from clarifai.runners.models.model_servicer import ModelServicer
+from clarifai.runners.utils.logging import logger
+
+
+def main():
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--port',
+      type=int,
+      default=8000,
+      help="The port to host the gRPC server at.",
+      choices=range(1024, 65535),
+  )
+  parser.add_argument(
+      '--pool_size',
+      type=int,
+      default=32,
+      help="The number of threads to use for the gRPC server.",
+      choices=range(1, 129),
+  )  # pylint: disable=range-builtin-not-iterating
+  parser.add_argument(
+      '--max_queue_size',
+      type=int,
+      default=10,
+      help='Max queue size of requests before we begin to reject requests (default: 10).',
+      choices=range(1, 21),
+  )  # pylint: disable=range-builtin-not-iterating
+  parser.add_argument(
+      '--max_msg_length',
+      type=int,
+      default=1024 * 1024 * 1024,
+      help='Max message length of grpc requests (default: 1 GB).',
+  )
+  parser.add_argument(
+      '--enable_tls',
+      action='store_true',
+      default=False,
+      help=
+      'Set to true to enable TLS (default: False) since this server is meant for local development only.',
+  )
+  parser.add_argument(
+      '--start_dev_server',
+      action='store_true',
+      default=False,
+      help=
+      'Set to true to start the gRPC server (default: False). If set to false, the server will not start and only the runner loop will start to fetch work from the API.',
+  )
+  parser.add_argument(
+      '--model_path',
+      type=str,
+      required=True,
+      help='The path to the model directory that contains implemention of the model.',
+  )
+
+  parsed_args = parser.parse_args()
+
+  # import the runner class that to be implement by the user
+  runner_path = os.path.join(parsed_args.model_path, "1", "model.py")
+
+  # arbitrary name given to the module to be imported
+  module = "runner_module"
+
+  spec = importlib.util.spec_from_file_location(module, runner_path)
+  runner_module = importlib.util.module_from_spec(spec)
+  sys.modules[module] = runner_module
+  spec.loader.exec_module(runner_module)
+
+  # Find all classes in the model.py file that are subclasses of BaseRunner
+  classes = [
+      cls for _, cls in inspect.getmembers(runner_module, inspect.isclass)
+      if issubclass(cls, BaseRunner) and cls.__module__ == runner_module.__name__
+  ]
+
+  # Ensure there is exactly one subclass of BaseRunner in the model.py file
+  if len(classes) != 1:
+    raise Exception("Expected exactly one subclass of BaseRunner, found: {}".format(len(classes)))
+
+  MyRunner = classes[0]
+
+  # initialize the Runner class. This is what the user implements.
+  # (Note) do we want to set runner_id, nodepool_id, compute_cluster_id, base_url, num_parallel_polls as env vars? or as args?
+  runner = MyRunner(
+      runner_id=os.environ["CLARIFAI_RUNNER_ID"],
+      nodepool_id=os.environ["CLARIFAI_NODEPOOL_ID"],
+      compute_cluster_id=os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"],
+      base_url=os.environ["CLARIFAI_API_BASE"],
+      num_parallel_polls=int(os.environ.get("CLARIFAI_NUM_THREADS", 1)),
+  )
+
+  # initialize the servicer
+  servicer = ModelServicer(runner)
+
+  # Setup the grpc server for local development.
+  if parsed_args.start_dev_server:
+    server = GRPCServer(
+        futures.ThreadPoolExecutor(
+            max_workers=parsed_args.pool_size,
+            thread_name_prefix="ServeCalls",
+        ),
+        parsed_args.max_msg_length,
+        parsed_args.max_queue_size,
+    )
+    server.add_port_to_server('[::]:%s' % parsed_args.port, parsed_args.enable_tls)
+
+    service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
+    server.start()
+    logger.info("Started server on port %s", parsed_args.port)
+    # server.wait_for_termination()  # won't get here currently.
+
+  runner.start()  # start the runner loop to fetch work from the API.
+
+
+if __name__ == '__main__':
+  main()
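The hunk above is the new clarifai/runners/server.py entrypoint. It reads the runner, nodepool and compute cluster IDs from environment variables, takes the model folder via --model_path, and optionally exposes a local gRPC dev server before starting the runner loop. Below is a hypothetical launch sketch, not part of the diff; every ID and path is a placeholder, and a valid Clarifai PAT plus pre-created runner resources would be required for the runner loop to actually fetch work.

# Hypothetical sketch (not part of the diff): launching the entrypoint above in-process
# with placeholder values. Real values come from resources created in Clarifai.
import os
import sys

os.environ["CLARIFAI_RUNNER_ID"] = "runner-placeholder"
os.environ["CLARIFAI_NODEPOOL_ID"] = "nodepool-placeholder"
os.environ["CLARIFAI_COMPUTE_CLUSTER_ID"] = "cluster-placeholder"
os.environ["CLARIFAI_API_BASE"] = "https://api.clarifai.com"
os.environ["CLARIFAI_NUM_THREADS"] = "1"

sys.argv = [
    "server.py",
    "--model_path", "my_model",  # folder containing 1/model.py with the runner subclass
    "--start_dev_server",  # also expose a local gRPC endpoint for testing
    "--port", "8000",
]

from clarifai.runners.server import main

main()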
File without changes