dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/_core/_ML_vision_models.py DELETED
@@ -1,644 +0,0 @@
- import torch
- from torch import nn
- import torchvision.models as vision_models
- from torchvision.models import detection as detection_models
- from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
- from typing import List, Dict, Any, Literal, Optional
- from abc import ABC, abstractmethod
-
- from ._ML_models import _ArchitectureHandlerMixin
- from ._logger import get_logger
- from ._script_info import _script_info
-
-
- _LOGGER = get_logger("DragonModel")
-
-
- __all__ = [
-     "DragonResNet",
-     "DragonEfficientNet",
-     "DragonVGG",
-     "DragonFCN",
-     "DragonDeepLabv3",
-     "DragonFastRCNN",
- ]
-
-
- class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
-     """
-     Abstract base class for torchvision model wrappers.
-
-     Handles common logic for:
-     - Model instantiation (with/without pretrained weights)
-     - Input layer modification (for custom in_channels)
-     - Output layer modification (for custom num_classes)
-     - Architecture saving/loading and representation
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int,
-                  model_name: str,
-                  init_with_pretrained: bool,
-                  weights_enum_name: Optional[str] = None):
-         super().__init__()
-
-         # --- 1. Validation and Configuration ---
-         if not hasattr(vision_models, model_name):
-             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.")
-             raise ValueError()
-
-         self.num_classes = num_classes
-         self.in_channels = in_channels
-         self.model_name = model_name
-         self._pretrained_default_transforms = None
-
-         # --- 2. Instantiate the base model ---
-         if init_with_pretrained:
-             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
-             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None
-
-             # Save transformations for pretrained models
-             if weights:
-                 self._pretrained_default_transforms = weights.transforms()
-
-             if weights is None and init_with_pretrained:
-                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
-                 self.model = getattr(vision_models, model_name)(pretrained=True)
-             else:
-                 self.model = getattr(vision_models, model_name)(weights=weights)
-         else:
-             self.model = getattr(vision_models, model_name)(weights=None)
-
-         # --- 3. Modify the input layer (using abstract method) ---
-         if in_channels != 3:
-             original_conv1 = self._get_input_layer()
-
-             new_conv1 = nn.Conv2d(
-                 in_channels,
-                 original_conv1.out_channels,
-                 kernel_size=original_conv1.kernel_size, # type: ignore
-                 stride=original_conv1.stride, # type: ignore
-                 padding=original_conv1.padding, # type: ignore
-                 bias=(original_conv1.bias is not None)
-             )
-
-             # (Optional) Average original weights if starting from pretrained
-             if init_with_pretrained and original_conv1.in_channels == 3:
-                 with torch.no_grad():
-                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
-                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
-
-             self._set_input_layer(new_conv1)
-
-         # --- 4. Modify the output layer (using abstract method) ---
-         original_fc = self._get_output_layer()
-         if original_fc is None: # Handle case where layer isn't found
-             _LOGGER.error(f"Model '{model_name}' has an unexpected classifier structure. Cannot replace final layer.")
-             raise AttributeError("Could not find final classifier layer.")
-
-         num_filters = original_fc.in_features
-         self._set_output_layer(nn.Linear(num_filters, num_classes))
-
-     @abstractmethod
-     def _get_input_layer(self) -> nn.Conv2d:
-         """Returns the first convolutional layer of the model."""
-         raise NotImplementedError
-
-     @abstractmethod
-     def _set_input_layer(self, layer: nn.Conv2d):
-         """Sets the first convolutional layer of the model."""
-         raise NotImplementedError
-
-     @abstractmethod
-     def _get_output_layer(self) -> Optional[nn.Linear]:
-         """Returns the final fully-connected layer of the model."""
-         raise NotImplementedError
-
-     @abstractmethod
-     def _set_output_layer(self, layer: nn.Linear):
-         """Sets the final fully-connected layer of the model."""
-         raise NotImplementedError
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Defines the forward pass of the model."""
-         return self.model(x)
-
-     def get_architecture_config(self) -> Dict[str, Any]:
-         """
-         Returns the structural configuration of the model.
-         The 'init_with_pretrained' flag is intentionally omitted,
-         as .load() should restore the architecture, not the weights.
-         """
-         return {
-             'num_classes': self.num_classes,
-             'in_channels': self.in_channels,
-             'model_name': self.model_name
-         }
-
-     def __repr__(self) -> str:
-         """Returns the developer-friendly string representation of the model."""
-         return (
-             f"{self.__class__.__name__}(model='{self.model_name}', "
-             f"in_channels={self.in_channels}, "
-             f"num_classes={self.num_classes})"
-         )
-
-
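The input-layer surgery in _BaseVisionWrapper reuses pretrained RGB filters for a different channel count by averaging conv1's weights over the channel dimension and tiling the result. A minimal standalone sketch of the same trick against plain torchvision (illustration only, not part of the diff; resnet18 is an arbitrary choice):

    import torch
    from torch import nn
    from torchvision.models import resnet18, ResNet18_Weights

    in_channels = 1  # e.g., grayscale input
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    old = model.conv1  # Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    new = nn.Conv2d(in_channels, old.out_channels, kernel_size=old.kernel_size,
                    stride=old.stride, padding=old.padding, bias=(old.bias is not None))
    with torch.no_grad():
        avg = old.weight.mean(dim=1, keepdim=True)        # (64, 1, 7, 7): average over RGB
        new.weight[:] = avg.repeat(1, in_channels, 1, 1)  # tile back to (64, in_channels, 7, 7)
    model.conv1 = new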
- class DragonResNet(_BaseVisionWrapper):
-     """
-     Image Classification
-
-     A customizable wrapper for the torchvision ResNet family, compatible
-     with saving/loading architecture.
-
-     This wrapper allows for customizing the model backbone, input channels,
-     and the number of output classes for transfer learning.
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int = 3,
-                  model_name: Literal["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"] = 'resnet50',
-                  init_with_pretrained: bool = False):
-         """
-         Args:
-             num_classes (int):
-                 Number of output classes for the final layer.
-             in_channels (int):
-                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
-             model_name (str):
-                 The name of the ResNet model to use (e.g., 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'). Number is the layer count.
-             init_with_pretrained (bool):
-                 If True, initializes the model with weights pretrained on ImageNet. This flag is for initialization only and is NOT saved in the architecture config.
-         """
-
-         weights_enum_name = getattr(vision_models, f"{model_name.upper()}_Weights", None)
-
-         super().__init__(
-             num_classes=num_classes,
-             in_channels=in_channels,
-             model_name=model_name,
-             init_with_pretrained=init_with_pretrained,
-             weights_enum_name=weights_enum_name
-         )
-
-     def _get_input_layer(self) -> nn.Conv2d:
-         return self.model.conv1
-
-     def _set_input_layer(self, layer: nn.Conv2d):
-         self.model.conv1 = layer
-
-     def _get_output_layer(self) -> Optional[nn.Linear]:
-         return self.model.fc
-
-     def _set_output_layer(self, layer: nn.Linear):
-         self.model.fc = layer
-
-
- class DragonEfficientNet(_BaseVisionWrapper):
-     """
-     Image Classification
-
-     A customizable wrapper for the torchvision EfficientNet family, compatible
-     with saving/loading architecture.
-
-     This wrapper allows for customizing the model backbone, input channels,
-     and the number of output classes for transfer learning.
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int = 3,
-                  model_name: str = 'efficientnet_b0',
-                  init_with_pretrained: bool = False):
-         """
-         Args:
-             num_classes (int):
-                 Number of output classes for the final layer.
-             in_channels (int):
-                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
-             model_name (str):
-                 The name of the EfficientNet model to use (e.g., 'efficientnet_b0'
-                 through 'efficientnet_b7', or 'efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l').
-             init_with_pretrained (bool):
-                 If True, initializes the model with weights pretrained on
-                 ImageNet. This flag is for initialization only and is
-                 NOT saved in the architecture config. Defaults to False.
-         """
-
-         weights_enum_name = getattr(vision_models, f"{model_name.upper()}_Weights", None)
-
-         super().__init__(
-             num_classes=num_classes,
-             in_channels=in_channels,
-             model_name=model_name,
-             init_with_pretrained=init_with_pretrained,
-             weights_enum_name=weights_enum_name
-         )
-
-     def _get_input_layer(self) -> nn.Conv2d:
-         # The first conv layer in EfficientNet is model.features[0][0]
-         return self.model.features[0][0]
-
-     def _set_input_layer(self, layer: nn.Conv2d):
-         self.model.features[0][0] = layer
-
-     def _get_output_layer(self) -> Optional[nn.Linear]:
-         # The classifier in EfficientNet is model.classifier[1]
-         if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential):
-             output_layer = self.model.classifier[1]
-             if isinstance(output_layer, nn.Linear):
-                 return output_layer
-         return None
-
-     def _set_output_layer(self, layer: nn.Linear):
-         self.model.classifier[1] = layer
-
-
- class DragonVGG(_BaseVisionWrapper):
-     """
-     Image Classification
-
-     A customizable wrapper for the torchvision VGG family, compatible
-     with saving/loading architecture.
-
-     This wrapper allows for customizing the model backbone, input channels,
-     and the number of output classes for transfer learning.
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int = 3,
-                  model_name: Literal["vgg11", "vgg13", "vgg16", "vgg19", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"] = 'vgg16',
-                  init_with_pretrained: bool = False):
-         """
-         Args:
-             num_classes (int):
-                 Number of output classes for the final layer.
-             in_channels (int):
-                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
-             model_name (str):
-                 The name of the VGG model to use (e.g., 'vgg16', 'vgg16_bn').
-             init_with_pretrained (bool):
-                 If True, initializes the model with weights pretrained on
-                 ImageNet. This flag is for initialization only and is
-                 NOT saved in the architecture config. Defaults to False.
-         """
-
-         # Format model name to find weights enum, e.g., vgg16_bn -> VGG16_BN_Weights
-         weights_enum_name = f"{model_name.replace('_bn', '_BN').upper()}_Weights"
-
-         super().__init__(
-             num_classes=num_classes,
-             in_channels=in_channels,
-             model_name=model_name,
-             init_with_pretrained=init_with_pretrained,
-             weights_enum_name=weights_enum_name
-         )
-
-     def _get_input_layer(self) -> nn.Conv2d:
-         # The first conv layer in VGG is model.features[0]
-         return self.model.features[0]
-
-     def _set_input_layer(self, layer: nn.Conv2d):
-         self.model.features[0] = layer
-
-     def _get_output_layer(self) -> Optional[nn.Linear]:
-         # The final classifier in VGG is model.classifier[6]
-         if hasattr(self.model, 'classifier') and isinstance(self.model.classifier, nn.Sequential) and len(self.model.classifier) == 7:
-             output_layer = self.model.classifier[6]
-             if isinstance(output_layer, nn.Linear):
-                 return output_layer
-         return None
-
-     def _set_output_layer(self, layer: nn.Linear):
-         self.model.classifier[6] = layer
-
-
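The three classification wrappers above share one calling convention; a hypothetical usage sketch of this removed 19.x API follows (assuming the public module ml_tools/ML_vision_models.py, deleted above, re-exported these names). One caveat worth flagging: DragonResNet and DragonEfficientNet look up the weights enum under an upper-cased name (e.g., RESNET50_Weights), while torchvision names it ResNet50_Weights, so their pretrained path would typically fall through to the legacy pretrained=True branch; DragonVGG builds the correct name (e.g., VGG16_BN_Weights) explicitly.

    from ml_tools.ML_vision_models import DragonResNet  # 19.x import path; removed in 20.0.0

    # 10-class classifier over 4-channel images, ImageNet-initialized backbone.
    model = DragonResNet(num_classes=10, in_channels=4,
                         model_name="resnet50", init_with_pretrained=True)
    print(model)
    # DragonResNet(model='resnet50', in_channels=4, num_classes=10)
    print(model.get_architecture_config())
    # {'num_classes': 10, 'in_channels': 4, 'model_name': 'resnet50'}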
- # Image segmentation
- class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
-     """
-     Abstract base class for torchvision segmentation model wrappers.
-
-     Handles common logic for:
-     - Model instantiation (with/without pretrained weights and custom num_classes)
-     - Input layer modification (for custom in_channels)
-     - Forward pass dictionary unpacking (returns 'out' tensor)
-     - Architecture saving/loading and representation
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int,
-                  model_name: str,
-                  init_with_pretrained: bool,
-                  weights_enum_name: Optional[str] = None):
-         super().__init__()
-
-         # --- 1. Validation and Configuration ---
-         if not hasattr(vision_models.segmentation, model_name):
-             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.segmentation.")
-             raise ValueError()
-
-         self.num_classes = num_classes
-         self.in_channels = in_channels
-         self.model_name = model_name
-         self._pretrained_default_transforms = None
-
-         # --- 2. Instantiate the base model ---
-         model_kwargs = {
-             'num_classes': num_classes,
-             'weights': None
-         }
-         model_constructor = getattr(vision_models.segmentation, model_name)
-
-         if init_with_pretrained:
-             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
-             weights = weights_enum.DEFAULT if weights_enum else None
-
-             # save pretrained model transformations
-             if weights:
-                 self._pretrained_default_transforms = weights.transforms()
-
-             if weights is None:
-                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
-                 # Legacy models used 'pretrained=True' and num_classes was separate
-                 self.model = model_constructor(pretrained=True, **model_kwargs)
-             else:
-                 # Modern way: weights object implies pretraining
-                 model_kwargs['weights'] = weights
-                 self.model = model_constructor(**model_kwargs)
-         else:
-             self.model = model_constructor(**model_kwargs)
-
-         # --- 3. Modify the input layer (using abstract method) ---
-         if in_channels != 3:
-             original_conv1 = self._get_input_layer()
-
-             new_conv1 = nn.Conv2d(
-                 in_channels,
-                 original_conv1.out_channels,
-                 kernel_size=original_conv1.kernel_size, # type: ignore
-                 stride=original_conv1.stride, # type: ignore
-                 padding=original_conv1.padding, # type: ignore
-                 bias=(original_conv1.bias is not None)
-             )
-
-             # (Optional) Average original weights if starting from pretrained
-             if init_with_pretrained and original_conv1.in_channels == 3:
-                 with torch.no_grad():
-                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
-                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
-
-             self._set_input_layer(new_conv1)
-
-     @abstractmethod
-     def _get_input_layer(self) -> nn.Conv2d:
-         """Returns the first convolutional layer of the model (in the backbone)."""
-         raise NotImplementedError
-
-     @abstractmethod
-     def _set_input_layer(self, layer: nn.Conv2d):
-         """Sets the first convolutional layer of the model (in the backbone)."""
-         raise NotImplementedError
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Defines the forward pass.
-         Returns the 'out' tensor from the segmentation model's output dict.
-         """
-         output_dict = self.model(x)
-         return output_dict['out'] # Key for standard torchvision seg models
-
-     def get_architecture_config(self) -> Dict[str, Any]:
-         """
-         Returns the structural configuration of the model.
-         The 'init_with_pretrained' flag is intentionally omitted,
-         as .load() should restore the architecture, not the weights.
-         """
-         return {
-             'num_classes': self.num_classes,
-             'in_channels': self.in_channels,
-             'model_name': self.model_name
-         }
-
-     def __repr__(self) -> str:
-         """Returns the developer-friendly string representation of the model."""
-         return (
-             f"{self.__class__.__name__}(model='{self.model_name}', "
-             f"in_channels={self.in_channels}, "
-             f"num_classes={self.num_classes})"
-         )
-
-
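The instantiation branch in _BaseSegmentationWrapper toggles between torchvision's modern weights-enum API and the legacy boolean flag. A sketch of the two paths it chooses between, in plain torchvision calls (not part of the diff; num_classes=21 matches the COCO-with-VOC-labels checkpoints):

    from torchvision.models import segmentation

    # Modern path (torchvision >= 0.13): a weights object implies pretraining.
    weights = segmentation.FCN_ResNet50_Weights.DEFAULT
    model = segmentation.fcn_resnet50(weights=weights, num_classes=21)
    preprocess = weights.transforms()  # what the wrapper caches as _pretrained_default_transforms

    # Legacy fallback (pre-0.13 style; deprecated and later removed upstream).
    model = segmentation.fcn_resnet50(pretrained=True, num_classes=21)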
- class DragonFCN(_BaseSegmentationWrapper):
-     """
-     Image Segmentation
-
-     A customizable wrapper for the torchvision FCN (Fully Convolutional Network)
-     family, compatible with saving/loading architecture.
-
-     This wrapper allows for customizing the model backbone, input channels,
-     and the number of output classes for transfer learning.
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int = 3,
-                  model_name: Literal["fcn_resnet50", "fcn_resnet101"] = 'fcn_resnet50',
-                  init_with_pretrained: bool = False):
-         """
-         Args:
-             num_classes (int):
-                 Number of output classes (including background).
-             in_channels (int):
-                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
-             model_name (str):
-                 The name of the FCN model to use ('fcn_resnet50' or 'fcn_resnet101').
-             init_with_pretrained (bool):
-                 If True, initializes the model with weights pretrained on COCO.
-                 This flag is for initialization only and is NOT saved in the
-                 architecture config. Defaults to False.
-         """
-         # Format model name to find weights enum, e.g., fcn_resnet50 -> FCN_ResNet50_Weights
-         weights_model_name = model_name.replace('fcn_', 'FCN_').replace('resnet', 'ResNet')
-         weights_enum_name = f"{weights_model_name}_Weights"
-
-         super().__init__(
-             num_classes=num_classes,
-             in_channels=in_channels,
-             model_name=model_name,
-             init_with_pretrained=init_with_pretrained,
-             weights_enum_name=weights_enum_name
-         )
-
-     def _get_input_layer(self) -> nn.Conv2d:
-         # FCN models use a ResNet backbone, input layer is backbone.conv1
-         return self.model.backbone.conv1
-
-     def _set_input_layer(self, layer: nn.Conv2d):
-         self.model.backbone.conv1 = layer
-
-
- class DragonDeepLabv3(_BaseSegmentationWrapper):
-     """
-     Image Segmentation
-
-     A customizable wrapper for the torchvision DeepLabv3 family, compatible
-     with saving/loading architecture.
-
-     This wrapper allows for customizing the model backbone, input channels,
-     and the number of output classes for transfer learning.
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int = 3,
-                  model_name: Literal["deeplabv3_resnet50", "deeplabv3_resnet101"] = 'deeplabv3_resnet50',
-                  init_with_pretrained: bool = False):
-         """
-         Args:
-             num_classes (int):
-                 Number of output classes (including background).
-             in_channels (int):
-                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
-             model_name (str):
-                 The name of the DeepLabv3 model to use ('deeplabv3_resnet50' or 'deeplabv3_resnet101').
-             init_with_pretrained (bool):
-                 If True, initializes the model with weights pretrained on COCO.
-                 This flag is for initialization only and is NOT saved in the
-                 architecture config. Defaults to False.
-         """
-
-         # Format model name to find weights enum, e.g., deeplabv3_resnet50 -> DeepLabV3_ResNet50_Weights
-         weights_model_name = model_name.replace('deeplabv3_', 'DeepLabV3_').replace('resnet', 'ResNet')
-         weights_enum_name = f"{weights_model_name}_Weights"
-
-         super().__init__(
-             num_classes=num_classes,
-             in_channels=in_channels,
-             model_name=model_name,
-             init_with_pretrained=init_with_pretrained,
-             weights_enum_name=weights_enum_name
-         )
-
-     def _get_input_layer(self) -> nn.Conv2d:
-         # DeepLabv3 models use a ResNet backbone, input layer is backbone.conv1
-         return self.model.backbone.conv1
-
-     def _set_input_layer(self, layer: nn.Conv2d):
-         self.model.backbone.conv1 = layer
-
-
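Because the shared forward unpacks the output dict, DragonFCN and DragonDeepLabv3 return raw per-pixel logits rather than torchvision's {'out': ..., 'aux': ...} mapping. A shape sketch, assuming the removed 19.x classes behave as documented above:

    import torch

    model = DragonFCN(num_classes=5, in_channels=3)  # hypothetical 19.x call
    model.eval()
    x = torch.randn(2, 3, 224, 224)        # (batch, channels, height, width)
    with torch.no_grad():
        logits = model(x)                  # only the 'out' tensor is returned
    print(logits.shape)                    # torch.Size([2, 5, 224, 224])
    masks = logits.argmax(dim=1)           # per-pixel class ids, shape (2, 224, 224)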
- class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
-     """
-     Object Detection
-
-     A customizable wrapper for the torchvision Faster R-CNN family.
-
-     This wrapper allows for customizing the model backbone, input channels,
-     and the number of output classes for transfer learning.
-
-     NOTE: Use an Object Detection compatible trainer.
-     """
-     def __init__(self,
-                  num_classes: int,
-                  in_channels: int = 3,
-                  model_name: Literal["fasterrcnn_resnet50_fpn", "fasterrcnn_resnet50_fpn_v2"] = 'fasterrcnn_resnet50_fpn_v2',
-                  init_with_pretrained: bool = False):
-         """
-         Args:
-             num_classes (int):
-                 Number of output classes (including background).
-             in_channels (int):
-                 Number of input channels (e.g., 1 for grayscale, 3 for RGB).
-             model_name (str):
-                 The name of the Faster R-CNN model to use.
-             init_with_pretrained (bool):
-                 If True, initializes the model with weights pretrained on COCO.
-                 This flag is for initialization only and is NOT saved in the
-                 architecture config. Defaults to False.
-         """
-         super().__init__()
-
-         # --- 1. Validation and Configuration ---
-         if not hasattr(detection_models, model_name):
-             _LOGGER.error(f"'{model_name}' is not a valid model name in torchvision.models.detection.")
-             raise ValueError()
-
-         self.num_classes = num_classes
-         self.in_channels = in_channels
-         self.model_name = model_name
-         self._pretrained_default_transforms = None
-
-         # --- 2. Instantiate the base model ---
-         model_constructor = getattr(detection_models, model_name)
-
-         # Format model name to find weights enum, e.g., fasterrcnn_resnet50_fpn_v2 -> FasterRCNN_ResNet50_FPN_V2_Weights
-         weights_model_name = model_name.replace('fasterrcnn_', 'FasterRCNN_').replace('resnet', 'ResNet').replace('_fpn', '_FPN')
-         weights_enum_name = f"{weights_model_name.upper()}_Weights"
-
-         weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
-         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
-
-         if weights:
-             self._pretrained_default_transforms = weights.transforms()
-
-         self.model = model_constructor(weights=weights, weights_backbone=weights)
-
-         # --- 4. Modify the output layer (Box Predictor) ---
-         # Get the number of input features for the classifier
-         in_features = self.model.roi_heads.box_predictor.cls_score.in_features
-         # Replace the pre-trained head with a new one
-         self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
-
-         # --- 3. Modify the input layer (Backbone conv1) ---
-         if in_channels != 3:
-             original_conv1 = self.model.backbone.body.conv1
-
-             new_conv1 = nn.Conv2d(
-                 in_channels,
-                 original_conv1.out_channels,
-                 kernel_size=original_conv1.kernel_size, # type: ignore
-                 stride=original_conv1.stride, # type: ignore
-                 padding=original_conv1.padding, # type: ignore
-                 bias=(original_conv1.bias is not None)
-             )
-
-             # (Optional) Average original weights if starting from pretrained
-             if init_with_pretrained and original_conv1.in_channels == 3 and weights is not None:
-                 with torch.no_grad():
-                     # Average the weights across the input channel dimension
-                     avg_weights = torch.mean(original_conv1.weight, dim=1, keepdim=True)
-                     # Repeat the averaged weights for the new number of input channels
-                     new_conv1.weight[:] = avg_weights.repeat(1, in_channels, 1, 1)
-
-             self.model.backbone.body.conv1 = new_conv1
-
-     def forward(self, images: List[torch.Tensor], targets: Optional[List[Dict[str, torch.Tensor]]] = None):
-         """
-         Defines the forward pass.
-
-         - In train mode, expects (images, targets) and returns a dict of losses.
-         - In eval mode, expects (images) and returns a list of prediction dicts.
-         """
-         # The model's forward pass handles train/eval mode internally.
-         return self.model(images, targets)
-
-     def get_architecture_config(self) -> Dict[str, Any]:
-         """
-         Returns the structural configuration of the model.
-         The 'init_with_pretrained' flag is intentionally omitted,
-         as .load() should restore the architecture, not the weights.
-         """
-         return {
-             'num_classes': self.num_classes,
-             'in_channels': self.in_channels,
-             'model_name': self.model_name
-         }
-
-     def __repr__(self) -> str:
-         """Returns the developer-friendly string representation of the model."""
-         return (
-             f"{self.__class__.__name__}(model='{self.model_name}', "
-             f"in_channels={self.in_channels}, "
-             f"num_classes={self.num_classes})"
-         )
-
-
- def info():
-     _script_info(__all__)
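DragonFastRCNN's dual calling convention mirrors torchvision's detection API: a loss dict in train mode, prediction dicts in eval mode. A minimal sketch of both modes (hypothetical usage of the removed class; boxes are [x1, y1, x2, y2]). Note in passing that the .upper() in the enum lookup above appears to yield FASTERRCNN_RESNET50_FPN_V2_Weights, whereas torchvision names it FasterRCNN_ResNet50_FPN_V2_Weights, so init_with_pretrained=True would likely resolve no weights here.

    import torch

    model = DragonFastRCNN(num_classes=3)  # 2 foreground classes + background

    # Train mode: (images, targets) in -> dict of losses out.
    model.train()
    images = [torch.randn(3, 480, 640)]
    targets = [{"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]),
                "labels": torch.tensor([1])}]
    loss_dict = model(images, targets)     # loss_classifier, loss_box_reg, ...
    total_loss = sum(loss_dict.values())

    # Eval mode: images in -> list of {'boxes', 'labels', 'scores'} dicts.
    model.eval()
    with torch.no_grad():
        predictions = model(images)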