dragon-ml-toolbox 13.1.0__py3-none-any.whl → 14.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,627 @@
+ import torch
+ from torch import nn
+ import torchvision.models as vision_models
+ from torchvision.models import detection as detection_models
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+ from typing import List, Dict, Any, Literal, Optional
+ from abc import ABC, abstractmethod
+
+ from .ML_models import _ArchitectureHandlerMixin
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+
+
+ __all__ = [
+     "DragonResNet",
+     "DragonEfficientNet",
+     "DragonVGG",
+     "DragonFCN",
+     "DragonDeepLabv3",
+     "DragonFastRCNN",
+ ]
+
+
+ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
+     """
+     Abstract base class for torchvision model wrappers.
+
+     Handles common logic for:
+     - Model instantiation (with/without pretrained weights)
+     - Input layer modification (for custom in_channels)
+     - Output layer modification (for custom num_classes)
+     - Architecture saving/loading and representation
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int,
+                  model_name: str,
+                  init_with_pretrained: bool,
+                  weights_enum_name: Optional[str] = None):
+         super().__init__()
+
+         # --- 1. Validation and Configuration ---
+         if not hasattr(vision_models, model_name):
+             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.")
+             raise ValueError()
+
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.model_name = model_name
+
+         # --- 2. Instantiate the base model ---
+         if init_with_pretrained:
+             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
+             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None
+
+             if weights is None:
+                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
+                 self.model = getattr(vision_models, model_name)(pretrained=True)
+             else:
+                 self.model = getattr(vision_models, model_name)(weights=weights)
+         else:
+             self.model = getattr(vision_models, model_name)(weights=None)
+
+         # --- 3. Modify the input layer (using abstract method) ---
+         if in_channels != 3:
+             original_conv1 = self._get_input_layer()
+
+             new_conv1 = nn.Conv2d(
+                 in_channels,
+                 original_conv1.out_channels,
+                 kernel_size=original_conv1.kernel_size,  # type: ignore
+                 stride=original_conv1.stride,  # type: ignore
+                 padding=original_conv1.padding,  # type: ignore
+                 bias=(original_conv1.bias is not None)
+             )
+
+             # (Optional) Average original weights if starting from pretrained
+             if init_with_pretrained and original_conv1.in_channels == 3:
+                 with torch.no_grad():
+                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
+                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
+
+             self._set_input_layer(new_conv1)
+
+         # --- 4. Modify the output layer (using abstract method) ---
+         original_fc = self._get_output_layer()
+         if original_fc is None:  # Handle case where layer isn't found
+             _LOGGER.error(f"Model '{model_name}' has an unexpected classifier structure. Cannot replace final layer.")
+             raise AttributeError("Could not find final classifier layer.")
+
+         num_filters = original_fc.in_features
+         self._set_output_layer(nn.Linear(num_filters, num_classes))
+
+     @abstractmethod
+     def _get_input_layer(self) -> nn.Conv2d:
+         """Returns the first convolutional layer of the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _set_input_layer(self, layer: nn.Conv2d):
+         """Sets the first convolutional layer of the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         """Returns the final fully-connected layer of the model."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _set_output_layer(self, layer: nn.Linear):
+         """Sets the final fully-connected layer of the model."""
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Defines the forward pass of the model."""
+         return self.model(x)
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """
+         Returns the structural configuration of the model.
+         The 'init_with_pretrained' flag is intentionally omitted,
+         as .load() should restore the architecture, not the weights.
+         """
+         return {
+             'num_classes': self.num_classes,
+             'in_channels': self.in_channels,
+             'model_name': self.model_name
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         return (
+             f"{self.__class__.__name__}(model='{self.model_name}', "
+             f"in_channels={self.in_channels}, "
+             f"num_classes={self.num_classes})"
+         )
+
+
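+ # A quick sanity sketch of the channel-averaging trick used above (illustrative
+ # only; the 64-filter, 7x7 shapes are assumptions matching a ResNet-style stem):
+ #
+ # >>> w = torch.randn(64, 3, 7, 7)                 # pretrained conv1 weight
+ # >>> avg = torch.mean(w, dim=1, keepdim=True)     # -> (64, 1, 7, 7)
+ # >>> avg.repeat(1, 5, 1, 1).shape                 # new weight for in_channels=5
+ # torch.Size([64, 5, 7, 7])
+
+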
+ class DragonResNet(_BaseVisionWrapper):
+     """
+     Image Classification
+
+     A customizable wrapper for the torchvision ResNet family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] = 'resnet50',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes for the final layer.
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the ResNet model to use (e.g., 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'). The number is the layer count.
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on ImageNet. This flag is for initialization only and is NOT saved in the architecture config.
+         """
+
+         # Format the model name to find the weights enum, e.g., resnet50 -> ResNet50_Weights
+         weights_enum_name = f"{model_name.replace('resnet', 'ResNet')}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         return self.model.conv1
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.conv1 = layer
+
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         return self.model.fc
+
+     def _set_output_layer(self, layer: nn.Linear):
+         self.model.fc = layer
+
+
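+ # Minimal usage sketch for the classification wrappers (illustrative; the batch
+ # size, class count, and input size are assumptions):
+ #
+ # >>> model = DragonResNet(num_classes=5, in_channels=1, model_name="resnet18")
+ # >>> logits = model(torch.randn(4, 1, 224, 224))
+ # >>> logits.shape
+ # torch.Size([4, 5])
+
+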
+ class DragonEfficientNet(_BaseVisionWrapper):
+     """
+     Image Classification
+
+     A customizable wrapper for the torchvision EfficientNet family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: str = 'efficientnet_b0',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes for the final layer.
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the EfficientNet model to use (e.g., 'efficientnet_b0'
+                 through 'efficientnet_b7', or 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on
+                 ImageNet. This flag is for initialization only and is
+                 NOT saved in the architecture config. Defaults to False.
+         """
+
+         # Format the model name to find the weights enum, e.g., efficientnet_b0 -> EfficientNet_B0_Weights
+         weights_enum_name = f"EfficientNet_{model_name.removeprefix('efficientnet_').upper()}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # The first conv layer in EfficientNet is model.features[0][0]
+         return self.model.features[0][0]
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.features[0][0] = layer
+
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         # The classifier in EfficientNet is model.classifier[1]
+         if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential):
+             output_layer = self.model.classifier[1]
+             if isinstance(output_layer, nn.Linear):
+                 return output_layer
+         return None
+
+     def _set_output_layer(self, layer: nn.Linear):
+         self.model.classifier[1] = layer
+
+
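+ # The __repr__ inherited from _BaseVisionWrapper gives a compact summary
+ # (illustrative output):
+ #
+ # >>> DragonEfficientNet(num_classes=10, model_name="efficientnet_b0")
+ # DragonEfficientNet(model='efficientnet_b0', in_channels=3, num_classes=10)
+
+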
+ class DragonVGG(_BaseVisionWrapper):
+     """
+     Image Classification
+
+     A customizable wrapper for the torchvision VGG family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["vgg11", "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"] = 'vgg16',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes for the final layer.
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the VGG model to use (e.g., 'vgg16', 'vgg16_bn').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on
+                 ImageNet. This flag is for initialization only and is
+                 NOT saved in the architecture config. Defaults to False.
+         """
+
+         # Format model name to find weights enum, e.g., vgg16_bn -> VGG16_BN_Weights
+         weights_enum_name = f"{model_name.replace('_bn', '_BN').upper()}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # The first conv layer in VGG is model.features[0]
+         return self.model.features[0]
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.features[0] = layer
+
+     def _get_output_layer(self) -> Optional[nn.Linear]:
+         # The final classifier in VGG is model.classifier[6]
+         if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential) and len(self.model.classifier) == 7:
+             output_layer = self.model.classifier[6]
+             if isinstance(output_layer, nn.Linear):
+                 return output_layer
+         return None
+
+     def _set_output_layer(self, layer: nn.Linear):
+         self.model.classifier[6] = layer
+
+
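+ # Architecture configs can rebuild a matching (randomly initialized) model,
+ # which is what the save/load mixin relies on (illustrative sketch):
+ #
+ # >>> vgg = DragonVGG(num_classes=3, in_channels=1, model_name="vgg11_bn")
+ # >>> cfg = vgg.get_architecture_config()
+ # >>> clone = DragonVGG(**cfg)   # same structure, fresh weights
+
+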
+ # Image segmentation
+ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
+     """
+     Abstract base class for torchvision segmentation model wrappers.
+
+     Handles common logic for:
+     - Model instantiation (with/without pretrained weights and custom num_classes)
+     - Input layer modification (for custom in_channels)
+     - Forward pass dictionary unpacking (returns 'out' tensor)
+     - Architecture saving/loading and representation
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int,
+                  model_name: str,
+                  init_with_pretrained: bool,
+                  weights_enum_name: Optional[str] = None):
+         super().__init__()
+
+         # --- 1. Validation and Configuration ---
+         if not hasattr(vision_models.segmentation, model_name):
+             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.segmentation.")
+             raise ValueError()
+
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.model_name = model_name
+
+         # --- 2. Instantiate the base model ---
+         model_constructor = getattr(vision_models.segmentation, model_name)
+
+         if init_with_pretrained:
+             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
+             weights = weights_enum.DEFAULT if weights_enum else None
+
+             if weights is None:
+                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
+                 # Legacy interface: 'pretrained' must not be combined with 'weights'
+                 self.model = model_constructor(pretrained=True, num_classes=num_classes)
+             else:
+                 # torchvision rejects a custom num_classes alongside pretrained weights,
+                 # so load the pretrained heads first, then swap in fresh 1x1 classifiers.
+                 self.model = model_constructor(weights=weights)
+                 old_head = self.model.classifier[-1]
+                 self.model.classifier[-1] = nn.Conv2d(old_head.in_channels, num_classes, kernel_size=1)
+                 if getattr(self.model, 'aux_classifier', None) is not None:
+                     old_aux = self.model.aux_classifier[-1]
+                     self.model.aux_classifier[-1] = nn.Conv2d(old_aux.in_channels, num_classes, kernel_size=1)
+         else:
+             self.model = model_constructor(weights=None, num_classes=num_classes)
+
+         # --- 3. Modify the input layer (using abstract method) ---
+         if in_channels != 3:
+             original_conv1 = self._get_input_layer()
+
+             new_conv1 = nn.Conv2d(
+                 in_channels,
+                 original_conv1.out_channels,
+                 kernel_size=original_conv1.kernel_size,  # type: ignore
+                 stride=original_conv1.stride,  # type: ignore
+                 padding=original_conv1.padding,  # type: ignore
+                 bias=(original_conv1.bias is not None)
+             )
+
+             # (Optional) Average original weights if starting from pretrained
+             if init_with_pretrained and original_conv1.in_channels == 3:
+                 with torch.no_grad():
+                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
+                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
+
+             self._set_input_layer(new_conv1)
+
+     @abstractmethod
+     def _get_input_layer(self) -> nn.Conv2d:
+         """Returns the first convolutional layer of the model (in the backbone)."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def _set_input_layer(self, layer: nn.Conv2d):
+         """Sets the first convolutional layer of the model (in the backbone)."""
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Defines the forward pass.
+         Returns the 'out' tensor from the segmentation model's output dict.
+         """
+         output_dict = self.model(x)
+         return output_dict['out']  # Key for standard torchvision seg models
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """
+         Returns the structural configuration of the model.
+         The 'init_with_pretrained' flag is intentionally omitted,
+         as .load() should restore the architecture, not the weights.
+         """
+         return {
+             'num_classes': self.num_classes,
+             'in_channels': self.in_channels,
+             'model_name': self.model_name
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         return (
+             f"{self.__class__.__name__}(model='{self.model_name}', "
+             f"in_channels={self.in_channels}, "
+             f"num_classes={self.num_classes})"
+         )
+
+
+ class DragonFCN(_BaseSegmentationWrapper):
+     """
+     Image Segmentation
+
+     A customizable wrapper for the torchvision FCN (Fully Convolutional Network)
+     family, compatible with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["fcn_resnet50", "fcn_resnet101"] = 'fcn_resnet50',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes (including background).
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the FCN model to use ('fcn_resnet50' or 'fcn_resnet101').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on COCO.
+                 This flag is for initialization only and is NOT saved in the
+                 architecture config. Defaults to False.
+         """
+         # Format model name to find weights enum, e.g., fcn_resnet50 -> FCN_ResNet50_Weights
+         weights_model_name = model_name.replace('fcn_', 'FCN_').replace('resnet', 'ResNet')
+         weights_enum_name = f"{weights_model_name}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # FCN models use a ResNet backbone, input layer is backbone.conv1
+         return self.model.backbone.conv1
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.backbone.conv1 = layer
+
+
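+ # Minimal segmentation sketch (illustrative; the input size and class count are
+ # assumptions). forward() returns the 'out' logits, upsampled to the input size:
+ #
+ # >>> seg = DragonFCN(num_classes=4, in_channels=1)
+ # >>> out = seg(torch.randn(2, 1, 256, 256))
+ # >>> out.shape
+ # torch.Size([2, 4, 256, 256])
+
+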
+ class DragonDeepLabv3(_BaseSegmentationWrapper):
+     """
+     Image Segmentation
+
+     A customizable wrapper for the torchvision DeepLabv3 family, compatible
+     with saving/loading architecture.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["deeplabv3_resnet50", "deeplabv3_resnet101"] = 'deeplabv3_resnet50',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes (including background).
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the DeepLabv3 model to use ('deeplabv3_resnet50' or 'deeplabv3_resnet101').
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on COCO.
+                 This flag is for initialization only and is NOT saved in the
+                 architecture config. Defaults to False.
+         """
+
+         # Format model name to find weights enum, e.g., deeplabv3_resnet50 -> DeepLabV3_ResNet50_Weights
+         weights_model_name = model_name.replace('deeplabv3_', 'DeepLabV3_').replace('resnet', 'ResNet')
+         weights_enum_name = f"{weights_model_name}_Weights"
+
+         super().__init__(
+             num_classes=num_classes,
+             in_channels=in_channels,
+             model_name=model_name,
+             init_with_pretrained=init_with_pretrained,
+             weights_enum_name=weights_enum_name
+         )
+
+     def _get_input_layer(self) -> nn.Conv2d:
+         # DeepLabv3 models use a ResNet backbone, input layer is backbone.conv1
+         return self.model.backbone.conv1
+
+     def _set_input_layer(self, layer: nn.Conv2d):
+         self.model.backbone.conv1 = layer
+
+
+ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
+     """
+     Object Detection
+
+     A customizable wrapper for the torchvision Faster R-CNN family.
+
+     This wrapper allows for customizing the model backbone, input channels,
+     and the number of output classes for transfer learning.
+
+     NOTE: This model is NOT compatible with the MLTrainer class.
+     """
+     def __init__(self,
+                  num_classes: int,
+                  in_channels: int = 3,
+                  model_name: Literal["fasterrcnn_resnet50_fpn", "fasterrcnn_resnet50_fpn_v2"] = 'fasterrcnn_resnet50_fpn_v2',
+                  init_with_pretrained: bool = False):
+         """
+         Args:
+             num_classes (int):
+                 Number of output classes (including background).
+             in_channels (int):
+                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
+             model_name (str):
+                 The name of the Faster R-CNN model to use.
+             init_with_pretrained (bool):
+                 If True, initializes the model with weights pretrained on COCO.
+                 This flag is for initialization only and is NOT saved in the
+                 architecture config. Defaults to False.
+         """
+         super().__init__()
+
+         # --- 1. Validation and Configuration ---
+         if not hasattr(detection_models, model_name):
+             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.detection.")
+             raise ValueError()
+
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.model_name = model_name
+
+         # --- 2. Instantiate the base model ---
+         model_constructor = getattr(detection_models, model_name)
+
+         # Format model name to find weights enum, e.g., fasterrcnn_resnet50_fpn_v2 -> FasterRCNN_ResNet50_FPN_V2_Weights
+         weights_model_name = (model_name.replace('fasterrcnn_', 'FasterRCNN_')
+                               .replace('resnet', 'ResNet')
+                               .replace('_fpn', '_FPN')
+                               .replace('_v2', '_V2'))
+         weights_enum_name = f"{weights_model_name}_Weights"
+
+         weights_enum = getattr(detection_models, weights_enum_name, None)
+         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+         # 'weights_backbone' expects ResNet50 classification weights, not the
+         # detection weights enum, so it is resolved separately. Both supported
+         # model names use a resnet50 backbone.
+         backbone_weights = vision_models.ResNet50_Weights.IMAGENET1K_V1 if init_with_pretrained else None
+         self.model = model_constructor(weights=weights, weights_backbone=backbone_weights)
+
+         # --- 3. Modify the output layer (Box Predictor) ---
+         # Get the number of input features for the classifier
+         in_features = self.model.roi_heads.box_predictor.cls_score.in_features
+         # Replace the pre-trained head with a new one
+         self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
+
+         # --- 4. Modify the input layer (Backbone conv1) ---
+         if in_channels != 3:
+             original_conv1 = self.model.backbone.body.conv1
+
+             new_conv1 = nn.Conv2d(
+                 in_channels,
+                 original_conv1.out_channels,
+                 kernel_size=original_conv1.kernel_size,  # type: ignore
+                 stride=original_conv1.stride,  # type: ignore
+                 padding=original_conv1.padding,  # type: ignore
+                 bias=(original_conv1.bias is not None)
+             )
+
+             # (Optional) Average original weights if starting from pretrained
+             # (with the constructor above, the backbone is pretrained whenever
+             # init_with_pretrained is True)
+             if init_with_pretrained and original_conv1.in_channels == 3:
+                 with torch.no_grad():
+                     # Average the weights across the input channel dimension
+                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
+                     # Repeat the averaged weights for the new number of input channels
+                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
+
+             self.model.backbone.body.conv1 = new_conv1
+
+     def forward(self, images: List[torch.Tensor], targets: Optional[List[Dict[str, torch.Tensor]]] = None):
+         """
+         Defines the forward pass.
+
+         - In train mode, expects (images, targets) and returns a dict of losses.
+         - In eval mode, expects (images) and returns a list of prediction dicts.
+         """
+         # The model's forward pass handles train/eval mode internally.
+         return self.model(images, targets)
+
+     def get_architecture_config(self) -> Dict[str, Any]:
+         """
+         Returns the structural configuration of the model.
+         The 'init_with_pretrained' flag is intentionally omitted,
+         as .load() should restore the architecture, not the weights.
+         """
+         return {
+             'num_classes': self.num_classes,
+             'in_channels': self.in_channels,
+             'model_name': self.model_name
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         return (
+             f"{self.__class__.__name__}(model='{self.model_name}', "
+             f"in_channels={self.in_channels}, "
+             f"num_classes={self.num_classes})"
+         )
+
+
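+ # Train/eval usage sketch (illustrative; the boxes, labels, and image sizes are
+ # made-up values in the format torchvision detection models expect):
+ #
+ # >>> det = DragonFastRCNN(num_classes=3)
+ # >>> images = [torch.randn(3, 320, 320)]
+ # >>> targets = [{"boxes": torch.tensor([[10., 20., 100., 120.]]),
+ # ...             "labels": torch.tensor([1])}]
+ # >>> losses = det(images, targets)   # train mode: dict of losses
+ # >>> det.eval()
+ # >>> preds = det(images)             # eval mode: list of prediction dicts
+
+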
+ def info():
+     _script_info(__all__)
@@ -0,0 +1,58 @@
+ from typing import Union, Dict, Type, Callable
+ from PIL import ImageOps, Image
+
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+ from .keys import VisionTransformRecipeKeys
+
+
+ __all__ = [
+     "TRANSFORM_REGISTRY",
+     "ResizeAspectFill"
+ ]
+
14
+ # --- Custom Vision Transform Class ---
15
+ class ResizeAspectFill:
16
+ """
17
+ Custom transformation to make an image square by padding it to match the
18
+ longest side, preserving the aspect ratio. The image is finally centered.
19
+
20
+ Args:
21
+ pad_color (Union[str, int]): Color to use for the padding.
22
+ Defaults to "black".
23
+ """
24
+ def __init__(self, pad_color: Union[str, int] = "black") -> None:
25
+ self.pad_color = pad_color
26
+ # Store kwargs to allow for recreation
27
+ self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
28
+ # self._kwargs = {"pad_color": pad_color}
+
+     def __call__(self, image: Image.Image) -> Image.Image:
+         if not isinstance(image, Image.Image):
+             _LOGGER.error(f"Expected PIL.Image.Image, got {type(image).__name__}")
+             raise TypeError()
+
+         w, h = image.size
+         if w == h:
+             return image
+
+         # Determine padding to center the image
+         if w > h:
+             top_padding = (w - h) // 2
+             bottom_padding = w - h - top_padding
+             padding = (0, top_padding, 0, bottom_padding)
+         else:  # h > w
+             left_padding = (h - w) // 2
+             right_padding = h - w - left_padding
+             padding = (left_padding, 0, right_padding, 0)
+
+         return ImageOps.expand(image, padding, fill=self.pad_color)
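+
+ # Usage sketch (illustrative; the 300x200 size is an assumption):
+ #
+ # >>> img = Image.new("RGB", (300, 200))
+ # >>> ResizeAspectFill(pad_color="white")(img).size
+ # (300, 300)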
+
+
+ # NOTE: Add custom transforms here.
+ TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
+     "ResizeAspectFill": ResizeAspectFill,
+ }
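+
+ # Recipes can rebuild a transform from its registered name and stored kwargs
+ # (a sketch of the intended lookup; 'white' is an arbitrary example value):
+ #
+ # >>> cls = TRANSFORM_REGISTRY["ResizeAspectFill"]
+ # >>> transform = cls(pad_color="white")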
+
+
+ def info():
+     _script_info(__all__)