clarifai 11.4.1__py3-none-any.whl → 11.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/base.py +7 -0
  3. clarifai/cli/model.py +6 -8
  4. clarifai/client/app.py +2 -1
  5. clarifai/client/auth/helper.py +6 -4
  6. clarifai/client/compute_cluster.py +2 -1
  7. clarifai/client/dataset.py +8 -1
  8. clarifai/client/deployment.py +2 -1
  9. clarifai/client/input.py +2 -1
  10. clarifai/client/model.py +2 -1
  11. clarifai/client/model_client.py +1 -1
  12. clarifai/client/module.py +2 -1
  13. clarifai/client/nodepool.py +2 -1
  14. clarifai/client/runner.py +2 -1
  15. clarifai/client/search.py +2 -1
  16. clarifai/client/user.py +2 -1
  17. clarifai/client/workflow.py +2 -1
  18. clarifai/runners/models/mcp_class.py +114 -0
  19. clarifai/runners/models/model_builder.py +179 -46
  20. clarifai/runners/models/model_class.py +5 -22
  21. clarifai/runners/models/model_run_locally.py +0 -4
  22. clarifai/runners/models/visual_classifier_class.py +75 -0
  23. clarifai/runners/models/visual_detector_class.py +79 -0
  24. clarifai/runners/utils/code_script.py +75 -44
  25. clarifai/runners/utils/const.py +15 -0
  26. clarifai/runners/utils/data_types/data_types.py +48 -0
  27. clarifai/runners/utils/data_utils.py +99 -45
  28. clarifai/runners/utils/loader.py +23 -2
  29. clarifai/runners/utils/method_signatures.py +4 -4
  30. clarifai/runners/utils/openai_convertor.py +103 -0
  31. clarifai/urls/helper.py +80 -12
  32. clarifai/utils/config.py +19 -0
  33. clarifai/utils/constants.py +4 -0
  34. clarifai/utils/logging.py +22 -5
  35. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/METADATA +1 -2
  36. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/RECORD +40 -37
  37. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/WHEEL +1 -1
  38. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/entry_points.txt +0 -0
  39. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/licenses/LICENSE +0 -0
  40. {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/top_level.txt +0 -0
clarifai/runners/models/model_builder.py
@@ -14,15 +14,17 @@ import yaml
 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
 from clarifai_grpc.grpc.api.status import status_code_pb2
 from google.protobuf import json_format
-from rich import print
-from rich.markup import escape

 from clarifai.client.base import BaseClient
 from clarifai.runners.models.model_class import ModelClass
 from clarifai.runners.utils.const import (
+    AMD_PYTHON_BASE_IMAGE,
+    AMD_VLLM_BASE_IMAGE,
     AVAILABLE_PYTHON_IMAGES,
     AVAILABLE_TORCH_IMAGES,
     CONCEPTS_REQUIRED_MODEL_TYPE,
+    DEFAULT_AMD_GPU_VERSION,
+    DEFAULT_AMD_TORCH_VERSION,
     DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
     DEFAULT_PYTHON_VERSION,
     DEFAULT_RUNTIME_DOWNLOAD_PATH,
@@ -43,13 +45,6 @@ dependencies = [
 ]


-def _clear_line(n: int = 1) -> None:
-    LINE_UP = '\033[1A'  # Move cursor up one line
-    LINE_CLEAR = '\x1b[2K'  # Clear the entire line
-    for _ in range(n):
-        print(LINE_UP, end=LINE_CLEAR, flush=True)
-
-
 def is_related(object_class, main_class):
     # Check if the object_class is a subclass of main_class
     if issubclass(object_class, main_class):
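11.4.3 drops this ANSI-escape helper (and the `rich` print import above) in favor of carriage-return progress output, as seen in the `monitor_model_build` hunk further down. A minimal sketch of that style, with illustrative loop values:

import time

for pct in range(0, 101, 20):
    # end='\r' returns the cursor to column 0, so each print overwrites
    # the previous progress line instead of clearing it with ANSI codes
    print(f"Status: uploading, Progress: {pct}%", end='\r', flush=True)
    time.sleep(0.1)
print()  # leave the final progress line in place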
@@ -361,13 +356,23 @@ class ModelBuilder:
         if self.config.get("checkpoints"):
             loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()

-            if loader_type == "huggingface" and hf_token:
-                is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
-                if not is_valid_token:
+            if loader_type == "huggingface":
+                is_valid_token = hf_token and HuggingFaceLoader.validate_hftoken(hf_token)
+                if not is_valid_token and hf_token:
+                    logger.info(
+                        "Continuing without Hugging Face token for validating config in model builder."
+                    )
+
+                has_repo_access = HuggingFaceLoader.validate_hf_repo_access(
+                    repo_id=self.config.get("checkpoints", {}).get("repo_id"),
+                    token=hf_token if is_valid_token else None,
+                )
+
+                if not has_repo_access:
                     logger.error(
-                        "Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints."
+                        f"Invalid Hugging Face repo access for repo {self.config.get('checkpoints').get('repo_id')}. Please check your repo and try again."
                     )
-                    logger.info("Continuing without Hugging Face token")
+                    sys.exit("Token does not have access to HuggingFace repo , exiting.")

         num_threads = self.config.get("num_threads")
         if num_threads or num_threads == 0:
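Net effect: an invalid token now only logs a notice and validation proceeds anonymously, but missing repo access aborts the upload via `sys.exit`, where 11.4.1 merely warned. For reference, a hypothetical `checkpoints` section as the parsed dict this path reads (key names follow the hunk; the values and the "type" key are assumptions):

config = {
    "checkpoints": {
        "type": "huggingface",             # loader_type, assumed key name
        "repo_id": "some-org/some-model",  # checked via validate_hf_repo_access
        "hf_token": "hf_xxx",              # optional; validated before use
    }
}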
@@ -405,11 +410,17 @@ class ModelBuilder:
         signatures = {method.name: method.signature for method in method_info.values()}
         return signatures_to_yaml(signatures)

-    def get_method_signatures(self):
+    def get_method_signatures(self, mocking=True):
         """
         Returns the method signatures for the model class.
+
+        Args:
+            mocking (bool): Whether to mock the model class or not. Defaults to False.
+
+        Returns:
+            list: A list of method signatures for the model class.
         """
-        model_class = self.load_model_class(mocking=True)
+        model_class = self.load_model_class(mocking=mocking)
         method_info = model_class._get_method_info()
         signatures = [method.signature for method in method_info.values()]
         return signatures
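A hypothetical usage sketch for the new `mocking` flag, assuming `ModelBuilder` is constructed from a model folder as in `upload_model` further down (the folder path is a placeholder):

from clarifai.runners.models.model_builder import ModelBuilder

builder = ModelBuilder("./my_model_folder")
# mocking=True inspects signatures without loading model weights
for sig in builder.get_method_signatures(mocking=True):
    print(sig)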
@@ -431,22 +442,42 @@ class ModelBuilder:
         return self._client

     @property
-    def model_url(self):
+    def model_ui_url(self):
+        url_helper = ClarifaiUrlHelper(self._client.auth_helper)
+        # Note(zeiler): the UI experience isn't the best when including version id right now.
+        # if self.model_version_id is not None:
+        #     return url_helper.clarifai_url(
+        #         self.client.user_app_id.user_id,
+        #         self.client.user_app_id.app_id,
+        #         "models",
+        #         self.model_id,
+        #         self.model_version_id,
+        #     )
+        # else:
+        return url_helper.clarifai_url(
+            self.client.user_app_id.user_id,
+            self.client.user_app_id.app_id,
+            "models",
+            self.model_id,
+        )
+
+    @property
+    def model_api_url(self):
         url_helper = ClarifaiUrlHelper(self._client.auth_helper)
         if self.model_version_id is not None:
-            return url_helper.clarifai_url(
+            return url_helper.api_url(
                 self.client.user_app_id.user_id,
                 self.client.user_app_id.app_id,
                 "models",
                 self.model_id,
+                self.model_version_id,
             )
         else:
-            return url_helper.clarifai_url(
+            return url_helper.api_url(
                 self.client.user_app_id.user_id,
                 self.client.user_app_id.app_id,
                 "models",
                 self.model_id,
-                self.model_version_id,
             )

     def _get_model_proto(self):
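The old `model_url` property is split into a UI link (`model_ui_url`) and an API link (`model_api_url`). Roughly the shapes involved; these URL formats are assumptions based on Clarifai's public conventions, not taken from the diff:

builder.model_ui_url   # e.g. https://clarifai.com/<user_id>/<app_id>/models/<model_id>
builder.model_api_url  # e.g. https://api.clarifai.com/v2/users/<user_id>/apps/<app_id>/models/<model_id>/versions/<version_id>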
@@ -532,6 +563,30 @@ class ModelBuilder:
             dependencies_version[dependency] = version if version else None
         return dependencies_version

+    def _is_amd(self):
+        """
+        Check if the model is AMD or not.
+        """
+        is_amd_gpu = False
+        is_nvidia_gpu = False
+        if "inference_compute_info" in self.config:
+            inference_compute_info = self.config.get('inference_compute_info')
+            if 'accelerator_type' in inference_compute_info:
+                for accelerator in inference_compute_info['accelerator_type']:
+                    if 'amd' in accelerator.lower():
+                        is_amd_gpu = True
+                    elif 'nvidia' in accelerator.lower():
+                        is_nvidia_gpu = True
+        if is_amd_gpu and is_nvidia_gpu:
+            raise Exception(
+                "Both AMD and NVIDIA GPUs are specified in the config file, please use only one type of GPU."
+            )
+        if is_amd_gpu:
+            logger.info("Using AMD base image to build the Docker image and upload the model")
+        elif is_nvidia_gpu:
+            logger.info("Using NVIDIA base image to build the Docker image and upload the model")
+        return is_amd_gpu
+
     def create_dockerfile(self):
         dockerfile_template = os.path.join(
             os.path.dirname(os.path.dirname(__file__)),
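A sketch of the parsed config that makes `_is_amd()` return True; the accelerator string is invented, and only the 'amd' substring matters to the check above:

config = {
    "inference_compute_info": {
        "accelerator_type": ["AMD-MI300X"],  # 'amd' substring selects the AMD base image
    }
}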
@@ -562,30 +617,85 @@ class ModelBuilder:
             )
             python_version = DEFAULT_PYTHON_VERSION

-        # This is always the final image used for runtime.
-        final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
-        downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
-
         # Parse the requirements.txt file to determine the base image
         dependencies = self._parse_requirements()
-        if 'torch' in dependencies and dependencies['torch']:
-            torch_version = dependencies['torch']
-
-            # Sort in reverse so that newer cuda versions come first and are preferred.
-            for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
-                if torch_version in image and f'py{python_version}' in image:
-                    # like cu124, rocm6.3, etc.
-                    gpu_version = image.split('-')[-1]
-                    final_image = TORCH_BASE_IMAGE.format(
-                        torch_version=torch_version,
-                        python_version=python_version,
-                        gpu_version=gpu_version,
+
+        is_amd_gpu = self._is_amd()
+        if is_amd_gpu:
+            final_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version)
+            downloader_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version)
+            if 'vllm' in dependencies:
+                if python_version != DEFAULT_PYTHON_VERSION:
+                    raise Exception(
+                        f"vLLM is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
+                    )
+                torch_version = dependencies.get('torch', None)
+                if 'torch' in dependencies:
+                    if python_version != DEFAULT_PYTHON_VERSION:
+                        raise Exception(
+                            f"torch is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
+                        )
+                    if not torch_version:
+                        logger.info(
+                            f"torch version not found in requirements.txt, using the default version {DEFAULT_AMD_TORCH_VERSION}"
+                        )
+                        torch_version = DEFAULT_AMD_TORCH_VERSION
+                    if torch_version not in [DEFAULT_AMD_TORCH_VERSION]:
+                        raise Exception(
+                            f"torch version {torch_version} not supported, please use one of the following versions: {DEFAULT_AMD_TORCH_VERSION} in your requirements.txt"
+                        )
+                python_version = DEFAULT_PYTHON_VERSION
+                gpu_version = DEFAULT_AMD_GPU_VERSION
+                final_image = AMD_VLLM_BASE_IMAGE.format(
+                    torch_version=torch_version,
+                    python_version=python_version,
+                    gpu_version=gpu_version,
+                )
+                logger.info("Using vLLM base image to build the Docker image")
+            elif 'torch' in dependencies:
+                torch_version = dependencies['torch']
+                if python_version != DEFAULT_PYTHON_VERSION:
+                    raise Exception(
+                        f"torch is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
                     )
+                if not torch_version:
                     logger.info(
-                        f"Using Torch version {torch_version} base image to build the Docker image"
+                        f"torch version not found in requirements.txt, using the default version {DEFAULT_AMD_TORCH_VERSION}"
                     )
-                    break
-
+                    torch_version = DEFAULT_AMD_TORCH_VERSION
+                if torch_version not in [DEFAULT_AMD_TORCH_VERSION]:
+                    raise Exception(
+                        f"torch version {torch_version} not supported, please use one of the following versions: {DEFAULT_AMD_TORCH_VERSION} in your requirements.txt"
+                    )
+                python_version = DEFAULT_PYTHON_VERSION
+                gpu_version = DEFAULT_AMD_GPU_VERSION
+                final_image = TORCH_BASE_IMAGE.format(
+                    torch_version=torch_version,
+                    python_version=python_version,
+                    gpu_version=gpu_version,
+                )
+                logger.info(
+                    f"Using Torch version {torch_version} base image to build the Docker image"
+                )
+        else:
+            final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
+            downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
+            if 'torch' in dependencies and dependencies['torch']:
+                torch_version = dependencies['torch']
+                # Sort in reverse so that newer cuda versions come first and are preferred.
+                for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
+                    if torch_version in image and f'py{python_version}' in image:
+                        # like cu124, rocm6.3, etc.
+                        gpu_version = image.split('-')[-1]
+                        final_image = TORCH_BASE_IMAGE.format(
+                            torch_version=torch_version,
+                            python_version=python_version,
+                            gpu_version=gpu_version,
+                        )
+                        logger.info(
+                            f"Using Torch version {torch_version} base image to build the Docker image"
+                        )
+                        break
         if 'clarifai' not in dependencies:
             raise Exception(
                 f"clarifai not found in requirements.txt, please add clarifai to the requirements.txt file with a fixed version. Current version is clarifai=={CLIENT_VERSION}"
@@ -835,7 +945,6 @@ class ModelBuilder:
             percent_completed = response.status.percent_completed
             details = response.status.details

-            _clear_line()
             print(
                 f"Status: {response.status.description}, Progress: {percent_completed}% - {details} ",
                 f"request_id: {response.status.req_id}",
@@ -847,9 +956,26 @@ class ModelBuilder:
             return
         self.model_version_id = response.model_version_id
         logger.info(f"Created Model Version ID: {self.model_version_id}")
-        logger.info(f"Full url to that version is: {self.model_url}")
+        logger.info(f"Full url to that version is: {self.model_ui_url}")
         try:
-            self.monitor_model_build()
+            is_uploaded = self.monitor_model_build()
+            if is_uploaded:
+                # python code to run the model.
+                from clarifai.runners.utils import code_script
+
+                method_signatures = self.get_method_signatures()
+                snippet = code_script.generate_client_script(
+                    method_signatures,
+                    user_id=self.client.user_app_id.user_id,
+                    app_id=self.client.user_app_id.app_id,
+                    model_id=self.model_proto.id,
+                )
+                logger.info("""\n
+XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+# Here is a code snippet to use this model:
+XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+""")
+                logger.info(snippet)
         finally:
             if os.path.exists(self.tar_file):
                 logger.debug(f"Cleaning up upload file: {self.tar_file}")
@@ -933,7 +1059,12 @@ class ModelBuilder:
             for log_entry in logs.log_entries:
                 if log_entry.url not in seen_logs:
                     seen_logs.add(log_entry.url)
-                    logger.info(f"{escape(log_entry.message.strip())}")
+                    log_entry_msg = re.sub(
+                        r"(\\*)(\[[a-z#/@][^[]*?])",
+                        lambda m: f"{m.group(1)}{m.group(1)}\\{m.group(2)}",
+                        log_entry.message.strip(),
+                    )
+                    logger.info(log_entry_msg)
             if status_code == status_code_pb2.MODEL_BUILDING:
                 print(
                     f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True
@@ -945,7 +1076,7 @@ class ModelBuilder:
                 logger.info("Model build complete!")
                 logger.info(f"Build time elapsed {time.time() - st:.1f}s)")
                 logger.info(
-                    f"Check out the model at {self.model_url} version: {self.model_version_id}"
+                    f"Check out the model at {self.model_ui_url} version: {self.model_version_id}"
                 )
                 return True
             else:
@@ -970,10 +1101,12 @@ def upload_model(folder, stage, skip_dockerfile):
     exists = builder.check_model_exists()
     if exists:
         logger.info(
-            f"Model already exists at {builder.model_url}, this upload will create a new version for it."
+            f"Model already exists at {builder.model_ui_url}, this upload will create a new version for it."
         )
     else:
-        logger.info(f"New model will be created at {builder.model_url} with it's first version.")
+        logger.info(
+            f"New model will be created at {builder.model_ui_url} with it's first version."
+        )

     input("Press Enter to continue...")
     builder.upload_model_version()
clarifai/runners/models/model_class.py
@@ -9,7 +9,6 @@ from typing import Any, Dict, Iterator, List

 from clarifai_grpc.grpc.api import resources_pb2, service_pb2
 from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
-from google.protobuf import json_format

 from clarifai.runners.utils import data_types
 from clarifai.runners.utils.data_utils import DataConverter
@@ -100,7 +99,6 @@ class ModelClass(ABC):
         try:
             # TODO add method name field to proto
             method_name = 'predict'
-            inference_params = get_inference_params(request)
             if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
                 method_name = request.inputs[0].data.metadata['_method_name']
             if (
@@ -124,7 +122,7 @@ class ModelClass(ABC):
                     input.data.CopyFrom(new_data)
             # convert inputs to python types
             inputs = self._convert_input_protos_to_python(
-                request.inputs, inference_params, signature.input_fields, python_param_types
+                request.inputs, signature.input_fields, python_param_types
             )
             if len(inputs) == 1:
                 inputs = inputs[0]
@@ -163,7 +161,6 @@ class ModelClass(ABC):
     ) -> Iterator[service_pb2.MultiOutputResponse]:
         try:
             method_name = 'generate'
-            inference_params = get_inference_params(request)
             if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
                 method_name = request.inputs[0].data.metadata['_method_name']
             method = getattr(self, method_name)
@@ -180,7 +177,7 @@ class ModelClass(ABC):
                     )
                     input.data.CopyFrom(new_data)
             inputs = self._convert_input_protos_to_python(
-                request.inputs, inference_params, signature.input_fields, python_param_types
+                request.inputs, signature.input_fields, python_param_types
             )
             if len(inputs) == 1:
                 inputs = inputs[0]
@@ -226,7 +223,6 @@ class ModelClass(ABC):
         assert len(request.inputs) == 1, "Streaming requires exactly one input"

         method_name = 'stream'
-        inference_params = get_inference_params(request)
         if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
             method_name = request.inputs[0].data.metadata['_method_name']
         method = getattr(self, method_name)
@@ -251,7 +247,7 @@ class ModelClass(ABC):
                     input.data.CopyFrom(new_data)
         # convert all inputs for the first request, including the first stream value
         inputs = self._convert_input_protos_to_python(
-            request.inputs, inference_params, signature.input_fields, python_param_types
+            request.inputs, signature.input_fields, python_param_types
         )
         kwargs = inputs[0]

@@ -264,7 +260,7 @@ class ModelClass(ABC):
             # subsequent streaming items contain only the streaming input
             for request in request_iterator:
                 item = self._convert_input_protos_to_python(
-                    request.inputs, inference_params, [stream_sig], python_param_types
+                    request.inputs, [stream_sig], python_param_types
                 )
                 item = item[0][stream_argname]
                 yield item
@@ -297,13 +293,12 @@ class ModelClass(ABC):
     def _convert_input_protos_to_python(
         self,
         inputs: List[resources_pb2.Input],
-        inference_params: dict,
         variables_signature: List[resources_pb2.ModelTypeField],
         python_param_types,
     ) -> List[Dict[str, Any]]:
         result = []
         for input in inputs:
-            kwargs = deserialize(input.data, variables_signature, inference_params)
+            kwargs = deserialize(input.data, variables_signature)
             # dynamic cast to annotated types
             for k, v in kwargs.items():
                 if k not in python_param_types:
@@ -374,18 +369,6 @@ class ModelClass(ABC):
         return method_info


-# Helper function to get the inference params
-def get_inference_params(request) -> dict:
-    """Get the inference params from the request."""
-    inference_params = {}
-    if request.model.model_version.id != "":
-        output_info = request.model.model_version.output_info
-        output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)
-        if "params" in output_info:
-            inference_params = output_info["params"]
-    return inference_params
-
-
 class _MethodInfo:
     def __init__(self, method):
         self.name = method.__name__
clarifai/runners/models/model_run_locally.py
@@ -442,10 +442,6 @@ def main(
     manager = ModelRunLocally(model_path)
     # get whatever stage is in config.yaml to force download now
     # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
-    _, _, _, when, _, _ = manager.builder._validate_config_checkpoints()
-    manager.builder.download_checkpoints(
-        stage=when, checkpoint_path_override=manager.builder.checkpoint_path
-    )
     if inside_container:
         if not manager.is_docker_installed():
             sys.exit(1)
clarifai/runners/models/visual_classifier_class.py (new file)
@@ -0,0 +1,75 @@
+import os
+import tempfile
+from io import BytesIO
+from typing import Dict, Iterator, List
+
+import cv2
+import torch
+from PIL import Image as PILImage
+
+from clarifai.runners.models.model_class import ModelClass
+from clarifai.runners.utils.data_types import Concept, Frame, Image
+from clarifai.utils.logging import logger
+
+
+class VisualClassifierClass(ModelClass):
+    """Base class for visual classification models supporting image and video processing."""
+
+    @staticmethod
+    def preprocess_image(image_bytes: bytes) -> PILImage:
+        """Convert image bytes to PIL Image."""
+        return PILImage.open(BytesIO(image_bytes)).convert("RGB")
+
+    @staticmethod
+    def video_to_frames(video_bytes: bytes) -> Iterator[Frame]:
+        """Convert video bytes to frames.
+
+        Args:
+            video_bytes: Raw video data in bytes
+
+        Yields:
+            Frame with JPEG encoded frame data as bytes and timestamp in milliseconds
+        """
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file:
+            temp_video_file.write(video_bytes)
+            temp_video_path = temp_video_file.name
+            logger.debug(f"temp_video_path: {temp_video_path}")
+
+            video = cv2.VideoCapture(temp_video_path)
+            logger.debug(f"video opened: {video.isOpened()}")
+
+            while video.isOpened():
+                ret, frame = video.read()
+                if not ret:
+                    break
+                # Get frame timestamp in milliseconds
+                timestamp_ms = video.get(cv2.CAP_PROP_POS_MSEC)
+                frame_bytes = cv2.imencode('.jpg', frame)[1].tobytes()
+                yield Frame(image=Image(bytes=frame_bytes), time=timestamp_ms)
+
+            video.release()
+            os.unlink(temp_video_path)
+
+    @staticmethod
+    def process_concepts(
+        logits: torch.Tensor, threshold: float, model_labels: Dict[int, str]
+    ) -> List[List[Concept]]:
+        """Convert model logits into a structured format of concepts.
+
+        Args:
+            logits: Model output logits as a tensor (batch_size x num_classes)
+            model_labels: Dictionary mapping label indices to label names
+
+        Returns:
+            List of lists containing Concept objects for each input in the batch
+        """
+        outputs = []
+        for logit in logits:
+            probs = torch.softmax(logit, dim=-1)
+            sorted_indices = torch.argsort(probs, dim=-1, descending=True)
+            output_concepts = []
+            for idx in sorted_indices:
+                concept = Concept(name=model_labels[idx.item()], value=probs[idx].item())
+                output_concepts.append(concept)
+            outputs.append(output_concepts)
+        return outputs
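A hypothetical subclass sketch showing how these helpers compose; the `@VisualClassifierClass.method` decorator and the fabricated logits are assumptions for illustration, not part of this diff:

from typing import List

import torch

from clarifai.runners.models.visual_classifier_class import VisualClassifierClass
from clarifai.runners.utils.data_types import Concept, Image


class MyClassifier(VisualClassifierClass):
    def load_model(self):
        self.labels = {0: "cat", 1: "dog"}  # made-up label map

    @VisualClassifierClass.method
    def predict(self, image: Image) -> List[Concept]:
        pil_img = self.preprocess_image(image.bytes)  # bytes -> RGB PIL image
        logits = torch.randn(1, len(self.labels))  # stand-in for a real model forward pass
        # threshold is accepted but unused by process_concepts in this release
        return self.process_concepts(logits, 0.0, self.labels)[0]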
clarifai/runners/models/visual_detector_class.py (new file)
@@ -0,0 +1,79 @@
+import os
+import tempfile
+from io import BytesIO
+from typing import Dict, Iterator, List
+
+import cv2
+import torch
+from PIL import Image as PILImage
+
+from clarifai.runners.models.model_class import ModelClass
+from clarifai.runners.utils.data_types import Concept, Frame, Image, Region
+from clarifai.utils.logging import logger
+
+
+class VisualDetectorClass(ModelClass):
+    """Base class for visual detection models supporting image and video processing."""
+
+    @staticmethod
+    def preprocess_image(image_bytes: bytes) -> PILImage:
+        """Convert image bytes to PIL Image."""
+        return PILImage.open(BytesIO(image_bytes)).convert("RGB")
+
+    @staticmethod
+    def video_to_frames(video_bytes: bytes) -> Iterator[Frame]:
+        """Convert video bytes to frames.
+
+        Args:
+            video_bytes: Raw video data in bytes
+
+        Yields:
+            Frame with JPEG encoded frame data as bytes and timestamp in milliseconds
+        """
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file:
+            temp_video_file.write(video_bytes)
+            temp_video_path = temp_video_file.name
+            logger.debug(f"temp_video_path: {temp_video_path}")
+
+            video = cv2.VideoCapture(temp_video_path)
+            logger.debug(f"video opened: {video.isOpened()}")
+
+            while video.isOpened():
+                ret, frame = video.read()
+                if not ret:
+                    break
+                # Get frame timestamp in milliseconds
+                timestamp_ms = video.get(cv2.CAP_PROP_POS_MSEC)
+                frame_bytes = cv2.imencode('.jpg', frame)[1].tobytes()
+                yield Frame(image=Image(bytes=frame_bytes), time=timestamp_ms)
+
+            video.release()
+            os.unlink(temp_video_path)
+
+    @staticmethod
+    def process_detections(
+        results: List[Dict[str, torch.Tensor]], threshold: float, model_labels: Dict[int, str]
+    ) -> List[List[Region]]:
+        """Convert model outputs into a structured format of detections.
+
+        Args:
+            results: Raw detection results from model
+            threshold: Confidence threshold for detections
+            model_labels: Dictionary mapping label indices to names
+
+        Returns:
+            List of lists containing Region objects for each detection
+        """
+        outputs = []
+        for result in results:
+            detections = []
+            for score, label_idx, box in zip(result["scores"], result["labels"], result["boxes"]):
+                if score > threshold:
+                    label = model_labels[label_idx.item()]
+                    detections.append(
+                        Region(
+                            box=box.tolist(), concepts=[Concept(name=label, value=score.item())]
+                        )
+                    )
+            outputs.append(detections)
+        return outputs
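A self-contained check of `process_detections` with fabricated tensors; the label map and boxes are made up:

import torch

from clarifai.runners.models.visual_detector_class import VisualDetectorClass

results = [{
    "scores": torch.tensor([0.92, 0.31]),
    "labels": torch.tensor([0, 1]),
    "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0], [5.0, 5.0, 50.0, 50.0]]),
}]
regions = VisualDetectorClass.process_detections(
    results, threshold=0.5, model_labels={0: "cat", 1: "dog"}
)
print(regions)  # only the 0.92 'cat' detection survives the 0.5 threshold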