clarifai 11.4.1__py3-none-any.whl → 11.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/base.py +7 -0
- clarifai/cli/model.py +6 -8
- clarifai/client/app.py +2 -1
- clarifai/client/auth/helper.py +6 -4
- clarifai/client/compute_cluster.py +2 -1
- clarifai/client/dataset.py +8 -1
- clarifai/client/deployment.py +2 -1
- clarifai/client/input.py +2 -1
- clarifai/client/model.py +2 -1
- clarifai/client/model_client.py +1 -1
- clarifai/client/module.py +2 -1
- clarifai/client/nodepool.py +2 -1
- clarifai/client/runner.py +2 -1
- clarifai/client/search.py +2 -1
- clarifai/client/user.py +2 -1
- clarifai/client/workflow.py +2 -1
- clarifai/runners/models/mcp_class.py +114 -0
- clarifai/runners/models/model_builder.py +179 -46
- clarifai/runners/models/model_class.py +5 -22
- clarifai/runners/models/model_run_locally.py +0 -4
- clarifai/runners/models/visual_classifier_class.py +75 -0
- clarifai/runners/models/visual_detector_class.py +79 -0
- clarifai/runners/utils/code_script.py +75 -44
- clarifai/runners/utils/const.py +15 -0
- clarifai/runners/utils/data_types/data_types.py +48 -0
- clarifai/runners/utils/data_utils.py +99 -45
- clarifai/runners/utils/loader.py +23 -2
- clarifai/runners/utils/method_signatures.py +4 -4
- clarifai/runners/utils/openai_convertor.py +103 -0
- clarifai/urls/helper.py +80 -12
- clarifai/utils/config.py +19 -0
- clarifai/utils/constants.py +4 -0
- clarifai/utils/logging.py +22 -5
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/METADATA +1 -2
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/RECORD +40 -37
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/WHEEL +1 -1
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/entry_points.txt +0 -0
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/licenses/LICENSE +0 -0
- {clarifai-11.4.1.dist-info → clarifai-11.4.3.dist-info}/top_level.txt +0 -0
@@ -14,15 +14,17 @@ import yaml
|
|
14
14
|
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
15
15
|
from clarifai_grpc.grpc.api.status import status_code_pb2
|
16
16
|
from google.protobuf import json_format
|
17
|
-
from rich import print
|
18
|
-
from rich.markup import escape
|
19
17
|
|
20
18
|
from clarifai.client.base import BaseClient
|
21
19
|
from clarifai.runners.models.model_class import ModelClass
|
22
20
|
from clarifai.runners.utils.const import (
|
21
|
+
AMD_PYTHON_BASE_IMAGE,
|
22
|
+
AMD_VLLM_BASE_IMAGE,
|
23
23
|
AVAILABLE_PYTHON_IMAGES,
|
24
24
|
AVAILABLE_TORCH_IMAGES,
|
25
25
|
CONCEPTS_REQUIRED_MODEL_TYPE,
|
26
|
+
DEFAULT_AMD_GPU_VERSION,
|
27
|
+
DEFAULT_AMD_TORCH_VERSION,
|
26
28
|
DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
|
27
29
|
DEFAULT_PYTHON_VERSION,
|
28
30
|
DEFAULT_RUNTIME_DOWNLOAD_PATH,
|
@@ -43,13 +45,6 @@ dependencies = [
|
|
43
45
|
]
|
44
46
|
|
45
47
|
|
46
|
-
def _clear_line(n: int = 1) -> None:
|
47
|
-
LINE_UP = '\033[1A' # Move cursor up one line
|
48
|
-
LINE_CLEAR = '\x1b[2K' # Clear the entire line
|
49
|
-
for _ in range(n):
|
50
|
-
print(LINE_UP, end=LINE_CLEAR, flush=True)
|
51
|
-
|
52
|
-
|
53
48
|
def is_related(object_class, main_class):
|
54
49
|
# Check if the object_class is a subclass of main_class
|
55
50
|
if issubclass(object_class, main_class):
|
@@ -361,13 +356,23 @@ class ModelBuilder:
|
|
361
356
|
if self.config.get("checkpoints"):
|
362
357
|
loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
|
363
358
|
|
364
|
-
if loader_type == "huggingface"
|
365
|
-
is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
|
366
|
-
if not is_valid_token:
|
359
|
+
if loader_type == "huggingface":
|
360
|
+
is_valid_token = hf_token and HuggingFaceLoader.validate_hftoken(hf_token)
|
361
|
+
if not is_valid_token and hf_token:
|
362
|
+
logger.info(
|
363
|
+
"Continuing without Hugging Face token for validating config in model builder."
|
364
|
+
)
|
365
|
+
|
366
|
+
has_repo_access = HuggingFaceLoader.validate_hf_repo_access(
|
367
|
+
repo_id=self.config.get("checkpoints", {}).get("repo_id"),
|
368
|
+
token=hf_token if is_valid_token else None,
|
369
|
+
)
|
370
|
+
|
371
|
+
if not has_repo_access:
|
367
372
|
logger.error(
|
368
|
-
"Invalid Hugging Face
|
373
|
+
f"Invalid Hugging Face repo access for repo {self.config.get('checkpoints').get('repo_id')}. Please check your repo and try again."
|
369
374
|
)
|
370
|
-
|
375
|
+
sys.exit("Token does not have access to HuggingFace repo , exiting.")
|
371
376
|
|
372
377
|
num_threads = self.config.get("num_threads")
|
373
378
|
if num_threads or num_threads == 0:
|
@@ -405,11 +410,17 @@ class ModelBuilder:
|
|
405
410
|
signatures = {method.name: method.signature for method in method_info.values()}
|
406
411
|
return signatures_to_yaml(signatures)
|
407
412
|
|
408
|
-
def get_method_signatures(self):
|
413
|
+
def get_method_signatures(self, mocking=True):
    """
    Returns the method signatures for the model class.

    Args:
        mocking (bool): Whether to mock the model class or not; passed straight
            through to ``load_model_class``. Defaults to True.

    Returns:
        list: A list of method signatures for the model class.
    """
    model_class = self.load_model_class(mocking=mocking)
    method_info = model_class._get_method_info()
    signatures = [method.signature for method in method_info.values()]
    return signatures
|
@@ -431,22 +442,42 @@ class ModelBuilder:
|
|
431
442
|
return self._client
|
432
443
|
|
433
444
|
@property
|
434
|
-
def
|
445
|
+
def model_ui_url(self):
    """Return the Clarifai web-UI URL of this model (user / app / models / model_id).

    The model version id is deliberately not included: per the original author's
    note (zeiler), the UI experience is not yet good when a version id is part of
    the URL. Use ``model_api_url`` when a version-specific URL is needed.
    """
    url_helper = ClarifaiUrlHelper(self._client.auth_helper)
    return url_helper.clarifai_url(
        self.client.user_app_id.user_id,
        self.client.user_app_id.app_id,
        "models",
        self.model_id,
    )
|
463
|
+
|
464
|
+
@property
|
465
|
+
def model_api_url(self):
|
435
466
|
url_helper = ClarifaiUrlHelper(self._client.auth_helper)
|
436
467
|
if self.model_version_id is not None:
|
437
|
-
return url_helper.
|
468
|
+
return url_helper.api_url(
|
438
469
|
self.client.user_app_id.user_id,
|
439
470
|
self.client.user_app_id.app_id,
|
440
471
|
"models",
|
441
472
|
self.model_id,
|
473
|
+
self.model_version_id,
|
442
474
|
)
|
443
475
|
else:
|
444
|
-
return url_helper.
|
476
|
+
return url_helper.api_url(
|
445
477
|
self.client.user_app_id.user_id,
|
446
478
|
self.client.user_app_id.app_id,
|
447
479
|
"models",
|
448
480
|
self.model_id,
|
449
|
-
self.model_version_id,
|
450
481
|
)
|
451
482
|
|
452
483
|
def _get_model_proto(self):
|
@@ -532,6 +563,30 @@ class ModelBuilder:
|
|
532
563
|
dependencies_version[dependency] = version if version else None
|
533
564
|
return dependencies_version
|
534
565
|
|
566
|
+
def _is_amd(self):
    """
    Check if the model is AMD or not.

    Reads ``inference_compute_info.accelerator_type`` from the model config and
    classifies each entry by a case-insensitive substring match on "amd" or
    "nvidia". A missing or null section/value means no GPU vendor is detected;
    a bare string value is treated as a one-element list.

    Returns:
        bool: True if any accelerator entry mentions AMD.

    Raises:
        Exception: If both AMD and NVIDIA accelerators are specified.
    """
    # `or {}` / `or []` guard against keys that are present but null in the YAML,
    # which previously raised TypeError on the membership test.
    inference_compute_info = self.config.get('inference_compute_info') or {}
    accelerator_types = inference_compute_info.get('accelerator_type') or []
    if isinstance(accelerator_types, str):
        # Iterating a plain string would test single characters; wrap it instead.
        accelerator_types = [accelerator_types]

    is_amd_gpu = False
    is_nvidia_gpu = False
    for accelerator in accelerator_types:
        # AMD is checked first, so an entry matching both vendors counts as AMD only.
        if 'amd' in accelerator.lower():
            is_amd_gpu = True
        elif 'nvidia' in accelerator.lower():
            is_nvidia_gpu = True

    if is_amd_gpu and is_nvidia_gpu:
        raise Exception(
            "Both AMD and NVIDIA GPUs are specified in the config file, please use only one type of GPU."
        )
    if is_amd_gpu:
        logger.info("Using AMD base image to build the Docker image and upload the model")
    elif is_nvidia_gpu:
        logger.info("Using NVIDIA base image to build the Docker image and upload the model")
    return is_amd_gpu
|
589
|
+
|
535
590
|
def create_dockerfile(self):
|
536
591
|
dockerfile_template = os.path.join(
|
537
592
|
os.path.dirname(os.path.dirname(__file__)),
|
@@ -562,30 +617,85 @@ class ModelBuilder:
|
|
562
617
|
)
|
563
618
|
python_version = DEFAULT_PYTHON_VERSION
|
564
619
|
|
565
|
-
# This is always the final image used for runtime.
|
566
|
-
final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
|
567
|
-
downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
|
568
|
-
|
569
620
|
# Parse the requirements.txt file to determine the base image
|
570
621
|
dependencies = self._parse_requirements()
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
622
|
+
|
623
|
+
is_amd_gpu = self._is_amd()
|
624
|
+
if is_amd_gpu:
|
625
|
+
final_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version)
|
626
|
+
downloader_image = AMD_PYTHON_BASE_IMAGE.format(python_version=python_version)
|
627
|
+
if 'vllm' in dependencies:
|
628
|
+
if python_version != DEFAULT_PYTHON_VERSION:
|
629
|
+
raise Exception(
|
630
|
+
f"vLLM is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
|
631
|
+
)
|
632
|
+
torch_version = dependencies.get('torch', None)
|
633
|
+
if 'torch' in dependencies:
|
634
|
+
if python_version != DEFAULT_PYTHON_VERSION:
|
635
|
+
raise Exception(
|
636
|
+
f"torch is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
|
637
|
+
)
|
638
|
+
if not torch_version:
|
639
|
+
logger.info(
|
640
|
+
f"torch version not found in requirements.txt, using the default version {DEFAULT_AMD_TORCH_VERSION}"
|
641
|
+
)
|
642
|
+
torch_version = DEFAULT_AMD_TORCH_VERSION
|
643
|
+
if torch_version not in [DEFAULT_AMD_TORCH_VERSION]:
|
644
|
+
raise Exception(
|
645
|
+
f"torch version {torch_version} not supported, please use one of the following versions: {DEFAULT_AMD_TORCH_VERSION} in your requirements.txt"
|
646
|
+
)
|
647
|
+
python_version = DEFAULT_PYTHON_VERSION
|
648
|
+
gpu_version = DEFAULT_AMD_GPU_VERSION
|
649
|
+
final_image = AMD_VLLM_BASE_IMAGE.format(
|
650
|
+
torch_version=torch_version,
|
651
|
+
python_version=python_version,
|
652
|
+
gpu_version=gpu_version,
|
653
|
+
)
|
654
|
+
logger.info("Using vLLM base image to build the Docker image")
|
655
|
+
elif 'torch' in dependencies:
|
656
|
+
torch_version = dependencies['torch']
|
657
|
+
if python_version != DEFAULT_PYTHON_VERSION:
|
658
|
+
raise Exception(
|
659
|
+
f"torch is not supported with Python version {python_version}, please use Python version {DEFAULT_PYTHON_VERSION} in your config.yaml"
|
583
660
|
)
|
661
|
+
if not torch_version:
|
584
662
|
logger.info(
|
585
|
-
f"
|
663
|
+
f"torch version not found in requirements.txt, using the default version {DEFAULT_AMD_TORCH_VERSION}"
|
586
664
|
)
|
587
|
-
|
588
|
-
|
665
|
+
torch_version = DEFAULT_AMD_TORCH_VERSION
|
666
|
+
if torch_version not in [DEFAULT_AMD_TORCH_VERSION]:
|
667
|
+
raise Exception(
|
668
|
+
f"torch version {torch_version} not supported, please use one of the following versions: {DEFAULT_AMD_TORCH_VERSION} in your requirements.txt"
|
669
|
+
)
|
670
|
+
python_version = DEFAULT_PYTHON_VERSION
|
671
|
+
gpu_version = DEFAULT_AMD_GPU_VERSION
|
672
|
+
final_image = TORCH_BASE_IMAGE.format(
|
673
|
+
torch_version=torch_version,
|
674
|
+
python_version=python_version,
|
675
|
+
gpu_version=gpu_version,
|
676
|
+
)
|
677
|
+
logger.info(
|
678
|
+
f"Using Torch version {torch_version} base image to build the Docker image"
|
679
|
+
)
|
680
|
+
else:
|
681
|
+
final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
|
682
|
+
downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
|
683
|
+
if 'torch' in dependencies and dependencies['torch']:
|
684
|
+
torch_version = dependencies['torch']
|
685
|
+
# Sort in reverse so that newer cuda versions come first and are preferred.
|
686
|
+
for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
|
687
|
+
if torch_version in image and f'py{python_version}' in image:
|
688
|
+
# like cu124, rocm6.3, etc.
|
689
|
+
gpu_version = image.split('-')[-1]
|
690
|
+
final_image = TORCH_BASE_IMAGE.format(
|
691
|
+
torch_version=torch_version,
|
692
|
+
python_version=python_version,
|
693
|
+
gpu_version=gpu_version,
|
694
|
+
)
|
695
|
+
logger.info(
|
696
|
+
f"Using Torch version {torch_version} base image to build the Docker image"
|
697
|
+
)
|
698
|
+
break
|
589
699
|
if 'clarifai' not in dependencies:
|
590
700
|
raise Exception(
|
591
701
|
f"clarifai not found in requirements.txt, please add clarifai to the requirements.txt file with a fixed version. Current version is clarifai=={CLIENT_VERSION}"
|
@@ -835,7 +945,6 @@ class ModelBuilder:
|
|
835
945
|
percent_completed = response.status.percent_completed
|
836
946
|
details = response.status.details
|
837
947
|
|
838
|
-
_clear_line()
|
839
948
|
print(
|
840
949
|
f"Status: {response.status.description}, Progress: {percent_completed}% - {details} ",
|
841
950
|
f"request_id: {response.status.req_id}",
|
@@ -847,9 +956,26 @@ class ModelBuilder:
|
|
847
956
|
return
|
848
957
|
self.model_version_id = response.model_version_id
|
849
958
|
logger.info(f"Created Model Version ID: {self.model_version_id}")
|
850
|
-
logger.info(f"Full url to that version is: {self.
|
959
|
+
logger.info(f"Full url to that version is: {self.model_ui_url}")
|
851
960
|
try:
|
852
|
-
self.monitor_model_build()
|
961
|
+
is_uploaded = self.monitor_model_build()
|
962
|
+
if is_uploaded:
|
963
|
+
# python code to run the model.
|
964
|
+
from clarifai.runners.utils import code_script
|
965
|
+
|
966
|
+
method_signatures = self.get_method_signatures()
|
967
|
+
snippet = code_script.generate_client_script(
|
968
|
+
method_signatures,
|
969
|
+
user_id=self.client.user_app_id.user_id,
|
970
|
+
app_id=self.client.user_app_id.app_id,
|
971
|
+
model_id=self.model_proto.id,
|
972
|
+
)
|
973
|
+
logger.info("""\n
|
974
|
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
975
|
+
# Here is a code snippet to use this model:
|
976
|
+
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
977
|
+
""")
|
978
|
+
logger.info(snippet)
|
853
979
|
finally:
|
854
980
|
if os.path.exists(self.tar_file):
|
855
981
|
logger.debug(f"Cleaning up upload file: {self.tar_file}")
|
@@ -933,7 +1059,12 @@ class ModelBuilder:
|
|
933
1059
|
for log_entry in logs.log_entries:
|
934
1060
|
if log_entry.url not in seen_logs:
|
935
1061
|
seen_logs.add(log_entry.url)
|
936
|
-
|
1062
|
+
log_entry_msg = re.sub(
|
1063
|
+
r"(\\*)(\[[a-z#/@][^[]*?])",
|
1064
|
+
lambda m: f"{m.group(1)}{m.group(1)}\\{m.group(2)}",
|
1065
|
+
log_entry.message.strip(),
|
1066
|
+
)
|
1067
|
+
logger.info(log_entry_msg)
|
937
1068
|
if status_code == status_code_pb2.MODEL_BUILDING:
|
938
1069
|
print(
|
939
1070
|
f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True
|
@@ -945,7 +1076,7 @@ class ModelBuilder:
|
|
945
1076
|
logger.info("Model build complete!")
|
946
1077
|
logger.info(f"Build time elapsed {time.time() - st:.1f}s)")
|
947
1078
|
logger.info(
|
948
|
-
f"Check out the model at {self.
|
1079
|
+
f"Check out the model at {self.model_ui_url} version: {self.model_version_id}"
|
949
1080
|
)
|
950
1081
|
return True
|
951
1082
|
else:
|
@@ -970,10 +1101,12 @@ def upload_model(folder, stage, skip_dockerfile):
|
|
970
1101
|
exists = builder.check_model_exists()
|
971
1102
|
if exists:
|
972
1103
|
logger.info(
|
973
|
-
f"Model already exists at {builder.
|
1104
|
+
f"Model already exists at {builder.model_ui_url}, this upload will create a new version for it."
|
974
1105
|
)
|
975
1106
|
else:
|
976
|
-
logger.info(
|
1107
|
+
logger.info(
|
1108
|
+
f"New model will be created at {builder.model_ui_url} with it's first version."
|
1109
|
+
)
|
977
1110
|
|
978
1111
|
input("Press Enter to continue...")
|
979
1112
|
builder.upload_model_version()
|
@@ -9,7 +9,6 @@ from typing import Any, Dict, Iterator, List
|
|
9
9
|
|
10
10
|
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
|
11
11
|
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
|
12
|
-
from google.protobuf import json_format
|
13
12
|
|
14
13
|
from clarifai.runners.utils import data_types
|
15
14
|
from clarifai.runners.utils.data_utils import DataConverter
|
@@ -100,7 +99,6 @@ class ModelClass(ABC):
|
|
100
99
|
try:
|
101
100
|
# TODO add method name field to proto
|
102
101
|
method_name = 'predict'
|
103
|
-
inference_params = get_inference_params(request)
|
104
102
|
if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
|
105
103
|
method_name = request.inputs[0].data.metadata['_method_name']
|
106
104
|
if (
|
@@ -124,7 +122,7 @@ class ModelClass(ABC):
|
|
124
122
|
input.data.CopyFrom(new_data)
|
125
123
|
# convert inputs to python types
|
126
124
|
inputs = self._convert_input_protos_to_python(
|
127
|
-
request.inputs,
|
125
|
+
request.inputs, signature.input_fields, python_param_types
|
128
126
|
)
|
129
127
|
if len(inputs) == 1:
|
130
128
|
inputs = inputs[0]
|
@@ -163,7 +161,6 @@ class ModelClass(ABC):
|
|
163
161
|
) -> Iterator[service_pb2.MultiOutputResponse]:
|
164
162
|
try:
|
165
163
|
method_name = 'generate'
|
166
|
-
inference_params = get_inference_params(request)
|
167
164
|
if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
|
168
165
|
method_name = request.inputs[0].data.metadata['_method_name']
|
169
166
|
method = getattr(self, method_name)
|
@@ -180,7 +177,7 @@ class ModelClass(ABC):
|
|
180
177
|
)
|
181
178
|
input.data.CopyFrom(new_data)
|
182
179
|
inputs = self._convert_input_protos_to_python(
|
183
|
-
request.inputs,
|
180
|
+
request.inputs, signature.input_fields, python_param_types
|
184
181
|
)
|
185
182
|
if len(inputs) == 1:
|
186
183
|
inputs = inputs[0]
|
@@ -226,7 +223,6 @@ class ModelClass(ABC):
|
|
226
223
|
assert len(request.inputs) == 1, "Streaming requires exactly one input"
|
227
224
|
|
228
225
|
method_name = 'stream'
|
229
|
-
inference_params = get_inference_params(request)
|
230
226
|
if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
|
231
227
|
method_name = request.inputs[0].data.metadata['_method_name']
|
232
228
|
method = getattr(self, method_name)
|
@@ -251,7 +247,7 @@ class ModelClass(ABC):
|
|
251
247
|
input.data.CopyFrom(new_data)
|
252
248
|
# convert all inputs for the first request, including the first stream value
|
253
249
|
inputs = self._convert_input_protos_to_python(
|
254
|
-
request.inputs,
|
250
|
+
request.inputs, signature.input_fields, python_param_types
|
255
251
|
)
|
256
252
|
kwargs = inputs[0]
|
257
253
|
|
@@ -264,7 +260,7 @@ class ModelClass(ABC):
|
|
264
260
|
# subsequent streaming items contain only the streaming input
|
265
261
|
for request in request_iterator:
|
266
262
|
item = self._convert_input_protos_to_python(
|
267
|
-
request.inputs,
|
263
|
+
request.inputs, [stream_sig], python_param_types
|
268
264
|
)
|
269
265
|
item = item[0][stream_argname]
|
270
266
|
yield item
|
@@ -297,13 +293,12 @@ class ModelClass(ABC):
|
|
297
293
|
def _convert_input_protos_to_python(
|
298
294
|
self,
|
299
295
|
inputs: List[resources_pb2.Input],
|
300
|
-
inference_params: dict,
|
301
296
|
variables_signature: List[resources_pb2.ModelTypeField],
|
302
297
|
python_param_types,
|
303
298
|
) -> List[Dict[str, Any]]:
|
304
299
|
result = []
|
305
300
|
for input in inputs:
|
306
|
-
kwargs = deserialize(input.data, variables_signature
|
301
|
+
kwargs = deserialize(input.data, variables_signature)
|
307
302
|
# dynamic cast to annotated types
|
308
303
|
for k, v in kwargs.items():
|
309
304
|
if k not in python_param_types:
|
@@ -374,18 +369,6 @@ class ModelClass(ABC):
|
|
374
369
|
return method_info
|
375
370
|
|
376
371
|
|
377
|
-
# Helper function to get the inference params
|
378
|
-
def get_inference_params(request) -> dict:
|
379
|
-
"""Get the inference params from the request."""
|
380
|
-
inference_params = {}
|
381
|
-
if request.model.model_version.id != "":
|
382
|
-
output_info = request.model.model_version.output_info
|
383
|
-
output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)
|
384
|
-
if "params" in output_info:
|
385
|
-
inference_params = output_info["params"]
|
386
|
-
return inference_params
|
387
|
-
|
388
|
-
|
389
372
|
class _MethodInfo:
|
390
373
|
def __init__(self, method):
|
391
374
|
self.name = method.__name__
|
@@ -442,10 +442,6 @@ def main(
|
|
442
442
|
manager = ModelRunLocally(model_path)
|
443
443
|
# get whatever stage is in config.yaml to force download now
|
444
444
|
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
|
445
|
-
_, _, _, when, _, _ = manager.builder._validate_config_checkpoints()
|
446
|
-
manager.builder.download_checkpoints(
|
447
|
-
stage=when, checkpoint_path_override=manager.builder.checkpoint_path
|
448
|
-
)
|
449
445
|
if inside_container:
|
450
446
|
if not manager.is_docker_installed():
|
451
447
|
sys.exit(1)
|
@@ -0,0 +1,75 @@
|
|
1
|
+
import os
|
2
|
+
import tempfile
|
3
|
+
from io import BytesIO
|
4
|
+
from typing import Dict, Iterator, List
|
5
|
+
|
6
|
+
import cv2
|
7
|
+
import torch
|
8
|
+
from PIL import Image as PILImage
|
9
|
+
|
10
|
+
from clarifai.runners.models.model_class import ModelClass
|
11
|
+
from clarifai.runners.utils.data_types import Concept, Frame, Image
|
12
|
+
from clarifai.utils.logging import logger
|
13
|
+
|
14
|
+
|
15
|
+
class VisualClassifierClass(ModelClass):
    """Base class for visual classification models supporting image and video processing."""

    @staticmethod
    def preprocess_image(image_bytes: bytes) -> PILImage.Image:
        """Decode encoded image bytes into an RGB PIL image.

        Args:
            image_bytes: Encoded image data in any format PIL can open.

        Returns:
            The decoded image, converted to RGB mode.
        """
        return PILImage.open(BytesIO(image_bytes)).convert("RGB")

    @staticmethod
    def video_to_frames(video_bytes: bytes) -> Iterator[Frame]:
        """Convert video bytes to frames.

        The bytes are written to a temporary ``.mp4`` file so OpenCV can decode
        them. The capture handle is released and the temporary file deleted when
        the generator is exhausted, raises, or is closed early by the consumer.

        Args:
            video_bytes: Raw video data in bytes

        Yields:
            Frame with JPEG encoded frame data as bytes and timestamp in milliseconds
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file:
            temp_video_file.write(video_bytes)
            temp_video_path = temp_video_file.name
        logger.debug(f"temp_video_path: {temp_video_path}")

        video = None
        try:
            video = cv2.VideoCapture(temp_video_path)
            logger.debug(f"video opened: {video.isOpened()}")

            while video.isOpened():
                ret, frame = video.read()
                if not ret:
                    break
                # OpenCV's current stream position (ms), used as this frame's timestamp.
                timestamp_ms = video.get(cv2.CAP_PROP_POS_MSEC)
                frame_bytes = cv2.imencode('.jpg', frame)[1].tobytes()
                yield Frame(image=Image(bytes=frame_bytes), time=timestamp_ms)
        finally:
            # `finally` also runs on GeneratorExit, so a consumer that stops early
            # no longer leaks the capture handle or the temporary file.
            if video is not None:
                video.release()
            os.unlink(temp_video_path)

    @staticmethod
    def process_concepts(
        logits: torch.Tensor, threshold: float, model_labels: Dict[int, str]
    ) -> List[List[Concept]]:
        """Convert model logits into a structured format of concepts.

        Args:
            logits: Model output logits as a tensor (batch_size x num_classes). A
                1-D tensor (num_classes,) is treated as a batch of one.
            threshold: Minimum softmax probability (inclusive) a concept needs to be
                kept. Pass 0.0 to keep every concept.
            model_labels: Dictionary mapping label indices to label names

        Returns:
            List of lists containing Concept objects for each input in the batch,
            each sorted by descending probability.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        # Softmax along the class axis is row-independent, so one batched call
        # gives the same values as a per-row call.
        probs = torch.softmax(logits, dim=-1)

        outputs = []
        for row in probs:
            # Convert once to Python floats/ints instead of per-element `.item()`.
            row_probs = row.tolist()
            order = torch.argsort(row, dim=-1, descending=True).tolist()
            outputs.append(
                [
                    Concept(name=model_labels[idx], value=row_probs[idx])
                    for idx in order
                    if row_probs[idx] >= threshold
                ]
            )
        return outputs
|
@@ -0,0 +1,79 @@
|
|
1
|
+
import os
|
2
|
+
import tempfile
|
3
|
+
from io import BytesIO
|
4
|
+
from typing import Dict, Iterator, List
|
5
|
+
|
6
|
+
import cv2
|
7
|
+
import torch
|
8
|
+
from PIL import Image as PILImage
|
9
|
+
|
10
|
+
from clarifai.runners.models.model_class import ModelClass
|
11
|
+
from clarifai.runners.utils.data_types import Concept, Frame, Image, Region
|
12
|
+
from clarifai.utils.logging import logger
|
13
|
+
|
14
|
+
|
15
|
+
class VisualDetectorClass(ModelClass):
    """Base class for visual detection models supporting image and video processing."""

    @staticmethod
    def preprocess_image(image_bytes: bytes) -> PILImage.Image:
        """Decode encoded image bytes into an RGB PIL image.

        Args:
            image_bytes: Encoded image data in any format PIL can open.

        Returns:
            The decoded image, converted to RGB mode.
        """
        return PILImage.open(BytesIO(image_bytes)).convert("RGB")

    @staticmethod
    def video_to_frames(video_bytes: bytes) -> Iterator[Frame]:
        """Convert video bytes to frames.

        The bytes are written to a temporary ``.mp4`` file so OpenCV can decode
        them. The capture handle is released and the temporary file deleted when
        the generator is exhausted, raises, or is closed early by the consumer.

        Args:
            video_bytes: Raw video data in bytes

        Yields:
            Frame with JPEG encoded frame data as bytes and timestamp in milliseconds
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video_file:
            temp_video_file.write(video_bytes)
            temp_video_path = temp_video_file.name
        logger.debug(f"temp_video_path: {temp_video_path}")

        video = None
        try:
            video = cv2.VideoCapture(temp_video_path)
            logger.debug(f"video opened: {video.isOpened()}")

            while video.isOpened():
                ret, frame = video.read()
                if not ret:
                    break
                # OpenCV's current stream position (ms), used as this frame's timestamp.
                timestamp_ms = video.get(cv2.CAP_PROP_POS_MSEC)
                frame_bytes = cv2.imencode('.jpg', frame)[1].tobytes()
                yield Frame(image=Image(bytes=frame_bytes), time=timestamp_ms)
        finally:
            # `finally` also runs on GeneratorExit, so a consumer that stops early
            # no longer leaks the capture handle or the temporary file.
            if video is not None:
                video.release()
            os.unlink(temp_video_path)

    @staticmethod
    def process_detections(
        results: List[Dict[str, torch.Tensor]], threshold: float, model_labels: Dict[int, str]
    ) -> List[List[Region]]:
        """Convert model outputs into a structured format of detections.

        Args:
            results: Raw detection results from model; each entry is read via its
                "scores", "labels" and "boxes" keys, which are zipped element-wise.
            threshold: Confidence threshold for detections (exclusive: a detection
                is kept only if its score is strictly greater).
            model_labels: Dictionary mapping label indices to names

        Returns:
            One list per entry of ``results``, each containing a Region (box plus a
            single Concept) for every detection above the threshold.
        """
        outputs = []
        for result in results:
            detections = []
            for score, label_idx, box in zip(result["scores"], result["labels"], result["boxes"]):
                # Compare the tensor directly so the threshold test keeps the
                # tensor's dtype semantics.
                if score > threshold:
                    label = model_labels[label_idx.item()]
                    detections.append(
                        Region(
                            box=box.tolist(), concepts=[Concept(name=label, value=score.item())]
                        )
                    )
            outputs.append(detections)
        return outputs
|