sinapsis-anomalib 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sinapsis-anomalib might be problematic.
- sinapsis_anomalib/__init__.py +0 -0
- sinapsis_anomalib/helpers/__init__.py +0 -0
- sinapsis_anomalib/helpers/config_factory.py +122 -0
- sinapsis_anomalib/templates/__init__.py +31 -0
- sinapsis_anomalib/templates/anomalib_base.py +212 -0
- sinapsis_anomalib/templates/anomalib_base_inference.py +247 -0
- sinapsis_anomalib/templates/anomalib_export.py +204 -0
- sinapsis_anomalib/templates/anomalib_openvino_inference.py +93 -0
- sinapsis_anomalib/templates/anomalib_torch_inference.py +83 -0
- sinapsis_anomalib/templates/anomalib_train.py +205 -0
- sinapsis_anomalib-0.1.0.dist-info/METADATA +949 -0
- sinapsis_anomalib-0.1.0.dist-info/RECORD +15 -0
- sinapsis_anomalib-0.1.0.dist-info/WHEEL +5 -0
- sinapsis_anomalib-0.1.0.dist-info/licenses/LICENSE +661 -0
- sinapsis_anomalib-0.1.0.dist-info/top_level.txt +1 -0

sinapsis_anomalib/templates/anomalib_export.py
@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Any

from anomalib.deploy import CompressionType, ExportType
from pydantic.dataclasses import dataclass
from sinapsis_core.data_containers.data_packet import DataContainer
from sinapsis_core.template_base import Template, TemplateAttributeType
from sinapsis_core.template_base.dynamic_template_factory import make_dynamic_template
from sinapsis_core.utils.env_var_keys import SINAPSIS_BUILD_DOCS
from torchmetrics import Metric

from sinapsis_anomalib.templates.anomalib_base import (
    AnomalibBase,
    AnomalibBaseAttributes,
)
from sinapsis_anomalib.templates.anomalib_train import AnomalibTrainDataClass


@dataclass(frozen=True, slots=True)
class AnomalibExportDataClass:
    """Container for export results.

    Attributes:
        exported_model_path (Path | str): Path to the exported model file(s).
            Can be either a Path object or a string path.
    """

    exported_model_path: Path | str


class AnomalibExportAttributes(AnomalibBaseAttributes):
    """Export-specific attribute configuration for Anomalib models.

    Attributes:
        export_type (ExportType | str): Target export format. Defaults to TORCH.
        export_root (str | Path | None): Root directory for exported files.
        input_size (tuple[int, int] | None): Expected input dimensions (height, width).
        compression_type (CompressionType | None): Quantization/compression method.
        metric (Metric | str | None): Metric used for compression calibration.
        ov_args (dict[str, Any] | None): OpenVINO-specific export arguments.
        ckpt_path (str | None): Explicit path to the model checkpoint.
        generic_key_chkpt (str | None): Key to retrieve training results.
    """

    export_type: ExportType = ExportType.TORCH
    export_root: str | Path | None = None
    input_size: tuple[int, int] | None = None
    compression_type: CompressionType | None = None
    metric: Metric | str | None = None
    ov_args: dict[str, Any] | None = None
    ckpt_path: str | None = None
    generic_key_chkpt: str | None = None


class AnomalibExport(AnomalibBase):
    """Export functionality for trained Anomalib models.

    Usage example:

    agent:
      name: my_test_agent
    templates:
    - template_name: InputTemplate
      class_name: InputTemplate
      attributes: {}
    - template_name: CfaExport
      class_name: CfaExport
      template_input: InputTemplate
      attributes:
        folder_attributes_config_path: null
        generic_key: 'my_generic_key'
        callbacks: null
        normalization: null
        threshold: null
        task: null
        image_metrics: null
        pixel_metrics: null
        logger: null
        default_root_dir: null
        callback_configs: null
        logger_configs: null
        export_type: 'openvino'
        export_root: null
        input_size: null
        transform: null
        compression_type: null
        metric: null
        ov_args: null
        ckpt_path: null
        generic_key_chkpt: null
        cfa_init:
          backbone: wide_resnet50_2
          gamma_c: 1
          gamma_d: 1
          num_nearest_neighbors: 3
          num_hard_negative_features: 3
          radius: 1.0e-05

    For a full list of options use the sinapsis cli: sinapsis info --all-template-names
    If you want to see all available models, please visit:
    https://anomalib.readthedocs.io/en/v1.2.0/markdown/guides/reference/models/image/index.html

    Notes:
        - Supports multiple export formats via ExportType.
        - Enables model compression through CompressionType.
        - Can load models from checkpoints or previous training results.
    """

    AttributesBaseModel = AnomalibExportAttributes
    SUFFIX = "Export"

    def __init__(self, attributes: TemplateAttributeType) -> None:
        super().__init__(attributes)
        self.data_module = None
        if self.attributes.compression_type in (
            CompressionType.INT8_ACQ,
            CompressionType.INT8_PTQ,
        ):
            self.data_module = self.setup_data_loader()

    def _get_checkpoint_path(self, container: DataContainer) -> str | Path:
        """Resolves the checkpoint path for model export.

        Args:
            container (DataContainer): Container with potential training results.

        Raises:
            ValueError: If no valid checkpoint path is found.

        Returns:
            str | Path: Path to the model checkpoint.

        Note:
            Priority order:
            1. Explicit ckpt_path from attributes
            2. Training results from the container (if generic_key_chkpt is provided)
        """
        if self.attributes.ckpt_path:
            return self.attributes.ckpt_path

        generic_data = self._get_generic_data(container, self.attributes.generic_key_chkpt)
        if generic_data and isinstance(generic_data, AnomalibTrainDataClass):
            return generic_data.checkpoint_path

        raise ValueError("No checkpoint path found")

    def export_model(self, container: DataContainer) -> AnomalibExportDataClass:
        """Exports the model to the specified format.

        Args:
            container (DataContainer): Input data container.

        Returns:
            AnomalibExportDataClass: Contains the exported model path.
        """
        ckpt_path = self._get_checkpoint_path(container)

        exported_path = self.engine.export(
            model=self.model,
            export_type=self.attributes.export_type,
            export_root=self.attributes.export_root,
            input_size=self.attributes.input_size,
            compression_type=self.attributes.compression_type,
            datamodule=self.data_module,
            metric=self.attributes.metric,
            ov_args=self.attributes.ov_args,
            ckpt_path=ckpt_path,
        )

        return AnomalibExportDataClass(exported_model_path=exported_path)

    def execute(self, container: DataContainer) -> DataContainer:
        """Performs model export and stores the results.

        Args:
            container (DataContainer): Input data container.

        Returns:
            DataContainer: Container with export results stored as generic data.
        """
        result = self.export_model(container)
        self._set_generic_data(container, result)
        return container


def __getattr__(name: str) -> Template:
    """Create dynamic templates.

    Only create a template if it is imported; this avoids creating all the base models for all templates
    and potential import errors due to unavailable packages.
    """
    if name in AnomalibExport.WrapperEntry.module_att_names:
        return make_dynamic_template(name, AnomalibExport)
    raise AttributeError(f"template `{name}` not found in {__name__}")


__all__ = AnomalibExport.WrapperEntry.module_att_names

if SINAPSIS_BUILD_DOCS:
    dynamic_templates = [__getattr__(template_name) for template_name in __all__]
    for template in dynamic_templates:
        globals()[template.__name__] = template
        del template
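
For orientation, the export above ultimately delegates to anomalib's Engine.export (self.engine and self.model are provided by AnomalibBase, which is not shown in this diff). A minimal standalone sketch of the equivalent direct call, assuming anomalib v1.x; the Engine import, the Cfa model choice, and the paths are illustrative assumptions, not values shipped by this package:

# Hedged sketch: direct use of anomalib's export API, mirroring AnomalibExport.export_model.
# Assumes anomalib v1.x; all paths below are placeholders.
from anomalib.deploy import ExportType
from anomalib.engine import Engine
from anomalib.models import Cfa

engine = Engine()
model = Cfa()  # any anomalib image model; Cfa matches the CfaExport example above
exported_path = engine.export(
    model=model,
    export_type=ExportType.OPENVINO,    # same field the template exposes as export_type
    export_root="path/to/export_root",  # placeholder output directory
    ckpt_path="path/to/model.ckpt",     # placeholder; the template resolves this from attributes or a prior training run
)
print(exported_path)  # the value AnomalibExportDataClass stores as exported_model_path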

sinapsis_anomalib/templates/anomalib_openvino_inference.py
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
import os
from typing import Literal

import cv2 as cv
import numpy as np
import torchvision.transforms.v2 as T
from anomalib.deploy import OpenVINOInferencer

from sinapsis_anomalib.templates.anomalib_base_inference import (
    AnomalibBaseInference,
    AnomalibInferenceAttributes,
)


class AnomalibOpenVINOInferenceAttributes(AnomalibInferenceAttributes):
    """OpenVINO-specific inference attribute configuration.

    Attributes:
        device (Literal["CPU", "GPU"]): Target hardware accelerator for inference.
            Must be either 'CPU' or 'GPU'.
        model_height (int): The image height expected by the OpenVINO model.
        model_width (int): The image width expected by the OpenVINO model.
    """

    device: Literal["CPU", "GPU"]
    model_height: int
    model_width: int


class AnomalibOpenVINOInference(AnomalibBaseInference):
    """OpenVINO-specific inference implementation for Anomalib models.

    Extends base inference to provide optimized model execution using the OpenVINO toolkit.

    Usage example:

    agent:
      name: my_test_agent
    templates:
    - template_name: InputTemplate
      class_name: InputTemplate
      attributes: {}
    - template_name: AnomalibOpenVINOInference
      class_name: AnomalibOpenVINOInference
      template_input: InputTemplate
      attributes:
        model_path: 'path/to/model.xml'
        transforms: null
        device: CPU
        model_height: 256
        model_width: 256
    """

    AttributesBaseModel = AnomalibOpenVINOInferenceAttributes

    def get_inferencer(self) -> OpenVINOInferencer:
        """Initialize the OpenVINO inferencer with model and metadata.

        Returns:
            OpenVINOInferencer: Inferencer instance with model and metadata loaded.
        """
        model_path: str = os.path.expanduser(self.attributes.model_path)

        return OpenVINOInferencer(
            path=model_path,
            device=self.attributes.device,
        )

    @staticmethod
    def postprocess_segmentation_mask(binary_mask: np.ndarray, image: np.ndarray) -> np.ndarray:
        """Apply resizing and squeezing to the mask.

        Args:
            binary_mask (np.ndarray): Mask produced by the OpenVINOInferencer.
            image (np.ndarray): Original image.

        Returns:
            np.ndarray: Postprocessed mask.
        """
        height, width = image.shape[:2]
        binary_mask = np.squeeze(binary_mask).astype(np.uint8)
        return cv.resize(binary_mask, (width, height))

    def get_transformation_list(self) -> list[T.Transform]:
        """Construct the list of transformations for the OpenVINOInferencer.

        Returns:
            list[T.Transform]: List of transformations.
        """
        transforms = [T.ToImage(), T.Resize([self.attributes.model_height, self.attributes.model_width])]
        transforms.extend(self._convert_transform_names())
        return transforms
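
A minimal sketch of the OpenVINO inference path this template wraps, assuming anomalib v1.x. The predict() call and the pred_mask field follow anomalib's inferencer result API and do not appear in this diff, so treat them as assumptions; the paths are placeholders.

# Hedged sketch: direct OpenVINO inference plus the same mask postprocessing the template applies.
import cv2 as cv
import numpy as np
from anomalib.deploy import OpenVINOInferencer

inferencer = OpenVINOInferencer(path="path/to/model.xml", device="CPU")  # mirrors get_inferencer
image = cv.cvtColor(cv.imread("path/to/image.png"), cv.COLOR_BGR2RGB)

result = inferencer.predict(image)  # assumed anomalib API; returns predictions incl. pred_mask

# Same postprocessing as postprocess_segmentation_mask: squeeze, cast, resize to the input size.
height, width = image.shape[:2]
mask = np.squeeze(result.pred_mask).astype(np.uint8)
mask = cv.resize(mask, (width, height))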

sinapsis_anomalib/templates/anomalib_torch_inference.py
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
import os
from typing import Literal

import numpy as np
import torch
import torchvision.transforms.v2 as T
from anomalib.deploy import TorchInferencer

from sinapsis_anomalib.templates.anomalib_base_inference import (
    AnomalibBaseInference,
    AnomalibInferenceAttributes,
)


class AnomalibTorchInferenceAttributes(AnomalibInferenceAttributes):
    """PyTorch-specific inference attribute configuration.

    Attributes:
        device (Literal["cuda", "cpu"]): Target device for inference (either 'cuda' or 'cpu').
    """

    device: Literal["cuda", "cpu"]


class AnomalibTorchInference(AnomalibBaseInference):
    """PyTorch-specific inference implementation for Anomalib models.

    Extends base inference to provide native PyTorch model execution.

    Usage example:

    agent:
      name: my_test_agent
    templates:
    - template_name: InputTemplate
      class_name: InputTemplate
      attributes: {}
    - template_name: AnomalibTorchInference
      class_name: AnomalibTorchInference
      template_input: InputTemplate
      attributes:
        model_path: 'path/to/model.pt'
        transforms: null
        device: 'cuda'
    """

    AttributesBaseModel = AnomalibTorchInferenceAttributes

    def get_inferencer(self) -> TorchInferencer:
        """Get a PyTorch inferencer instance.

        Returns:
            TorchInferencer: Inferencer instance with the model loaded on the specified device.
        """
        model_path: str = os.path.expanduser(self.attributes.model_path)
        return TorchInferencer(path=model_path, device=self.attributes.device)

    @staticmethod
    def postprocess_segmentation_mask(binary_mask: torch.Tensor, image: np.ndarray) -> np.ndarray:
        """Apply rescaling, squeezing and conversion from torch tensor to numpy array format.

        Args:
            binary_mask (torch.Tensor): Mask produced by the TorchInferencer.
            image (np.ndarray): Input image.

        Returns:
            np.ndarray: Postprocessed mask.
        """
        height, width = image.shape[:2]
        rescaled_mask = torch.squeeze(T.Resize(size=[height, width])(binary_mask))
        return rescaled_mask.cpu().numpy().astype(np.uint8)

    def get_transformation_list(self) -> list[T.Transform]:
        """Construct the list of transformations for the TorchInferencer.

        Returns:
            list[T.Transform]: List of transformations.
        """
        transforms = [T.ToImage(), T.ToDtype(torch.float32, scale=True)]

        transforms.extend(self._convert_transform_names())
        return transforms
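
The PyTorch counterpart contributes the same two pieces: a preprocessing pipeline and a mask postprocessing step, while the actual prediction call is handled by AnomalibBaseInference (not part of this diff). A minimal sketch of those two pieces on their own, assuming torchvision's v2 transforms; the model path is a placeholder.

# Hedged sketch: the torch backend's preprocessing and mask postprocessing, mirrored outside the template.
import numpy as np
import torch
import torchvision.transforms.v2 as T
from anomalib.deploy import TorchInferencer

inferencer = TorchInferencer(path="path/to/model.pt", device="cpu")  # mirrors get_inferencer

# Preprocessing as in get_transformation_list: tensor image, scaled float32.
preprocess = T.Compose([T.ToImage(), T.ToDtype(torch.float32, scale=True)])


def to_numpy_mask(binary_mask: torch.Tensor, image: np.ndarray) -> np.ndarray:
    """Mirror of postprocess_segmentation_mask: resize the mask back to the
    original resolution, squeeze it, and convert it to a uint8 numpy array."""
    height, width = image.shape[:2]
    rescaled = torch.squeeze(T.Resize(size=[height, width])(binary_mask))
    return rescaled.cpu().numpy().astype(np.uint8)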

sinapsis_anomalib/templates/anomalib_train.py
@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Any, Literal

from anomalib.engine.engine import _TrainerArgumentsCache
from pydantic import Field
from pydantic.dataclasses import dataclass
from sinapsis_core.data_containers.data_packet import DataContainer
from sinapsis_core.template_base import Template, TemplateAttributeType
from sinapsis_core.template_base.dynamic_template_factory import make_dynamic_template
from sinapsis_core.utils.env_var_keys import SINAPSIS_BUILD_DOCS

from sinapsis_anomalib.templates.anomalib_base import (
    AnomalibBase,
    AnomalibBaseAttributes,
)


@dataclass(frozen=True)
class AnomalibTrainKeys:
    """Constants for accessing training-related configuration keys.

    Attributes:
        ACCELERATOR (str): Key for accessing the accelerator field in the trainer args dict.
        DEVICES (str): Key for accessing the devices field in the trainer args dict.
        CALLBACK_METRICS (str): Key for accessing callback metrics ('callback_metrics').
        MAX_EPOCHS (str): Key for the maximum epochs setting ('max_epochs').
        CKPT_PATH (str): Key for the checkpoint path ('ckpt_path').
        BEST_MODEL_PATH (str): Key for the best model path ('best_model_path').
    """

    ACCELERATOR: str = "accelerator"
    CALLBACK_METRICS: str = "callback_metrics"
    MAX_EPOCHS: str = "max_epochs"
    CKPT_PATH: str = "ckpt_path"
    BEST_MODEL_PATH: str = "best_model_path"


@dataclass(frozen=True, slots=True)
class AnomalibTrainDataClass:
    """Container for training results and artifacts.

    Attributes:
        metrics (dict[str, float]): Dictionary of training metrics (metric_name: value).
        checkpoint_path (Path | str | None): Path to the best model checkpoint,
            or None if no checkpoint was saved.
    """

    metrics: dict[str, float]
    checkpoint_path: Path | str | None


class AnomalibTrainAttributes(AnomalibBaseAttributes):
    """Training-specific configuration attributes.

    Attributes:
        max_epochs (int | None): Maximum number of training epochs.
            If None, uses the default from the model configuration.
        accelerator (Literal["cpu", "gpu", "tpu", "hpu", "auto"]): Device to be used during training.
            Defaults to "cpu".
        ckpt_path (str | Path | None): Path to a checkpoint for resuming training.
            If None, starts training from scratch.
        trainer_args (dict[str, Any]): General trainer arguments. For more details see:
            https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-flags
    """

    max_epochs: int | None = None
    accelerator: Literal["cpu", "gpu", "tpu", "hpu", "auto"] = "cpu"
    ckpt_path: str | Path | None = None
    trainer_args: dict[str, Any] = Field(default_factory=dict)


class AnomalibTrain(AnomalibBase):
    """Training implementation for Anomalib models.

    Usage example:

    agent:
      name: my_test_agent
    templates:
    - template_name: InputTemplate
      class_name: InputTemplate
      attributes: {}
    - template_name: CfaTrain
      class_name: CfaTrain
      template_input: InputTemplate
      attributes:
        folder_attributes_config_path: 'path/to/config.yaml'
        generic_key: 'my_generic_key'
        callbacks: null
        normalization: null
        threshold: null
        image_metrics: null
        pixel_metrics: null
        logger: null
        default_root_dir: null
        callback_configs: null
        logger_configs: null
        max_epochs: null
        ckpt_path: null
        accelerator: gpu
        trainer_args:
          devices: "0"
        cfa_init:
          backbone: wide_resnet50_2
          gamma_c: 1
          gamma_d: 1
          num_nearest_neighbors: 3
          num_hard_negative_features: 3
          radius: 1.0e-05
    For a full list of options use the sinapsis cli: sinapsis info --all-template-names.
    If you want to see all available models, please visit:
    https://anomalib.readthedocs.io/en/v1.2.0/markdown/guides/reference/models/image/index.html
    """

    SUFFIX = "Train"
    AttributesBaseModel = AnomalibTrainAttributes

    def __init__(self, attributes: TemplateAttributeType) -> None:
        super().__init__(attributes)
        self.data_module = self.setup_data_loader()

    def _update_trainer_args(self) -> None:
        """Updates the trainer configuration with the current settings.

        Specifically sets the maximum number of training epochs
        based on the attributes configuration.
        """
        existing_args = self.engine._cache.args
        existing_args[AnomalibTrainKeys.MAX_EPOCHS] = self.attributes.max_epochs

        self.attributes.trainer_args[AnomalibTrainKeys.ACCELERATOR] = self.attributes.accelerator
        existing_args.update(self.attributes.trainer_args)
        self.engine._cache = _TrainerArgumentsCache(**existing_args)

    def _get_training_metrics(self) -> dict[str, Any]:
        """Extracts training metrics from the trainer.

        Returns:
            dict[str, float]: Dictionary of metric names and their values.
                Returns an empty dict if no metrics are available.
        """
        if not hasattr(self.engine.trainer, AnomalibTrainKeys.CALLBACK_METRICS):
            return {}
        return {k: v.item() if hasattr(v, "item") else v for k, v in self.engine.trainer.callback_metrics.items()}

    def train_model(self) -> AnomalibTrainDataClass:
        """Executes the model training process.

        Returns:
            AnomalibTrainDataClass: Contains:
                - metrics: Collected training metrics
                - checkpoint_path: Path to the best model checkpoint

        Note:
            - Updates the trainer configuration before starting.
            - Uses the provided checkpoint, if available, for resuming.
        """
        self._update_trainer_args()

        self.engine.train(
            model=self.model,
            datamodule=self.data_module,
            ckpt_path=self.attributes.ckpt_path,
        )

        metrics = self._get_training_metrics()
        checkpoint_path = (
            self.engine.best_model_path if hasattr(self.engine, AnomalibTrainKeys.BEST_MODEL_PATH) else None
        )

        return AnomalibTrainDataClass(metrics=metrics, checkpoint_path=checkpoint_path)

    def execute(self, container: DataContainer) -> DataContainer:
        """Executes training and stores the results.

        Args:
            container: Input data container.

        Returns:
            Modified container with training results stored under generic_key.
        """
        result = self.train_model()
        self._set_generic_data(container, result)
        return container


def __getattr__(name: str) -> Template:
    """Create dynamic templates.

    Only create a template if it is imported; this avoids creating all the base models for all templates
    and potential import errors due to unavailable packages.
    """
    if name in AnomalibTrain.WrapperEntry.module_att_names:
        return make_dynamic_template(name, AnomalibTrain)
    raise AttributeError(f"template `{name}` not found in {__name__}")


__all__ = AnomalibTrain.WrapperEntry.module_att_names

if SINAPSIS_BUILD_DOCS:
    dynamic_templates = [__getattr__(template_name) for template_name in __all__]
    for template in dynamic_templates:
        globals()[template.__name__] = template
        del template