stouputils 1.14.0__py3-none-any.whl → 1.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. stouputils/__init__.pyi +15 -0
  2. stouputils/_deprecated.pyi +12 -0
  3. stouputils/all_doctests.pyi +46 -0
  4. stouputils/applications/__init__.pyi +2 -0
  5. stouputils/applications/automatic_docs.py +3 -0
  6. stouputils/applications/automatic_docs.pyi +106 -0
  7. stouputils/applications/upscaler/__init__.pyi +3 -0
  8. stouputils/applications/upscaler/config.pyi +18 -0
  9. stouputils/applications/upscaler/image.pyi +109 -0
  10. stouputils/applications/upscaler/video.pyi +60 -0
  11. stouputils/archive.pyi +67 -0
  12. stouputils/backup.pyi +109 -0
  13. stouputils/collections.pyi +86 -0
  14. stouputils/continuous_delivery/__init__.pyi +5 -0
  15. stouputils/continuous_delivery/cd_utils.pyi +129 -0
  16. stouputils/continuous_delivery/github.pyi +162 -0
  17. stouputils/continuous_delivery/pypi.pyi +52 -0
  18. stouputils/continuous_delivery/pyproject.pyi +67 -0
  19. stouputils/continuous_delivery/stubs.pyi +39 -0
  20. stouputils/ctx.pyi +211 -0
  21. stouputils/data_science/config/get.py +51 -51
  22. stouputils/data_science/data_processing/image/__init__.py +66 -66
  23. stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
  24. stouputils/data_science/data_processing/image/axis_flip.py +58 -58
  25. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
  26. stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
  27. stouputils/data_science/data_processing/image/blur.py +59 -59
  28. stouputils/data_science/data_processing/image/brightness.py +54 -54
  29. stouputils/data_science/data_processing/image/canny.py +110 -110
  30. stouputils/data_science/data_processing/image/clahe.py +92 -92
  31. stouputils/data_science/data_processing/image/common.py +30 -30
  32. stouputils/data_science/data_processing/image/contrast.py +53 -53
  33. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
  34. stouputils/data_science/data_processing/image/denoise.py +378 -378
  35. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
  36. stouputils/data_science/data_processing/image/invert.py +64 -64
  37. stouputils/data_science/data_processing/image/laplacian.py +60 -60
  38. stouputils/data_science/data_processing/image/median_blur.py +52 -52
  39. stouputils/data_science/data_processing/image/noise.py +59 -59
  40. stouputils/data_science/data_processing/image/normalize.py +65 -65
  41. stouputils/data_science/data_processing/image/random_erase.py +66 -66
  42. stouputils/data_science/data_processing/image/resize.py +69 -69
  43. stouputils/data_science/data_processing/image/rotation.py +80 -80
  44. stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
  45. stouputils/data_science/data_processing/image/sharpening.py +55 -55
  46. stouputils/data_science/data_processing/image/shearing.py +64 -64
  47. stouputils/data_science/data_processing/image/threshold.py +64 -64
  48. stouputils/data_science/data_processing/image/translation.py +71 -71
  49. stouputils/data_science/data_processing/image/zoom.py +83 -83
  50. stouputils/data_science/data_processing/image_augmentation.py +118 -118
  51. stouputils/data_science/data_processing/image_preprocess.py +183 -183
  52. stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
  53. stouputils/data_science/data_processing/technique.py +481 -481
  54. stouputils/data_science/dataset/__init__.py +45 -45
  55. stouputils/data_science/dataset/dataset.py +292 -292
  56. stouputils/data_science/dataset/dataset_loader.py +135 -135
  57. stouputils/data_science/dataset/grouping_strategy.py +296 -296
  58. stouputils/data_science/dataset/image_loader.py +100 -100
  59. stouputils/data_science/dataset/xy_tuple.py +696 -696
  60. stouputils/data_science/metric_dictionnary.py +106 -106
  61. stouputils/data_science/mlflow_utils.py +206 -206
  62. stouputils/data_science/models/abstract_model.py +149 -149
  63. stouputils/data_science/models/all.py +85 -85
  64. stouputils/data_science/models/keras/all.py +38 -38
  65. stouputils/data_science/models/keras/convnext.py +62 -62
  66. stouputils/data_science/models/keras/densenet.py +50 -50
  67. stouputils/data_science/models/keras/efficientnet.py +60 -60
  68. stouputils/data_science/models/keras/mobilenet.py +56 -56
  69. stouputils/data_science/models/keras/resnet.py +52 -52
  70. stouputils/data_science/models/keras/squeezenet.py +233 -233
  71. stouputils/data_science/models/keras/vgg.py +42 -42
  72. stouputils/data_science/models/keras/xception.py +38 -38
  73. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
  74. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
  75. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
  76. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
  77. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
  78. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
  79. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
  80. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
  81. stouputils/data_science/models/keras_utils/visualizations.py +416 -416
  82. stouputils/data_science/models/sandbox.py +116 -116
  83. stouputils/data_science/range_tuple.py +234 -234
  84. stouputils/data_science/utils.py +285 -285
  85. stouputils/decorators.pyi +242 -0
  86. stouputils/image.pyi +172 -0
  87. stouputils/installer/__init__.py +18 -18
  88. stouputils/installer/__init__.pyi +5 -0
  89. stouputils/installer/common.pyi +39 -0
  90. stouputils/installer/downloader.pyi +24 -0
  91. stouputils/installer/linux.py +144 -144
  92. stouputils/installer/linux.pyi +39 -0
  93. stouputils/installer/main.py +223 -223
  94. stouputils/installer/main.pyi +57 -0
  95. stouputils/installer/windows.py +136 -136
  96. stouputils/installer/windows.pyi +31 -0
  97. stouputils/io.pyi +213 -0
  98. stouputils/parallel.py +12 -10
  99. stouputils/parallel.pyi +211 -0
  100. stouputils/print.pyi +136 -0
  101. stouputils/py.typed +1 -1
  102. stouputils/stouputils/parallel.pyi +4 -4
  103. stouputils/version_pkg.pyi +15 -0
  104. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/METADATA +1 -1
  105. stouputils-1.14.2.dist-info/RECORD +171 -0
  106. stouputils-1.14.0.dist-info/RECORD +0 -140
  107. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/WHEEL +0 -0
  108. {stouputils-1.14.0.dist-info → stouputils-1.14.2.dist-info}/entry_points.txt +0 -0
@@ -1,233 +1,233 @@
1
- """ SqueezeNet model implementation.
2
-
3
- This module provides a wrapper class for the SqueezeNet model, a lightweight CNN architecture
4
- that achieves AlexNet-level accuracy with 50x fewer parameters and a model size of less than 0.5MB.
5
- SqueezeNet uses "fire modules" consisting of a squeeze layer with 1x1 filters followed by an
6
- expand layer with a mix of 1x1 and 3x3 convolution filters.
7
-
8
- Available models:
9
- - SqueezeNet: Compact model with excellent performance-to-parameter ratio
10
-
11
- The model supports transfer learning from ImageNet pre-trained weights.
12
- """
13
- # pyright: reportUnknownArgumentType=false
14
- # pyright: reportUnknownMemberType=false
15
- # pyright: reportUnknownVariableType=false
16
- # pyright: reportMissingTypeStubs=false
17
-
18
- # Imports
19
- from __future__ import annotations
20
-
21
- from typing import Any
22
-
23
- from keras import backend
24
- from keras.layers import (
25
- Activation,
26
- Convolution2D,
27
- Dropout,
28
- GlobalAveragePooling2D,
29
- GlobalMaxPooling2D,
30
- Input,
31
- MaxPooling2D,
32
- concatenate,
33
- )
34
- from keras.models import Model
35
- from keras.utils import get_file, get_source_inputs
36
-
37
- from ....decorators import simple_cache
38
- from ..base_keras import BaseKeras
39
- from ..model_interface import CLASS_ROUTINE_DOCSTRING, MODEL_DOCSTRING
40
-
41
- # Constants
42
- SQ1X1: str = "squeeze1x1"
43
-
44
- WEIGHTS_PATH = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels.h5"
45
- WEIGHTS_PATH_NO_TOP = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5"
46
-
47
-
48
- # Modular function for Fire Node
49
- def fire_module(x: Any, fire_id: int, squeeze: int = 16, expand: int = 64):
50
- """ Create a fire module with specified parameters.
51
-
52
- Args:
53
- x (Tensor): Input tensor
54
- fire_id (int): ID for the fire module
55
- squeeze (int): Number of filters for squeeze layer
56
- expand (int): Number of filters for expand layers
57
-
58
- Returns:
59
- Tensor: Output tensor from the fire module
60
- """
61
- s_id: str = f"fire{fire_id}"
62
-
63
- if backend.image_data_format() == "channels_first":
64
- channel_axis: int = 1
65
- else:
66
- channel_axis: int = 3
67
-
68
- x = Convolution2D(squeeze, (1, 1), padding="valid", name=f"{s_id}/squeeze1x1")(x)
69
- x = Activation("relu", name=f"{s_id}/relu_squeeze1x1")(x)
70
-
71
- left = Convolution2D(expand, (1, 1), padding="valid", name=f"{s_id}/expand1x1")(x)
72
- left = Activation("relu", name=f"{s_id}/relu_expand1x1")(left)
73
-
74
- right = Convolution2D(expand, (3, 3), padding="same", name=f"{s_id}/expand3x3")(x)
75
- right = Activation("relu", name=f"{s_id}/relu_expand3x3")(right)
76
-
77
- x = concatenate([left, right], axis=channel_axis, name=f"{s_id}/concat")
78
- return x
79
-
80
-
81
- # Original SqueezeNet from paper
82
- def SqueezeNet_keras( # noqa: N802
83
- include_top: bool = True,
84
- weights: str = "imagenet",
85
- input_tensor: Any = None,
86
- input_shape: tuple[Any, ...] | None = None,
87
- pooling: str | None = None,
88
- classes: int = 1000
89
- ) -> Model:
90
- """ Instantiates the SqueezeNet architecture.
91
-
92
- Args:
93
- include_top (bool): Whether to include the fully-connected layer at the top
94
- weights (str): One of `None` or 'imagenet'
95
- input_tensor (Tensor): Optional Keras tensor as input
96
- input_shape (tuple): Optional shape tuple
97
- pooling (str): Optional pooling mode for feature extraction
98
- classes (int): Number of classes to classify images into
99
-
100
- Returns:
101
- Model: A Keras model instance
102
- """
103
-
104
- if weights not in {'imagenet', None}:
105
- raise ValueError(
106
- "The `weights` argument should be either `None` (random initialization) "
107
- "or `imagenet` (pre-training on ImageNet)."
108
- )
109
-
110
- if include_top and weights == 'imagenet' and classes != 1000:
111
- raise ValueError(
112
- "If using `weights` as imagenet with `include_top` as true, `classes` should be 1000"
113
- )
114
-
115
- # Manually handle input shape logic instead of _obtain_input_shape
116
- default_size: int = 227
117
- min_size: int = 48
118
- if backend.image_data_format() == 'channels_first':
119
- default_shape: tuple[int, int, int] = (3, default_size, default_size)
120
- if weights == 'imagenet' and include_top and input_shape is not None and input_shape[0] != 3:
121
- raise ValueError(
122
- "When specifying `input_shape` and loading 'imagenet' weights, 'channels_first' input_shape "
123
- "should be (3, H, W)."
124
- )
125
- else: # channels_last
126
- default_shape = (default_size, default_size, 3)
127
- if weights == 'imagenet' and include_top and input_shape is not None and input_shape[2] != 3:
128
- raise ValueError(
129
- "When specifying `input_shape` and loading 'imagenet' weights, 'channels_last' input_shape "
130
- "should be (H, W, 3)."
131
- )
132
-
133
- if input_shape is None:
134
- input_shape = default_shape
135
- else:
136
- # Basic validation
137
- if len(input_shape) != 3:
138
- raise ValueError("`input_shape` must be a tuple of three integers.")
139
- if backend.image_data_format() == 'channels_first':
140
- if input_shape[1] is not None and input_shape[1] < min_size:
141
- raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
142
- if input_shape[2] is not None and input_shape[2] < min_size:
143
- raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
144
- else: # channels_last
145
- if input_shape[0] is not None and input_shape[0] < min_size:
146
- raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
147
- if input_shape[1] is not None and input_shape[1] < min_size:
148
- raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
149
-
150
- # Handle input tensor
151
- if input_tensor is None:
152
- img_input = Input(shape=input_shape)
153
- else:
154
- if not backend.is_keras_tensor(input_tensor):
155
- img_input = Input(tensor=input_tensor, shape=input_shape)
156
- else:
157
- img_input = input_tensor
158
-
159
- x = Convolution2D(64, (3, 3), strides=(2, 2), padding='valid', name='conv1')(img_input)
160
- x = Activation('relu', name='relu_conv1')(x)
161
- x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(x)
162
-
163
- x = fire_module(x, fire_id=2, squeeze=16, expand=64)
164
- x = fire_module(x, fire_id=3, squeeze=16, expand=64)
165
- x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(x)
166
-
167
- x = fire_module(x, fire_id=4, squeeze=32, expand=128)
168
- x = fire_module(x, fire_id=5, squeeze=32, expand=128)
169
- x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5')(x)
170
-
171
- x = fire_module(x, fire_id=6, squeeze=48, expand=192)
172
- x = fire_module(x, fire_id=7, squeeze=48, expand=192)
173
- x = fire_module(x, fire_id=8, squeeze=64, expand=256)
174
- x = fire_module(x, fire_id=9, squeeze=64, expand=256)
175
-
176
- if include_top:
177
- # It's not obvious where to cut the network...
178
- # Could do the 8th or 9th layer... some work recommends cutting earlier layers.
179
-
180
- x = Dropout(0.5, name='drop9')(x)
181
-
182
- x = Convolution2D(classes, (1, 1), padding='valid', name='conv10')(x)
183
- x = Activation('relu', name='relu_conv10')(x)
184
- x = GlobalAveragePooling2D()(x)
185
- x = Activation('softmax', name='loss')(x)
186
- else:
187
- if pooling == 'avg':
188
- x = GlobalAveragePooling2D()(x)
189
- elif pooling == 'max':
190
- x = GlobalMaxPooling2D()(x)
191
- elif pooling is None:
192
- pass
193
- else:
194
- raise ValueError("Unknown argument for 'pooling'=" + pooling)
195
-
196
- # Ensure that the model takes into account
197
- # any potential predecessors of `input_tensor`.
198
- if input_tensor is not None:
199
- inputs = get_source_inputs(input_tensor)
200
- else:
201
- inputs = img_input
202
-
203
- model = Model(inputs, x, name='squeezenet')
204
-
205
- # load weights
206
- if weights == 'imagenet':
207
- if include_top:
208
- weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels.h5',
209
- WEIGHTS_PATH,
210
- cache_subdir='models')
211
- else:
212
- weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5',
213
- WEIGHTS_PATH_NO_TOP,
214
- cache_subdir='models')
215
-
216
- model.load_weights(weights_path)
217
- return model
218
-
219
-
220
- # Classes
221
- class SqueezeNet(BaseKeras):
222
- def _get_base_model(self) -> Model:
223
- return SqueezeNet_keras(
224
- include_top=False, classes=self.num_classes, input_shape=(224, 224, 3)
225
- )
226
-
227
-
228
- # Docstrings
229
- for model in [SqueezeNet]:
230
- model.__doc__ = MODEL_DOCSTRING.format(model=model.__name__)
231
- model.class_routine = simple_cache(model.class_routine)
232
- model.class_routine.__doc__ = CLASS_ROUTINE_DOCSTRING.format(model=model.__name__)
233
-
1
+ """ SqueezeNet model implementation.
2
+
3
+ This module provides a wrapper class for the SqueezeNet model, a lightweight CNN architecture
4
+ that achieves AlexNet-level accuracy with 50x fewer parameters and a model size of less than 0.5MB.
5
+ SqueezeNet uses "fire modules" consisting of a squeeze layer with 1x1 filters followed by an
6
+ expand layer with a mix of 1x1 and 3x3 convolution filters.
7
+
8
+ Available models:
9
+ - SqueezeNet: Compact model with excellent performance-to-parameter ratio
10
+
11
+ The model supports transfer learning from ImageNet pre-trained weights.
12
+ """
13
+ # pyright: reportUnknownArgumentType=false
14
+ # pyright: reportUnknownMemberType=false
15
+ # pyright: reportUnknownVariableType=false
16
+ # pyright: reportMissingTypeStubs=false
17
+
18
+ # Imports
19
+ from __future__ import annotations
20
+
21
+ from typing import Any
22
+
23
+ from keras import backend
24
+ from keras.layers import (
25
+ Activation,
26
+ Convolution2D,
27
+ Dropout,
28
+ GlobalAveragePooling2D,
29
+ GlobalMaxPooling2D,
30
+ Input,
31
+ MaxPooling2D,
32
+ concatenate,
33
+ )
34
+ from keras.models import Model
35
+ from keras.utils import get_file, get_source_inputs
36
+
37
+ from ....decorators import simple_cache
38
+ from ..base_keras import BaseKeras
39
+ from ..model_interface import CLASS_ROUTINE_DOCSTRING, MODEL_DOCSTRING
40
+
41
+ # Constants
42
+ SQ1X1: str = "squeeze1x1"
43
+
44
+ WEIGHTS_PATH = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels.h5"
45
+ WEIGHTS_PATH_NO_TOP = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5"
46
+
47
+
48
+ # Modular function for Fire Node
49
+ def fire_module(x: Any, fire_id: int, squeeze: int = 16, expand: int = 64):
50
+ """ Create a fire module with specified parameters.
51
+
52
+ Args:
53
+ x (Tensor): Input tensor
54
+ fire_id (int): ID for the fire module
55
+ squeeze (int): Number of filters for squeeze layer
56
+ expand (int): Number of filters for expand layers
57
+
58
+ Returns:
59
+ Tensor: Output tensor from the fire module
60
+ """
61
+ s_id: str = f"fire{fire_id}"
62
+
63
+ if backend.image_data_format() == "channels_first":
64
+ channel_axis: int = 1
65
+ else:
66
+ channel_axis: int = 3
67
+
68
+ x = Convolution2D(squeeze, (1, 1), padding="valid", name=f"{s_id}/squeeze1x1")(x)
69
+ x = Activation("relu", name=f"{s_id}/relu_squeeze1x1")(x)
70
+
71
+ left = Convolution2D(expand, (1, 1), padding="valid", name=f"{s_id}/expand1x1")(x)
72
+ left = Activation("relu", name=f"{s_id}/relu_expand1x1")(left)
73
+
74
+ right = Convolution2D(expand, (3, 3), padding="same", name=f"{s_id}/expand3x3")(x)
75
+ right = Activation("relu", name=f"{s_id}/relu_expand3x3")(right)
76
+
77
+ x = concatenate([left, right], axis=channel_axis, name=f"{s_id}/concat")
78
+ return x
79
+
80
+
81
+ # Original SqueezeNet from paper
82
+ def SqueezeNet_keras( # noqa: N802
83
+ include_top: bool = True,
84
+ weights: str = "imagenet",
85
+ input_tensor: Any = None,
86
+ input_shape: tuple[Any, ...] | None = None,
87
+ pooling: str | None = None,
88
+ classes: int = 1000
89
+ ) -> Model:
90
+ """ Instantiates the SqueezeNet architecture.
91
+
92
+ Args:
93
+ include_top (bool): Whether to include the fully-connected layer at the top
94
+ weights (str): One of `None` or 'imagenet'
95
+ input_tensor (Tensor): Optional Keras tensor as input
96
+ input_shape (tuple): Optional shape tuple
97
+ pooling (str): Optional pooling mode for feature extraction
98
+ classes (int): Number of classes to classify images into
99
+
100
+ Returns:
101
+ Model: A Keras model instance
102
+ """
103
+
104
+ if weights not in {'imagenet', None}:
105
+ raise ValueError(
106
+ "The `weights` argument should be either `None` (random initialization) "
107
+ "or `imagenet` (pre-training on ImageNet)."
108
+ )
109
+
110
+ if include_top and weights == 'imagenet' and classes != 1000:
111
+ raise ValueError(
112
+ "If using `weights` as imagenet with `include_top` as true, `classes` should be 1000"
113
+ )
114
+
115
+ # Manually handle input shape logic instead of _obtain_input_shape
116
+ default_size: int = 227
117
+ min_size: int = 48
118
+ if backend.image_data_format() == 'channels_first':
119
+ default_shape: tuple[int, int, int] = (3, default_size, default_size)
120
+ if weights == 'imagenet' and include_top and input_shape is not None and input_shape[0] != 3:
121
+ raise ValueError(
122
+ "When specifying `input_shape` and loading 'imagenet' weights, 'channels_first' input_shape "
123
+ "should be (3, H, W)."
124
+ )
125
+ else: # channels_last
126
+ default_shape = (default_size, default_size, 3)
127
+ if weights == 'imagenet' and include_top and input_shape is not None and input_shape[2] != 3:
128
+ raise ValueError(
129
+ "When specifying `input_shape` and loading 'imagenet' weights, 'channels_last' input_shape "
130
+ "should be (H, W, 3)."
131
+ )
132
+
133
+ if input_shape is None:
134
+ input_shape = default_shape
135
+ else:
136
+ # Basic validation
137
+ if len(input_shape) != 3:
138
+ raise ValueError("`input_shape` must be a tuple of three integers.")
139
+ if backend.image_data_format() == 'channels_first':
140
+ if input_shape[1] is not None and input_shape[1] < min_size:
141
+ raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
142
+ if input_shape[2] is not None and input_shape[2] < min_size:
143
+ raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
144
+ else: # channels_last
145
+ if input_shape[0] is not None and input_shape[0] < min_size:
146
+ raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
147
+ if input_shape[1] is not None and input_shape[1] < min_size:
148
+ raise ValueError(f"Input size must be at least {min_size}x{min_size}, got `input_shape=`{input_shape}")
149
+
150
+ # Handle input tensor
151
+ if input_tensor is None:
152
+ img_input = Input(shape=input_shape)
153
+ else:
154
+ if not backend.is_keras_tensor(input_tensor):
155
+ img_input = Input(tensor=input_tensor, shape=input_shape)
156
+ else:
157
+ img_input = input_tensor
158
+
159
+ x = Convolution2D(64, (3, 3), strides=(2, 2), padding='valid', name='conv1')(img_input)
160
+ x = Activation('relu', name='relu_conv1')(x)
161
+ x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(x)
162
+
163
+ x = fire_module(x, fire_id=2, squeeze=16, expand=64)
164
+ x = fire_module(x, fire_id=3, squeeze=16, expand=64)
165
+ x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(x)
166
+
167
+ x = fire_module(x, fire_id=4, squeeze=32, expand=128)
168
+ x = fire_module(x, fire_id=5, squeeze=32, expand=128)
169
+ x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5')(x)
170
+
171
+ x = fire_module(x, fire_id=6, squeeze=48, expand=192)
172
+ x = fire_module(x, fire_id=7, squeeze=48, expand=192)
173
+ x = fire_module(x, fire_id=8, squeeze=64, expand=256)
174
+ x = fire_module(x, fire_id=9, squeeze=64, expand=256)
175
+
176
+ if include_top:
177
+ # It's not obvious where to cut the network...
178
+ # Could do the 8th or 9th layer... some work recommends cutting earlier layers.
179
+
180
+ x = Dropout(0.5, name='drop9')(x)
181
+
182
+ x = Convolution2D(classes, (1, 1), padding='valid', name='conv10')(x)
183
+ x = Activation('relu', name='relu_conv10')(x)
184
+ x = GlobalAveragePooling2D()(x)
185
+ x = Activation('softmax', name='loss')(x)
186
+ else:
187
+ if pooling == 'avg':
188
+ x = GlobalAveragePooling2D()(x)
189
+ elif pooling == 'max':
190
+ x = GlobalMaxPooling2D()(x)
191
+ elif pooling is None:
192
+ pass
193
+ else:
194
+ raise ValueError("Unknown argument for 'pooling'=" + pooling)
195
+
196
+ # Ensure that the model takes into account
197
+ # any potential predecessors of `input_tensor`.
198
+ if input_tensor is not None:
199
+ inputs = get_source_inputs(input_tensor)
200
+ else:
201
+ inputs = img_input
202
+
203
+ model = Model(inputs, x, name='squeezenet')
204
+
205
+ # load weights
206
+ if weights == 'imagenet':
207
+ if include_top:
208
+ weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels.h5',
209
+ WEIGHTS_PATH,
210
+ cache_subdir='models')
211
+ else:
212
+ weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5',
213
+ WEIGHTS_PATH_NO_TOP,
214
+ cache_subdir='models')
215
+
216
+ model.load_weights(weights_path)
217
+ return model
218
+
219
+
220
+ # Classes
221
+ class SqueezeNet(BaseKeras):
222
+ def _get_base_model(self) -> Model:
223
+ return SqueezeNet_keras(
224
+ include_top=False, classes=self.num_classes, input_shape=(224, 224, 3)
225
+ )
226
+
227
+
228
+ # Docstrings
229
+ for model in [SqueezeNet]:
230
+ model.__doc__ = MODEL_DOCSTRING.format(model=model.__name__)
231
+ model.class_routine = simple_cache(model.class_routine)
232
+ model.class_routine.__doc__ = CLASS_ROUTINE_DOCSTRING.format(model=model.__name__)
233
+
@@ -1,42 +1,42 @@
1
- """ VGG models implementation.
2
-
3
- This module provides wrapper classes for the VGG family of models from the Keras applications.
4
- VGG models are characterized by their simplicity, using only 3x3 convolutional layers
5
- stacked on top of each other with increasing depth.
6
-
7
- Available models:
8
- - VGG16: 16-layer model with 13 convolutional layers and 3 fully connected layers
9
- - VGG19: 19-layer model with 16 convolutional layers and 3 fully connected layers
10
-
11
- Both models support transfer learning from ImageNet pre-trained weights.
12
- """
13
- # pyright: reportUnknownVariableType=false
14
- # pyright: reportMissingTypeStubs=false
15
-
16
- # Imports
17
- from __future__ import annotations
18
-
19
- from keras.models import Model
20
- from keras.src.applications.vgg16 import VGG16 as VGG16_keras # noqa: N811
21
- from keras.src.applications.vgg19 import VGG19 as VGG19_keras # noqa: N811
22
-
23
- from ....decorators import simple_cache
24
- from ..base_keras import BaseKeras
25
- from ..model_interface import CLASS_ROUTINE_DOCSTRING, MODEL_DOCSTRING
26
-
27
-
28
- # Base class
29
- class VGG19(BaseKeras):
30
- def _get_base_model(self) -> Model:
31
- return VGG19_keras(include_top=False, classes=self.num_classes)
32
- class VGG16(BaseKeras):
33
- def _get_base_model(self) -> Model:
34
- return VGG16_keras(include_top=False, classes=self.num_classes)
35
-
36
-
37
- # Docstrings
38
- for model in [VGG19, VGG16]:
39
- model.__doc__ = MODEL_DOCSTRING.format(model=model.__name__)
40
- model.class_routine = simple_cache(model.class_routine)
41
- model.class_routine.__doc__ = CLASS_ROUTINE_DOCSTRING.format(model=model.__name__)
42
-
1
+ """ VGG models implementation.
2
+
3
+ This module provides wrapper classes for the VGG family of models from the Keras applications.
4
+ VGG models are characterized by their simplicity, using only 3x3 convolutional layers
5
+ stacked on top of each other with increasing depth.
6
+
7
+ Available models:
8
+ - VGG16: 16-layer model with 13 convolutional layers and 3 fully connected layers
9
+ - VGG19: 19-layer model with 16 convolutional layers and 3 fully connected layers
10
+
11
+ Both models support transfer learning from ImageNet pre-trained weights.
12
+ """
13
+ # pyright: reportUnknownVariableType=false
14
+ # pyright: reportMissingTypeStubs=false
15
+
16
+ # Imports
17
+ from __future__ import annotations
18
+
19
+ from keras.models import Model
20
+ from keras.src.applications.vgg16 import VGG16 as VGG16_keras # noqa: N811
21
+ from keras.src.applications.vgg19 import VGG19 as VGG19_keras # noqa: N811
22
+
23
+ from ....decorators import simple_cache
24
+ from ..base_keras import BaseKeras
25
+ from ..model_interface import CLASS_ROUTINE_DOCSTRING, MODEL_DOCSTRING
26
+
27
+
28
+ # Base class
29
+ class VGG19(BaseKeras):
30
+ def _get_base_model(self) -> Model:
31
+ return VGG19_keras(include_top=False, classes=self.num_classes)
32
+ class VGG16(BaseKeras):
33
+ def _get_base_model(self) -> Model:
34
+ return VGG16_keras(include_top=False, classes=self.num_classes)
35
+
36
+
37
+ # Docstrings
38
+ for model in [VGG19, VGG16]:
39
+ model.__doc__ = MODEL_DOCSTRING.format(model=model.__name__)
40
+ model.class_routine = simple_cache(model.class_routine)
41
+ model.class_routine.__doc__ = CLASS_ROUTINE_DOCSTRING.format(model=model.__name__)
42
+
@@ -1,38 +1,38 @@
1
- """ Xception model implementation.
2
-
3
- This module provides a wrapper class for the Xception model, a deep convolutional neural network
4
- designed for efficient image classification. Xception uses depthwise separable convolutions,
5
- which significantly reduce the number of parameters and computational complexity compared to
6
- standard convolutional layers.
7
-
8
- Available models:
9
- - Xception: The standard Xception model
10
-
11
- The model supports transfer learning from ImageNet pre-trained weights.
12
- """
13
- # pyright: reportUnknownVariableType=false
14
- # pyright: reportMissingTypeStubs=false
15
-
16
- # Imports
17
- from __future__ import annotations
18
-
19
- from keras.models import Model
20
- from keras.src.applications.xception import Xception as Xception_keras
21
-
22
- from ....decorators import simple_cache
23
- from ..base_keras import BaseKeras
24
- from ..model_interface import CLASS_ROUTINE_DOCSTRING, MODEL_DOCSTRING
25
-
26
-
27
- # Base class
28
- class Xception(BaseKeras):
29
- def _get_base_model(self) -> Model:
30
- return Xception_keras(include_top=False, classes=self.num_classes)
31
-
32
-
33
- # Docstrings
34
- for model in [Xception]:
35
- model.__doc__ = MODEL_DOCSTRING.format(model=model.__name__)
36
- model.class_routine = simple_cache(model.class_routine)
37
- model.class_routine.__doc__ = CLASS_ROUTINE_DOCSTRING.format(model=model.__name__)
38
-
1
+ """ Xception model implementation.
2
+
3
+ This module provides a wrapper class for the Xception model, a deep convolutional neural network
4
+ designed for efficient image classification. Xception uses depthwise separable convolutions,
5
+ which significantly reduce the number of parameters and computational complexity compared to
6
+ standard convolutional layers.
7
+
8
+ Available models:
9
+ - Xception: The standard Xception model
10
+
11
+ The model supports transfer learning from ImageNet pre-trained weights.
12
+ """
13
+ # pyright: reportUnknownVariableType=false
14
+ # pyright: reportMissingTypeStubs=false
15
+
16
+ # Imports
17
+ from __future__ import annotations
18
+
19
+ from keras.models import Model
20
+ from keras.src.applications.xception import Xception as Xception_keras
21
+
22
+ from ....decorators import simple_cache
23
+ from ..base_keras import BaseKeras
24
+ from ..model_interface import CLASS_ROUTINE_DOCSTRING, MODEL_DOCSTRING
25
+
26
+
27
+ # Base class
28
+ class Xception(BaseKeras):
29
+ def _get_base_model(self) -> Model:
30
+ return Xception_keras(include_top=False, classes=self.num_classes)
31
+
32
+
33
+ # Docstrings
34
+ for model in [Xception]:
35
+ model.__doc__ = MODEL_DOCSTRING.format(model=model.__name__)
36
+ model.class_routine = simple_cache(model.class_routine)
37
+ model.class_routine.__doc__ = CLASS_ROUTINE_DOCSTRING.format(model=model.__name__)
38
+