clarifai 11.4.1__py3-none-any.whl → 11.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/base.py +7 -0
  3. clarifai/cli/model.py +6 -8
  4. clarifai/client/app.py +2 -1
  5. clarifai/client/auth/helper.py +6 -4
  6. clarifai/client/compute_cluster.py +2 -1
  7. clarifai/client/dataset.py +8 -1
  8. clarifai/client/deployment.py +2 -1
  9. clarifai/client/input.py +2 -1
  10. clarifai/client/model.py +2 -1
  11. clarifai/client/model_client.py +1 -1
  12. clarifai/client/module.py +2 -1
  13. clarifai/client/nodepool.py +2 -1
  14. clarifai/client/runner.py +2 -1
  15. clarifai/client/search.py +2 -1
  16. clarifai/client/user.py +2 -1
  17. clarifai/client/workflow.py +2 -1
  18. clarifai/runners/models/mcp_class.py +114 -0
  19. clarifai/runners/models/model_builder.py +179 -46
  20. clarifai/runners/models/model_class.py +5 -22
  21. clarifai/runners/models/model_run_locally.py +0 -4
  22. clarifai/runners/models/visual_classifier_class.py +75 -0
  23. clarifai/runners/models/visual_detector_class.py +79 -0
  24. clarifai/runners/utils/code_script.py +75 -44
  25. clarifai/runners/utils/const.py +15 -0
  26. clarifai/runners/utils/data_types/data_types.py +48 -0
  27. clarifai/runners/utils/data_utils.py +99 -45
  28. clarifai/runners/utils/loader.py +23 -2
  29. clarifai/runners/utils/method_signatures.py +4 -4
  30. clarifai/runners/utils/openai_convertor.py +103 -0
  31. clarifai/urls/helper.py +80 -12
  32. clarifai/utils/config.py +19 -0
  33. clarifai/utils/constants.py +4 -0
  34. clarifai/utils/logging.py +22 -5
  35. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/METADATA +1 -2
  36. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/RECORD +40 -37
  37. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/WHEEL +1 -1
  38. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/entry_points.txt +0 -0
  39. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/licenses/LICENSE +0 -0
  40. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,8 @@ from typing import List
3
3
 
4
4
  from clarifai_grpc.grpc.api import resources_pb2
5
5
 
6
- from clarifai.runners.utils import data_types
6
+ from clarifai.runners.utils import data_utils
7
+ from clarifai.urls.helper import ClarifaiUrlHelper
7
8
 
8
9
 
9
10
  def generate_client_script(
@@ -15,6 +16,38 @@ def generate_client_script(
15
16
  deployment_id: str = None,
16
17
  use_ctx: bool = False,
17
18
  ) -> str:
19
+ url_helper = ClarifaiUrlHelper()
20
+
21
+ # Provide an mcp client config
22
+ if len(method_signatures) == 1 and method_signatures[0].name == "mcp_transport":
23
+ api_url = url_helper.api_url(
24
+ user_id,
25
+ app_id,
26
+ "models",
27
+ model_id,
28
+ )
29
+
30
+ _CLIENT_TEMPLATE = """
31
+ import asyncio
32
+ import os
33
+ from fastmcp import Client
34
+ from fastmcp.client.transports import StreamableHttpTransport
35
+
36
+ transport = StreamableHttpTransport(url="%s/mcp",
37
+ headers={"Authorization": "Bearer " + os.environ["CLARIFAI_PAT"]})
38
+
39
+ async def main():
40
+ async with Client(transport) as client:
41
+ tools = await client.list_tools()
42
+ print(f"Available tools: {tools}")
43
+ result = await client.call_tool(tools[0].name, {"a": 5, "b": 3})
44
+ print(f"Result: {result[0].text}")
45
+
46
+ if __name__ == "__main__":
47
+ asyncio.run(main())
48
+ """
49
+ return _CLIENT_TEMPLATE % api_url
50
+
18
51
  _CLIENT_TEMPLATE = """\
19
52
  import os
20
53
 
@@ -35,19 +68,13 @@ from clarifai.runners.utils import data_types
35
68
  model_section = """
36
69
  model = Model.from_current_context()"""
37
70
  else:
38
- model_section = """
39
- model = Model("https://clarifai.com/{user_id}/{app_id}/{model_id}",
71
+ model_ui_url = url_helper.clarifai_url(user_id, app_id, "models", model_id)
72
+ model_section = f"""
73
+ model = Model({model_ui_url},
40
74
  deployment_id = {deployment_id}, # Only needed for dedicated deployed models
41
75
  {base_url_str}
42
76
  )
43
77
  """
44
- model_section = _CLIENT_TEMPLATE.format(
45
- user_id=user_id,
46
- app_id=app_id,
47
- model_id=model_id,
48
- deployment_id=deployment_id,
49
- base_url_str=base_url_str,
50
- )
51
78
 
52
79
  # Generate client template
53
80
  client_template = _CLIENT_TEMPLATE.format(
@@ -58,28 +85,24 @@ model = Model.from_current_context()"""
58
85
  method_signatures_str = []
59
86
  for method_signature in method_signatures:
60
87
  method_name = method_signature.name
61
- if method_signature.method_type in [
62
- resources_pb2.RunnerMethodType.UNARY_UNARY,
63
- resources_pb2.RunnerMethodType.UNARY_STREAMING,
64
- ]:
65
- client_script_str = f'response = model.{method_name}('
66
- annotations = _get_annotations_source(method_signature)
67
- for param_name, (param_type, default_value) in annotations.items():
68
- if param_name == "return":
69
- continue
70
- if default_value is None:
71
- default_value = _set_default_value(param_type)
72
- if param_type == "str":
73
- default_value = repr(default_value)
74
-
75
- client_script_str += f"{param_name}={default_value}, "
76
- client_script_str = client_script_str.rstrip(", ") + ")"
77
- if method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_UNARY:
78
- client_script_str += "\nprint(response)"
79
- elif method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_STREAMING:
80
- client_script_str += "\nfor res in response:\n print(res)"
81
- client_script_str += "\n"
82
- method_signatures_str.append(client_script_str)
88
+ client_script_str = f'response = model.{method_name}('
89
+ annotations = _get_annotations_source(method_signature)
90
+ for param_name, (param_type, default_value) in annotations.items():
91
+ print(
92
+ f"param_name: {param_name}, param_type: {param_type}, default_value: {default_value}"
93
+ )
94
+ if param_name == "return":
95
+ continue
96
+ if default_value is None:
97
+ default_value = _set_default_value(param_type)
98
+ client_script_str += f"{param_name}={default_value}, "
99
+ client_script_str = client_script_str.rstrip(", ") + ")"
100
+ if method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_UNARY:
101
+ client_script_str += "\nprint(response)"
102
+ elif method_signature.method_type == resources_pb2.RunnerMethodType.UNARY_STREAMING:
103
+ client_script_str += "\nfor res in response:\n print(res)"
104
+ client_script_str += "\n"
105
+ method_signatures_str.append(client_script_str)
83
106
 
84
107
  method_signatures_str = "\n".join(method_signatures_str)
85
108
  # Combine all parts
@@ -107,9 +130,8 @@ def _get_annotations_source(method_signature: resources_pb2.MethodSignature) ->
107
130
  if input_field.iterator:
108
131
  param_type = f"Iterator[{param_type}]"
109
132
  default_value = None
110
- if input_field.default:
133
+ if data_utils.Param.get_default(input_field):
111
134
  default_value = _parse_default_value(input_field)
112
-
113
135
  annotations[param_name] = (param_type, default_value)
114
136
  if not method_signature.output_fields:
115
137
  raise ValueError("MethodSignature must have at least one output field")
@@ -177,23 +199,21 @@ def _map_default_value(field_type):
177
199
  elif field_type == "bool":
178
200
  default_value = False
179
201
  elif field_type == "data_types.Image":
180
- default_value = data_types.Image.from_url("https://samples.clarifai.com/metro-north.jpg")
202
+ default_value = 'data_types.Image.from_url("https://samples.clarifai.com/metro-north.jpg")'
181
203
  elif field_type == "data_types.Text":
182
- default_value = data_types.Text("What's the future of AI?")
204
+ default_value = 'data_types.Text("What is the future of AI?")'
183
205
  elif field_type == "data_types.Audio":
184
- default_value = data_types.Audio.from_url("https://samples.clarifai.com/audio.mp3")
206
+ default_value = 'data_types.Audio.from_url("https://samples.clarifai.com/audio.mp3")'
185
207
  elif field_type == "data_types.Video":
186
- default_value = data_types.Video.from_url("https://samples.clarifai.com/video.mp4")
208
+ default_value = 'data_types.Video.from_url("https://samples.clarifai.com/video.mp4")'
187
209
  elif field_type == "data_types.Concept":
188
- default_value = data_types.Concept(id="concept_id", name="dog", value=0.95)
210
+ default_value = 'data_types.Concept(id="concept_id", name="dog", value=0.95)'
189
211
  elif field_type == "data_types.Region":
190
- default_value = data_types.Region(
191
- box=[0.1, 0.1, 0.5, 0.5],
192
- )
212
+ default_value = 'data_types.Region(box=[0.1, 0.1, 0.5, 0.5],)'
193
213
  elif field_type == "data_types.Frame":
194
- default_value = data_types.Frame.from_url("https://samples.clarifai.com/video.mp4", 0)
214
+ default_value = 'data_types.Frame.from_url("https://samples.clarifai.com/video.mp4", 0)'
195
215
  elif field_type == "data_types.NDArray":
196
- default_value = data_types.NDArray([1, 2, 3])
216
+ default_value = 'data_types.NDArray([1, 2, 3])'
197
217
  else:
198
218
  default_value = None
199
219
  return default_value
@@ -203,6 +223,12 @@ def _set_default_value(field_type):
203
223
  """
204
224
  Set the default value of a field if it is not set.
205
225
  """
226
+ is_iterator = False
227
+ print(f"before field_type: {field_type}")
228
+ if field_type.startswith("Iterator["):
229
+ is_iterator = True
230
+ field_type = field_type[9:-1]
231
+ print(f"after field_type: {field_type}")
206
232
  default_value = None
207
233
  default_value = _map_default_value(field_type)
208
234
  if field_type.startswith("List["):
@@ -219,6 +245,11 @@ def _set_default_value(field_type):
219
245
  element_type_defaults = [_map_default_value(et) for et in element_types]
220
246
  default_value = f"{{{', '.join([str(et) for et in element_type_defaults])}}}"
221
247
 
248
+ if field_type == 'str':
249
+ default_value = repr(default_value)
250
+ if is_iterator:
251
+ default_value = f'iter([{default_value}])'
252
+ print(f"after default_value: {default_value}")
222
253
  return default_value
223
254
 
224
255
 
@@ -4,14 +4,28 @@ registry = os.environ.get('CLARIFAI_BASE_IMAGE_REGISTRY', 'public.ecr.aws/clarif
4
4
 
5
5
  GIT_SHA = "b8ae56bf3b7c95e686ca002b07ca83d259c716eb"
6
6
 
7
+ AMD_GIT_SHA = "81e942130173f54927e7c9a65aabc7e32780616d"
8
+
7
9
  PYTHON_BASE_IMAGE = registry + '/python-base:{python_version}-' + GIT_SHA
8
10
  TORCH_BASE_IMAGE = registry + '/torch:{torch_version}-py{python_version}-{gpu_version}-' + GIT_SHA
9
11
 
12
+ AMD_PYTHON_BASE_IMAGE = registry + '/amd-python-base:{python_version}-' + AMD_GIT_SHA
13
+ AMD_TORCH_BASE_IMAGE = (
14
+ registry + '/amd-torch:{torch_version}-py{python_version}-{gpu_version}-' + AMD_GIT_SHA
15
+ )
16
+ AMD_VLLM_BASE_IMAGE = (
17
+ registry + '/amd-vllm:{torch_version}-py{python_version}-{gpu_version}-' + AMD_GIT_SHA
18
+ )
19
+
10
20
  # List of available python base images
11
21
  AVAILABLE_PYTHON_IMAGES = ['3.11', '3.12']
12
22
 
13
23
  DEFAULT_PYTHON_VERSION = 3.12
14
24
 
25
+ DEFAULT_AMD_TORCH_VERSION = '2.8.0.dev20250511+rocm6.4'
26
+
27
+ DEFAULT_AMD_GPU_VERSION = 'rocm6.4'
28
+
15
29
  # By default we download at runtime.
16
30
  DEFAULT_DOWNLOAD_CHECKPOINT_WHEN = "runtime"
17
31
 
@@ -29,6 +43,7 @@ AVAILABLE_TORCH_IMAGES = [
29
43
  '2.7.0-py3.12-cu128',
30
44
  '2.7.0-py3.12-rocm6.3',
31
45
  ]
46
+
32
47
  CONCEPTS_REQUIRED_MODEL_TYPE = [
33
48
  'visual-classifier',
34
49
  'visual-detector',
@@ -395,6 +395,22 @@ class Image(MessageData):
395
395
  raise ValueError("Image has no bytes")
396
396
  return PILImage.open(io.BytesIO(self.proto.base64))
397
397
 
398
+ def to_base64_str(self) -> str:
399
+ if not self.proto.base64:
400
+ raise ValueError("Image has no bytes")
401
+ if isinstance(self.proto.base64, str):
402
+ return self.proto.base64
403
+ if isinstance(self.proto.base64, bytes):
404
+ try:
405
+ # trying direct decode (if already a base64 bytes)
406
+ return self.proto.base64.decode('utf-8')
407
+ except UnicodeDecodeError:
408
+ import base64
409
+
410
+ return base64.b64encode(self.proto.base64).decode('utf-8')
411
+ else:
412
+ raise TypeError("Expected str or bytes for Image.base64")
413
+
398
414
  def to_numpy(self) -> np.ndarray:
399
415
  return np.asarray(self.to_pil())
400
416
 
@@ -466,6 +482,22 @@ class Audio(MessageData):
466
482
  def to_proto(self) -> AudioProto:
467
483
  return self.proto
468
484
 
485
+ def to_base64_str(self) -> str:
486
+ if not self.proto.base64:
487
+ raise ValueError("Audio has no bytes")
488
+ if isinstance(self.proto.base64, str):
489
+ return self.proto.base64
490
+ if isinstance(self.proto.base64, bytes):
491
+ try:
492
+ # trying direct decode (if already a base64 bytes)
493
+ return self.proto.base64.decode('utf-8')
494
+ except UnicodeDecodeError:
495
+ import base64
496
+
497
+ return base64.b64encode(self.proto.base64).decode('utf-8')
498
+ else:
499
+ raise TypeError("Expected str or bytes for Audio.base64")
500
+
469
501
  @classmethod
470
502
  def from_proto(cls, proto: AudioProto) -> "Audio":
471
503
  return cls(proto)
@@ -578,6 +610,22 @@ class Video(MessageData):
578
610
  def to_proto(self) -> VideoProto:
579
611
  return self.proto
580
612
 
613
+ def to_base64_str(self) -> str:
614
+ if not self.proto.base64:
615
+ raise ValueError("Video has no bytes")
616
+ if isinstance(self.proto.base64, str):
617
+ return self.proto.base64
618
+ if isinstance(self.proto.base64, bytes):
619
+ try:
620
+ # trying direct decode (if already a base64 bytes)
621
+ return self.proto.base64.decode('utf-8')
622
+ except UnicodeDecodeError:
623
+ import base64
624
+
625
+ return base64.b64encode(self.proto.base64).decode('utf-8')
626
+ else:
627
+ raise TypeError("Expected str or bytes for Video.base64")
628
+
581
629
  @classmethod
582
630
  def from_proto(cls, proto: VideoProto) -> "Video":
583
631
  return cls(proto)
@@ -1,62 +1,89 @@
1
+ import base64
2
+ import json
1
3
  import math
2
4
  import operator
3
5
  from io import BytesIO
4
- from typing import List
6
+ from typing import Dict, List
5
7
 
8
+ import requests
6
9
  from clarifai_grpc.grpc.api import resources_pb2
7
10
  from clarifai_grpc.grpc.api.resources_pb2 import ModelTypeEnumOption, ModelTypeRangeInfo
8
11
  from clarifai_grpc.grpc.api.resources_pb2 import ModelTypeField as ParamProto
9
- from PIL import Image
12
+ from PIL import Image as PILImage
10
13
 
11
- from clarifai.runners.utils.data_types import MessageData
14
+ from clarifai.runners.utils.data_types import Audio, Image, MessageData, Video
12
15
 
13
16
 
14
- def image_to_bytes(img: Image.Image, format="JPEG") -> bytes:
17
+ def image_to_bytes(img: PILImage.Image, format="JPEG") -> bytes:
15
18
  buffered = BytesIO()
16
19
  img.save(buffered, format=format)
17
20
  img_str = buffered.getvalue()
18
21
  return img_str
19
22
 
20
23
 
21
- def bytes_to_image(bytes_img) -> Image.Image:
22
- img = Image.open(BytesIO(bytes_img))
24
+ def bytes_to_image(bytes_img) -> PILImage.Image:
25
+ img = PILImage.open(BytesIO(bytes_img))
23
26
  return img
24
27
 
25
28
 
26
- def is_openai_chat_format(messages):
27
- """
28
- Verify if the given argument follows the OpenAI chat messages format.
29
-
30
- Args:
31
- messages (list): A list of dictionaries representing chat messages.
32
-
33
- Returns:
34
- bool: True if valid, False otherwise.
35
- """
36
- if not isinstance(messages, list):
37
- return False
38
-
39
- valid_roles = {"system", "user", "assistant", "function"}
40
-
41
- for msg in messages:
42
- if not isinstance(msg, dict):
43
- return False
44
- if "role" not in msg or "content" not in msg:
45
- return False
46
- if msg["role"] not in valid_roles:
47
- return False
48
-
49
- content = msg["content"]
50
-
51
- # Content should be either a string (text message) or a multimodal list
52
- if isinstance(content, str):
53
- continue # Valid text message
54
-
55
- elif isinstance(content, list):
56
- for item in content:
57
- if not isinstance(item, dict):
58
- return False
59
- return True
29
+ def process_image(image: Image) -> Dict:
30
+ """Convert Clarifai Image object to OpenAI image format."""
31
+
32
+ if image.bytes:
33
+ b64_img = image.to_base64_str()
34
+ return {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{b64_img}"}}
35
+ elif image.url:
36
+ return {'type': 'image_url', 'image_url': {'url': image.url}}
37
+ else:
38
+ raise ValueError("Image must contain either bytes or URL")
39
+
40
+
41
+ def process_audio(audio: Audio) -> Dict:
42
+ """Convert Clarifai Audio object to OpenAI audio format."""
43
+
44
+ if audio.bytes:
45
+ audio = audio.to_base64_str()
46
+ audio = {
47
+ "type": "input_audio",
48
+ "input_audio": {"data": audio, "format": "wav"},
49
+ }
50
+ elif audio.url:
51
+ response = requests.get(audio.url)
52
+ if response.status_code != 200:
53
+ raise ValueError(f"Failed to fetch audio from URL: {audio.url}")
54
+ audio_base64_str = base64.b64encode(response.content).decode('utf-8')
55
+ audio = {
56
+ "type": "input_audio",
57
+ "input_audio": {"data": audio_base64_str, "format": "wav"},
58
+ }
59
+ else:
60
+ raise ValueError("Audio must contain either bytes or URL")
61
+
62
+ return audio
63
+
64
+
65
+ def process_video(video: Video) -> Dict:
66
+ """Convert Clarifai Video object to OpenAI video format."""
67
+
68
+ if video.bytes:
69
+ video = "data:video/mp4;base64," + video.to_base64_str()
70
+ video = {
71
+ "type": "video_url",
72
+ "video_url": {"url": video},
73
+ }
74
+ elif video.url:
75
+ response = requests.get(video.url)
76
+ if response.status_code != 200:
77
+ raise ValueError(f"Failed to fetch video from URL: {video.url}")
78
+ video_base64_str = base64.b64encode(response.content).decode('utf-8')
79
+ video = {
80
+ "type": "video_url",
81
+ "video_url": {"url": video_base64_str},
82
+ }
83
+ else:
84
+ raise ValueError("Video must contain either bytes or URL")
85
+
86
+ return video
60
87
 
61
88
 
62
89
  class Param(MessageData):
@@ -64,7 +91,7 @@ class Param(MessageData):
64
91
 
65
92
  def __init__(
66
93
  self,
67
- default=None,
94
+ default,
68
95
  description=None,
69
96
  min_value=None,
70
97
  max_value=None,
@@ -77,6 +104,7 @@ class Param(MessageData):
77
104
  self.max_value = max_value
78
105
  self.choices = choices
79
106
  self.is_param = is_param
107
+ self._patch_encoder()
80
108
 
81
109
  def __repr__(self) -> str:
82
110
  attrs = []
@@ -153,6 +181,16 @@ class Param(MessageData):
153
181
  def __ge__(self, other):
154
182
  return self.default >= other
155
183
 
184
+ def __getattribute__(self, name):
185
+ """Intercept attribute access to mimic default value behavior"""
186
+ try:
187
+ # First try to get Param attributes normally
188
+ return object.__getattribute__(self, name)
189
+ except AttributeError:
190
+ # Fall back to the default value's attributes
191
+ default = object.__getattribute__(self, 'default')
192
+ return getattr(default, name)
193
+
156
194
  # Arithmetic operators – # arithmetic & bitwise operators – auto-generated
157
195
  _arith_ops = {
158
196
  "__add__": operator.add,
@@ -169,7 +207,6 @@ class Param(MessageData):
169
207
  "__rshift__": operator.rshift,
170
208
  }
171
209
 
172
- # Create both left- and right-hand versions of each operator
173
210
  for _name, _op in _arith_ops.items():
174
211
 
175
212
  def _make(op):
@@ -243,6 +280,24 @@ class Param(MessageData):
243
280
  return self
244
281
  return self.default
245
282
 
283
+ def __json__(self):
284
+ return self.default if not hasattr(self.default, '__json__') else self.default.__json__()
285
+
286
+ @classmethod
287
+ def _patch_encoder(cls):
288
+ # only patch once
289
+ if getattr(json.JSONEncoder, "_user_patched", False):
290
+ return
291
+ original = json.JSONEncoder.default
292
+
293
+ def default(self, obj):
294
+ if isinstance(obj, Param):
295
+ return obj.__json__()
296
+ return original(self, obj)
297
+
298
+ json.JSONEncoder.default = default
299
+ json.JSONEncoder._user_patched = True
300
+
246
301
  def to_proto(self, proto=None) -> ParamProto:
247
302
  if proto is None:
248
303
  proto = ParamProto()
@@ -254,7 +309,7 @@ class Param(MessageData):
254
309
  option = ModelTypeEnumOption(id=str(choice))
255
310
  proto.model_type_enum_options.append(option)
256
311
 
257
- proto.required = self.default is None
312
+ proto.required = False
258
313
 
259
314
  if self.min_value is not None or self.max_value is not None:
260
315
  range_info = ModelTypeRangeInfo()
@@ -324,8 +379,7 @@ class Param(MessageData):
324
379
 
325
380
  if proto is None:
326
381
  proto = ParamProto()
327
- if default is not None:
328
- proto.default = json.dumps(default)
382
+ proto.default = json.dumps(default)
329
383
  return proto
330
384
  except Exception:
331
385
  if default is not None:
@@ -41,7 +41,7 @@ class HuggingFaceLoader:
41
41
  return True
42
42
  except Exception as e:
43
43
  logger.error(
44
- f"Error setting up Hugging Face token, please make sure you have the correct token: {e}"
44
+ f"Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints. Failed reason: {e}"
45
45
  )
46
46
  return False
47
47
 
@@ -63,7 +63,6 @@ class HuggingFaceLoader:
63
63
  try:
64
64
  is_hf_model_exists = self.validate_hf_model()
65
65
  if not is_hf_model_exists:
66
- logger.error("Model %s not found on Hugging Face" % (self.repo_id))
67
66
  return False
68
67
 
69
68
  self.ignore_patterns = self._get_ignore_patterns()
@@ -205,6 +204,28 @@ class HuggingFaceLoader:
205
204
  ]
206
205
  return self.ignore_patterns
207
206
 
207
+ @classmethod
208
+ def validate_hf_repo_access(cls, repo_id: str, token: str = None) -> bool:
209
+ # check if model exists on HF
210
+ try:
211
+ from huggingface_hub import auth_check
212
+ from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
213
+ except ImportError:
214
+ raise ImportError(cls.HF_DOWNLOAD_TEXT)
215
+
216
+ try:
217
+ auth_check(repo_id, token=token)
218
+ logger.info("Hugging Face repo access validated")
219
+ return True
220
+ except GatedRepoError:
221
+ logger.error(
222
+ "Hugging Face repo is gated. Please make sure you have access to the repo."
223
+ )
224
+ return False
225
+ except RepositoryNotFoundError:
226
+ logger.error("Hugging Face repo not found. Please make sure the repo exists.")
227
+ return False
228
+
208
229
  @staticmethod
209
230
  def validate_config(checkpoint_path: str):
210
231
  # check if downloaded config.json exists
@@ -302,6 +302,9 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
302
302
  raise TypeError(f'Missing required argument: {sig.name}')
303
303
  continue # skip missing fields, they can be set to default on the server
304
304
  data = kwargs[sig.name]
305
+ default = data_utils.Param.get_default(sig)
306
+ if data is None and default is None:
307
+ continue
305
308
  serializer = serializer_from_signature(sig)
306
309
  # TODO determine if any (esp the first) var can go in the proto without parts
307
310
  # and whether to put this in the signature or dynamically determine it
@@ -312,7 +315,7 @@ def serialize(kwargs, signatures, proto=None, is_output=False):
312
315
  return proto
313
316
 
314
317
 
315
- def deserialize(proto, signatures, inference_params={}, is_output=False):
318
+ def deserialize(proto, signatures, is_output=False):
316
319
  '''
317
320
  Deserialize the given proto into kwargs using the given signatures.
318
321
  '''
@@ -323,11 +326,8 @@ def deserialize(proto, signatures, inference_params={}, is_output=False):
323
326
  for sig_i, sig in enumerate(signatures):
324
327
  serializer = serializer_from_signature(sig)
325
328
  part = parts_by_name.get(sig.name)
326
- inference_params_value = inference_params.get(sig.name)
327
329
  if part is not None:
328
330
  kwargs[sig.name] = serializer.deserialize(part.data)
329
- elif inference_params_value is not None:
330
- kwargs[sig.name] = inference_params_value
331
331
  else:
332
332
  if sig_i == 0:
333
333
  # possible inlined first value