clarifai 11.4.1__py3-none-any.whl → 11.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/base.py +7 -0
- clarifai/cli/model.py +6 -8
- clarifai/client/app.py +2 -1
- clarifai/client/auth/helper.py +6 -4
- clarifai/client/compute_cluster.py +2 -1
- clarifai/client/dataset.py +8 -1
- clarifai/client/deployment.py +2 -1
- clarifai/client/input.py +2 -1
- clarifai/client/model.py +2 -1
- clarifai/client/model_client.py +1 -1
- clarifai/client/module.py +2 -1
- clarifai/client/nodepool.py +2 -1
- clarifai/client/runner.py +2 -1
- clarifai/client/search.py +2 -1
- clarifai/client/user.py +2 -1
- clarifai/client/workflow.py +2 -1
- clarifai/runners/models/mcp_class.py +114 -0
- clarifai/runners/models/model_builder.py +179 -46
- clarifai/runners/models/model_class.py +5 -22
- clarifai/runners/models/model_run_locally.py +0 -4
- clarifai/runners/models/visual_classifier_class.py +75 -0
- clarifai/runners/models/visual_detector_class.py +79 -0
- clarifai/runners/utils/code_script.py +75 -44
- clarifai/runners/utils/const.py +15 -0
- clarifai/runners/utils/data_types/data_types.py +48 -0
- clarifai/runners/utils/data_utils.py +99 -45
- clarifai/runners/utils/loader.py +23 -2
- clarifai/runners/utils/method_signatures.py +4 -4
- clarifai/runners/utils/openai_convertor.py +103 -0
- clarifai/urls/helper.py +80 -12
- clarifai/utils/config.py +19 -0
- clarifai/utils/constants.py +4 -0
- clarifai/utils/logging.py +22 -5
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/METADATA +1 -2
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/RECORD +40 -37
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/WHEEL +1 -1
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/entry_points.txt +0 -0
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/licenses/LICENSE +0 -0
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,8 @@ from typing import List
|
|
3
3
|
|
4
4
|
from clarifai_grpc.grpc.api import resources_pb2
|
5
5
|
|
6
|
-
from clarifai.runners.utils import
|
6
|
+
from clarifai.runners.utils import data_utils
|
7
|
+
from clarifai.urls.helper import ClarifaiUrlHelper
|
7
8
|
|
8
9
|
|
9
10
|
def generate_client_script(
|
@@ -15,6 +16,38 @@ def generate_client_script(
|
|
15
16
|
deployment_id: str = None,
|
16
17
|
use_ctx: bool = False,
|
17
18
|
) -> str:
|
19
|
+
url_helper = ClarifaiUrlHelper()
|
20
|
+
|
21
|
+
# Provide an mcp client config
|
22
|
+
if len(method_signatures) == 1 and method_signatures[0].name == "mcp_transport":
|
23
|
+
api_url = url_helper.api_url(
|
24
|
+
user_id,
|
25
|
+
app_id,
|
26
|
+
"models",
|
27
|
+
model_id,
|
28
|
+
)
|
29
|
+
|
30
|
+
_CLIENT_TEMPLATE = """
|
31
|
+
import asyncio
|
32
|
+
import os
|
33
|
+
from fastmcp import Client
|
34
|
+
from fastmcp.client.transports import StreamableHttpTransport
|
35
|
+
|
36
|
+
transport = StreamableHttpTransport(url="%s/mcp",
|
37
|
+
headers={"Authorization": "Bearer " + os.environ["CLARIFAI_PAT"]})
|
38
|
+
|
39
|
+
async def main():
|
40
|
+
async with Client(transport) as client:
|
41
|
+
tools = await client.list_tools()
|
42
|
+
print(f"Available tools: {tools}")
|
43
|
+
result = await client.call_tool(tools[0].name, {"a": 5, "b": 3})
|
44
|
+
print(f"Result: {result[0].text}")
|
45
|
+
|
46
|
+
if __name__ == "__main__":
|
47
|
+
asyncio.run(main())
|
48
|
+
"""
|
49
|
+
return _CLIENT_TEMPLATE % api_url
|
50
|
+
|
18
51
|
_CLIENT_TEMPLATE = """\
|
19
52
|
import os
|
20
53
|
|
@@ -35,19 +68,13 @@ from clarifai.runners.utils import data_types
|
|
35
68
|
model_section = """
|
36
69
|
model = Model.from_current_context()"""
|
37
70
|
else:
|
38
|
-
|
39
|
-
|
71
|
+
model_ui_url = url_helper.clarifai_url(user_id, app_id, "models", model_id)
|
72
|
+
model_section = f"""
|
73
|
+
model = Model({model_ui_url},
|
40
74
|
deployment_id = {deployment_id}, # Only needed for dedicated deployed models
|
41
75
|
{base_url_str}
|
42
76
|
)
|
43
77
|
"""
|
44
|
-
model_section = _CLIENT_TEMPLATE.format(
|
45
|
-
user_id=user_id,
|
46
|
-
app_id=app_id,
|
47
|
-
model_id=model_id,
|
48
|
-
deployment_id=deployment_id,
|
49
|
-
base_url_str=base_url_str,
|
50
|
-
)
|
51
78
|
|
52
79
|
# Generate client template
|
53
80
|
client_template = _CLIENT_TEMPLATE.format(
|
@@ -58,28 +85,24 @@ model = Model.from_current_context()"""
|
|
58
85
|
method_signatures_str = []
|
59
86
|
for method_signature in method_signatures:
|
60
87
|
method_name = method_signature.name
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
client_script_str
|
77
|
-
|
78
|
-
|
79
|
-
elif method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_STREAMING:
|
80
|
-
client_script_str += "\nfor res in response:\n print(res)"
|
81
|
-
client_script_str += "\n"
|
82
|
-
method_signatures_str.append(client_script_str)
|
88
|
+
client_script_str = f'response = model.{method_name}('
|
89
|
+
annotations = _get_annotations_source(method_signature)
|
90
|
+
for param_name, (param_type, default_value) in annotations.items():
|
91
|
+
print(
|
92
|
+
f"param_name: {param_name}, param_type: {param_type}, default_value: {default_value}"
|
93
|
+
)
|
94
|
+
if param_name == "return":
|
95
|
+
continue
|
96
|
+
if default_value is None:
|
97
|
+
default_value = _set_default_value(param_type)
|
98
|
+
client_script_str += f"{param_name}={default_value}, "
|
99
|
+
client_script_str = client_script_str.rstrip(", ") + ")"
|
100
|
+
if method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_UNARY:
|
101
|
+
client_script_str += "\nprint(response)"
|
102
|
+
elif method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_STREAMING:
|
103
|
+
client_script_str += "\nfor res in response:\n print(res)"
|
104
|
+
client_script_str += "\n"
|
105
|
+
method_signatures_str.append(client_script_str)
|
83
106
|
|
84
107
|
method_signatures_str = "\n".join(method_signatures_str)
|
85
108
|
# Combine all parts
|
@@ -107,9 +130,8 @@ def _get_annotations_source(method_signature: resources_pb2.MethodSignature) ->
|
|
107
130
|
if input_field.iterator:
|
108
131
|
param_type = f"Iterator[{param_type}]"
|
109
132
|
default_value = None
|
110
|
-
if input_field
|
133
|
+
if data_utils.Param.get_default(input_field):
|
111
134
|
default_value = _parse_default_value(input_field)
|
112
|
-
|
113
135
|
annotations[param_name] = (param_type, default_value)
|
114
136
|
if not method_signature.output_fields:
|
115
137
|
raise ValueError("MethodSignature must have at least one output field")
|
@@ -177,23 +199,21 @@ def _map_default_value(field_type):
|
|
177
199
|
elif field_type == "bool":
|
178
200
|
default_value = False
|
179
201
|
elif field_type == "data_types.Image":
|
180
|
-
default_value = data_types.Image.from_url("https://samples.clarifai.com/metro-north.jpg")
|
202
|
+
default_value = 'data_types.Image.from_url("https://samples.clarifai.com/metro-north.jpg")'
|
181
203
|
elif field_type == "data_types.Text":
|
182
|
-
default_value = data_types.Text("What
|
204
|
+
default_value = 'data_types.Text("What is the future of AI?")'
|
183
205
|
elif field_type == "data_types.Audio":
|
184
|
-
default_value = data_types.Audio.from_url("https://samples.clarifai.com/audio.mp3")
|
206
|
+
default_value = 'data_types.Audio.from_url("https://samples.clarifai.com/audio.mp3")'
|
185
207
|
elif field_type == "data_types.Video":
|
186
|
-
default_value = data_types.Video.from_url("https://samples.clarifai.com/video.mp4")
|
208
|
+
default_value = 'data_types.Video.from_url("https://samples.clarifai.com/video.mp4")'
|
187
209
|
elif field_type == "data_types.Concept":
|
188
|
-
default_value = data_types.Concept(id="concept_id", name="dog", value=0.95)
|
210
|
+
default_value = 'data_types.Concept(id="concept_id", name="dog", value=0.95)'
|
189
211
|
elif field_type == "data_types.Region":
|
190
|
-
default_value = data_types.Region(
|
191
|
-
box=[0.1, 0.1, 0.5, 0.5],
|
192
|
-
)
|
212
|
+
default_value = 'data_types.Region(box=[0.1, 0.1, 0.5, 0.5],)'
|
193
213
|
elif field_type == "data_types.Frame":
|
194
|
-
default_value = data_types.Frame.from_url("https://samples.clarifai.com/video.mp4", 0)
|
214
|
+
default_value = 'data_types.Frame.from_url("https://samples.clarifai.com/video.mp4", 0)'
|
195
215
|
elif field_type == "data_types.NDArray":
|
196
|
-
default_value = data_types.NDArray([1, 2, 3])
|
216
|
+
default_value = 'data_types.NDArray([1, 2, 3])'
|
197
217
|
else:
|
198
218
|
default_value = None
|
199
219
|
return default_value
|
@@ -203,6 +223,12 @@ def _set_default_value(field_type):
|
|
203
223
|
"""
|
204
224
|
Set the default value of a field if it is not set.
|
205
225
|
"""
|
226
|
+
is_iterator = False
|
227
|
+
print(f"before field_type: {field_type}")
|
228
|
+
if field_type.startswith("Iterator["):
|
229
|
+
is_iterator = True
|
230
|
+
field_type = field_type[9:-1]
|
231
|
+
print(f"after field_type: {field_type}")
|
206
232
|
default_value = None
|
207
233
|
default_value = _map_default_value(field_type)
|
208
234
|
if field_type.startswith("List["):
|
@@ -219,6 +245,11 @@ def _set_default_value(field_type):
|
|
219
245
|
element_type_defaults = [_map_default_value(et) for et in element_types]
|
220
246
|
default_value = f"{{{', '.join([str(et) for et in element_type_defaults])}}}"
|
221
247
|
|
248
|
+
if field_type == 'str':
|
249
|
+
default_value = repr(default_value)
|
250
|
+
if is_iterator:
|
251
|
+
default_value = f'iter([{default_value}])'
|
252
|
+
print(f"after default_value: {default_value}")
|
222
253
|
return default_value
|
223
254
|
|
224
255
|
|
clarifai/runners/utils/const.py
CHANGED
@@ -4,14 +4,28 @@ registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarif
|
|
4
4
|
|
5
5
|
GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
|
6
6
|
|
7
|
+
AMD_GIT_SHA = "81e942130173f54927e7c9a65aabc7e32780616d"
|
8
|
+
|
7
9
|
PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
|
8
10
|
TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
|
9
11
|
|
12
|
+
AMD_PYTHON_BASE_IMAGE = registry + '/amd-python-base:{python_version}-' + AMD_GIT_SHA
|
13
|
+
AMD_TORCH_BASE_IMAGE = (
|
14
|
+
registry + '/amd-torch:{torch_version}-py{python_version}-{gpu_version}-' + AMD_GIT_SHA
|
15
|
+
)
|
16
|
+
AMD_VLLM_BASE_IMAGE = (
|
17
|
+
registry + '/amd-vllm:{torch_version}-py{python_version}-{gpu_version}-' + AMD_GIT_SHA
|
18
|
+
)
|
19
|
+
|
10
20
|
# List of available python base images
|
11
21
|
AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
|
12
22
|
|
13
23
|
DEFAULT_PYTHON_VERSION = 3.12
|
14
24
|
|
25
|
+
DEFAULT_AMD_TORCH_VERSION = '2.8.0.dev20250511+rocm6.4'
|
26
|
+
|
27
|
+
DEFAULT_AMD_GPU_VERSION = 'rocm6.4'
|
28
|
+
|
15
29
|
# By default we download at runtime.
|
16
30
|
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime"
|
17
31
|
|
@@ -29,6 +43,7 @@ AVAILABLE_TORCH_IMAGES = [
|
|
29
43
|
'2.7.0-py3.12-cu128',
|
30
44
|
'2.7.0-py3.12-rocm6.3',
|
31
45
|
]
|
46
|
+
|
32
47
|
CONCEPTS_REQUIRED_MODEL_TYPE = [
|
33
48
|
'visual-classifier',
|
34
49
|
'visual-detector',
|
@@ -395,6 +395,22 @@ class Image(MessageData):
|
|
395
395
|
raise ValueError("Image has no bytes")
|
396
396
|
return PILImage.open(io.BytesIO(self.proto.base64))
|
397
397
|
|
398
|
+
def to_base64_str(self) -> str:
    """Return the image payload as base64 text.

    A ``str`` payload is returned unchanged. A ``bytes`` payload is returned
    as-is (decoded to text) only when it is strictly valid base64; any other
    bytes are treated as raw content and base64-encoded.

    The previous check ("does it decode as UTF-8?") mistook raw text-based
    payloads such as SVG for base64 and returned them verbatim.

    Returns:
        The payload as a base64 string.

    Raises:
        ValueError: if the proto carries no payload.
        TypeError: if the payload is neither ``str`` nor ``bytes``.
    """
    import base64

    payload = self.proto.base64
    if not payload:
        raise ValueError("Image has no bytes")
    if isinstance(payload, str):
        return payload
    if not isinstance(payload, bytes):
        raise TypeError("Expected str or bytes for Image.base64")
    try:
        # validate=True rejects any byte outside the base64 alphabet and bad
        # padding (binascii.Error is a ValueError subclass).
        base64.b64decode(payload, validate=True)
    except ValueError:
        # Not base64: raw content that still has to be encoded.
        return base64.b64encode(payload).decode('ascii')
    # Validation passed, so every byte is an ASCII base64 character.
    return payload.decode('ascii')
|
413
|
+
|
398
414
|
def to_numpy(self) -> np.ndarray:
|
399
415
|
return np.asarray(self.to_pil())
|
400
416
|
|
@@ -466,6 +482,22 @@ class Audio(MessageData):
|
|
466
482
|
def to_proto(self) -> AudioProto:
|
467
483
|
return self.proto
|
468
484
|
|
485
|
+
def to_base64_str(self) -> str:
    """Return the audio payload as base64 text.

    A ``str`` payload is returned unchanged. A ``bytes`` payload is returned
    as-is (decoded to text) only when it is strictly valid base64; any other
    bytes are treated as raw content and base64-encoded.

    The previous check ("does it decode as UTF-8?") would return any raw
    payload that happened to be valid UTF-8 verbatim instead of encoding it.

    Returns:
        The payload as a base64 string.

    Raises:
        ValueError: if the proto carries no payload.
        TypeError: if the payload is neither ``str`` nor ``bytes``.
    """
    import base64

    payload = self.proto.base64
    if not payload:
        raise ValueError("Audio has no bytes")
    if isinstance(payload, str):
        return payload
    if not isinstance(payload, bytes):
        raise TypeError("Expected str or bytes for Audio.base64")
    try:
        # validate=True rejects any byte outside the base64 alphabet and bad
        # padding (binascii.Error is a ValueError subclass).
        base64.b64decode(payload, validate=True)
    except ValueError:
        # Not base64: raw content that still has to be encoded.
        return base64.b64encode(payload).decode('ascii')
    # Validation passed, so every byte is an ASCII base64 character.
    return payload.decode('ascii')
|
500
|
+
|
469
501
|
@classmethod
|
470
502
|
def from_proto(cls, proto: AudioProto) -> "Audio":
|
471
503
|
return cls(proto)
|
@@ -578,6 +610,22 @@ class Video(MessageData):
|
|
578
610
|
def to_proto(self) -> VideoProto:
|
579
611
|
return self.proto
|
580
612
|
|
613
|
+
def to_base64_str(self) -> str:
    """Return the video payload as base64 text.

    A ``str`` payload is returned unchanged. A ``bytes`` payload is returned
    as-is (decoded to text) only when it is strictly valid base64; any other
    bytes are treated as raw content and base64-encoded.

    The previous check ("does it decode as UTF-8?") would return any raw
    payload that happened to be valid UTF-8 verbatim instead of encoding it.

    Returns:
        The payload as a base64 string.

    Raises:
        ValueError: if the proto carries no payload.
        TypeError: if the payload is neither ``str`` nor ``bytes``.
    """
    import base64

    payload = self.proto.base64
    if not payload:
        raise ValueError("Video has no bytes")
    if isinstance(payload, str):
        return payload
    if not isinstance(payload, bytes):
        raise TypeError("Expected str or bytes for Video.base64")
    try:
        # validate=True rejects any byte outside the base64 alphabet and bad
        # padding (binascii.Error is a ValueError subclass).
        base64.b64decode(payload, validate=True)
    except ValueError:
        # Not base64: raw content that still has to be encoded.
        return base64.b64encode(payload).decode('ascii')
    # Validation passed, so every byte is an ASCII base64 character.
    return payload.decode('ascii')
|
628
|
+
|
581
629
|
@classmethod
|
582
630
|
def from_proto(cls, proto: VideoProto) -> "Video":
|
583
631
|
return cls(proto)
|
@@ -1,62 +1,89 @@
|
|
1
|
+
import base64
|
2
|
+
import json
|
1
3
|
import math
|
2
4
|
import operator
|
3
5
|
from io import BytesIO
|
4
|
-
from typing import List
|
6
|
+
from typing import Dict, List
|
5
7
|
|
8
|
+
import requests
|
6
9
|
from clarifai_grpc.grpc.api import resources_pb2
|
7
10
|
from clarifai_grpc.grpc.api.resources_pb2 import ModelTypeEnumOption, ModelTypeRangeInfo
|
8
11
|
from clarifai_grpc.grpc.api.resources_pb2 import ModelTypeField as ParamProto
|
9
|
-
from PIL import Image
|
12
|
+
from PIL import Image as PILImage
|
10
13
|
|
11
|
-
from clarifai.runners.utils.data_types import MessageData
|
14
|
+
from clarifai.runners.utils.data_types import Audio, Image, MessageData, Video
|
12
15
|
|
13
16
|
|
14
|
-
def image_to_bytes(img:
|
17
|
+
def image_to_bytes(img: PILImage.Image, format="JPEG") -> bytes:
|
15
18
|
buffered = BytesIO()
|
16
19
|
img.save(buffered, format=format)
|
17
20
|
img_str = buffered.getvalue()
|
18
21
|
return img_str
|
19
22
|
|
20
23
|
|
21
|
-
def bytes_to_image(bytes_img) ->
|
22
|
-
img =
|
24
|
+
def bytes_to_image(bytes_img) -> PILImage.Image:
|
25
|
+
img = PILImage.open(BytesIO(bytes_img))
|
23
26
|
return img
|
24
27
|
|
25
28
|
|
26
|
-
def
|
27
|
-
"""
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
return
|
29
|
+
def process_image(image: Image) -> Dict:
|
30
|
+
"""Convert Clarifai Image object to OpenAI image format."""
|
31
|
+
|
32
|
+
if image.bytes:
|
33
|
+
b64_img = image.to_base64_str()
|
34
|
+
return {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{b64_img}"}}
|
35
|
+
elif image.url:
|
36
|
+
return {'type': 'image_url', 'image_url': {'url': image.url}}
|
37
|
+
else:
|
38
|
+
raise ValueError("Image must contain either bytes or URL")
|
39
|
+
|
40
|
+
|
41
|
+
def process_audio(audio: Audio) -> Dict:
    """Convert Clarifai Audio object to OpenAI audio format.

    Inline bytes take precedence; otherwise the audio is downloaded from
    ``audio.url`` and base64-encoded.

    Args:
        audio: object exposing ``bytes``, ``url`` and ``to_base64_str()``.

    Returns:
        ``{"type": "input_audio", "input_audio": {"data": <base64>, "format": "wav"}}``.
        The format is always reported as ``"wav"``; the content is not inspected.

    Raises:
        ValueError: if the download does not return HTTP 200, or if the audio
            has neither bytes nor a URL.
    """
    if audio.bytes:
        encoded = audio.to_base64_str()
    elif audio.url:
        # Without a timeout a stalled server would block this call forever.
        response = requests.get(audio.url, timeout=30)
        if response.status_code != 200:
            raise ValueError(f"Failed to fetch audio from URL: {audio.url}")
        encoded = base64.b64encode(response.content).decode('utf-8')
    else:
        raise ValueError("Audio must contain either bytes or URL")

    return {
        "type": "input_audio",
        "input_audio": {"data": encoded, "format": "wav"},
    }
|
63
|
+
|
64
|
+
|
65
|
+
def process_video(video: Video) -> Dict:
    """Convert Clarifai Video object to OpenAI video format.

    Inline bytes take precedence; otherwise the video is downloaded from
    ``video.url`` and base64-encoded. Both paths return the content as a
    ``data:video/mp4;base64,...`` URI. The URL path used to return the bare
    base64 text under ``"url"``, which is not a usable URL.

    Args:
        video: object exposing ``bytes``, ``url`` and ``to_base64_str()``.

    Returns:
        ``{"type": "video_url", "video_url": {"url": <data URI>}}``.
        The MIME type is always ``video/mp4``; the content is not inspected.

    Raises:
        ValueError: if the download does not return HTTP 200, or if the video
            has neither bytes nor a URL.
    """
    if video.bytes:
        encoded = video.to_base64_str()
    elif video.url:
        # Without a timeout a stalled server would block this call forever.
        response = requests.get(video.url, timeout=30)
        if response.status_code != 200:
            raise ValueError(f"Failed to fetch video from URL: {video.url}")
        encoded = base64.b64encode(response.content).decode('utf-8')
    else:
        raise ValueError("Video must contain either bytes or URL")

    return {
        "type": "video_url",
        "video_url": {"url": "data:video/mp4;base64," + encoded},
    }
|
60
87
|
|
61
88
|
|
62
89
|
class Param(MessageData):
|
@@ -64,7 +91,7 @@ class Param(MessageData):
|
|
64
91
|
|
65
92
|
def __init__(
|
66
93
|
self,
|
67
|
-
default
|
94
|
+
default,
|
68
95
|
description=None,
|
69
96
|
min_value=None,
|
70
97
|
max_value=None,
|
@@ -77,6 +104,7 @@ class Param(MessageData):
|
|
77
104
|
self.max_value = max_value
|
78
105
|
self.choices = choices
|
79
106
|
self.is_param = is_param
|
107
|
+
self._patch_encoder()
|
80
108
|
|
81
109
|
def __repr__(self) -> str:
|
82
110
|
attrs = []
|
@@ -153,6 +181,16 @@ class Param(MessageData):
|
|
153
181
|
def __ge__(self, other):
|
154
182
|
return self.default >= other
|
155
183
|
|
184
|
+
def __getattribute__(self, name):
|
185
|
+
"""Intercept attribute access to mimic default value behavior"""
|
186
|
+
try:
|
187
|
+
# First try to get Param attributes normally
|
188
|
+
return object.__getattribute__(self, name)
|
189
|
+
except AttributeError:
|
190
|
+
# Fall back to the default value's attributes
|
191
|
+
default = object.__getattribute__(self, 'default')
|
192
|
+
return getattr(default, name)
|
193
|
+
|
156
194
|
# Arithmetic operators – # arithmetic & bitwise operators – auto-generated
|
157
195
|
_arith_ops = {
|
158
196
|
"__add__": operator.add,
|
@@ -169,7 +207,6 @@ class Param(MessageData):
|
|
169
207
|
"__rshift__": operator.rshift,
|
170
208
|
}
|
171
209
|
|
172
|
-
# Create both left- and right-hand versions of each operator
|
173
210
|
for _name, _op in _arith_ops.items():
|
174
211
|
|
175
212
|
def _make(op):
|
@@ -243,6 +280,24 @@ class Param(MessageData):
|
|
243
280
|
return self
|
244
281
|
return self.default
|
245
282
|
|
283
|
+
def __json__(self):
|
284
|
+
return self.default if not hasattr(self.default, '__json__') else self.default.__json__()
|
285
|
+
|
286
|
+
@classmethod
|
287
|
+
def _patch_encoder(cls):
|
288
|
+
# only patch once
|
289
|
+
if getattr(json.JSONEncoder, "_user_patched", False):
|
290
|
+
return
|
291
|
+
original = json.JSONEncoder.default
|
292
|
+
|
293
|
+
def default(self, obj):
|
294
|
+
if isinstance(obj, Param):
|
295
|
+
return obj.__json__()
|
296
|
+
return original(self, obj)
|
297
|
+
|
298
|
+
json.JSONEncoder.default = default
|
299
|
+
json.JSONEncoder._user_patched = True
|
300
|
+
|
246
301
|
def to_proto(self, proto=None) -> ParamProto:
|
247
302
|
if proto is None:
|
248
303
|
proto = ParamProto()
|
@@ -254,7 +309,7 @@ class Param(MessageData):
|
|
254
309
|
option = ModelTypeEnumOption(id=str(choice))
|
255
310
|
proto.model_type_enum_options.append(option)
|
256
311
|
|
257
|
-
proto.required =
|
312
|
+
proto.required = False
|
258
313
|
|
259
314
|
if self.min_value is not None or self.max_value is not None:
|
260
315
|
range_info = ModelTypeRangeInfo()
|
@@ -324,8 +379,7 @@ class Param(MessageData):
|
|
324
379
|
|
325
380
|
if proto is None:
|
326
381
|
proto = ParamProto()
|
327
|
-
|
328
|
-
proto.default = json.dumps(default)
|
382
|
+
proto.default = json.dumps(default)
|
329
383
|
return proto
|
330
384
|
except Exception:
|
331
385
|
if default is not None:
|
clarifai/runners/utils/loader.py
CHANGED
@@ -41,7 +41,7 @@ class HuggingFaceLoader:
|
|
41
41
|
return True
|
42
42
|
except Exception as e:
|
43
43
|
logger.error(
|
44
|
-
f"
|
44
|
+
f"Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints. Failed reason: {e}"
|
45
45
|
)
|
46
46
|
return False
|
47
47
|
|
@@ -63,7 +63,6 @@ class HuggingFaceLoader:
|
|
63
63
|
try:
|
64
64
|
is_hf_model_exists = self.validate_hf_model()
|
65
65
|
if not is_hf_model_exists:
|
66
|
-
logger.error("Model %s not found on Hugging Face" % (self.repo_id))
|
67
66
|
return False
|
68
67
|
|
69
68
|
self.ignore_patterns = self._get_ignore_patterns()
|
@@ -205,6 +204,28 @@ class HuggingFaceLoader:
|
|
205
204
|
]
|
206
205
|
return self.ignore_patterns
|
207
206
|
|
207
|
+
@classmethod
|
208
|
+
def validate_hf_repo_access(cls, repo_id: str, token: str = None) -> bool:
|
209
|
+
# check if model exists on HF
|
210
|
+
try:
|
211
|
+
from huggingface_hub import auth_check
|
212
|
+
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
213
|
+
except ImportError:
|
214
|
+
raise ImportError(cls.HF_DOWNLOAD_TEXT)
|
215
|
+
|
216
|
+
try:
|
217
|
+
auth_check(repo_id, token=token)
|
218
|
+
logger.info("Hugging Face repo access validated")
|
219
|
+
return True
|
220
|
+
except GatedRepoError:
|
221
|
+
logger.error(
|
222
|
+
"Hugging Face repo is gated. Please make sure you have access to the repo."
|
223
|
+
)
|
224
|
+
return False
|
225
|
+
except RepositoryNotFoundError:
|
226
|
+
logger.error("Hugging Face repo not found. Please make sure the repo exists.")
|
227
|
+
return False
|
228
|
+
|
208
229
|
@staticmethod
|
209
230
|
def validate_config(checkpoint_path: str):
|
210
231
|
# check if downloaded config.json exists
|
@@ -302,6 +302,9 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
|
|
302
302
|
raise TypeError(f'Missing required argument: {sig.name}')
|
303
303
|
continue # skip missing fields, they can be set to default on the server
|
304
304
|
data = kwargs[sig.name]
|
305
|
+
default = data_utils.Param.get_default(sig)
|
306
|
+
if data is None and default is None:
|
307
|
+
continue
|
305
308
|
serializer = serializer_from_signature(sig)
|
306
309
|
# TODO determine if any (esp the first) var can go in the proto without parts
|
307
310
|
# and whether to put this in the signature or dynamically determine it
|
@@ -312,7 +315,7 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
|
|
312
315
|
return proto
|
313
316
|
|
314
317
|
|
315
|
-
def deserialize(proto, signatures,
|
318
|
+
def deserialize(proto, signatures, is_output=False):
|
316
319
|
'''
|
317
320
|
Deserialize the given proto into kwargs using the given signatures.
|
318
321
|
'''
|
@@ -323,11 +326,8 @@ def deserialize(proto, signatures, inference_params={}, is_output=False):
|
|
323
326
|
for sig_i, sig in enumerate(signatures):
|
324
327
|
serializer = serializer_from_signature(sig)
|
325
328
|
part = parts_by_name.get(sig.name)
|
326
|
-
inference_params_value = inference_params.get(sig.name)
|
327
329
|
if part is not None:
|
328
330
|
kwargs[sig.name] = serializer.deserialize(part.data)
|
329
|
-
elif inference_params_value is not None:
|
330
|
-
kwargs[sig.name] = inference_params_value
|
331
331
|
else:
|
332
332
|
if sig_i == 0:
|
333
333
|
# possible inlined first value
|