sinapsis-anomalib 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sinapsis-anomalib might be problematic. Click here for more details.
- sinapsis_anomalib/__init__.py +0 -0
- sinapsis_anomalib/helpers/__init__.py +0 -0
- sinapsis_anomalib/helpers/config_factory.py +122 -0
- sinapsis_anomalib/templates/__init__.py +31 -0
- sinapsis_anomalib/templates/anomalib_base.py +212 -0
- sinapsis_anomalib/templates/anomalib_base_inference.py +247 -0
- sinapsis_anomalib/templates/anomalib_export.py +204 -0
- sinapsis_anomalib/templates/anomalib_openvino_inference.py +93 -0
- sinapsis_anomalib/templates/anomalib_torch_inference.py +83 -0
- sinapsis_anomalib/templates/anomalib_train.py +205 -0
- sinapsis_anomalib-0.1.0.dist-info/METADATA +949 -0
- sinapsis_anomalib-0.1.0.dist-info/RECORD +15 -0
- sinapsis_anomalib-0.1.0.dist-info/WHEEL +5 -0
- sinapsis_anomalib-0.1.0.dist-info/licenses/LICENSE +661 -0
- sinapsis_anomalib-0.1.0.dist-info/top_level.txt +1 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from anomalib.loggers import (
|
|
5
|
+
AnomalibCometLogger,
|
|
6
|
+
AnomalibMLFlowLogger,
|
|
7
|
+
AnomalibTensorBoardLogger,
|
|
8
|
+
AnomalibWandbLogger,
|
|
9
|
+
)
|
|
10
|
+
from lightning.pytorch.callbacks import (
|
|
11
|
+
Callback,
|
|
12
|
+
DeviceStatsMonitor,
|
|
13
|
+
EarlyStopping,
|
|
14
|
+
LearningRateMonitor,
|
|
15
|
+
ModelCheckpoint,
|
|
16
|
+
RichProgressBar,
|
|
17
|
+
)
|
|
18
|
+
from lightning.pytorch.loggers import (
|
|
19
|
+
CometLogger,
|
|
20
|
+
MLFlowLogger,
|
|
21
|
+
TensorBoardLogger,
|
|
22
|
+
WandbLogger,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CallbackFactory:
    """Factory that maps short names to PyTorch Lightning callback classes.

    Out of the box the registry contains ``rich_progress`` (``RichProgressBar``),
    ``early_stopping`` (``EarlyStopping``), ``model_checkpoint`` (``ModelCheckpoint``),
    ``lr_monitor`` (``LearningRateMonitor``) and ``device_stats`` (``DeviceStatsMonitor``).
    Further classes can be added at runtime with :meth:`register_callback`.
    """

    def __init__(self) -> None:
        # Registry of name -> callback class; instances are only built in ``create``.
        self.callbacks = {
            "rich_progress": RichProgressBar,
            "early_stopping": EarlyStopping,
            "model_checkpoint": ModelCheckpoint,
            "lr_monitor": LearningRateMonitor,
            "device_stats": DeviceStatsMonitor,
        }

    def create(self, callback_name: str, config: dict[str, Any] | None = None) -> Callback:
        """Instantiate the callback registered under ``callback_name``.

        Args:
            callback_name (str): Registry key of the callback to build.
            config (dict[str, Any] | None, optional): Keyword arguments forwarded to the
                callback constructor. ``None`` or an empty dict builds the callback with no
                arguments. Defaults to None.

        Raises:
            ValueError: If nothing is registered under ``callback_name``.

        Returns:
            Callback: The newly created callback instance.
        """
        callback_cls = self.callbacks.get(callback_name)
        if not callback_cls:
            raise ValueError(f"Callback {callback_name} not supported")
        constructor_kwargs = config if config else {}
        return callback_cls(**constructor_kwargs)

    def register_callback(self, name: str, callback_class: type[Callback]) -> None:
        """Add a callback class to the registry, replacing any entry with the same name.

        Args:
            name (str): Registry key under which the class is stored.
            callback_class (type[Callback]): Callback class to associate with ``name``.
        """
        self.callbacks.update({name: callback_class})
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class LoggerFactory:
    """Factory that maps short names to experiment-logger classes.

    The registry ships with ``tensorboard`` (``AnomalibTensorBoardLogger``), ``wandb``
    (``AnomalibWandbLogger``), ``mlflow`` (``AnomalibMLFlowLogger``) and ``comet``
    (``AnomalibCometLogger``). Additional classes can be added with :meth:`register_logger`.
    """

    def __init__(self) -> None:
        # Registry of name -> logger class; instances are only built in ``create``.
        self.loggers = {
            "tensorboard": AnomalibTensorBoardLogger,
            "wandb": AnomalibWandbLogger,
            "mlflow": AnomalibMLFlowLogger,
            "comet": AnomalibCometLogger,
        }

    def create(
        self, logger_name: str, config: dict[str, Any] | None = None
    ) -> TensorBoardLogger | WandbLogger | MLFlowLogger | CometLogger:
        """Instantiate the logger registered under ``logger_name``.

        Args:
            logger_name (str): Registry key of the logger to build.
            config (dict[str, Any] | None, optional): Keyword arguments forwarded to the
                logger constructor. ``None`` or an empty dict builds the logger with no
                arguments. Defaults to None.

        Raises:
            ValueError: If nothing is registered under ``logger_name``.

        Returns:
            TensorBoardLogger | WandbLogger | MLFlowLogger | CometLogger: The newly created logger.
        """
        logger_cls = self.loggers.get(logger_name)
        if not logger_cls:
            raise ValueError(f"Logger {logger_name} not supported")
        constructor_kwargs = config if config else {}
        return logger_cls(**constructor_kwargs)

    def register_logger(self, name: str, logger_class: type) -> None:
        """Add a logger class to the registry, replacing any entry with the same name.

        Args:
            name (str): Registry key under which the class is stored.
            logger_class (type): Logger class to associate with ``name``.
        """
        self.loggers.update({name: logger_class})
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import importlib
|
|
3
|
+
from typing import Callable, cast
|
|
4
|
+
|
|
5
|
+
from sinapsis.templates import _import_template_package
|
|
6
|
+
|
|
7
|
+
_root_lib_path = "sinapsis_anomalib.templates"

# Submodules whose templates are discovered through ``_import_template_package``.
_ADDITIONAL_TEMPLATE_MODULES = [
    f"{_root_lib_path}.{module_name}" for module_name in ("anomalib_train", "anomalib_export")
]

# Template name -> dotted path of the module defining it; modules are imported lazily
# by the module-level ``__getattr__``.
_template_lookup: dict = {
    "AnomalibTorchInference": f"{_root_lib_path}.anomalib_torch_inference",
    "AnomalibOpenVINOInference": f"{_root_lib_path}.anomalib_openvino_inference",
}

# NOTE(review): presumably returns a {template name: module path} mapping — confirm in sinapsis.
for t_module in _ADDITIONAL_TEMPLATE_MODULES:
    _template_lookup.update(_import_template_package(t_module))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def __getattr__(name: str) -> Callable:
    """Lazily import and return the template called ``name`` (PEP 562 module hook).

    Args:
        name (str): Attribute requested on this package.

    Raises:
        AttributeError: If ``name`` is not a key of ``_template_lookup``.

    Returns:
        Callable: The attribute ``name`` of the module registered for it.
    """
    if name not in _template_lookup:
        raise AttributeError(f"template `{name}` not found in {_root_lib_path}")
    template_module = importlib.import_module(_template_lookup[name])
    return cast(Callable, getattr(template_module, name))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Public API: every template name known to the lazy lookup table.
__all__ = [*_template_lookup]
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from collections.abc import Iterable, Sequence
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Protocol, TypeAlias
|
|
6
|
+
|
|
7
|
+
from anomalib import models as anomalib_models
|
|
8
|
+
from anomalib.data import Folder
|
|
9
|
+
from anomalib.engine import Engine
|
|
10
|
+
from anomalib.utils.types import NORMALIZATION, THRESHOLD
|
|
11
|
+
from lightning.pytorch.callbacks import Callback
|
|
12
|
+
from lightning.pytorch.loggers import Logger
|
|
13
|
+
from pydantic import BaseModel, ConfigDict
|
|
14
|
+
from pydantic.dataclasses import dataclass
|
|
15
|
+
from sinapsis_core.data_containers.data_packet import DataContainer
|
|
16
|
+
from sinapsis_core.template_base import TemplateAttributes, TemplateAttributeType
|
|
17
|
+
from sinapsis_core.template_base.dynamic_template import BaseDynamicWrapperTemplate, WrapperEntryConfig
|
|
18
|
+
|
|
19
|
+
from sinapsis_anomalib.helpers.config_factory import CallbackFactory, LoggerFactory
|
|
20
|
+
|
|
21
|
+
# Attribute names of ``anomalib.models`` left out of the dynamic wrapping
# (passed as ``exclude_module_atts`` to ``WrapperEntryConfig``).
EXCLUDED_MODELS = ["EfficientAd", "VlmAd", "Cfa", "Dfkde", "Fastflow", "Supersimplenet", "AiVad"]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class AnomalibKeys:
    """Names of the ``Engine`` keyword arguments filled from factory-built objects.

    Attributes:
        CALLBACKS (str): Engine kwarg that receives the list of created callbacks.
        LOGGER (str): Engine kwarg that receives the created logger(s).
    """

    CALLBACKS: str = "callbacks"  # Engine(callbacks=...)
    LOGGER: str = "logger"  # Engine(logger=...)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Metric specification: a single name, a list of names, or a name -> kwargs mapping
# (NOTE(review): mapping semantics assumed from the shape — confirm against anomalib).
METRICS_TYPE: TypeAlias = list[str] | str | dict[str, dict[str, Any]]
# A filesystem location as text or Path; None means "not set".
PATH_TYPE: TypeAlias = str | Path | None
# One location or a sequence of locations.
PATH_SEQUENCE_TYPE: TypeAlias = str | Path | Sequence[str | Path] | None
# Lightning logger(s) or a bool/None flag; presumably mirrors ``Engine(logger=...)`` — confirm.
LOGGER_TYPE: TypeAlias = Logger | Iterable[Logger] | bool | None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class EngineConfig(BaseModel):
    """Pydantic schema of the keyword arguments used to build an Anomalib ``Engine``.

    Unknown keys in the input are not declared here, so only the fields below survive
    validation; ``callback_configs``/``logger_configs`` are consumed by the template
    itself rather than forwarded to the engine.

    Attributes:
        callbacks (list[Callback] | None): Ready-made PyTorch Lightning callbacks.
        normalization (NORMALIZATION | None): Normalization configuration.
        threshold (THRESHOLD | None): Anomaly threshold configuration.
        image_metrics (METRICS_TYPE | None): Image-level evaluation metrics.
        pixel_metrics (METRICS_TYPE | None): Pixel-level evaluation metrics.
        logger (LOGGER_TYPE): Logger instance(s) or flag.
        default_root_dir (PATH_TYPE): Root directory for outputs.
        callback_configs (dict[str, dict[str, Any]] | None): Callback name -> constructor kwargs.
        logger_configs (dict[str, dict[str, Any]] | None): Logger name -> constructor kwargs.
    """

    callbacks: list[Callback] | None = None
    normalization: NORMALIZATION | None = None
    threshold: THRESHOLD | None = None
    image_metrics: METRICS_TYPE | None = None
    pixel_metrics: METRICS_TYPE | None = None
    logger: LOGGER_TYPE = None
    default_root_dir: PATH_TYPE = None
    callback_configs: dict[str, dict[str, Any]] | None = None
    logger_configs: dict[str, dict[str, Any]] | None = None

    # Callback/Logger are plain classes, not pydantic types, so allow them explicitly.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class HasSUFFIX(Protocol):
    """Structural type for classes that declare a string ``SUFFIX`` class attribute.

    Used to type the ``owner`` argument of ``DynamicWrapperEntry.__get__``.

    Attributes:
        SUFFIX (str): Suffix string exposed by the class.
    """

    SUFFIX: str  # annotation only; implementers supply the value
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class DynamicWrapperEntry:
    """Non-data descriptor that builds a ``WrapperEntryConfig`` from the owner's ``SUFFIX``.

    Every attribute read creates a new config wrapping ``anomalib.models`` with
    ``EXCLUDED_MODELS`` excluded, so a subclass only has to override ``SUFFIX``
    to change the generated template-name suffix.
    """

    def __get__(self, _instance: object, owner: type[HasSUFFIX]) -> WrapperEntryConfig:
        """Return a wrapper entry configured for ``owner``.

        Args:
            _instance (object): Instance the attribute was read through (ignored).
            owner (type[HasSUFFIX]): Class on which the descriptor is defined or inherited.

        Returns:
            WrapperEntryConfig: Config whose ``template_name_suffix`` is ``owner.SUFFIX``.
        """
        entry_kwargs = {
            "wrapped_object": anomalib_models,
            "template_name_suffix": owner.SUFFIX,
            "exclude_module_atts": EXCLUDED_MODELS,
        }
        return WrapperEntryConfig(**entry_kwargs)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class AnomalibBaseAttributes(TemplateAttributes):
    """Template attributes shared by the Anomalib Train and Export templates.

    Apart from ``folder_attributes`` (used to build the ``Folder`` datamodule), the
    fields mirror ``EngineConfig`` and are used to configure the Anomalib ``Engine``.

    Attributes:
        folder_attributes (dict[str, Any] | None): Keyword arguments for ``anomalib.data.Folder``.
            Required for training, optional for export.
        callbacks (list[Callback] | None): Ready-made Lightning callbacks.
        normalization (NORMALIZATION | None): Normalization configuration.
        threshold (THRESHOLD | None): Prediction threshold configuration.
        image_metrics (METRICS_TYPE | None): Image-level metrics.
        pixel_metrics (METRICS_TYPE | None): Pixel-level metrics.
        logger (LOGGER_TYPE): Lightning logger(s) or flag.
        default_root_dir (PATH_TYPE): Output directory.
        callback_configs (dict[str, dict[str, Any]] | None): Callback name -> constructor kwargs.
        logger_configs (dict[str, dict[str, Any]] | None): Logger name -> constructor kwargs.
    """

    folder_attributes: dict[str, Any] | None = None
    callbacks: list[Callback] | None = None
    normalization: NORMALIZATION | None = None
    threshold: THRESHOLD | None = None
    image_metrics: METRICS_TYPE | None = None
    pixel_metrics: METRICS_TYPE | None = None
    logger: LOGGER_TYPE = None
    default_root_dir: PATH_TYPE = None
    callback_configs: dict[str, dict[str, Any]] | None = None
    logger_configs: dict[str, dict[str, Any]] | None = None

    # Needed because Callback/Logger instances are not pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class AnomalibBase(BaseDynamicWrapperTemplate):
    """Base class for Anomalib model Train and Export templates.

    The wrapped model comes from ``WrapperEntry`` (a ``DynamicWrapperEntry`` over
    ``anomalib.models`` that uses the subclass ``SUFFIX`` as template-name suffix), and an
    Anomalib ``Engine`` is configured from the template attributes during construction.

    Notes:
        - Subclasses must override SUFFIX as needed for their specific purpose
        - When using 'Train' suffix, all essential Folder attributes must be provided
        - INT8_PTQ/INT8_ACQ compression requires complete Folder configuration
        - Callback and logger configurations are optional but must follow Anomalib specs
        - Callbacks built from ``callback_configs`` are appended to any explicitly passed
          ``callbacks``; loggers built from ``logger_configs`` replace ``logger``.
    """

    SUFFIX: str = "Wrapper"
    AttributesBaseModel = AnomalibBaseAttributes
    CATEGORY = "Anomalib"
    WrapperEntry = DynamicWrapperEntry()

    def __init__(self, attributes: TemplateAttributeType) -> None:
        """Initialize the template, the callback/logger factories and the Engine.

        Args:
            attributes (TemplateAttributeType): Raw template attributes (this class declares
                ``AnomalibBaseAttributes`` as its ``AttributesBaseModel``).
        """
        super().__init__(attributes=attributes)
        # NOTE(review): assumed to be the Anomalib model class chosen by the dynamic wrapper.
        self.model = self.wrapped_callable
        self.callback_factory = CallbackFactory()
        self.logger_factory = LoggerFactory()
        # The factories must exist before the engine is configured.
        self.engine = self.setup_engine()

    def _create_callbacks(self) -> list[Callback]:
        """Create callbacks from ``attributes.callback_configs``.

        Returns:
            list[Callback]: Instantiated callback objects (empty when no configs are set).
        """
        configs = self.attributes.callback_configs or {}
        return [self.callback_factory.create(name, config) for name, config in configs.items()]

    def _create_loggers(self) -> list[Logger]:
        """Create loggers from ``attributes.logger_configs``.

        Returns:
            list[Logger]: Instantiated logger objects (empty when no configs are set).
        """
        configs = self.attributes.logger_configs or {}
        return [self.logger_factory.create(name, config) for name, config in configs.items()]

    def setup_engine(self) -> Engine:
        """Configure and initialize the Anomalib Engine.

        Engine kwargs are the non-``None`` ``EngineConfig`` fields of the attributes, minus
        ``callback_configs``/``logger_configs``, which are turned into objects via the
        factories instead.

        Returns:
            Engine: Configured engine instance.
        """
        engine_kwargs = EngineConfig(**self.attributes.model_dump()).model_dump(
            exclude={"callback_configs", "logger_configs"}, exclude_none=True
        )
        if self.attributes.callback_configs:
            # Keep explicitly passed callback instances instead of silently dropping them.
            explicit_callbacks = engine_kwargs.get(AnomalibKeys.CALLBACKS) or []
            engine_kwargs[AnomalibKeys.CALLBACKS] = [*explicit_callbacks, *self._create_callbacks()]

        if self.attributes.logger_configs:
            # Config-built loggers take precedence over the ``logger`` attribute.
            engine_kwargs[AnomalibKeys.LOGGER] = self._create_loggers()

        return Engine(**engine_kwargs)

    def setup_data_loader(self) -> Folder:
        """Initialize the Folder datamodule from ``attributes.folder_attributes``.

        Raises:
            ValueError: If ``folder_attributes`` was not provided.

        Returns:
            Folder: Configured data module instance.
        """
        if self.attributes.folder_attributes is None:
            raise ValueError("`folder_attributes` must be provided to set up the Folder datamodule")
        return Folder(**self.attributes.folder_attributes)

    @abstractmethod
    def execute(self, container: DataContainer) -> DataContainer:
        """Template method to be implemented by subclasses.

        Args:
            container (DataContainer): Input data container

        Returns:
            DataContainer: Processed data container
        """
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
|
|
5
|
+
import cv2 as cv
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torchvision.transforms.v2 as T
|
|
8
|
+
from anomalib.data.utils.label import LabelName
|
|
9
|
+
from anomalib.deploy import OpenVINOInferencer, TorchInferencer
|
|
10
|
+
from pydantic import Field
|
|
11
|
+
from pydantic.dataclasses import dataclass
|
|
12
|
+
from sinapsis_core.data_containers.annotations import (
|
|
13
|
+
BoundingBox,
|
|
14
|
+
ImageAnnotations,
|
|
15
|
+
Segmentation,
|
|
16
|
+
)
|
|
17
|
+
from sinapsis_core.data_containers.data_packet import DataContainer, ImagePacket
|
|
18
|
+
from sinapsis_core.template_base import (
|
|
19
|
+
Template,
|
|
20
|
+
TemplateAttributes,
|
|
21
|
+
TemplateAttributeType,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class AnomalibInferenceKeys:
    """String constants shared by the Anomalib inference templates.

    Attributes:
        TRANSFORMS (str): Key naming the image transformation configuration.
        TASK (str): Key naming the task type.
        ANOMALOUS (str): ``label_str`` used for images predicted as abnormal.
        NORMAL (str): ``label_str`` used for images predicted as normal.
    """

    TRANSFORMS: str = "transforms"
    TASK: str = "task"
    # Human-readable labels written into ImageAnnotations.label_str.
    ANOMALOUS: str = "anomalous"
    NORMAL: str = "normal"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AnomalibInferenceAttributes(TemplateAttributes):
    """Configuration attributes for Anomalib inference templates.

    Attributes:
        model_path (str): Path to the exported model file.
        transforms (dict): Mapping from a ``torchvision.transforms.v2`` class name to the keyword
            arguments used to build it; a falsy value builds the transform with no arguments.
            Defaults to an empty dict.
        anomaly_area_threshold (float): Detected regions whose bounding-box area (w * h, in
            pixels) is not greater than this value are discarded. Defaults to 100.
    """

    model_path: str
    # name -> constructor kwargs, consumed by AnomalibBaseInference._convert_transform_names
    transforms: dict = Field(default_factory=dict)
    anomaly_area_threshold: float = 100
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class AnomalibBaseInference(Template):
    """Base class for Anomalib model inference implementations.

    Provides common functionality for processing images through Anomalib models,
    handling classification, segmentation, and detection tasks. Subclasses provide the
    inferencer (``get_inferencer``), the preprocessing transforms
    (``get_transformation_list``) and the mask postprocessing
    (``postprocess_segmentation_mask``).
    """

    AttributesBaseModel = AnomalibInferenceAttributes
    CATEGORY = "Anomalib"

    def __init__(self, attributes: TemplateAttributeType) -> None:
        super().__init__(attributes)
        self.inferencer = self.get_inferencer()
        self.transform = self._build_transforms()

    @abstractmethod
    def get_inferencer(self) -> TorchInferencer | OpenVINOInferencer:
        """Initialize and return the appropriate inferencer instance. Subclasses must implement this method.

        Returns:
            TorchInferencer | OpenVINOInferencer: Either a TorchInferencer or OpenVINOInferencer instance
        """

    def _convert_transform_names(self) -> list[T.Transform]:
        """Instantiate the ``torchvision.transforms.v2`` classes named in ``attributes.transforms``.

        A truthy mapping value is passed as keyword arguments; a falsy one builds the transform
        with no arguments. Names not found on ``torchvision.transforms.v2`` are skipped with a
        ``UserWarning`` instead of being ignored silently.

        Returns:
            list[T.Transform]: Instantiated transforms, in mapping order.
        """
        import warnings  # local import keeps the module-level import block unchanged

        transforms: list[T.Transform] = []
        for name, params in self.attributes.transforms.items():
            if not hasattr(T, name):
                warnings.warn(f"Unknown torchvision transform '{name}' was ignored", stacklevel=2)
                continue
            transform_class = getattr(T, name)
            transforms.append(transform_class(**params) if params else transform_class())
        return transforms

    @abstractmethod
    def get_transformation_list(self) -> list[T.Transform]:
        """Construct the list of transformation according to the specified inferencer.

        Returns:
            list[T.Transform]: List of transformations.
        """

    def _build_transforms(self) -> T.Compose:
        """Build the complete transform pipeline for image preprocessing.

        Returns:
            T.Compose: A composed transform pipeline
        """
        return T.Compose(self.get_transformation_list())

    @staticmethod
    @abstractmethod
    def postprocess_segmentation_mask(binary_mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Apply postprocessing operations to the generated mask according to the used inferencer.

        Args:
            binary_mask (np.ndarray): Mask produced by inferencer.
            image (np.ndarray): Original image.

        Returns:
            np.ndarray: Post-processed mask.
        """

    def get_boxes_from_mask(self, np_mask: np.ndarray) -> list[list[float]]:
        """Produce a list of bounding boxes from the predicted segmentation mask.

        Only outer contours are used, so holes inside an anomalous region do not yield extra
        nested boxes. Boxes whose area is not greater than ``anomaly_area_threshold`` are dropped.

        Args:
            np_mask (np.ndarray): Predicted binary mask (must be accepted by ``cv.findContours``).

        Returns:
            list[list[float]]: List of bboxes in [x, y, w, h] format.
        """
        contours, _ = cv.findContours(np_mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        boxes = []
        for contour in contours:
            contours_poly = cv.approxPolyDP(contour, 3, True)
            x, y, w, h = cv.boundingRect(contours_poly)
            if w * h > self.attributes.anomaly_area_threshold:
                boxes.append([x, y, w, h])
        return boxes

    def process_packet(self, data_packet: ImagePacket) -> None:
        """Process an individual image packet through the inference pipeline.

        When the model yields a mask with regions above the area threshold, one detection
        annotation (bbox + segmentation) is produced per region; otherwise a single
        image-level classification annotation is produced, so the label and score are never lost.

        Args:
            data_packet (ImagePacket): The image packet containing the image to process
        """
        image = deepcopy(data_packet.content)
        processed_image = self.transform(image)
        result = self.inferencer.predict(processed_image)

        label = result.pred_label
        label_str = AnomalibInferenceKeys.ANOMALOUS if label == LabelName.ABNORMAL else AnomalibInferenceKeys.NORMAL
        pred_score = result.pred_score * 100

        annotations: list[ImageAnnotations] = []
        if result.pred_mask is not None:
            np_mask = self.postprocess_segmentation_mask(result.pred_mask, image)
            boxes = self.get_boxes_from_mask(np_mask)
            annotations = self._create_detection_annotations(
                boxes=boxes, pred_score=pred_score, label=label, label_str=label_str, pred_mask=np_mask
            )

        if not annotations:
            # No mask, or no region survived the area threshold: keep the image-level result.
            annotations = [
                self._create_classification_annotation(pred_score=pred_score, label=label, label_str=label_str)
            ]

        data_packet.annotations = annotations

    @staticmethod
    def _create_classification_annotation(pred_score: float, label: int, label_str: str) -> ImageAnnotations:
        """Create classification annotation from prediction results.

        Args:
            pred_score (float): Confidence score of the prediction (0-100)
            label (int): Numeric label value
            label_str (str): String representation of the label

        Returns:
            ImageAnnotations: Object containing classification results
        """
        return ImageAnnotations(
            label=label,
            label_str=label_str,
            confidence_score=pred_score,
        )

    @staticmethod
    def _create_detection_annotations(
        boxes: list[list[float]],
        pred_score: float,
        label: int,
        label_str: str,
        pred_mask: np.ndarray,
    ) -> list[ImageAnnotations]:
        """Create detection annotations from prediction results.

        Args:
            boxes (list[list[float]]): Bounding boxes in [x, y, w, h] format
            pred_score (float): Confidence score of the prediction (0-100)
            label (int): Numeric label value
            label_str (str): String representation of the label
            pred_mask (np.ndarray): Segmentation mask the boxes were extracted from

        Returns:
            list[ImageAnnotations]: One annotation per box, each with the mask cropped to its box
        """
        annotations = []
        for box in boxes:
            x, y, w, h = box
            bbox = BoundingBox(x=x, y=y, w=w, h=h)

            # Full-size mask that keeps only the positive pixels inside this box.
            box_mask = np.zeros_like(pred_mask, dtype=np.uint8)
            x1, y1 = int(x), int(y)
            x2, y2 = int(x + w), int(y + h)

            box_mask[y1:y2, x1:x2] = np.where(pred_mask[y1:y2, x1:x2] > 0, 1, 0)

            segmentation = Segmentation(mask=box_mask)

            annotations.append(
                ImageAnnotations(
                    label=label,
                    label_str=label_str,
                    confidence_score=pred_score,
                    bbox=bbox,
                    segmentation=segmentation,
                )
            )
        return annotations

    def execute(self, container: DataContainer) -> DataContainer:
        """Process all images in the data container through the inference pipeline.

        Args:
            container (DataContainer): Input data container with images to process

        Returns:
            DataContainer: The processed data container with inference results
        """
        for image_packet in container.images:
            self.process_packet(image_packet)
        return container
|