clarifai 10.8.1__py3-none-any.whl → 10.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,244 @@
+ from typing import Dict, List, Tuple, Union
+
+ import numpy as np
+ from clarifai_grpc.grpc.api import resources_pb2
+ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
+ from PIL import Image
+ from pydantic import BaseModel, ConfigDict, PrivateAttr, computed_field
+
+ from clarifai.client.auth.helper import ClarifaiAuthHelper
+
+ from .data_utils import bytes_to_image, image_to_bytes
+
+
+ class BaseDataHandler(BaseModel):
+   _proto: Union[resources_pb2.Input, resources_pb2.Output]
+   _auth: ClarifaiAuthHelper = PrivateAttr(default=None)
+
+   model_config = ConfigDict(arbitrary_types_allowed=True)
+
+   def to_python(self):
+     return dict(text=self.text, image=self.image, audio=self.audio)
+
+   # ---------------- Start get/setters ---------------- #
+   # Proto
+   @property
+   def proto(self):
+     return self._proto
+
+   def set_proto(self, proto):
+     self._proto = proto
+
+   # Status
+   @computed_field
+   def status(self) -> status_pb2.Status:
+     return self._proto.status
+
+   def set_status(self, code: int, description: str = ""):
+     self._proto.status.code = code
+     self._proto.status.description = description
+
+   # Text
+   @computed_field
+   def text(self) -> Union[None, str]:
+     data = self._proto.data.text
+     text = None
+     if data.ByteSize():
+       if data.raw:
+         text = data.raw
+       else:
+         # Downloading text from data.url is not supported yet.
+         raise NotImplementedError
+     return text
+
+   def set_text(self, text: str):
+     self._proto.data.text.raw = text
+
+   # Image
+   @computed_field
+   def image(self, format: str = "np") -> Union[None, Image.Image, np.ndarray]:
+     data = self._proto.data.image
+     image = None
+     if data.ByteSize():
+       data: resources_pb2.Image = data
+       if data.base64:
+         image = data.base64
+       elif data.url:
+         # Downloading the image from data.url is not supported yet.
+         raise NotImplementedError
+       image = bytes_to_image(image)
+       image = image if format != "np" else np.asarray(image).astype("uint8")
+     return image
+
+   def set_image(self, image: Union[Image.Image, np.ndarray]):
+     if isinstance(image, np.ndarray):
+       image = Image.fromarray(image)
+     self._proto.data.image.base64 = image_to_bytes(image)
+
+   # Audio
+   @computed_field
+   def audio(self) -> bytes:
+     data = self._proto.data.audio
+     audio = None
+     if data.ByteSize():
+       if data.base64:
+         audio = data.base64
+     return audio
+
+   def set_audio(self, audio: bytes):
+     self._proto.data.audio.base64 = audio
+
+   # Bboxes
+   @computed_field
+   def bboxes(self, real_coord: bool = False, image_width: int = None,
+              image_height: int = None) -> Tuple[List, List, List]:
+     if real_coord:
+       assert (image_height and image_width
+              ), "image_height and image_width are required when returning real coordinates"
+     xyxy = []
+     scores = []
+     concepts = []
+     for each in self._proto.data.regions:
+       box = each.region_info
+       score = each.value
+       concept = each.data.concepts[0].id
+       x1 = box.left_col
+       y1 = box.top_row
+       x2 = box.right_col
+       y2 = box.bottom_row
+       if real_coord:
+         x1 = x1 * image_width
+         y1 = y1 * image_height
+         x2 = x2 * image_width
+         y2 = y2 * image_height
+       xyxy.append([x1, y1, x2, y2])
+       scores.append(score)
+       concepts.append(concept)
+
+     return xyxy, scores, concepts
+
+   def set_bboxes(
+       self,
+       boxes: list,
+       scores: list,
+       concepts: list,
+       real_coord: bool = False,
+       image_width: int = None,
+       image_height: int = None,
+   ):
+     if real_coord:
+       assert (image_height and
+               image_width), "image_height and image_width are required when `real_coord` is set"
+       bboxes = [[x[1] / image_height, x[0] / image_width, x[3] / image_height, x[2] / image_width]
+                 for x in boxes]  # normalize the bboxes to [0, 1] and reorder to [y1, x1, y2, x2]
+       bboxes = np.clip(bboxes, 0, 1.0)
+     else:
+       bboxes = boxes  # already expected to be normalized [y1, x1, y2, x2]
+
+     regions = []
+     for ith, bbox in enumerate(bboxes):
+       score = scores[ith]
+       concept = concepts[ith]
+       if any(each > 1.0 for each in bbox):
+         raise ValueError(
+             "Box coordinates are not normalized to [0, 1]. Please set `real_coord` to True and provide image_height and image_width to normalize them."
+         )
+       region = resources_pb2.RegionInfo(bounding_box=resources_pb2.BoundingBox(
+           top_row=bbox[0],  # y_min
+           left_col=bbox[1],  # x_min
+           bottom_row=bbox[2],  # y_max
+           right_col=bbox[3],  # x_max
+       ))
+       data = resources_pb2.Data(concepts=[resources_pb2.Concept(id=concept, value=score)])
+       regions.append(resources_pb2.Region(region_info=region, data=data))
+
+     self._proto.data.ClearField("regions")
+     self._proto.data.regions.extend(regions)
+
+   # Concepts
+   @computed_field
+   def concepts(self) -> Dict[str, float]:
+     con_scores = {}
+     for each in self.proto.data.concepts:
+       con_scores.update({each.id: each.value})
+     return con_scores
+
+   def set_concepts(self, concept_score_pairs: Dict[str, float]):
+     concepts = []
+     for concept, score in concept_score_pairs.items():
+       con_score = resources_pb2.Concept(id=concept, name=concept, value=score)
+       concepts.append(con_score)
+     if concepts:
+       self._proto.data.ClearField("concepts")
+       for each in concepts:
+         self._proto.data.concepts.append(each)
+
+   # Embeddings
+   @computed_field
+   def embeddings(self) -> List[List[float]]:
+     return [each.vector for each in self.proto.data.embeddings]
+
+   def set_embeddings(self, list_vectors: List[List[float]]):
+     if list_vectors[0]:
+       self._proto.data.ClearField("embeddings")
+       for vec in list_vectors:
+         self._proto.data.embeddings.append(
+             resources_pb2.Embedding(vector=vec, num_dimensions=len(vec)))
+
+   # ---------------- End get/setters ---------------- #
+
+   # Constructors
+   @classmethod
+   def from_proto(cls, proto):
+     clss = cls()
+     clss.set_proto(proto)
+     return clss
+
+   @classmethod
+   def from_data(
+       cls,
+       status_code: int = status_code_pb2.SUCCESS,
+       status_description: str = "",
+       text: str = None,
+       image: Union[Image.Image, np.ndarray] = None,
+       audio: bytes = None,
+       boxes: dict = None,
+       concepts: Dict[str, float] = {},
+       embeddings: List[List[float]] = [],
+   ) -> 'OutputDataHandler':
+     clss = cls()
+     if isinstance(image, (Image.Image, np.ndarray)):
+       clss.set_image(image)
+     if text:
+       clss.set_text(text)
+     if audio:
+       clss.set_audio(audio)
+     if boxes:
+       clss.set_bboxes(**boxes)
+     if concepts:
+       clss.set_concepts(concepts)
+     if embeddings:
+       clss.set_embeddings(embeddings)
+
+     clss.set_status(code=status_code, description=status_description)
+
+     return clss
+
+
+ class InputDataHandler(BaseDataHandler):
+   _proto: resources_pb2.Input = resources_pb2.Input()
+
+   def set_proto(self, proto: resources_pb2.Input):
+     assert isinstance(proto, resources_pb2.Input)
+     self._proto = proto
+
+
+ class OutputDataHandler(BaseDataHandler):
+   _proto: resources_pb2.Output = resources_pb2.Output()
+
+   def set_proto(self, proto: resources_pb2.Output):
+     assert isinstance(proto, resources_pb2.Output)
+     self._proto = proto
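
For orientation, here is a minimal sketch of how the new handler classes might be used; the module path and the concrete values are assumptions for illustration, not taken from the diff:

    import numpy as np
    # Assumed module path for the new file added in this release.
    from clarifai.runners.utils.data_handler import OutputDataHandler

    # Build an Output proto carrying an image and classification concepts.
    output = OutputDataHandler.from_data(
        image=np.zeros((64, 64, 3), dtype="uint8"),  # any HxWx3 uint8 array
        concepts={"cat": 0.9, "dog": 0.1},
    )
    print(output.status.code)  # status_code_pb2.SUCCESS by default
    print(output.concepts)     # {'cat': 0.9, 'dog': 0.1}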
@@ -0,0 +1,15 @@
+ from io import BytesIO
+
+ from PIL import Image
+
+
+ def image_to_bytes(img: Image.Image, format="JPEG") -> bytes:
+   buffered = BytesIO()
+   img.save(buffered, format=format)
+   img_str = buffered.getvalue()
+   return img_str
+
+
+ def bytes_to_image(bytes_img) -> Image.Image:
+   img = Image.open(BytesIO(bytes_img))
+   return img
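
A quick roundtrip check of these two helpers (illustrative only; JPEG encoding is lossy, so pixel values may shift slightly while the size is preserved):

    from PIL import Image

    img = Image.new("RGB", (8, 8), color="red")
    data = image_to_bytes(img)       # JPEG-encoded bytes
    restored = bytes_to_image(data)  # back to a PIL image
    assert restored.size == (8, 8)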
@@ -0,0 +1,70 @@
+ import importlib.util
+ import json
+ import os
+ import subprocess
+
+
+ class HuggingFaceLoarder:
+
+   def __init__(self, repo_id=None, token=None):
+     self.repo_id = repo_id
+     self.token = token
+     if token:
+       try:
+         if importlib.util.find_spec("huggingface_hub") is None:
+           raise ImportError(
+               "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
+           )
+         os.environ['HF_TOKEN'] = token
+         subprocess.run(f'huggingface-cli login --token={os.environ["HF_TOKEN"]}', shell=True)
+       except Exception as e:
+         raise Exception("Error setting up Hugging Face token") from e
+
+   def download_checkpoints(self, checkpoint_path: str):
+     # Raise an error if huggingface_hub wasn't installed.
+     try:
+       from huggingface_hub import snapshot_download
+     except ImportError:
+       raise ImportError(
+           "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
+       )
+     if os.path.exists(checkpoint_path):
+       print("Checkpoints already exist")
+     else:
+       os.makedirs(checkpoint_path, exist_ok=True)
+       try:
+         is_hf_model_exists = self.validate_hf_model()
+         if not is_hf_model_exists:
+           print("Model not found on Hugging Face")
+           return False
+         snapshot_download(repo_id=self.repo_id, local_dir=checkpoint_path)
+       except Exception as e:
+         print("Error downloading model checkpoints ", e)
+         return False
+
+     if not self.validate_download(checkpoint_path):
+       print("Error downloading model checkpoints")
+       return False
+     return True
+
+   def validate_hf_model(self):
+     # Check if the model exists on Hugging Face.
+     from huggingface_hub import file_exists, repo_exists
+     return repo_exists(self.repo_id) and file_exists(self.repo_id, 'config.json')
+
+   def validate_download(self, checkpoint_path: str):
+     # Check that everything the repo lists was actually downloaded.
+     from huggingface_hub import list_repo_files
+     repo_files = list_repo_files(self.repo_id)
+     return len(repo_files) > 0 and len(os.listdir(checkpoint_path)) >= len(repo_files)
+
+   def fetch_labels(self, checkpoint_path: str):
+     # Fetch labels for classification, detection and segmentation models.
+     config_path = os.path.join(checkpoint_path, 'config.json')
+     with open(config_path, 'r') as f:
+       config = json.load(f)
+
+     labels = config['id2label']
+     return labels
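
A hedged usage sketch for the loader (the class name, including its "Loarder" spelling, is as shipped; the repo id and paths below are hypothetical):

    loader = HuggingFaceLoarder(repo_id="org/model-name")  # hypothetical repo id
    if loader.download_checkpoints("./checkpoints"):
        # Works for models whose config.json carries an id2label mapping.
        labels = loader.fetch_labels("./checkpoints")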
@@ -0,0 +1,6 @@
+ import os
+
+ from clarifai.utils.logging import get_logger
+
+ logger_level = os.environ.get("LOG_LEVEL", "INFO")
+ logger = get_logger(logger_level, __name__)
@@ -0,0 +1,42 @@
+ import concurrent.futures
+
+ import fsspec
+
+ from .logging import logger
+
+
+ def download_input(input):
+   """
+   This function will download any urls that are not already bytes.
+   """
+   if input.data.image.url and not input.data.image.base64:
+     # Download the image
+     with fsspec.open(input.data.image.url, 'rb') as f:
+       input.data.image.base64 = f.read()
+   if input.data.video.url and not input.data.video.base64:
+     # Download the video
+     with fsspec.open(input.data.video.url, 'rb') as f:
+       input.data.video.base64 = f.read()
+   if input.data.audio.url and not input.data.audio.base64:
+     # Download the audio
+     with fsspec.open(input.data.audio.url, 'rb') as f:
+       input.data.audio.base64 = f.read()
+   if input.data.text.url and not input.data.text.raw:
+     # Download the text
+     with fsspec.open(input.data.text.url, 'r') as f:
+       input.data.text.raw = f.read()
+
+
+ def ensure_urls_downloaded(request, max_threads=128):
+   """
+   This function will download any urls that are not already bytes and parallelize with a thread pool.
+   """
+   with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor:
+     futures = []
+     for input in request.inputs:
+       futures.append(executor.submit(download_input, input))
+     for future in concurrent.futures.as_completed(futures):
+       try:
+         future.result()
+       except Exception as e:
+         logger.exception(f"Error downloading input: {e}")
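
A minimal sketch of how this helper might be called before inference; the request construction is illustrative, using a public Clarifai sample image:

    from clarifai_grpc.grpc.api import resources_pb2, service_pb2

    request = service_pb2.PostModelOutputsRequest(inputs=[
        resources_pb2.Input(data=resources_pb2.Data(
            image=resources_pb2.Image(url="https://samples.clarifai.com/metro-north.jpg")))
    ])
    ensure_urls_downloaded(request)
    # inputs[0].data.image.base64 now holds the downloaded bytes.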
clarifai/utils/logging.py CHANGED
@@ -1,4 +1,12 @@
+ import datetime
+ import json
  import logging
+ import os
+ import socket
+ import sys
+ import threading
+ import time
+ import traceback
  from collections import defaultdict
  from typing import Any, Dict, List, Optional, Union
 
@@ -11,6 +19,41 @@ from rich.tree import Tree
 
  install()
 
+ # For the json logger.
+ JSON_LOGGER_NAME = "clarifai-json"
+ JSON_LOG_KEY = 'msg'
+ JSON_DEFAULT_CHAR_LENGTH = 400
+ FIELD_BLACKLIST = [
+     'msg', 'message', 'account', 'levelno', 'created', 'threadName', 'name', 'processName',
+     'module', 'funcName', 'msecs', 'relativeCreated', 'pathname', 'args', 'thread', 'process'
+ ]
+
+ # Create thread-local storage that the format() call below uses.
+ # This is only used by the json_logger in the appropriate CLARIFAI_DEPLOY levels.
+ thread_log_info = threading.local()
+
+
+ def get_logger_context():
+   return thread_log_info.__dict__
+
+
+ def set_logger_context(**kwargs):
+   thread_log_info.__dict__.update(kwargs)
+
+
+ def clear_logger_context():
+   thread_log_info.__dict__.clear()
+
+
+ def restore_logger_context(context):
+   thread_log_info.__dict__.clear()
+   thread_log_info.__dict__.update(context)
+
+
+ def get_req_id_from_context():
+   ctx = get_logger_context()
+   return ctx.get('req_id', '')
+
 
  def display_workflow_tree(nodes_data: List[Dict]) -> None:
    """Displays a tree of the workflow nodes."""
@@ -84,12 +127,24 @@ def _configure_logger(name: str, logger_level: Union[int, str] = logging.NOTSET)
    for handler in logger.handlers[:]:
      logger.removeHandler(handler)
 
-   # Add the new rich handler and formatter
-   handler = RichHandler(
-       rich_tracebacks=True, log_time_format="%Y-%m-%d %H:%M:%S", console=Console(width=255))
-   formatter = logging.Formatter('%(name)s: %(message)s')
-   handler.setFormatter(formatter)
-   logger.addHandler(handler)
+   # If ENABLE_JSON_LOGGER is 'true' then definitely use the json logger.
+   # If ENABLE_JSON_LOGGER is 'false' then definitely don't use the json logger.
+   # If ENABLE_JSON_LOGGER is not set, then use the json logger only when running in k8s.
+   enabled_json = os.getenv('ENABLE_JSON_LOGGER', None)
+   in_k8s = 'KUBERNETES_SERVICE_HOST' in os.environ
+   if enabled_json == 'true' or (in_k8s and enabled_json != 'false'):
+     # Add the json handler and formatter
+     handler = logging.StreamHandler()
+     formatter = JsonFormatter()
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+   else:
+     # Add the new rich handler and formatter
+     handler = RichHandler(
+         rich_tracebacks=True, log_time_format="%Y-%m-%d %H:%M:%S", console=Console(width=255))
+     formatter = logging.Formatter('%(name)s: %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
 
 
  def get_logger(logger_level: Union[int, str] = logging.NOTSET,
@@ -151,3 +206,154 @@ def display_concept_relations_tree(relations_dict: Dict[str, Any]) -> None:
    for child in children:
      tree.add(child)
    rprint(tree)
+
+
+ def _default_json_default(obj):
+   """
+   Handle objects that could not be serialized to JSON automatically.
+
+   Coerce everything to strings.
+   All objects representing time get output as ISO8601.
+   """
+   if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
+     return obj.isoformat()
+   else:
+     return _object_to_string_with_truncation(obj)
+
+
+ def _object_to_string_with_truncation(obj) -> str:
+   """
+   Truncate the object's string representation.
+
+   It's preferable not to log objects that would trigger this function;
+   it's better to extract the important parts from them and log those as
+   regular Python types, like str or int, which won't be passed here.
+
+   The message produced here carries additional information that can help
+   find and fix truncation cases:
+   - the hardcoded part of the message can be used to search for all such entries in the logs
+   - the object's class can help with a more detailed investigation
+   """
+   objstr = str(obj)
+   if len(objstr) > JSON_DEFAULT_CHAR_LENGTH:
+     type_name = type(obj).__name__
+     truncated = objstr[:JSON_DEFAULT_CHAR_LENGTH]
+     objstr = f"{truncated}...[{type_name} was truncated, len={len(objstr)} chars]"
+   return objstr
+
+
+ class JsonFormatter(logging.Formatter):
+
+   def __init__(self,
+                fmt=None,
+                datefmt=None,
+                style='%',
+                json_cls=None,
+                json_default=_default_json_default):
+     """
+     :param fmt: Config as a JSON string, allowed fields;
+            extra: provide extra fields always present in logs
+            source_host: override source host name
+     :param datefmt: Date format to use (required by the logging.Formatter
+            interface but not used)
+     :param json_cls: JSON encoder to forward to json.dumps
+     :param json_default: Default JSON representation for unknown types,
+            by default coerce everything to a string
+     """
+     if fmt is not None:
+       self._fmt = json.loads(fmt)
+     else:
+       self._fmt = {}
+     self.json_default = json_default
+     self.json_cls = json_cls
+     if 'extra' not in self._fmt:
+       self.defaults = {}
+     else:
+       self.defaults = self._fmt['extra']
+     if 'source_host' in self._fmt:
+       self.source_host = self._fmt['source_host']
+     else:
+       try:
+         self.source_host = socket.gethostname()
+       except Exception:
+         self.source_host = ""
+
+   def _build_fields(self, defaults, fields):
+     """Return provided fields including any in defaults"""
+     return dict(list(defaults.get('@fields', {}).items()) + list(fields.items()))
+
+   # Override the format function to fit Clarifai
+   def format(self, record):
+     fields = record.__dict__.copy()
+
+     # logger.info({...}) directly.
+     if isinstance(record.msg, dict):
+       fields.update(record.msg)
+       fields.pop('msg')
+       msg = ""
+     else:  # logger.info("message", {...})
+       if isinstance(record.args, dict):
+         fields.update(record.args)
+       msg = record.getMessage()
+     for k in FIELD_BLACKLIST:
+       fields.pop(k, None)
+     # Rename 'levelname' to 'level' and make the value lowercase to match Go logs
+     level = fields.pop('levelname', None)
+     if level:
+       fields['level'] = level.lower()
+
+     # Get the thread-local data
+     req_id = getattr(thread_log_info, 'req_id', None)
+     if req_id:
+       fields['req_id'] = req_id
+     orig_req_id = getattr(thread_log_info, 'orig_req_id', None)
+     if orig_req_id:
+       fields['orig_req_id'] = orig_req_id
+     requester = getattr(thread_log_info, 'requester', None)
+     if requester:
+       fields['requester'] = requester
+
+     user_id = getattr(thread_log_info, 'user_id', None)
+     if user_id:
+       fields['user_id'] = user_id
+
+     if hasattr(thread_log_info, 'start_time'):
+       # pylint: disable=no-member
+       fields['duration_ms'] = (time.time() - thread_log_info.start_time) * 1000
+
+     if 'exc_info' in fields:
+       if fields['exc_info']:
+         formatted = traceback.format_exception(*fields['exc_info'])
+         fields['exception'] = formatted
+       fields.pop('exc_info')
+
+     if 'exc_text' in fields and not fields['exc_text']:
+       fields.pop('exc_text')
+
+     logr = self.defaults.copy()
+
+     logr.update({
+         JSON_LOG_KEY: msg,
+         '@timestamp': datetime.datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.%fZ')
+     })
+
+     logr.update(fields)
+
+     try:
+       return json.dumps(logr, default=self.json_default, cls=self.json_cls)
+     except Exception:
+       exc_type, exc_value, exc_tb = sys.exc_info()
+       return json.dumps(
+           {
+               "msg": f"Fail to format log {exc_type.__name__}({exc_value}), {logr}",
+               "formatting_traceback": "\n".join(traceback.format_tb(exc_tb)),
+           },
+           default=self.json_default,
+           cls=self.json_cls,
+       )
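
A short sketch of the new JSON logging path in use (the env var and context helpers come from this diff; the printed output shape is approximate):

    import os
    os.environ['ENABLE_JSON_LOGGER'] = 'true'  # force the JsonFormatter branch

    from clarifai.utils.logging import get_logger, set_logger_context

    set_logger_context(req_id="abc123")  # attached to every record on this thread
    logger = get_logger("INFO", "my-module")
    logger.info("hello")
    # emits one JSON line roughly like:
    # {"msg": "hello", "@timestamp": "...", "level": "info", "req_id": "abc123"}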
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: clarifai
- Version: 10.8.1
+ Version: 10.8.3
  Summary: Clarifai Python SDK
  Home-page: https://github.com/Clarifai/clarifai-python
  Author: Clarifai
@@ -20,7 +20,8 @@ Classifier: Operating System :: OS Independent
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: clarifai-grpc >=10.8.3
+ Requires-Dist: clarifai-grpc >=10.8.7
+ Requires-Dist: clarifai-protocol >=0.0.4
  Requires-Dist: numpy >=1.22.0
  Requires-Dist: tqdm >=4.65.0
  Requires-Dist: tritonclient >=2.34.0