dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_models.py
@@ -0,0 +1,641 @@
+ import torch
+ from torch import nn
+ import torchvision.models as vision_models
+ from torchvision.models import detection as detection_models
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+ from typing import List, Dict, Any, Literal, Optional
+ from abc import ABC, abstractmethod
+
+ from .ML_models import _ArchitectureHandlerMixin
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+
+
+ __all__ = [
+     "DragonResNet",
+     "DragonEfficientNet",
+     "DragonVGG",
+     "DragonFCN",
+     "DragonDeepLabv3",
+     "DragonFastRCNN",
+ ]
+
+
+ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
+     """
+     Abstract base class for torchvision model wrappers.
+
+     Handles common logic for:
+     - Model instantiation (with/without pretrained weights)
+     - Input layer modification (for custom in_channels)
+     - Output layer modification (for custom num_classes)
+     - Architecture saving/loading and representation
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int,
+                  model_name: str,
+                  init_with_pretrained: bool,
+                  weights_enum_name: Optional[str] = None):
+         super().__init__()
+
+         # --- 1. Validation and Configuration ---
+         if not hasattr(vision_models, model_name):
+             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.")
+             raise ValueError()
+
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.model_name = model_name
+         self._pretrained_default_transforms = None
+
+         # --- 2. Instantiate the base model ---
+         if init_with_pretrained:
+             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
+             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None
+
+             # Save transformations for pretrained models
+             if weights:
+                 self._pretrained_default_transforms = weights.transforms()
+
+             if weights is None:
+                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
+                 self.model = getattr(vision_models, model_name)(pretrained=True)
+             else:
+                 self.model = getattr(vision_models, model_name)(weights=weights)
+         else:
+             self.model = getattr(vision_models, model_name)(weights=None)
+
+         # --- 3. Modify the input layer (using abstract method) ---
+         if in_channels != 3:
+             original_conv1 = self._get_input_layer()
+
+             new_conv1 = nn.Conv2d(
+                 in_channels,
+                 original_conv1.out_channels,
+                 kernel_size=original_conv1.kernel_size,  # type: ignore
+                 stride=original_conv1.stride,  # type: ignore
+                 padding=original_conv1.padding,  # type: ignore
+                 bias=(original_conv1.bias is not None)
+             )
+
+             # (Optional) Average original weights if starting from pretrained
+             if init_with_pretrained and original_conv1.in_channels == 3:
+                 with torch.no_grad():
+                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
+                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
+
+             self._set_input_layer(new_conv1)
+
+         # --- 4. Modify the output layer (using abstract method) ---
+         original_fc = self._get_output_layer()
+         if original_fc is None:  # Handle case where layer isn't found
+             _LOGGER.error(f"Model '{model_name}' has an unexpected classifier structure. Cannot replace final layer.")
+             raise AttributeError("Could not find final classifier layer.")
+
+         num_filters = original_fc.in_features
+         self._set_output_layer(nn.Linear(num_filters, num_classes))
+
+     @abstractmethod
+     def _get_input_layer(self) -> nn.Conv2d:
+         """Returns the first convolutional layer of the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _set_input_layer(self, layer: nn.Conv2d):
+         """Sets the first convolutional layer of the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         """Returns the final fully-connected layer of the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _set_output_layer(self, layer: nn.Linear):
+         """Sets the final fully-connected layer of the model."""
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Defines the forward pass of the model."""
+         return self.model(x)
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """
+         Returns the structural configuration of the model.
+         The 'init_with_pretrained' flag is intentionally omitted,
+         as .load() should restore the architecture, not the weights.
+         """
+         return {
+             'num_classes': self.num_classes,
+             'in_channels': self.in_channels,
+             'model_name': self.model_name
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         return (
+             f"{self.__class__.__name__}(model='{self.model_name}', "
+             f"in_channels={self.in_channels}, "
+             f"num_classes={self.num_classes})"
+         )
+
+
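The input-layer swap above (step 3) keeps pretrained RGB filters useful for other channel counts by averaging the three input planes and tiling the average across the new channels. The same trick in isolation, a minimal sketch with hypothetical layer shapes rather than anything from the package:

import torch
from torch import nn

old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # pretrained-style stem
new_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # grayscale replacement

with torch.no_grad():
    # Collapse the RGB dimension to one channel, then tile to the new count
    avg = old_conv.weight.mean(dim=1, keepdim=True)  # (64, 1, 7, 7)
    new_conv.weight[:] = avg.repeat(1, new_conv.in_channels, 1, 1)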
+ class DragonResNet(_BaseVisionWrapper):
+     """
+     Image Classification
+
+     A customizable wrapper for the torchvision ResNet family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] = 'resnet50',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes for the final layer.
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the ResNet model to use (e.g., 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'). The number is the layer count.
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on ImageNet. This flag is for initialization only and is NOT saved in the architecture config. Defaults to False.
+         """
+
+         # Format model name to find the weights enum, e.g., resnet50 -> ResNet50_Weights
+         weights_enum_name = f"ResNet{model_name[len('resnet'):]}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         return self.model.conv1
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.conv1 = layer
+
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         return self.model.fc
+
+     def _set_output_layer(self, layer: nn.Linear):
+         self.model.fc = layer
+
+
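A usage sketch (illustrative shapes; assumes the wheel is installed so ml_tools is importable): a single-channel, 10-class classifier built from the wrapper above.

import torch
from ml_tools.ML_vision_models import DragonResNet

model = DragonResNet(num_classes=10, in_channels=1, model_name='resnet18')
x = torch.randn(4, 1, 224, 224)  # batch of 4 grayscale images
logits = model(x)                # shape: (4, 10)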
+ class DragonEfficientNet(_BaseVisionWrapper):
+     """
+     Image Classification
+
+     A customizable wrapper for the torchvision EfficientNet family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: str = 'efficientnet_b0',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes for the final layer.
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the EfficientNet model to use (e.g., 'efficientnet_b0'
+                 through 'efficientnet_b7', or 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on
+                 ImageNet. This flag is for initialization only and is
+                 NOT saved in the architecture config. Defaults to False.
+         """
+
+         # Format model name to find the weights enum,
+         # e.g., efficientnet_b0 -> EfficientNet_B0_Weights, efficientnet_v2_s -> EfficientNet_V2_S_Weights
+         weights_enum_name = "EfficientNet" + model_name[len("efficientnet"):].upper() + "_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # The first conv layer in EfficientNet is model.features[0][0]
+         return self.model.features[0][0]
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.features[0][0] = layer
+
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         # The classifier in EfficientNet is model.classifier[1]
+         if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential):
+             output_layer = self.model.classifier[1]
+             if isinstance(output_layer, nn.Linear):
+                 return output_layer
+         return None
+
+     def _set_output_layer(self, layer: nn.Linear):
+         self.model.classifier[1] = layer
+
+
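Each subclass passes a weights-enum name string down to the base wrapper, which resolves it with getattr so a missing enum degrades to None (triggering the legacy fallback) instead of raising. A sketch of the lookup in isolation, on torchvision >= 0.13:

import torchvision.models as vision_models

enum_name = "EfficientNet_B0_Weights"                   # derived from model_name
weights_enum = getattr(vision_models, enum_name, None)  # None if no such enum exists
weights = weights_enum.IMAGENET1K_V1 if weights_enum else None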
+ class DragonVGG(_BaseVisionWrapper):
+     """
+     Image Classification
+
+     A customizable wrapper for the torchvision VGG family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["vgg11", "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"] = 'vgg16',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes for the final layer.
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the VGG model to use (e.g., 'vgg16', 'vgg16_bn').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on
+                 ImageNet. This flag is for initialization only and is
+                 NOT saved in the architecture config. Defaults to False.
+         """
+
+         # Format model name to find weights enum, e.g., vgg16_bn -> VGG16_BN_Weights
+         weights_enum_name = f"{model_name.replace('_bn', '_BN').upper()}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # The first conv layer in VGG is model.features[0]
+         return self.model.features[0]
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.features[0] = layer
+
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         # The final classifier in VGG is model.classifier[6]
+         if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential) and len(self.model.classifier) == 7:
+             output_layer = self.model.classifier[6]
+             if isinstance(output_layer, nn.Linear):
+                 return output_layer
+         return None
+
+     def _set_output_layer(self, layer: nn.Linear):
+         self.model.classifier[6] = layer
+
+
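get_architecture_config() deliberately stores only structural parameters, and its keys mirror each wrapper's __init__ signature, so an untrained copy of the architecture can be rebuilt from the dict alone. A sketch of that round trip (the mixin's file-level save/load helpers are not part of this diff):

from ml_tools.ML_vision_models import DragonVGG

model = DragonVGG(num_classes=5, in_channels=3, model_name='vgg16_bn')
config = model.get_architecture_config()
# {'num_classes': 5, 'in_channels': 3, 'model_name': 'vgg16_bn'}
rebuilt = DragonVGG(**config)  # same structure, freshly initialized weights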
+ # Image segmentation
+ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
+     """
+     Abstract base class for torchvision segmentation model wrappers.
+
+     Handles common logic for:
+     - Model instantiation (with/without pretrained weights and custom num_classes)
+     - Input layer modification (for custom in_channels)
+     - Forward pass dictionary unpacking (returns 'out' tensor)
+     - Architecture saving/loading and representation
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int,
+                  model_name: str,
+                  init_with_pretrained: bool,
+                  weights_enum_name: Optional[str] = None):
+         super().__init__()
+
+         # --- 1. Validation and Configuration ---
+         if not hasattr(vision_models.segmentation, model_name):
+             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.segmentation.")
+             raise ValueError()
+
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.model_name = model_name
+         self._pretrained_default_transforms = None
+
+         # --- 2. Instantiate the base model ---
+         # NOTE: torchvision may raise if a custom num_classes conflicts with the
+         # class count baked into pretrained segmentation weights.
+         model_kwargs = {
+             'num_classes': num_classes,
+             'weights': None
+         }
+         model_constructor = getattr(vision_models.segmentation, model_name)
+
+         if init_with_pretrained:
+             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
+             weights = weights_enum.DEFAULT if weights_enum else None
+
+             # Save pretrained model transformations
+             if weights:
+                 self._pretrained_default_transforms = weights.transforms()
+
+             if weights is None:
+                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
+                 # Legacy constructors take 'pretrained=True'; do not also pass 'weights'
+                 self.model = model_constructor(pretrained=True, num_classes=num_classes)
+             else:
+                 # Modern way: weights object implies pretraining
+                 model_kwargs['weights'] = weights
+                 self.model = model_constructor(**model_kwargs)
+         else:
+             self.model = model_constructor(**model_kwargs)
+
+         # --- 3. Modify the input layer (using abstract method) ---
+         if in_channels != 3:
+             original_conv1 = self._get_input_layer()
+
+             new_conv1 = nn.Conv2d(
+                 in_channels,
+                 original_conv1.out_channels,
+                 kernel_size=original_conv1.kernel_size,  # type: ignore
+                 stride=original_conv1.stride,  # type: ignore
+                 padding=original_conv1.padding,  # type: ignore
+                 bias=(original_conv1.bias is not None)
+             )
+
+             # (Optional) Average original weights if starting from pretrained
+             if init_with_pretrained and original_conv1.in_channels == 3:
+                 with torch.no_grad():
+                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
+                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
+
+             self._set_input_layer(new_conv1)
+
+     @abstractmethod
+     def _get_input_layer(self) -> nn.Conv2d:
+         """Returns the first convolutional layer of the model (in the backbone)."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _set_input_layer(self, layer: nn.Conv2d):
+         """Sets the first convolutional layer of the model (in the backbone)."""
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Defines the forward pass.
+         Returns the 'out' tensor from the segmentation model's output dict.
+         """
+         output_dict = self.model(x)
+         return output_dict['out']  # Key for standard torchvision seg models
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """
+         Returns the structural configuration of the model.
+         The 'init_with_pretrained' flag is intentionally omitted,
+         as .load() should restore the architecture, not the weights.
+         """
+         return {
+             'num_classes': self.num_classes,
+             'in_channels': self.in_channels,
+             'model_name': self.model_name
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         return (
+             f"{self.__class__.__name__}(model='{self.model_name}', "
+             f"in_channels={self.in_channels}, "
+             f"num_classes={self.num_classes})"
+         )
+
+
+ class DragonFCN(_BaseSegmentationWrapper):
+     """
+     Image Segmentation
+
+     A customizable wrapper for the torchvision FCN (Fully Convolutional Network)
+     family, compatible with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["fcn_resnet50", "fcn_resnet101"] = 'fcn_resnet50',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes (including background).
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the FCN model to use ('fcn_resnet50' or 'fcn_resnet101').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on COCO.
+                 This flag is for initialization only and is NOT saved in the
+                 architecture config. Defaults to False.
+         """
+         # Format model name to find weights enum, e.g., fcn_resnet50 -> FCN_ResNet50_Weights
+         weights_model_name = model_name.replace('fcn_', 'FCN_').replace('resnet', 'ResNet')
+         weights_enum_name = f"{weights_model_name}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # FCN models use a ResNet backbone; the input layer is backbone.conv1
+         return self.model.backbone.conv1
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.backbone.conv1 = layer
+
+
+ class DragonDeepLabv3(_BaseSegmentationWrapper):
+     """
+     Image Segmentation
+
+     A customizable wrapper for the torchvision DeepLabv3 family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["deeplabv3_resnet50", "deeplabv3_resnet101"] = 'deeplabv3_resnet50',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes (including background).
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the DeepLabv3 model to use ('deeplabv3_resnet50' or 'deeplabv3_resnet101').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on COCO.
+                 This flag is for initialization only and is NOT saved in the
+                 architecture config. Defaults to False.
+         """
+
+         # Format model name to find weights enum, e.g., deeplabv3_resnet50 -> DeepLabV3_ResNet50_Weights
+         weights_model_name = model_name.replace('deeplabv3_', 'DeepLabV3_').replace('resnet', 'ResNet')
+         weights_enum_name = f"{weights_model_name}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # DeepLabv3 models use a ResNet backbone; the input layer is backbone.conv1
+         return self.model.backbone.conv1
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.backbone.conv1 = layer
+
+
+ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
+     """
+     Object Detection
+
+     A customizable wrapper for the torchvision Faster R-CNN family.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+
+     NOTE: Use an Object Detection compatible trainer.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["fasterrcnn_resnet50_fpn", "fasterrcnn_resnet50_fpn_v2"] = 'fasterrcnn_resnet50_fpn_v2',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes (including background).
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the Faster R-CNN model to use.
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on COCO.
+                 This flag is for initialization only and is NOT saved in the
+                 architecture config. Defaults to False.
+         """
+         super().__init__()
+
+         # --- 1. Validation and Configuration ---
+         if not hasattr(detection_models, model_name):
+             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.detection.")
+             raise ValueError()
+
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.model_name = model_name
+         self._pretrained_default_transforms = None
+
+         # --- 2. Instantiate the base model ---
+         model_constructor = getattr(detection_models, model_name)
+
+         # Format model name to find the weights enum,
+         # e.g., fasterrcnn_resnet50_fpn_v2 -> FasterRCNN_ResNet50_FPN_V2_Weights
+         weights_model_name = (model_name.replace('fasterrcnn_', 'FasterRCNN_')
+                               .replace('resnet', 'ResNet')
+                               .replace('_fpn', '_FPN')
+                               .replace('_v2', '_V2'))
+         weights_enum_name = f"{weights_model_name}_Weights"
+
+         weights_enum = getattr(detection_models, weights_enum_name, None)
+         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+         if weights:
+             self._pretrained_default_transforms = weights.transforms()
+
+         # 'weights_backbone' keeps its default; passing the detection weights
+         # enum there would fail torchvision's type check.
+         self.model = model_constructor(weights=weights)
+
+         # --- 3. Modify the output layer (Box Predictor) ---
+         # Get the number of input features for the classifier
+         in_features = self.model.roi_heads.box_predictor.cls_score.in_features
+         # Replace the pre-trained head with a new one
+         self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+
+         # --- 4. Modify the input layer (Backbone conv1) ---
+         if in_channels != 3:
+             original_conv1 = self.model.backbone.body.conv1
+
+             new_conv1 = nn.Conv2d(
+                 in_channels,
+                 original_conv1.out_channels,
+                 kernel_size=original_conv1.kernel_size,  # type: ignore
+                 stride=original_conv1.stride,  # type: ignore
+                 padding=original_conv1.padding,  # type: ignore
+                 bias=(original_conv1.bias is not None)
+             )
+
+             # (Optional) Average original weights if starting from pretrained
+             if init_with_pretrained and original_conv1.in_channels == 3 and weights is not None:
+                 with torch.no_grad():
+                     # Average the weights across the input channel dimension
+                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
+                     # Repeat the averaged weights for the new number of input channels
+                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
+
+             self.model.backbone.body.conv1 = new_conv1
+
+     def forward(self, images: List[torch.Tensor], targets: Optional[List[Dict[str, torch.Tensor]]] = None):
+         """
+         Defines the forward pass.
+
+         - In train mode, expects (images, targets) and returns a dict of losses.
+         - In eval mode, expects (images) and returns a list of prediction dicts.
+         """
+         # The model's forward pass handles train/eval mode internally.
+         return self.model(images, targets)
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """
+         Returns the structural configuration of the model.
+         The 'init_with_pretrained' flag is intentionally omitted,
+         as .load() should restore the architecture, not the weights.
+         """
+         return {
+             'num_classes': self.num_classes,
+             'in_channels': self.in_channels,
+             'model_name': self.model_name
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         return (
+             f"{self.__class__.__name__}(model='{self.model_name}', "
+             f"in_channels={self.in_channels}, "
+             f"num_classes={self.num_classes})"
+         )
+
+
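As DragonFastRCNN's forward docstring notes, the wrapped detector follows the torchvision convention: a dict of losses in train mode, a list of per-image prediction dicts in eval mode. A sketch with dummy data (boxes are [x1, y1, x2, y2]; label 0 is reserved for background):

import torch
from ml_tools.ML_vision_models import DragonFastRCNN

model = DragonFastRCNN(num_classes=3)  # 2 object classes + background
images = [torch.randn(3, 480, 640) for _ in range(2)]

model.train()
targets = [{'boxes': torch.tensor([[10., 20., 100., 200.]]),
            'labels': torch.tensor([1])} for _ in images]
losses = model(images, targets)        # dict of loss tensors

model.eval()
with torch.no_grad():
    preds = model(images)              # list of {'boxes', 'labels', 'scores'}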
+
+ def info():
+     _script_info(__all__)