clarifai 10.8.2__py3-none-any.whl → 10.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,231 @@
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+ import numpy as np
4
+ from clarifai_grpc.grpc.api import resources_pb2
5
+ from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
6
+ from PIL import Image
7
+
8
+ from clarifai.client.auth.helper import ClarifaiAuthHelper
9
+
10
+ from .data_utils import bytes_to_image, image_to_bytes
11
+
12
+
13
class BaseDataHandler:
  """Wrap a Clarifai ``Input``/``Output`` proto with Python-friendly accessors.

  Getters decode fields of ``proto.data`` (text, image, audio, regions,
  concepts, embeddings) and ``proto.status``; the ``set_*`` methods write those
  fields in place on the wrapped proto.
  """

  def __init__(self,
               proto: Union[resources_pb2.Input, resources_pb2.Output],
               auth: ClarifaiAuthHelper = None):
    """
    Args:
      proto: the ``Input`` or ``Output`` message to wrap (mutated by the setters).
      auth: optional auth helper, stored on the instance.
    """
    self._proto = proto
    self._auth = auth

  def to_python(self):
    """Return ``{"text": ..., "image": ..., "audio": ...}`` decoded from the proto."""
    return dict(text=self.text, image=self.image, audio=self.audio)

  # ---------------- Start get/setters ---------------- #
  # Proto
  @property
  def proto(self):
    """The wrapped proto message."""
    return self._proto

  # Status
  @property
  def status(self) -> status_pb2.Status:
    """The proto's ``status`` message."""
    return self._proto.status

  def set_status(self, code: int, description: str = ""):
    """Set ``status.code`` (a ``status_code_pb2`` value) and ``status.description``."""
    self._proto.status.code = code
    self._proto.status.description = description

  # Text
  @property
  def text(self) -> Union[None, str]:
    """Return ``data.text.raw``, or None when no text message is set.

    Raises:
      NotImplementedError: if a text message is present but ``raw`` is empty
        (e.g. the text is only referenced by URL).
    """
    data = self._proto.data.text
    text = None
    if data.ByteSize():
      if data.raw:
        text = data.raw
      else:
        raise NotImplementedError
    return text

  def set_text(self, text: str):
    """Store ``text`` in ``data.text.raw``."""
    self._proto.data.text.raw = text

  # Image
  @property
  def image(self, format: str = "np") -> Union[None, Image.Image, np.ndarray]:
    """Decode ``data.image.base64``.

    Because this is a property, ``format`` always takes its default "np", so a
    uint8 ``np.ndarray`` is returned. Returns None when the proto carries no
    inline image bytes.

    Raises:
      NotImplementedError: if the image is only available by URL.
    """
    data: resources_pb2.Image = self._proto.data.image
    if data.base64:
      image = bytes_to_image(data.base64)
      return image if not format == "np" else np.asarray(image).astype("uint8")
    if data.url:
      raise NotImplementedError
    # No inline bytes: do not try to decode an empty payload.
    return None

  def set_image(self, image: Union[Image.Image, np.ndarray]):
    """Encode ``image`` (PIL image or array) with ``image_to_bytes`` into ``data.image.base64``."""
    if isinstance(image, np.ndarray):
      image = Image.fromarray(image)
    self._proto.data.image.base64 = image_to_bytes(image)

  # Audio
  @property
  def audio(self) -> bytes:
    """Return ``data.audio.base64`` bytes, or None when empty/unset."""
    data = self._proto.data.audio
    audio = None
    if data.ByteSize():
      if data.base64:
        audio = data.base64
    return audio

  def set_audio(self, audio: bytes):
    """Store raw audio bytes in ``data.audio.base64``."""
    self._proto.data.audio.base64 = audio

  # Bboxes
  @property
  def bboxes(self, real_coord: bool = False, image_width: int = None,
             image_height: int = None) -> Tuple[List, List, List]:
    """Return ``(xyxy, scores, concepts)`` for every region in ``data.regions``.

    ``xyxy`` holds ``[x1, y1, x2, y2]`` boxes read from each region's
    ``region_info.bounding_box``; ``scores`` holds each region's ``value``;
    ``concepts`` holds the id of each region's first concept. Being a property,
    ``real_coord`` is always False here, so coordinates are returned as stored.

    Raises:
      IndexError: if a region has no concepts.
    """
    if real_coord:
      assert (image_height and image_width
             ), "image_height and image_width are required when when return real coordinates"
    xyxy = []
    scores = []
    concepts = []
    for region in self._proto.data.regions:
      # Coordinates live on RegionInfo.bounding_box, not on RegionInfo itself.
      box = region.region_info.bounding_box
      score = region.value
      concept = region.data.concepts[0].id
      x1, y1, x2, y2 = box.left_col, box.top_row, box.right_col, box.bottom_row
      if real_coord:
        x1, x2 = x1 * image_width, x2 * image_width
        y1, y2 = y1 * image_height, y2 * image_height
      xyxy.append([x1, y1, x2, y2])
      scores.append(score)
      concepts.append(concept)

    return xyxy, scores, concepts

  def set_bboxes(self,
                 boxes: list,
                 scores: list,
                 concepts: list,
                 real_coord: bool = False,
                 image_width: int = None,
                 image_height: int = None):
    """Replace ``data.regions`` with one bounding-box region per box.

    Args:
      boxes: with ``real_coord=True``, pixel boxes ``[x1, y1, x2, y2]`` that are
        normalized by the image size and clipped to [0, 1]. Otherwise boxes must
        already be normalized and ordered ``[y1, x1, y2, x2]`` (top, left,
        bottom, right).
      scores: per-box score, stored on the region ``value`` and on its concept.
      concepts: per-box concept id.
      real_coord: whether ``boxes`` are in pixel coordinates.
      image_width, image_height: image size, required when ``real_coord`` is set.

    Raises:
      ValueError: if a box coordinate is greater than 1 after (optional) normalization.
    """
    if real_coord:
      assert (image_height and
              image_width), "image_height and image_width are required when `real_coord` is set"
      bboxes = [[x[1] / image_height, x[0] / image_width, x[3] / image_height, x[2] / image_width]
                for x in boxes]  # normalize the bboxes to [0,1] and [y1 x1 y2 x2]
      bboxes = np.clip(bboxes, 0, 1.0)
    else:
      # Already normalized [y1, x1, y2, x2] boxes are used as given.
      bboxes = boxes

    regions = []
    for ith, bbox in enumerate(bboxes):
      score = scores[ith]
      concept = concepts[ith]
      if any(each > 1.0 for each in bbox):
        raise ValueError(
            "Box coordinates is not normalized between [0, 1]. Please set real_coord to True and provide image_height and image_width to normalize"
        )
      region = resources_pb2.RegionInfo(bounding_box=resources_pb2.BoundingBox(
          top_row=bbox[0],  # y_min
          left_col=bbox[1],  # x_min
          bottom_row=bbox[2],  # y_max
          right_col=bbox[3],  # x_max
      ))
      # `concepts` is a repeated field, so it must be given a list.
      data = resources_pb2.Data(concepts=[resources_pb2.Concept(id=concept, value=score)])
      # Also store the score on the region itself; the `bboxes` getter reads `region.value`.
      regions.append(resources_pb2.Region(region_info=region, data=data, value=score))

    # Repeated proto fields cannot be assigned; clear and extend instead.
    self._proto.data.ClearField("regions")
    self._proto.data.regions.extend(regions)

  # Concepts
  @property
  def concepts(self) -> Dict[str, float]:
    """Return ``{concept_id: value}`` for ``data.concepts`` (later duplicates win)."""
    return {each.id: each.value for each in self.proto.data.concepts}

  def set_concepts(self, concept_score_pairs: Dict[str, float]):
    """Replace ``data.concepts`` with the given pairs; a no-op for an empty mapping."""
    concepts = [
        resources_pb2.Concept(id=concept, name=concept, value=score)
        for concept, score in concept_score_pairs.items()
    ]
    if concepts:
      self._proto.data.ClearField("concepts")
      for each in concepts:
        self._proto.data.concepts.append(each)

  # Embeddings
  @property
  def embeddings(self) -> List[List[float]]:
    """Return the ``vector`` of every entry in ``data.embeddings``."""
    return [each.vector for each in self.proto.data.embeddings]

  def set_embeddings(self, list_vectors: List[List[float]]):
    """Replace ``data.embeddings`` with ``list_vectors``.

    Nothing is changed when ``list_vectors`` is empty or its first vector is empty.
    """
    # len() instead of truthiness: works for empty input and for numpy rows.
    if len(list_vectors) and len(list_vectors[0]):
      self._proto.data.ClearField("embeddings")
      for vec in list_vectors:
        self._proto.data.embeddings.append(
            resources_pb2.Embedding(vector=vec, num_dimensions=len(vec)))

  # ---------------- End get/setters ---------------- #

  # Constructors
  @classmethod
  def from_proto(cls, proto):
    """Wrap an existing proto."""
    return cls(proto=proto)

  @classmethod
  def from_data(
      cls,
      status_code: int = status_code_pb2.SUCCESS,
      status_description: str = "",
      text: str = None,
      image: Union[Image.Image, np.ndarray] = None,
      audio: bytes = None,
      boxes: dict = None,
      concepts: Dict[str, float] = None,
      embeddings: List[List[float]] = None,
  ) -> 'OutputDataHandler':
    """Build a handler around a new ``resources_pb2.Output`` filled from the arguments.

    ``boxes`` is a dict of keyword arguments for ``set_bboxes``. Falsy arguments
    are skipped. The status is always set.
    NOTE(review): an ``Output`` proto is created even when called on
    ``InputDataHandler`` — confirm that is intended.
    """
    clss = cls(proto=resources_pb2.Output())
    if isinstance(image, (Image.Image, np.ndarray)):
      clss.set_image(image)
    if text:
      clss.set_text(text)
    if audio:
      clss.set_audio(audio)
    if boxes:
      clss.set_bboxes(**boxes)
    if concepts:
      clss.set_concepts(concepts)
    if embeddings:
      clss.set_embeddings(embeddings)

    clss.set_status(code=status_code, description=status_description)
    return clss
216
+
217
+
218
class InputDataHandler(BaseDataHandler):
  """Data handler wrapping a ``resources_pb2.Input`` proto."""

  def __init__(self,
               proto: resources_pb2.Input = None,
               auth: ClarifaiAuthHelper = None):
    """
    Args:
      proto: the Input to wrap; a new empty ``resources_pb2.Input`` when omitted.
      auth: optional auth helper.
    """
    # Create a fresh proto per instance. A proto default argument is evaluated
    # once and would be shared, and mutated by the setters, across all handlers.
    if proto is None:
      proto = resources_pb2.Input()
    super().__init__(proto=proto, auth=auth)
224
+
225
+
226
class OutputDataHandler(BaseDataHandler):
  """Data handler wrapping a ``resources_pb2.Output`` proto."""

  def __init__(self,
               proto: resources_pb2.Output = None,
               auth: ClarifaiAuthHelper = None):
    """
    Args:
      proto: the Output to wrap; a new empty ``resources_pb2.Output`` when omitted.
      auth: optional auth helper.
    """
    # Create a fresh proto per instance. A proto default argument is evaluated
    # once and would be shared, and mutated by the setters, across all handlers.
    if proto is None:
      proto = resources_pb2.Output()
    super().__init__(proto=proto, auth=auth)
@@ -0,0 +1,15 @@
1
+ from io import BytesIO
2
+
3
+ from PIL import Image
4
+
5
+
6
def image_to_bytes(img: Image.Image, format="JPEG") -> bytes:
  """Encode a PIL image as bytes in ``format`` (JPEG by default).

  JPEG cannot store an alpha channel or a palette, so images in RGBA, LA, P or
  PA mode are converted to RGB before JPEG encoding. Without the conversion,
  ``Image.save`` would raise.
  """
  if str(format).upper() == "JPEG" and img.mode in ("RGBA", "LA", "P", "PA"):
    img = img.convert("RGB")
  buffered = BytesIO()
  img.save(buffered, format=format)
  return buffered.getvalue()
11
+
12
+
13
def bytes_to_image(bytes_img) -> Image.Image:
  """Open encoded image bytes as a PIL image (decoding is done by ``Image.open``)."""
  return Image.open(BytesIO(bytes_img))
@@ -0,0 +1,71 @@
1
+ import importlib.util
2
+ import json
3
+ import os
4
+ import subprocess
5
+
6
+
7
class HuggingFaceLoarder:
  """Download and inspect Hugging Face checkpoints for ``repo_id``.

  The class name's spelling is kept as-is for backward compatibility.
  """

  def __init__(self, repo_id=None, token=None):
    """
    Args:
      repo_id: Hugging Face repository id.
      token: optional access token. When given, it is exported as ``HF_TOKEN``
        and passed to ``huggingface-cli login``. Setup problems are printed
        and do not abort construction.
    """
    self.repo_id = repo_id
    self.token = token
    if token:
      try:
        if importlib.util.find_spec("huggingface_hub") is None:
          raise ImportError(
              "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
          )
        os.environ['HF_TOKEN'] = token
        # Pass arguments as a list (no shell) so the token is never parsed by a shell.
        subprocess.run(['huggingface-cli', 'login', f'--token={token}'], check=False)
      except Exception as e:
        # Token setup is best-effort: report the failure instead of hiding it.
        print("Error setting up Hugging Face token ", e)

  def download_checkpoints(self, checkpoint_path: str):
    """Download the repo snapshot into ``checkpoint_path`` unless it is already there.

    Returns:
      True if valid checkpoints are present (already or after downloading),
      False if the model is missing on the Hub or the download failed.

    Raises:
      ImportError: if ``huggingface_hub`` is not installed.
    """
    try:
      from huggingface_hub import snapshot_download
    except ImportError:
      raise ImportError(
          "The 'huggingface_hub' package is not installed. Please install it using 'pip install huggingface_hub'."
      )
    if os.path.exists(checkpoint_path) and self.validate_download(checkpoint_path):
      print("Checkpoints already exist")
      return True

    os.makedirs(checkpoint_path, exist_ok=True)
    try:
      if not self.validate_hf_model():
        print("Model not found on Hugging Face")
        return False
      snapshot_download(repo_id=self.repo_id, local_dir=checkpoint_path)
    except Exception as e:
      print("Error downloading model checkpoints ", e)
      return False

    # Validate only after a download actually ran. Doing it in a `finally` also
    # ran it after "model not found", where listing the repo itself can raise.
    if not self.validate_download(checkpoint_path):
      print("Error downloading model checkpoints")
      return False
    return True

  def validate_hf_model(self,):
    """Return True if ``repo_id`` exists on the Hub and contains a ``config.json``."""
    from huggingface_hub import file_exists, repo_exists
    return repo_exists(self.repo_id) and file_exists(self.repo_id, 'config.json')

  def validate_download(self, checkpoint_path: str):
    """Heuristically check that ``checkpoint_path`` holds the downloaded repo.

    The check passes when the remote repo lists at least one file and the local
    directory has at least as many entries as that listing.
    """
    from huggingface_hub import list_repo_files

    local_count = len(os.listdir(checkpoint_path))
    # Query the Hub once instead of twice.
    repo_files = list_repo_files(self.repo_id)
    return local_count >= len(repo_files) and len(repo_files) > 0

  def fetch_labels(self, checkpoint_path: str):
    """Return ``id2label`` from ``<checkpoint_path>/config.json``.

    Used for classification, detection and segmentation models.

    Raises:
      FileNotFoundError: if ``config.json`` does not exist.
      KeyError: if the config has no ``id2label`` entry.
    """
    config_path = os.path.join(checkpoint_path, 'config.json')
    # JSON is UTF-8; do not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
      config = json.load(f)
    return config['id2label']
@@ -0,0 +1,6 @@
1
+ import os
2
+
3
+ from clarifai.utils.logging import get_logger
4
+
5
# Level for this module's logger: the LOG_LEVEL env var, falling back to "INFO".
logger_level = os.getenv("LOG_LEVEL", "INFO")
logger = get_logger(logger_level, __name__)
@@ -0,0 +1,42 @@
1
+ import concurrent.futures
2
+
3
+ import fsspec
4
+
5
+ from .logging import logger
6
+
7
+
8
def download_input(input):
  """
  Fill in, in place, the content of every URL-only field of ``input.data``.

  Image, video and audio URLs are read in binary mode into the field's
  ``base64``; a text URL is read in text mode into ``text.raw``. Fields that
  already carry inline content are left untouched.
  """
  payload = input.data
  # Binary media, processed in the order image, video, audio.
  for field_name in ('image', 'video', 'audio'):
    media = getattr(payload, field_name)
    if media.url and not media.base64:
      with fsspec.open(media.url, 'rb') as stream:
        media.base64 = stream.read()
  # Text is read as a string into `raw` rather than bytes into `base64`.
  text_field = payload.text
  if text_field.url and not text_field.raw:
    with fsspec.open(text_field.url, 'r') as stream:
      text_field.raw = stream.read()
28
+
29
+
30
def ensure_urls_downloaded(request, max_threads=128):
  """
  Run ``download_input`` on every input of ``request`` using a thread pool.

  A failure on one input is logged with its traceback and does not stop the
  other downloads.
  """
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as pool:
    pending = [pool.submit(download_input, item) for item in request.inputs]
    for finished in concurrent.futures.as_completed(pending):
      try:
        finished.result()
      except Exception as e:
        logger.exception(f"Error downloading input: {e}")
clarifai/utils/logging.py CHANGED
@@ -1,4 +1,12 @@
1
+ import datetime
2
+ import json
1
3
  import logging
4
+ import os
5
+ import socket
6
+ import sys
7
+ import threading
8
+ import time
9
+ import traceback
2
10
  from collections import defaultdict
3
11
  from typing import Any, Dict, List, Optional, Union
4
12
 
@@ -11,6 +19,41 @@ from rich.tree import Tree
11
19
 
12
20
# NOTE(review): `install` is imported outside this view; presumably
# rich.traceback.install (installs rich's traceback handler) — confirm.
install()
13
21
 
22
# ---- JSON logger settings ----
JSON_LOGGER_NAME = "clarifai-json"
# Key under which JsonFormatter emits the log message text.
JSON_LOG_KEY = 'msg'
# str(obj) values longer than this are truncated when serialized.
JSON_DEFAULT_CHAR_LENGTH = 400
# LogRecord attributes that are dropped from the JSON payload.
FIELD_BLACKLIST = [
    'msg',
    'message',
    'account',
    'levelno',
    'created',
    'threadName',
    'name',
    'processName',
    'module',
    'funcName',
    'msecs',
    'relativeCreated',
    'pathname',
    'args',
    'thread',
    'process',
]

# Per-thread logging context (e.g. req_id, requester, start_time) that the
# JSON formatter's format() reads and the *_logger_context helpers manage.
thread_log_info = threading.local()
34
+
35
+
36
def get_logger_context():
  """Return the calling thread's logging-context dict.

  This is the live per-thread dict, not a copy: mutating it changes the context.
  """
  return vars(thread_log_info)
38
+
39
+
40
def set_logger_context(**kwargs):
  """Add or overwrite keys in the calling thread's logging context."""
  context = vars(thread_log_info)
  context.update(kwargs)
42
+
43
+
44
def clear_logger_context():
  """Remove every key from the calling thread's logging context."""
  context = vars(thread_log_info)
  context.clear()
46
+
47
+
48
def restore_logger_context(context):
  """Replace the calling thread's logging context with the contents of ``context``.

  ``context`` is copied before the current context is cleared. Otherwise,
  passing the live dict returned by ``get_logger_context()`` would clear it and
  then restore nothing.
  """
  snapshot = dict(context)
  current = thread_log_info.__dict__
  current.clear()
  current.update(snapshot)
51
+
52
+
53
def get_req_id_from_context():
  """Return the ``req_id`` in the calling thread's logging context, or ''."""
  return get_logger_context().get('req_id', '')
56
+
14
57
 
15
58
  def display_workflow_tree(nodes_data: List[Dict]) -> None:
16
59
  """Displays a tree of the workflow nodes."""
@@ -84,12 +127,24 @@ def _configure_logger(name: str, logger_level: Union[int, str] = logging.NOTSET)
84
127
  for handler in logger.handlers[:]:
85
128
  logger.removeHandler(handler)
86
129
 
87
- # Add the new rich handler and formatter
88
- handler = RichHandler(
89
- rich_tracebacks=True, log_time_format="%Y-%m-%d %H:%M:%S", console=Console(width=255))
90
- formatter = logging.Formatter('%(name)s: %(message)s')
91
- handler.setFormatter(formatter)
92
- logger.addHandler(handler)
130
+ # If ENABLE_JSON_LOGGER is 'true' then definitely use json logger.
131
+ # If ENABLE_JSON_LOGGER is 'false' then definitely don't use json logger.
132
+ # If ENABLE_JSON_LOGGER is not set, then use json logger if in k8s.
133
+ enabled_json = os.getenv('ENABLE_JSON_LOGGER', None)
134
+ in_k8s = 'KUBERNETES_SERVICE_HOST' in os.environ
135
+ if enabled_json == 'true' or (in_k8s and enabled_json != 'false'):
136
+ # Add the json handler and formatter
137
+ handler = logging.StreamHandler()
138
+ formatter = JsonFormatter()
139
+ handler.setFormatter(formatter)
140
+ logger.addHandler(handler)
141
+ else:
142
+ # Add the new rich handler and formatter
143
+ handler = RichHandler(
144
+ rich_tracebacks=True, log_time_format="%Y-%m-%d %H:%M:%S", console=Console(width=255))
145
+ formatter = logging.Formatter('%(name)s: %(message)s')
146
+ handler.setFormatter(formatter)
147
+ logger.addHandler(handler)
93
148
 
94
149
 
95
150
  def get_logger(logger_level: Union[int, str] = logging.NOTSET,
@@ -151,3 +206,154 @@ def display_concept_relations_tree(relations_dict: Dict[str, Any]) -> None:
151
206
  for child in children:
152
207
  tree.add(child)
153
208
  rprint(tree)
209
+
210
+
211
def _default_json_default(obj):
  """
  Fallback for values that ``json.dumps`` cannot serialize natively.

  Dates, datetimes and times become ISO 8601 strings; anything else becomes
  its (possibly truncated) ``str()``.
  """
  temporal_types = (datetime.datetime, datetime.date, datetime.time)
  if not isinstance(obj, temporal_types):
    return _object_to_string_with_truncation(obj)
  return obj.isoformat()
222
+
223
+
224
def _object_to_string_with_truncation(obj) -> str:
  """
  Return ``str(obj)``, truncated to ``JSON_DEFAULT_CHAR_LENGTH`` characters.

  Prefer logging the relevant parts of an object as plain types (str, int),
  which never reach this function. When truncation does happen, the suffix
  carries a fixed marker (searchable in logs), the object's type name and its
  full length, to help track down and fix such cases.
  """
  text = str(obj)
  limit = JSON_DEFAULT_CHAR_LENGTH
  if len(text) <= limit:
    return text
  kind = type(obj).__name__
  return f"{text[:limit]}...[{kind} was truncated, len={len(text)} chars]"
244
+
245
+
246
class JsonFormatter(logging.Formatter):
  """Format each log record as a single-line JSON object.

  The object contains the message under ``JSON_LOG_KEY``, an ``@timestamp``,
  a lowercase ``level``, the record's attributes minus ``FIELD_BLACKLIST``,
  and any truthy per-thread context from ``thread_log_info`` (``req_id``,
  ``orig_req_id``, ``requester``, ``user_id``, plus ``duration_ms`` derived
  from ``start_time``).
  """

  # thread_log_info attributes copied into each record when set (truthy).
  _CONTEXT_FIELDS = ('req_id', 'orig_req_id', 'requester', 'user_id')

  def __init__(self,
               fmt=None,
               datefmt=None,
               style='%',
               json_cls=None,
               json_default=_default_json_default):
    """
    :param fmt: Config as a JSON string, allowed fields;
           extra: provide extra fields always present in logs
           source_host: override source host name
    :param datefmt: Date format, forwarded to logging.Formatter (not used by format())
    :param style: Format style, forwarded to logging.Formatter
    :param json_cls: JSON encoder to forward to json.dumps
    :param json_default: Default JSON representation for unknown types,
                         by default coerce everything to a string
    """
    # Initialize the base Formatter so inherited helpers (formatTime,
    # formatException, ...) find the attributes they rely on.
    super().__init__(fmt=None, datefmt=datefmt, style=style)

    if fmt is not None:
      self._fmt = json.loads(fmt)
    else:
      self._fmt = {}
    self.json_default = json_default
    self.json_cls = json_cls
    if 'extra' not in self._fmt:
      self.defaults = {}
    else:
      self.defaults = self._fmt['extra']
    if 'source_host' in self._fmt:
      self.source_host = self._fmt['source_host']
    else:
      try:
        self.source_host = socket.gethostname()
      except Exception:
        self.source_host = ""

  def _build_fields(self, defaults, fields):
    """Return provided fields including any in defaults
    """
    return dict(list(defaults.get('@fields', {}).items()) + list(fields.items()))

  # Override the format function to fit Clarifai
  def format(self, record):
    """Serialize ``record`` to a JSON string."""
    fields = record.__dict__.copy()

    if isinstance(record.msg, dict):
      # logger.info({...}): the dict's keys become top-level fields.
      fields.update(record.msg)
      fields.pop('msg')
      msg = ""
    else:
      # logger.info("message", {...}): a single dict argument is merged in too.
      if isinstance(record.args, dict):
        fields.update(record.args)
      msg = record.getMessage()
    for k in FIELD_BLACKLIST:
      fields.pop(k, None)
    # Rename 'levelname' to 'level' and make the value lowercase to match Go logs
    level = fields.pop('levelname', None)
    if level:
      fields['level'] = level.lower()

    # Copy the per-thread request context; each key is gated on its own value.
    for key in self._CONTEXT_FIELDS:
      value = getattr(thread_log_info, key, None)
      if value:
        fields[key] = value

    if hasattr(thread_log_info, 'start_time'):
      #pylint: disable=no-member
      fields['duration_ms'] = (time.time() - thread_log_info.start_time) * 1000

    if 'exc_info' in fields:
      if fields['exc_info']:
        fields['exception'] = traceback.format_exception(*fields['exc_info'])
      fields.pop('exc_info')

    if 'exc_text' in fields and not fields['exc_text']:
      fields.pop('exc_text')

    logr = self.defaults.copy()

    # Stamp the time the event was logged (record.created), in UTC.
    created = datetime.datetime.fromtimestamp(record.created, tz=datetime.timezone.utc)
    logr.update({
        JSON_LOG_KEY: msg,
        '@timestamp': created.strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
    })

    logr.update(fields)

    try:
      return json.dumps(logr, default=self.json_default, cls=self.json_cls)
    except Exception:
      exc_type, exc_value, tb = sys.exc_info()
      return json.dumps(
          {
              "msg": f"Fail to format log {exc_type.__name__}({exc_value}), {logr}",
              "formatting_traceback": "\n".join(traceback.format_tb(tb)),
          },
          default=self.json_default,
          cls=self.json_cls,
      )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: clarifai
3
- Version: 10.8.2
3
+ Version: 10.8.4
4
4
  Summary: Clarifai Python SDK
5
5
  Home-page: https://github.com/Clarifai/clarifai-python
6
6
  Author: Clarifai
@@ -20,7 +20,8 @@ Classifier: Operating System :: OS Independent
20
20
  Requires-Python: >=3.8
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
- Requires-Dist: clarifai-grpc >=10.8.6
23
+ Requires-Dist: clarifai-grpc >=10.8.7
24
+ Requires-Dist: clarifai-protocol >=0.0.4
24
25
  Requires-Dist: numpy >=1.22.0
25
26
  Requires-Dist: tqdm >=4.65.0
26
27
  Requires-Dist: tritonclient >=2.34.0