dragon-ml-toolbox 13.8.0__py3-none-any.whl → 14.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-13.8.0.dist-info → dragon_ml_toolbox-14.0.0.dist-info}/METADATA +2 -1
- {dragon_ml_toolbox-13.8.0.dist-info → dragon_ml_toolbox-14.0.0.dist-info}/RECORD +21 -14
- ml_tools/ML_datasetmaster.py +2 -185
- ml_tools/ML_evaluation.py +3 -3
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +3 -1
- ml_tools/ML_trainer.py +446 -11
- ml_tools/ML_utilities.py +50 -1
- ml_tools/ML_vision_datasetmaster.py +1315 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/_ML_pytorch_tabular.py +543 -0
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/custom_logger.py +11 -6
- ml_tools/keys.py +30 -0
- {dragon_ml_toolbox-13.8.0.dist-info → dragon_ml_toolbox-14.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.8.0.dist-info → dragon_ml_toolbox-14.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.8.0.dist-info → dragon_ml_toolbox-14.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-13.8.0.dist-info → dragon_ml_toolbox-14.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_models.py

@@ -0,0 +1,627 @@
import torch
from torch import nn
import torchvision.models as vision_models
from torchvision.models import detection as detection_models
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from typing import List, Dict, Any, Literal, Optional
from abc import ABC, abstractmethod

from .ML_models import _ArchitectureHandlerMixin
from ._logger import _LOGGER
from ._script_info import _script_info


__all__ = [
    "DragonResNet",
    "DragonEfficientNet",
    "DragonVGG",
    "DragonFCN",
    "DragonDeepLabv3",
    "DragonFastRCNN",
]


class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
    """
    Abstract base class for torchvision model wrappers.

    Handles common logic for:
    - Model instantiation (with/without pretrained weights)
    - Input layer modification (for custom in_channels)
    - Output layer modification (for custom num_classes)
    - Architecture saving/loading and representation
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 model_name: str,
                 init_with_pretrained: bool,
                 weights_enum_name: Optional[str] = None):
        super().__init__()

        # --- 1. Validation and Configuration ---
        if not hasattr(vision_models, model_name):
            _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.")
            raise ValueError()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.model_name = model_name

        # --- 2. Instantiate the base model ---
        if init_with_pretrained:
            weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
            weights = weights_enum.IMAGENET1K_V1 if weights_enum else None

            if weights is None and init_with_pretrained:
                _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                self.model = getattr(vision_models, model_name)(pretrained=True)
            else:
                self.model = getattr(vision_models, model_name)(weights=weights)
        else:
            self.model = getattr(vision_models, model_name)(weights=None)

        # --- 3. Modify the input layer (using abstract method) ---
        if in_channels != 3:
            original_conv1 = self._get_input_layer()

            new_conv1 = nn.Conv2d(
                in_channels,
                original_conv1.out_channels,
                kernel_size=original_conv1.kernel_size,  # type: ignore
                stride=original_conv1.stride,  # type: ignore
                padding=original_conv1.padding,  # type: ignore
                bias=(original_conv1.bias is not None)
            )

            # (Optional) Average original weights if starting from pretrained
            if init_with_pretrained and original_conv1.in_channels == 3:
                with torch.no_grad():
                    avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
                    new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)

            self._set_input_layer(new_conv1)

        # --- 4. Modify the output layer (using abstract method) ---
        original_fc = self._get_output_layer()
        if original_fc is None:  # Handle case where layer isn't found
            _LOGGER.error(f"Model '{model_name}' has an unexpected classifier structure. Cannot replace final layer.")
            raise AttributeError("Could not find final classifier layer.")

        num_filters = original_fc.in_features
        self._set_output_layer(nn.Linear(num_filters, num_classes))

    @abstractmethod
    def _get_input_layer(self) -> nn.Conv2d:
        """Returns the first convolutional layer of the model."""
        raise NotImplementedError

    @abstractmethod
    def _set_input_layer(self, layer: nn.Conv2d):
        """Sets the first convolutional layer of the model."""
        raise NotImplementedError

    @abstractmethod
    def _get_output_layer(self) -> Optional[nn.Linear]:
        """Returns the final fully-connected layer of the model."""
        raise NotImplementedError

    @abstractmethod
    def _set_output_layer(self, layer: nn.Linear):
        """Sets the final fully-connected layer of the model."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the forward pass of the model."""
        return self.model(x)

    def get_architecture_config(self) -> Dict[str, Any]:
        """
        Returns the structural configuration of the model.
        The 'init_with_pretrained' flag is intentionally omitted,
        as .load() should restore the architecture, not the weights.
        """
        return {
            'num_classes': self.num_classes,
            'in_channels': self.in_channels,
            'model_name': self.model_name
        }

    def __repr__(self) -> str:
        """Returns the developer-friendly string representation of the model."""
        return (
            f"{self.__class__.__name__}(model='{self.model_name}', "
            f"in_channels={self.in_channels}, "
            f"num_classes={self.num_classes})"
        )


class DragonResNet(_BaseVisionWrapper):
    """
    Image Classification

    A customizable wrapper for the torchvision ResNet family, compatible
    with saving/loading architecture.

    This wrapper allows for customizing the model backbone, input channels,
    and the number of output classes for transfer learning.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int = 3,
                 model_name: Literal["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] = 'resnet50',
                 init_with_pretrained: bool = False):
        """
        Args:
            num_classes (int):
                Number of output classes for the final layer.
            in_channels (int):
                Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            model_name (str):
                The name of the ResNet model to use (e.g., 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'). Number is the layer count.
            init_with_pretrained (bool):
                If True, initializes the model with weights pretrained on ImageNet. This flag is for initialization only and is NOT saved in the architecture config.
        """

        weights_enum_name = getattr(vision_models, f"{model_name.upper()}_Weights", None)

        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            model_name=model_name,
            init_with_pretrained=init_with_pretrained,
            weights_enum_name=weights_enum_name
        )

    def _get_input_layer(self) -> nn.Conv2d:
        return self.model.conv1

    def _set_input_layer(self, layer: nn.Conv2d):
        self.model.conv1 = layer

    def _get_output_layer(self) -> Optional[nn.Linear]:
        return self.model.fc

    def _set_output_layer(self, layer: nn.Linear):
        self.model.fc = layer


class DragonEfficientNet(_BaseVisionWrapper):
    """
    Image Classification

    A customizable wrapper for the torchvision EfficientNet family, compatible
    with saving/loading architecture.

    This wrapper allows for customizing the model backbone, input channels,
    and the number of output classes for transfer learning.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int = 3,
                 model_name: str = 'efficientnet_b0',
                 init_with_pretrained: bool = False):
        """
        Args:
            num_classes (int):
                Number of output classes for the final layer.
            in_channels (int):
                Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            model_name (str):
                The name of the EfficientNet model to use (e.g., 'efficientnet_b0'
                through 'efficientnet_b7', or 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l').
            init_with_pretrained (bool):
                If True, initializes the model with weights pretrained on
                ImageNet. This flag is for initialization only and is
                NOT saved in the architecture config. Defaults to False.
        """

        weights_enum_name = getattr(vision_models, f"{model_name.upper()}_Weights", None)

        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            model_name=model_name,
            init_with_pretrained=init_with_pretrained,
            weights_enum_name=weights_enum_name
        )

    def _get_input_layer(self) -> nn.Conv2d:
        # The first conv layer in EfficientNet is model.features[0][0]
        return self.model.features[0][0]

    def _set_input_layer(self, layer: nn.Conv2d):
        self.model.features[0][0] = layer

    def _get_output_layer(self) -> Optional[nn.Linear]:
        # The classifier in EfficientNet is model.classifier[1]
        if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential):
            output_layer = self.model.classifier[1]
            if isinstance(output_layer, nn.Linear):
                return output_layer
        return None

    def _set_output_layer(self, layer: nn.Linear):
        self.model.classifier[1] = layer


class DragonVGG(_BaseVisionWrapper):
    """
    Image Classification

    A customizable wrapper for the torchvision VGG family, compatible
    with saving/loading architecture.

    This wrapper allows for customizing the model backbone, input channels,
    and the number of output classes for transfer learning.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int = 3,
                 model_name: Literal["vgg11", "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"] = 'vgg16',
                 init_with_pretrained: bool = False):
        """
        Args:
            num_classes (int):
                Number of output classes for the final layer.
            in_channels (int):
                Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            model_name (str):
                The name of the VGG model to use (e.g., 'vgg16', 'vgg16_bn').
            init_with_pretrained (bool):
                If True, initializes the model with weights pretrained on
                ImageNet. This flag is for initialization only and is
                NOT saved in the architecture config. Defaults to False.
        """

        # Format model name to find weights enum, e.g., vgg16_bn -> VGG16_BN_Weights
        weights_enum_name = f"{model_name.replace('_bn', '_BN').upper()}_Weights"

        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            model_name=model_name,
            init_with_pretrained=init_with_pretrained,
            weights_enum_name=weights_enum_name
        )

    def _get_input_layer(self) -> nn.Conv2d:
        # The first conv layer in VGG is model.features[0]
        return self.model.features[0]

    def _set_input_layer(self, layer: nn.Conv2d):
        self.model.features[0] = layer

    def _get_output_layer(self) -> Optional[nn.Linear]:
        # The final classifier in VGG is model.classifier[6]
        if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential) and len(self.model.classifier) == 7:
            output_layer = self.model.classifier[6]
            if isinstance(output_layer, nn.Linear):
                return output_layer
        return None

    def _set_output_layer(self, layer: nn.Linear):
        self.model.classifier[6] = layer


# Image segmentation
class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
    """
    Abstract base class for torchvision segmentation model wrappers.

    Handles common logic for:
    - Model instantiation (with/without pretrained weights and custom num_classes)
    - Input layer modification (for custom in_channels)
    - Forward pass dictionary unpacking (returns 'out' tensor)
    - Architecture saving/loading and representation
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 model_name: str,
                 init_with_pretrained: bool,
                 weights_enum_name: Optional[str] = None):
        super().__init__()

        # --- 1. Validation and Configuration ---
        if not hasattr(vision_models.segmentation, model_name):
            _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.segmentation.")
            raise ValueError()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.model_name = model_name

        # --- 2. Instantiate the base model ---
        model_kwargs = {
            'num_classes': num_classes,
            'weights': None
        }
        model_constructor = getattr(vision_models.segmentation, model_name)

        if init_with_pretrained:
            weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
            weights = weights_enum.DEFAULT if weights_enum else None

            if weights is None:
                _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                # Legacy models used 'pretrained=True' and num_classes was separate
                self.model = model_constructor(pretrained=True, **model_kwargs)
            else:
                # Modern way: weights object implies pretraining
                model_kwargs['weights'] = weights
                self.model = model_constructor(**model_kwargs)
        else:
            self.model = model_constructor(**model_kwargs)

        # --- 3. Modify the input layer (using abstract method) ---
        if in_channels != 3:
            original_conv1 = self._get_input_layer()

            new_conv1 = nn.Conv2d(
                in_channels,
                original_conv1.out_channels,
                kernel_size=original_conv1.kernel_size,  # type: ignore
                stride=original_conv1.stride,  # type: ignore
                padding=original_conv1.padding,  # type: ignore
                bias=(original_conv1.bias is not None)
            )

            # (Optional) Average original weights if starting from pretrained
            if init_with_pretrained and original_conv1.in_channels == 3:
                with torch.no_grad():
                    avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
                    new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)

            self._set_input_layer(new_conv1)

    @abstractmethod
    def _get_input_layer(self) -> nn.Conv2d:
        """Returns the first convolutional layer of the model (in the backbone)."""
        raise NotImplementedError

    @abstractmethod
    def _set_input_layer(self, layer: nn.Conv2d):
        """Sets the first convolutional layer of the model (in the backbone)."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass.
        Returns the 'out' tensor from the segmentation model's output dict.
        """
        output_dict = self.model(x)
        return output_dict['out']  # Key for standard torchvision seg models

    def get_architecture_config(self) -> Dict[str, Any]:
        """
        Returns the structural configuration of the model.
        The 'init_with_pretrained' flag is intentionally omitted,
        as .load() should restore the architecture, not the weights.
        """
        return {
            'num_classes': self.num_classes,
            'in_channels': self.in_channels,
            'model_name': self.model_name
        }

    def __repr__(self) -> str:
        """Returns the developer-friendly string representation of the model."""
        return (
            f"{self.__class__.__name__}(model='{self.model_name}', "
            f"in_channels={self.in_channels}, "
            f"num_classes={self.num_classes})"
        )


class DragonFCN(_BaseSegmentationWrapper):
    """
    Image Segmentation

    A customizable wrapper for the torchvision FCN (Fully Convolutional Network)
    family, compatible with saving/loading architecture.

    This wrapper allows for customizing the model backbone, input channels,
    and the number of output classes for transfer learning.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int = 3,
                 model_name: Literal["fcn_resnet50", "fcn_resnet101"] = 'fcn_resnet50',
                 init_with_pretrained: bool = False):
        """
        Args:
            num_classes (int):
                Number of output classes (including background).
            in_channels (int):
                Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            model_name (str):
                The name of the FCN model to use ('fcn_resnet50' or 'fcn_resnet101').
            init_with_pretrained (bool):
                If True, initializes the model with weights pretrained on COCO.
                This flag is for initialization only and is NOT saved in the
                architecture config. Defaults to False.
        """
        # Format model name to find weights enum, e.g., fcn_resnet50 -> FCN_ResNet50_Weights
        weights_model_name = model_name.replace('fcn_', 'FCN_').replace('resnet', 'ResNet')
        weights_enum_name = f"{weights_model_name}_Weights"

        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            model_name=model_name,
            init_with_pretrained=init_with_pretrained,
            weights_enum_name=weights_enum_name
        )

    def _get_input_layer(self) -> nn.Conv2d:
        # FCN models use a ResNet backbone, input layer is backbone.conv1
        return self.model.backbone.conv1

    def _set_input_layer(self, layer: nn.Conv2d):
        self.model.backbone.conv1 = layer


class DragonDeepLabv3(_BaseSegmentationWrapper):
    """
    Image Segmentation

    A customizable wrapper for the torchvision DeepLabv3 family, compatible
    with saving/loading architecture.

    This wrapper allows for customizing the model backbone, input channels,
    and the number of output classes for transfer learning.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int = 3,
                 model_name: Literal["deeplabv3_resnet50", "deeplabv3_resnet101"] = 'deeplabv3_resnet50',
                 init_with_pretrained: bool = False):
        """
        Args:
            num_classes (int):
                Number of output classes (including background).
            in_channels (int):
                Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            model_name (str):
                The name of the DeepLabv3 model to use ('deeplabv3_resnet50' or 'deeplabv3_resnet101').
            init_with_pretrained (bool):
                If True, initializes the model with weights pretrained on COCO.
                This flag is for initialization only and is NOT saved in the
                architecture config. Defaults to False.
        """

        # Format model name to find weights enum, e.g., deeplabv3_resnet50 -> DeepLabV3_ResNet50_Weights
        weights_model_name = model_name.replace('deeplabv3_', 'DeepLabV3_').replace('resnet', 'ResNet')
        weights_enum_name = f"{weights_model_name}_Weights"

        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            model_name=model_name,
            init_with_pretrained=init_with_pretrained,
            weights_enum_name=weights_enum_name
        )

    def _get_input_layer(self) -> nn.Conv2d:
        # DeepLabv3 models use a ResNet backbone, input layer is backbone.conv1
        return self.model.backbone.conv1

    def _set_input_layer(self, layer: nn.Conv2d):
        self.model.backbone.conv1 = layer


class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
    """
    Object Detection

    A customizable wrapper for the torchvision Faster R-CNN family.

    This wrapper allows for customizing the model backbone, input channels,
    and the number of output classes for transfer learning.

    NOTE: This model is NOT compatible with the MLTrainer class.
    """
    def __init__(self,
                 num_classes: int,
                 in_channels: int = 3,
                 model_name: Literal["fasterrcnn_resnet50_fpn", "fasterrcnn_resnet50_fpn_v2"] = 'fasterrcnn_resnet50_fpn_v2',
                 init_with_pretrained: bool = False):
        """
        Args:
            num_classes (int):
                Number of output classes (including background).
            in_channels (int):
                Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            model_name (str):
                The name of the Faster R-CNN model to use.
            init_with_pretrained (bool):
                If True, initializes the model with weights pretrained on COCO.
                This flag is for initialization only and is NOT saved in the
                architecture config. Defaults to False.
        """
        super().__init__()

        # --- 1. Validation and Configuration ---
        if not hasattr(detection_models, model_name):
            _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.detection.")
            raise ValueError()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.model_name = model_name

        # --- 2. Instantiate the base model ---
        model_constructor = getattr(detection_models, model_name)

        # Format model name to find weights enum, e.g., fasterrcnn_resnet50_fpn_v2 -> FasterRCNN_ResNet50_FPN_V2_Weights
        weights_model_name = model_name.replace('fasterrcnn_', 'FasterRCNN_').replace('resnet', 'ResNet').replace('_fpn', '_FPN')
        weights_enum_name = f"{weights_model_name.upper()}_Weights"

        weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
        weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None

        self.model = model_constructor(weights=weights, weights_backbone=weights)

        # --- 4. Modify the output layer (Box Predictor) ---
        # Get the number of input features for the classifier
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        # Replace the pre-trained head with a new one
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # --- 3. Modify the input layer (Backbone conv1) ---
        if in_channels != 3:
            original_conv1 = self.model.backbone.body.conv1

            new_conv1 = nn.Conv2d(
                in_channels,
                original_conv1.out_channels,
                kernel_size=original_conv1.kernel_size,  # type: ignore
                stride=original_conv1.stride,  # type: ignore
                padding=original_conv1.padding,  # type: ignore
                bias=(original_conv1.bias is not None)
            )

            # (Optional) Average original weights if starting from pretrained
            if init_with_pretrained and original_conv1.in_channels == 3 and weights is not None:
                with torch.no_grad():
                    # Average the weights across the input channel dimension
                    avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
                    # Repeat the averaged weights for the new number of input channels
                    new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)

            self.model.backbone.body.conv1 = new_conv1

    def forward(self, images: List[torch.Tensor], targets: Optional[List[Dict[str, torch.Tensor]]] = None):
        """
        Defines the forward pass.

        - In train mode, expects (images, targets) and returns a dict of losses.
        - In eval mode, expects (images) and returns a list of prediction dicts.
        """
        # The model's forward pass handles train/eval mode internally.
        return self.model(images, targets)

    def get_architecture_config(self) -> Dict[str, Any]:
        """
        Returns the structural configuration of the model.
        The 'init_with_pretrained' flag is intentionally omitted,
        as .load() should restore the architecture, not the weights.
        """
        return {
            'num_classes': self.num_classes,
            'in_channels': self.in_channels,
            'model_name': self.model_name
        }

    def __repr__(self) -> str:
        """Returns the developer-friendly string representation of the model."""
        return (
            f"{self.__class__.__name__}(model='{self.model_name}', "
            f"in_channels={self.in_channels}, "
            f"num_classes={self.num_classes})"
        )


def info():
    _script_info(__all__)
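For orientation, a minimal usage sketch of the classification wrappers above. The import path ml_tools.ML_vision_models is inferred from the file list in this diff, and the shapes and class count are illustrative only:

import torch
from ml_tools.ML_vision_models import DragonResNet  # import path inferred from the file list above

# ResNet-18 backbone adapted to 1-channel (grayscale) input and 5 output classes,
# initialized without pretrained weights.
model = DragonResNet(num_classes=5, in_channels=1, model_name="resnet18")

logits = model(torch.randn(2, 1, 224, 224))   # forward() delegates to the wrapped torchvision model
print(logits.shape)                           # torch.Size([2, 5])
print(model.get_architecture_config())        # {'num_classes': 5, 'in_channels': 1, 'model_name': 'resnet18'}

The segmentation wrappers follow the same constructor pattern but return the 'out' tensor of the torchvision output dict from forward(), while DragonFastRCNN keeps the detection-style list-of-dicts interface, which is why its docstring flags it as incompatible with the MLTrainer class.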
ml_tools/ML_vision_transformers.py

@@ -0,0 +1,58 @@
from typing import Union, Dict, Type, Callable
from PIL import ImageOps, Image

from ._logger import _LOGGER
from ._script_info import _script_info
from .keys import VisionTransformRecipeKeys


__all__ = [
    "TRANSFORM_REGISTRY",
    "ResizeAspectFill"
]

# --- Custom Vision Transform Class ---
class ResizeAspectFill:
    """
    Custom transformation to make an image square by padding it to match the
    longest side, preserving the aspect ratio. The image is finally centered.

    Args:
        pad_color (Union[str, int]): Color to use for the padding.
            Defaults to "black".
    """
    def __init__(self, pad_color: Union[str, int] = "black") -> None:
        self.pad_color = pad_color
        # Store kwargs to allow for recreation
        self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
        # self._kwargs = {"pad_color": pad_color}

    def __call__(self, image: Image.Image) -> Image.Image:
        if not isinstance(image, Image.Image):
            _LOGGER.error(f"Expected PIL.Image.Image, got {type(image).__name__}")
            raise TypeError()

        w, h = image.size
        if w == h:
            return image

        # Determine padding to center the image
        if w > h:
            top_padding = (w - h) // 2
            bottom_padding = w - h - top_padding
            padding = (0, top_padding, 0, bottom_padding)
        else:  # h > w
            left_padding = (h - w) // 2
            right_padding = h - w - left_padding
            padding = (left_padding, 0, right_padding, 0)

        return ImageOps.expand(image, padding, fill=self.pad_color)


#NOTE: Add custom transforms here.
TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
    "ResizeAspectFill": ResizeAspectFill,
}

def info():
    _script_info(__all__)
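A short sketch of how the custom transform and registry might be used, again assuming the ml_tools.ML_vision_transformers import path inferred from the file list; the image size is arbitrary:

from PIL import Image
from ml_tools.ML_vision_transformers import ResizeAspectFill, TRANSFORM_REGISTRY  # path inferred from the file list

pad = ResizeAspectFill(pad_color="black")

img = Image.new("RGB", (300, 200), "white")   # non-square input
square = pad(img)
print(square.size)                            # (300, 300): padded evenly above and below

# The registry maps a transform name back to its class, e.g. when rebuilding a pipeline from a saved recipe.
transform_cls = TRANSFORM_REGISTRY["ResizeAspectFill"]
print(transform_cls is ResizeAspectFill)      # True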