sinapsis_huggingface-0.1.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package exactly as it appears in its public registry.
- sinapsis_huggingface-0.1.0.dist-info/METADATA +921 -0
- sinapsis_huggingface-0.1.0.dist-info/RECORD +33 -0
- sinapsis_huggingface-0.1.0.dist-info/WHEEL +5 -0
- sinapsis_huggingface-0.1.0.dist-info/licenses/LICENSE +661 -0
- sinapsis_huggingface-0.1.0.dist-info/top_level.txt +4 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/__init__.py +0 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/templates/__init__.py +26 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/templates/base_diffusers.py +191 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/templates/image_to_image_diffusers.py +156 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/templates/image_to_video_gen_xl_diffusers.py +59 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/templates/inpainting_diffusers.py +327 -0
- sinapsis_huggingface_diffusers/src/sinapsis_huggingface_diffusers/templates/text_to_image_diffusers.py +79 -0
- sinapsis_huggingface_embeddings/src/sinapsis_huggingface_embeddings/__init__.py +0 -0
- sinapsis_huggingface_embeddings/src/sinapsis_huggingface_embeddings/templates/__init__.py +22 -0
- sinapsis_huggingface_embeddings/src/sinapsis_huggingface_embeddings/templates/hugging_face_embedding_extractor.py +104 -0
- sinapsis_huggingface_embeddings/src/sinapsis_huggingface_embeddings/templates/speaker_embedding_from_audio.py +159 -0
- sinapsis_huggingface_embeddings/src/sinapsis_huggingface_embeddings/templates/speaker_embedding_from_dataset.py +95 -0
- sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/__init__.py +0 -0
- sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/helpers/__init__.py +0 -0
- sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/helpers/grounding_dino_keys.py +21 -0
- sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/templates/__init__.py +24 -0
- sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/templates/grounding_dino.py +333 -0
- sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/templates/grounding_dino_classification.py +226 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/__init__.py +0 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/helpers/__init__.py +4 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/helpers/text_to_sentences.py +48 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/__init__.py +23 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/base_transformers.py +134 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/image_to_text_transformers.py +73 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/speech_to_text_transformers.py +63 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/summarization_transformers.py +61 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/text_to_speech_transformers.py +153 -0
- sinapsis_huggingface_transformers/src/sinapsis_huggingface_transformers/templates/translation_transformers.py +75 -0
sinapsis_huggingface_embeddings/src/sinapsis_huggingface_embeddings/templates/speaker_embedding_from_dataset.py
ADDED
@@ -0,0 +1,95 @@
```python
# -*- coding: utf-8 -*-

from typing import Literal

from datasets import load_dataset
from sinapsis_core.data_containers.data_packet import DataContainer
from sinapsis_core.template_base import (
    Template,
    TemplateAttributes,
)
from sinapsis_core.utils.env_var_keys import SINAPSIS_CACHE_DIR


class SpeakerEmbeddingFromDatasetAttributes(TemplateAttributes):
    """Attributes for the SpeakerEmbeddingFromDataset template.

    Attributes:
        dataset_path (str): Path or name of the Hugging Face dataset containing speaker
            embeddings. For example, `"Matthijs/cmu-arctic-xvectors"`.
        data_cache_dir (str): Directory to cache the downloaded dataset. Defaults to the value
            of the `SINAPSIS_CACHE_DIR` environment variable.
        split (str): Dataset split to use (e.g., "train", "validation", or "test").
            Defaults to `"validation"`.
        sample_idx (int): Index of the dataset sample to extract the embedding from.
        xvector_key (str): Key in the dataset sample that stores the xvector. Defaults to
            `"xvector"`.
        target_packet (Literal["texts", "audios"]): Type of packet in the `DataContainer` to
            which the embedding will be attached. Must be either `"texts"` or `"audios"`.
    """

    dataset_path: str
    data_cache_dir: str = str(SINAPSIS_CACHE_DIR)
    split: str = "validation"
    sample_idx: int
    xvector_key: str = "xvector"
    target_packet: Literal["texts", "audios"]


class SpeakerEmbeddingFromDataset(Template):
    """Template to retrieve and attach speaker embeddings from a Hugging Face dataset.

    This template extracts a specified embedding (e.g., an xvector) from a dataset and attaches
    it to the `embedding` attribute of each packet of the selected type (`texts` or `audios`)
    in a `DataContainer`.

    Usage example:

    agent:
      name: my_test_agent
      templates:
      - template_name: InputTemplate
        class_name: InputTemplate
        attributes: {}
      - template_name: SpeakerEmbeddingFromDataset
        class_name: SpeakerEmbeddingFromDataset
        template_input: InputTemplate
        attributes:
          dataset_path: '/path/to/hugging/face/dataset'
          data_cache_dir: /path/to/cache/dir
          split: validation
          sample_idx: 1
          xvector_key: xvector
          target_packet: 'audios'
    """

    AttributesBaseModel = SpeakerEmbeddingFromDatasetAttributes
    CATEGORY = "Embeddings"

    def execute(self, container: DataContainer) -> DataContainer:
        """Retrieve and attach speaker embeddings to specified packets in a DataContainer.

        Args:
            container (DataContainer): The container holding the packets to which the embedding
                will be attached.

        Returns:
            DataContainer: The updated container with embeddings attached to the `embedding`
                attribute of the specified packet type.
        """
        packets = getattr(container, self.attributes.target_packet)
        embeddings_dataset = load_dataset(
            self.attributes.dataset_path,
            split=self.attributes.split,
            cache_dir=self.attributes.data_cache_dir,
        )
        speaker_embedding = embeddings_dataset[self.attributes.sample_idx][self.attributes.xvector_key]
        self.logger.info(
            f"Attaching embedding from index {self.attributes.sample_idx} to "
            f"{len(packets)} {self.attributes.target_packet} packets."
        )
        for packet in packets:
            packet.embedding = speaker_embedding

        return container
```
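The template above is essentially a thin wrapper over `datasets.load_dataset`. As a point of reference, here is a minimal standalone sketch of the same lookup run outside the Sinapsis pipeline, assuming the `datasets` package is installed and using the `Matthijs/cmu-arctic-xvectors` dataset named in the docstring:

```python
# Minimal sketch (not part of the package) of the lookup performed by
# SpeakerEmbeddingFromDataset.execute, run outside the Sinapsis pipeline.
# Assumes `pip install datasets` and access to the Hugging Face Hub.
from datasets import load_dataset

dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = dataset[1]["xvector"]  # x-vector for sample index 1
print(len(speaker_embedding))  # embedding dimensionality (512 for this dataset)
```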
sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/__init__.py
ADDED
File without changes

sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/helpers/__init__.py
ADDED
File without changes
sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/helpers/grounding_dino_keys.py
ADDED
@@ -0,0 +1,21 @@
```python
# -*- coding: utf-8 -*-
from pydantic.dataclasses import dataclass


@dataclass(frozen=True)
class GroundingDINOKeys:
    """Defines key constants used in the GroundingDINO workflow for consistent referencing.

    Attributes:
        CLASS_DELIMITER (str): Delimiter used to separate class names in the input text.
        CONFIDENCE_SCORE (str): Key for accessing confidence scores from the model output.
        INPUT_IDS (str): Key for tokenized input IDs used by the model.
        LABELS (str): Key for general label data in the processed output.
        BBOXES (str): Key for bounding box coordinates in the output.
    """

    CLASS_DELIMITER: str = "."
    CONFIDENCE_SCORE: str = "scores"
    INPUT_IDS: str = "input_ids"
    LABELS: str = "text_labels"
    BBOXES: str = "boxes"
```
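Because `GroundingDINOKeys` is declared with pydantic's `@dataclass(frozen=True)`, the key set is immutable at runtime. A short illustrative check (hypothetical class, not part of the package):

```python
# Illustrative only: a frozen pydantic dataclass rejects mutation after creation.
from pydantic.dataclasses import dataclass


@dataclass(frozen=True)
class Keys:
    LABELS: str = "text_labels"


keys = Keys()
print(keys.LABELS)  # -> text_labels
try:
    keys.LABELS = "labels"  # mutating a frozen instance raises an error
except Exception as err:
    print(type(err).__name__)
```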
sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/templates/__init__.py
ADDED
@@ -0,0 +1,24 @@
```python
# -*- coding: utf-8 -*-
import importlib
from typing import Any, Callable, cast

_root_lib_path = "sinapsis_huggingface_grounding_dino.templates"

_template_lookup = {
    "GroundingDINO": f"{_root_lib_path}.grounding_dino",
    "GroundingDINOClassification": f"{_root_lib_path}.grounding_dino_classification",
}


def __getattr__(name: str) -> Callable[..., Any]:
    if name in _template_lookup:
        module = importlib.import_module(_template_lookup[name])
        attr = getattr(module, name)
        if callable(attr):
            return cast(Callable[..., Any], attr)
        raise TypeError(f"Attribute `{name}` in `{_template_lookup[name]}` is not callable.")

    raise AttributeError(f"template `{name}` not found in {_root_lib_path}")


__all__ = list(_template_lookup.keys())
```
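This `__init__.py` uses a module-level `__getattr__` (PEP 562) so that importing the templates package stays cheap: the template modules, and their `torch`/`transformers` dependencies, are loaded only when a template name is first accessed. A minimal self-contained sketch of the same mechanism, with a hypothetical package `pkg` and submodule `pkg.heavy`:

```python
# pkg/__init__.py -- minimal PEP 562 lazy-import sketch (hypothetical names).
import importlib
from typing import Any

_lookup = {"Heavy": "pkg.heavy"}  # attribute name -> module that defines it


def __getattr__(name: str) -> Any:
    if name in _lookup:
        # The submodule is imported only on first attribute access.
        module = importlib.import_module(_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module 'pkg' has no attribute {name!r}")


__all__ = list(_lookup)
```

With this in place, `from pkg import Heavy` imports `pkg.heavy` only at that point, mirroring how `from sinapsis_huggingface_grounding_dino.templates import GroundingDINO` defers loading `grounding_dino.py`.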
sinapsis_huggingface_grounding_dino/src/sinapsis_huggingface_grounding_dino/templates/grounding_dino.py
ADDED
@@ -0,0 +1,333 @@
```python
# -*- coding: utf-8 -*-

from typing import Any, Literal

import torch
from sinapsis_core.data_containers.annotations import BoundingBox, ImageAnnotations
from sinapsis_core.data_containers.data_packet import DataContainer, ImagePacket
from sinapsis_core.template_base import Template, TemplateAttributes, TemplateAttributeType
from sinapsis_core.utils.env_var_keys import SINAPSIS_CACHE_DIR
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    GroundingDinoForObjectDetection,
    PreTrainedModel,
)

from sinapsis_huggingface_grounding_dino.helpers.grounding_dino_keys import GroundingDINOKeys


class GroundingBaseAttributes(TemplateAttributes):
    """GroundingBaseAttributes defines the base configuration attributes for the GroundingDINO
    classes.

    Attributes:
        model_path (str): Specifies the model identifier or file path for the base GroundingDINO
            model. Example: "IDEA-Research/grounding-dino-tiny".
        model_cache_dir (str): Directory where model files are or will be stored. Defaults to
            `SINAPSIS_CACHE_DIR`.
        inference_mode (Literal["object_detection", "zero_shot"]): Specifies the mode for model
            inference. "object_detection" enables direct object detection, while "zero_shot"
            enables zero-shot inference for unseen classes.
        threshold (float): Threshold for box detection. Defaults to 0.25.
        text_threshold (float): Threshold for text detection. Defaults to 0.25.
        device (Literal["cuda", "cpu"]): Device to be used for inference.
    """

    model_path: str
    model_cache_dir: str = str(SINAPSIS_CACHE_DIR)
    inference_mode: Literal["object_detection", "zero_shot"]
    threshold: float = 0.25
    text_threshold: float = 0.25
    device: Literal["cuda", "cpu"]


class GroundingDINO(Template):
    """Base class for Grounding DINO models.

    This module contains the foundational class for implementing object detection using
    the Grounding DINO model, leveraging transformers and PyTorch. It provides essential
    methods for running inference, formatting results, and creating annotations, which
    can be extended by other specialized classes related to Grounding DINO.

    Usage example:

    agent:
      name: my_test_agent
      templates:
      - template_name: InputTemplate
        class_name: InputTemplate
        attributes: {}
      - template_name: GroundingDINO
        class_name: GroundingDINO
        template_input: InputTemplate
        attributes:
          model_path: '/path/to/model'
          model_cache_dir: /path/to/cache/dir
          inference_mode: 'object_detection'
          threshold: 0.25
          text_threshold: 0.25
          device: 'cuda'
          text_input: 'object to detect'
    """

    CATEGORY = "Grounding DINO"
    KEYS = GroundingDINOKeys()

    class AttributesBaseModel(GroundingBaseAttributes):
        """AttributesBaseModel defines the configuration attributes for the GroundingDINO class.

        Attributes:
            text_input (str): Input text for GroundingDINO processing.
        """

        text_input: str

    def __init__(self, attributes: TemplateAttributeType) -> None:
        """Initializes the GroundingDINO class with the provided attributes.

        Args:
            attributes (TemplateAttributeType): Dictionary containing configuration parameters.
        """
        super().__init__(attributes)
        self.device = self.attributes.device
        self.processor = AutoProcessor.from_pretrained(
            self.attributes.model_path, cache_dir=self.attributes.model_cache_dir
        )
        self.model = self._set_model().to(self.device)
        self.max_tokens = self.processor.tokenizer.model_max_length
        self.text_input = self.validate_and_format_text_input(self.attributes.text_input)

    def _set_model(self) -> PreTrainedModel:
        """Loads the specific model variant based on the configured inference mode.

        Returns:
            PreTrainedModel: The pretrained model instance, loaded according to the inference mode.
        """
        if self.attributes.inference_mode == "object_detection":
            return GroundingDinoForObjectDetection.from_pretrained(
                self.attributes.model_path, cache_dir=self.attributes.model_cache_dir
            )
        return AutoModelForZeroShotObjectDetection.from_pretrained(
            self.attributes.model_path, cache_dir=self.attributes.model_cache_dir
        )

    def validate_and_format_text_input(self, text_input: str) -> str:
        """Validates and formats the text input for consistency.

        Args:
            text_input (str): The input text specifying object classes.

        Returns:
            str: A validated and formatted version of the input text.
        """
        delimiter = self.KEYS.CLASS_DELIMITER

        if delimiter not in text_input:
            raise ValueError(
                f"Invalid text_input format '{text_input}': Expected at least one '{delimiter}' "
                "to separate object names."
            )

        formatted_input = " ".join(text_input.split())
        formatted_input = f"{delimiter} ".join(part.strip() for part in formatted_input.split(delimiter))

        if not formatted_input.endswith(f"{delimiter} "):
            formatted_input += delimiter

        return formatted_input.strip()

    def _run_inference(self, image_packet: ImagePacket) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Runs inference on a given image packet.

        Args:
            image_packet (ImagePacket): Image data to be processed.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: Outputs and inputs of the
                processor, respectively.
        """
        inputs = self.processor(
            images=image_packet.content,
            text=self.text_input,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        return outputs, inputs

    def _post_process(
        self,
        outputs: dict[str, torch.Tensor],
        inputs: dict[str, torch.Tensor],
        image_size: tuple[int, int],
    ) -> list[dict[str, Any]]:
        """Processes model outputs to extract and format detection results.

        This method uses the processor's post-processing function to generate
        detection results such as bounding boxes, confidence scores, and class labels
        for the objects detected in the input image.

        Args:
            outputs (dict[str, torch.Tensor]): The raw output tensors from the model, containing
                information required for generating detections (e.g., bounding box coordinates,
                scores).
            inputs (dict[str, torch.Tensor]): Input tensors provided to the model, including input
                IDs for identifying objects.
            image_size (tuple[int, int]): The dimensions (height, width) of the input image, used
                to scale the bounding boxes to the image size.

        Returns:
            list[dict[str, Any]]: A list of dictionaries, each containing detection information
                for an identified object, including bounding boxes, labels, and confidence scores.
        """
        detections: list[dict[str, Any]] = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs[self.KEYS.INPUT_IDS],
            threshold=self.attributes.threshold,
            text_threshold=self.attributes.text_threshold,
            target_sizes=[image_size],
        )
        return detections

    def _format_results(self, results: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Formats the results by converting scores and boxes to CPU tensors and converting labels
        to strings.

        Args:
            results (list[dict[str, Any]]): List of detection results containing scores, labels,
                and boxes.

        Returns:
            list[dict[str, Any]]: A list of formatted results with tensors moved to the CPU and
                labels converted to strings.
        """
        formatted_results = []
        for result in results:
            formatted_result = {
                self.KEYS.CONFIDENCE_SCORE: result[self.KEYS.CONFIDENCE_SCORE].cpu(),
                self.KEYS.LABELS: [
                    str(label) if isinstance(label, str) else str(label.item()) for label in result[self.KEYS.LABELS]
                ],
                self.KEYS.BBOXES: result[self.KEYS.BBOXES].cpu(),
            }

            formatted_results.append(formatted_result)

        return formatted_results

    def get_class_names(self) -> list[str]:
        """Produce a list of class names from the given text input.

        Returns:
            list[str]: List of class names.
        """
        return [cls.strip() for cls in self.text_input.split(self.KEYS.CLASS_DELIMITER) if cls.strip()]

    def _filter_results(self, results: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Filters results to exclude annotations with empty or whitespace-only labels,
        and discards labels not found in the provided class names.

        Args:
            results (list[dict[str, Any]]): List of detection results containing scores, labels,
                and boxes.

        Returns:
            list[dict[str, Any]]: A list of filtered results with empty or invalid labels removed.
        """
        class_names = self.get_class_names()
        filtered_results = []

        for result in results:
            valid_labels = [idx for idx, label in enumerate(result[self.KEYS.LABELS]) if label.strip() in class_names]

            if valid_labels:
                filtered_result = {
                    self.KEYS.CONFIDENCE_SCORE: result[self.KEYS.CONFIDENCE_SCORE][valid_labels],
                    self.KEYS.LABELS: [result[self.KEYS.LABELS][i] for i in valid_labels],
                    self.KEYS.BBOXES: result[self.KEYS.BBOXES][valid_labels],
                }

                filtered_results.append(filtered_result)

        return filtered_results

    def _prepare_results(
        self,
        outputs: dict[str, torch.Tensor],
        inputs: dict[str, torch.Tensor],
        image_size: tuple[int, int],
    ) -> list[dict[str, Any]]:
        """Prepare and filter results from the model outputs.

        Args:
            outputs (dict[str, torch.Tensor]): Model outputs from inference.
            inputs (dict[str, torch.Tensor]): Input data fed to the model processor.
            image_size (tuple[int, int]): Size of the processed image.

        Returns:
            list[dict[str, Any]]: Final post-processed, filtered, and formatted results from
                the model.
        """
        results = self._post_process(outputs, inputs, image_size)
        formatted_results = self._format_results(results)
        return self._filter_results(formatted_results)

    def _create_annotations(self, image_packet: ImagePacket, results: list[dict[str, Any]]) -> None:
        """Creates annotation objects for the image packet based on filtered detection results.

        Args:
            image_packet (ImagePacket): Container holding the processed image and metadata.
            results (list[dict[str, Any]]): Detection results including bounding boxes, labels,
                and scores.
        """
        new_annotations: list[ImageAnnotations] = []
        for result in results:
            for idx, bbox in enumerate(result[self.KEYS.BBOXES].numpy()):
                xmin, ymin, xmax, ymax = bbox
                bounding_box = BoundingBox(
                    x=xmin,
                    y=ymin,
                    w=xmax - xmin,
                    h=ymax - ymin,
                )

                new_annotations.append(
                    ImageAnnotations(
                        label_str=result[self.KEYS.LABELS][idx],
                        confidence_score=float(result[self.KEYS.CONFIDENCE_SCORE][idx]),
                        bbox=bounding_box,
                    )
                )
        image_packet.annotations = new_annotations

    def _process_image_packet(self, image_packet: ImagePacket) -> None:
        """Processes a single image packet by preparing detection results and creating annotations.

        Args:
            image_packet (ImagePacket): An image packet to be processed, producing detection
                results and corresponding annotations.
        """
        outputs, inputs = self._run_inference(image_packet)
        image_size = image_packet.content.shape[:2]
        detections = self._prepare_results(outputs, inputs, image_size)
        self._create_annotations(image_packet, detections)

    def execute(self, container: DataContainer) -> DataContainer:
        """Executes the detection pipeline for a container of images by processing each image
        packet.

        Args:
            container (DataContainer): The container holding multiple image packets for processing.

        Returns:
            DataContainer: The container with updated annotations for each image packet.
        """
        for image_packet in container.images:
            self._process_image_packet(image_packet)

        return container
```
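Outside the Sinapsis template machinery, the same inference path can be exercised directly with `transformers`. A hedged standalone sketch, assuming the `IDEA-Research/grounding-dino-tiny` checkpoint mentioned in the docstring and a local `image.jpg`:

```python
# Standalone sketch (not part of the package) of the zero-shot path that
# GroundingDINO wraps above. Assumes `pip install torch transformers pillow`
# and a local image.jpg.
import torch
from PIL import Image
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

model_id = "IDEA-Research/grounding-dino-tiny"  # checkpoint named in the docstring
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)

image = Image.open("image.jpg")
# Classes are separated and terminated by "." -- the same convention
# validate_and_format_text_input enforces above.
text = "a cat. a dog."

inputs = processor(images=image, text=text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    threshold=0.25,
    text_threshold=0.25,
    target_sizes=[image.size[::-1]],  # PIL size is (width, height); we need (height, width)
)
for score, box in zip(results[0]["scores"], results[0]["boxes"]):
    print(round(score.item(), 3), [round(v, 1) for v in box.tolist()])
```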