mct-nightly 2.1.0.20240724.437__py3-none-any.whl → 2.1.0.20240726.430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/RECORD +35 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/pytorch/constants.py +6 -1
  5. model_compression_toolkit/core/pytorch/utils.py +27 -0
  6. model_compression_toolkit/data_generation/common/data_generation.py +20 -18
  7. model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
  8. model_compression_toolkit/data_generation/common/enums.py +24 -12
  9. model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
  10. model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
  11. model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
  12. model_compression_toolkit/data_generation/keras/constants.py +5 -2
  13. model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
  14. model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
  15. model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
  16. model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
  17. model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
  18. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
  19. model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
  20. model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
  21. model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
  22. model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
  23. model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
  24. model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
  25. model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
  26. model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
  27. model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
  28. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
  29. model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
  30. model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
  31. model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
  32. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
  33. {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/LICENSE.md +0 -0
  34. {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/WHEEL +0 -0
  35. {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,189 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import numpy as np
16
+ from typing import Tuple
17
+
18
+ import tensorflow as tf
19
+
20
def random_crop(image: tf.Tensor,
                height_crop: int,
                width_crop: int) -> tf.Tensor:
    """
    Take a randomly positioned spatial window of the requested size from a batch of images.

    The first (batch) and last (channel) dimensions are kept whole; only the two middle
    dimensions are cropped, so the input is indexed as (batch, height, width, channels).
    ``tf.image.random_crop`` draws a single random offset for the whole tensor, so every
    image in the batch is cropped at the same location.

    Args:
        image (tf.Tensor): Input image tensor.
        height_crop (int): Size of the crop in the height axis.
        width_crop (int): Size of the crop in the width axis.

    Returns:
        tf.Tensor: Cropped image tensor.
    """
    dims = tf.shape(image)
    crop_size = (dims[0], height_crop, width_crop, dims[-1])
    return tf.image.random_crop(image, size=crop_size)
40
+
41
+
42
def center_crop(image: tf.Tensor,
                height_crop: int,
                width_crop: int) -> tf.Tensor:
    """
    Center crop a batch of images to the specified size.

    Args:
        image (tf.Tensor): Input image tensor, indexed as (batch, height, width, channels);
            dimensions 1 and 2 are read as height and width.
        height_crop (int): Height of the image after the crop.
        width_crop (int): Width of the image after the crop.

    Returns:
        tf.Tensor: Cropped image tensor of spatial size (height_crop, width_crop).
    """
    input_shape = tf.shape(image)
    height, width = input_shape[1], input_shape[2]

    # Offsets that center the crop window. When the size difference is odd, the extra
    # pixel is dropped from the bottom/right. Offsets are clamped at 0 so they are never negative.
    offset_height = tf.maximum((height - height_crop) // 2, 0)
    offset_width = tf.maximum((width - width_crop) // 2, 0)

    return tf.image.crop_to_bounding_box(image, offset_height, offset_width, height_crop, width_crop)
68
+
69
+
70
def random_flip(image: tf.Tensor) -> tf.Tensor:
    """
    Randomly flip images horizontally (left to right).

    There is no probability argument: ``tf.image.random_flip_left_right`` flips with a
    fixed 1-in-2 chance and, for a 4-D batch, decides independently for each image.

    Args:
        image (tf.Tensor): Input image tensor (a single 3-D image or a 4-D batch).

    Returns:
        tf.Tensor: Image tensor, flipped or unchanged, with the same shape as the input.
    """
    return tf.image.random_flip_left_right(image)
82
+
83
+
84
def clip_images(images: tf.Tensor, valid_grid: tf.Tensor, reflection: bool = False) -> tf.Tensor:
    """
    Clip each channel of the images to the value range covered by the valid grid.

    Images are indexed channels-last, (batch, height, width, channels), the same layout
    that random_crop, center_crop and Smoothing in this module operate on. Row ``c`` of
    ``valid_grid`` holds the valid values of channel ``c``; its minimum and maximum are
    the clipping bounds for that channel.

    Args:
        images (tf.Tensor): The images to be clipped, channels-last.
        valid_grid (tf.Tensor): Per-channel valid values, shape (channels, num_values),
            e.g. the output of create_valid_grid.
        reflection (bool): If True, values outside the bounds are mirrored back across the
            violated bound (``2 * bound - x``) instead of being clamped. Defaults to False.

    Returns:
        tf.Tensor: The clipped images, with the same shape as ``images``.
    """
    # One (min, max) pair per channel. Shape (channels,) broadcasts over the last axis,
    # so every channel is clipped with its own bounds without a Python loop.
    min_vals = tf.cast(tf.reduce_min(valid_grid, axis=1), images.dtype)
    max_vals = tf.cast(tf.reduce_max(valid_grid, axis=1), images.dtype)
    clamped = tf.minimum(tf.maximum(images, min_vals), max_vals)
    if reflection:
        # Mirror out-of-range values back inside: x > max -> 2*max - x, x < min -> 2*min - x.
        return 2 * clamped - images
    return clamped
111
+
112
+
113
def create_valid_grid(means, stds) -> tf.Tensor:
    """
    Create a valid grid for image normalization.

    Every 8-bit pixel value 0..255 is normalized with each channel's mean and std, giving,
    per channel, the set of values a normalized image can take.

    Args:
        means: Per-channel normalization means (the grid is built for three channels).
        stds: Per-channel normalization standard deviations. These must be non-zero;
            this function does not check, and a zero std yields inf/nan entries.

    Returns:
        tf.Tensor: float32 tensor of shape (3, 256); row c holds (v - means[c]) / stds[c]
        for v in 0..255.
    """
    # (3, 256) grid: each of the 3 channel rows is the pixel range 0..255.
    pixel_grid = tf.constant(np.tile(np.arange(256), (3, 1)), dtype=tf.float32)

    # Normalize each row with its channel's mean and std; the (3, 1) column vectors
    # broadcast across the 256 pixel values.
    mean = tf.constant(np.array(means), dtype=tf.float32)
    std = tf.constant(np.array(stds), dtype=tf.float32)
    valid_grid = (pixel_grid - mean[:, tf.newaxis]) / std[:, tf.newaxis]

    return valid_grid
133
+
134
class Smoothing(tf.keras.layers.Layer):
    """
    A Keras layer that applies Gaussian smoothing to channels-last images.

    Each channel is convolved independently (depthwise) with the same normalized
    ``size x size`` Gaussian kernel, using stride 1 and 'SAME' padding, so the spatial
    size is preserved.
    """

    def __init__(self, size: int = 3, sigma: float = 1.25, **kwargs):
        """
        Initialize the Smoothing layer.

        Args:
            size (int): The size of the Gaussian kernel.
            sigma (float): The standard deviation of the Gaussian kernel.
            **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        super(Smoothing, self).__init__(**kwargs)
        self.size = size
        self.sigma = sigma
        # 2-D (size, size) kernel; build() replaces it with the 4-D depthwise kernel.
        self.kernel = self.gaussian_kernel(size, sigma)

    def build(self, input_shape):
        """
        Build the depthwise smoothing kernel for the input's channel count.

        Args:
            input_shape (TensorShape): Shape of the input tensor; its last dimension is
                taken as the number of channels.
        """
        # Recompute from the 2-D Gaussian rather than reshaping self.kernel, so a repeated
        # build (after self.kernel has already been tiled to 4-D) does not fail.
        base_kernel = tf.reshape(self.gaussian_kernel(self.size, self.sigma),
                                 [self.size, self.size, 1, 1])
        # depthwise_conv2d filter layout: (height, width, in_channels, channel_multiplier=1).
        self.kernel = tf.tile(base_kernel, [1, 1, input_shape[-1], 1])

    def call(self, inputs):
        """
        Apply Gaussian smoothing to the input image.

        Args:
            inputs (tf.Tensor): The input image tensor (channels-last).

        Returns:
            tf.Tensor: The smoothed image tensor, with the same shape as ``inputs``.
        """
        return tf.nn.depthwise_conv2d(inputs, self.kernel, strides=[1, 1, 1, 1], padding='SAME')

    def gaussian_kernel(self, size: int, sigma: float) -> tf.Tensor:
        """
        Create a normalized 2-D Gaussian kernel.

        Args:
            size (int): The size of the Gaussian kernel.
            sigma (float): The standard deviation of the Gaussian kernel.

        Returns:
            tf.Tensor: A (size, size) float32 kernel whose entries sum to 1.
        """
        # Offsets centered on 0 for odd sizes (size=3 -> [-1, 0, 1]); note -size // 2 is (-size) // 2.
        axis = tf.range(-size // 2 + 1, size // 2 + 1, dtype=tf.float32)
        x, y = tf.meshgrid(axis, axis)
        kernel = tf.exp(-(x ** 2 + y ** 2) / (2 * sigma ** 2))
        kernel = kernel / tf.reduce_sum(kernel)
        return kernel
@@ -12,111 +12,62 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Tuple, Dict, Type, List
16
-
17
- import numpy as np
15
+ from typing import Tuple, Dict, Type, Union, List
18
16
  import tensorflow as tf
19
17
 
20
- from model_compression_toolkit.data_generation.common.enums import ImagePipelineType, ImageNormalizationType
18
+ from model_compression_toolkit.data_generation.common.enums import ImagePipelineType
21
19
  from model_compression_toolkit.data_generation.common.image_pipeline import BaseImagePipeline
20
+ from model_compression_toolkit.data_generation.keras.image_operations import Smoothing, random_flip, random_crop, \
21
+ clip_images, create_valid_grid, center_crop
22
22
 
23
23
 
24
- # Define tf function for image manipulation
25
-
26
- def random_crop(image: tf.Tensor,
27
- height_crop: int,
28
- width_crop: int) -> tf.Tensor:
29
- """
30
- Randomly crop an image to the specified size.
31
-
32
- Args:
33
- image (tf.Tensor): Input image tensor.
34
- height_crop (int): Size of the crop in the height axis.
35
- width_crop (int): Size of the crop in the width axis.
36
-
37
- Returns:
38
- tf.Tensor: Cropped image tensor.
39
- """
40
- cropped_image = tf.image.random_crop(image,
41
- size=(tf.shape(image)[0],
42
- height_crop,
43
- width_crop,
44
- tf.shape(image)[-1]))
45
- return cropped_image
46
-
47
-
48
- def center_crop(image: tf.Tensor,
49
- output_size: Tuple) -> tf.Tensor:
50
- """
51
- Center crop an image to the specified size.
52
-
53
- Args:
54
- image (tf.Tensor): Input image tensor.
55
- output_size (Tuple): Size of image after the crop (height and width).
56
-
57
- Returns:
58
- tf.Tensor: Cropped image tensor.
59
- """
60
-
61
- # Calculate the cropping dimensions
62
- input_shape = tf.shape(image)
63
- height, width = input_shape[1], input_shape[2]
64
- target_height, target_width = output_size[0], output_size[1]
65
-
66
- # Calculate the cropping offsets
67
- offset_height = tf.maximum((height - target_height) // 2, 0)
68
- offset_width = tf.maximum((width - target_width) // 2, 0)
69
-
70
- # Crop the image
71
- cropped_image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width)
72
-
73
- return cropped_image
74
-
75
-
76
- def random_flip(image: tf.Tensor) -> tf.Tensor:
77
- """
78
- Randomly flip an image horizontally with a specified probability.
79
-
80
- Args:
81
- image (tf.Tensor): Input image tensor.
82
-
83
- Returns:
84
- tf.Tensor: Flipped image tensor.
85
- """
86
- flip_image = tf.image.random_flip_left_right(image)
87
- return flip_image
88
-
89
-
90
- class TensorflowCropFlipImagePipeline(BaseImagePipeline):
24
+ class TensorflowSmoothAugmentationImagePipeline(BaseImagePipeline):
91
25
  def __init__(self,
92
- output_image_size: Tuple,
93
- extra_pixels: int):
26
+ output_image_size: Union[int, Tuple[int, int]],
27
+ extra_pixels: Union[int, Tuple[int, int]],
28
+ normalization: List[List[int]],
29
+ image_clipping: bool = False,
30
+ smoothing_filter_size: int = 3,
31
+ smoothing_filter_sigma: float = 1.25):
94
32
  """
95
33
  Initialize the TensorflowCropFlipImagePipeline.
96
34
 
97
35
  Args:
98
- output_image_size (Tuple): The output image size.
99
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
100
- """
101
- super(TensorflowCropFlipImagePipeline, self, ).__init__(output_image_size, extra_pixels)
102
-
36
+ output_image_size (Union[int, Tuple[int, int]]): The output image size.
37
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
38
+ normalization (List[List[float]]): The image normalization values for processing images during optimization.
39
+ image_clipping (bool): Whether to clip images during optimization.
40
+ smoothing_filter_size (int): The size of the smoothing filter. Defaults to 3.
41
+ smoothing_filter_sigma (float): The standard deviation of the smoothing filter. Defaults to 1.25.
42
+ """
43
+ super(TensorflowSmoothAugmentationImagePipeline, self, ).__init__(output_image_size, extra_pixels, image_clipping, normalization)
44
+
45
+ smoothing = Smoothing(smoothing_filter_size, smoothing_filter_sigma)
103
46
  # List of image manipulation functions and their arguments.
104
47
  self.img_manipulation_list = [(random_flip, {}),
105
- (random_crop, {'height_crop': output_image_size[0],
106
- 'width_crop': output_image_size[1]})]
48
+ (smoothing, {}),
49
+ (random_crop, {'height_crop': self.output_image_size[0],
50
+ 'width_crop': self.output_image_size[1]}),
51
+ ]
107
52
 
108
53
  # List of output image manipulation functions and their arguments.
109
- self.img_output_finalize_list = [(center_crop, {'output_size': output_image_size})]
110
- self.extra_pixels = extra_pixels
54
+ self.img_output_finalize_list = [(smoothing, {}),
55
+ (center_crop, {'height_crop': self.output_image_size[0],
56
+ 'width_crop': self.output_image_size[1]}),
57
+ ]
58
+ if image_clipping:
59
+ clip_fn = (clip_images, {'valid_grid': create_valid_grid(self.normalization[0], self.normalization[1])})
60
+ self.img_manipulation_list.append(clip_fn)
61
+ self.img_output_finalize_list.append(clip_fn)
111
62
 
112
- def get_image_input_size(self) -> Tuple:
63
+ def get_image_input_size(self) -> Tuple[int, int]:
113
64
  """
114
65
  Get the size of the input image considering extra pixels.
115
66
 
116
67
  Returns:
117
- Tuple: Size of the input image.
68
+ Tuple[int, int]: Size of the input image.
118
69
  """
119
- return tuple(np.array(self.output_image_size) + self.extra_pixels)
70
+ return tuple([o + e for (o, e) in zip(self.output_image_size, self.extra_pixels)])
120
71
 
121
72
  def image_input_manipulation(self,
122
73
  images: tf.Tensor) -> tf.Tensor:
@@ -161,28 +112,30 @@ class TensorflowCropFlipImagePipeline(BaseImagePipeline):
161
112
 
162
113
  class TensorflowIdentityImagePipeline(BaseImagePipeline):
163
114
 
164
- def __init__(self, output_image_size: int,
165
- extra_pixels: int
115
+ def __init__(self, output_image_size: Union[int, Tuple[int, int]],
116
+ extra_pixels: Union[int, Tuple[int, int]],
117
+ normalization: List[List[int]],
118
+ image_clipping: bool = False
166
119
  ):
167
120
  """
168
121
  Initialize the TensorflowIdentityImagePipeline.
169
122
 
170
123
  Args:
171
- output_image_size (Tuple): The output image size.
172
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
124
+ output_image_size (Union[int, Tuple[int, int]]): The output image size.
125
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
126
+ normalization (List[List[float]]): The image normalization values for processing images during optimization.
127
+ image_clipping (bool): Whether to clip images during optimization.
173
128
  """
174
- super(TensorflowIdentityImagePipeline, self, ).__init__(output_image_size, extra_pixels)
175
- self.extra_pixels = extra_pixels
176
- self.output_image_size = output_image_size
129
+ super(TensorflowIdentityImagePipeline, self, ).__init__(output_image_size, extra_pixels, image_clipping, normalization)
177
130
 
178
- def get_image_input_size(self) -> Tuple:
131
+ def get_image_input_size(self) -> Tuple[int, int]:
179
132
  """
180
- Get the size of the input image considering extra pixels.
133
+ Get the size of the input image.
181
134
 
182
135
  Returns:
183
- Tuple: Size of the input image.
136
+ Tuple[int, int]: Size of the input image.
184
137
  """
185
- return tuple(np.array(self.output_image_size) + self.extra_pixels)
138
+ return self.output_image_size
186
139
 
187
140
  def image_input_manipulation(self,
188
141
  images: tf.Tensor) -> tf.Tensor:
@@ -214,12 +167,5 @@ class TensorflowIdentityImagePipeline(BaseImagePipeline):
214
167
  # Dictionary mapping ImagePipelineType to corresponding image pipeline classes
215
168
  image_pipeline_dict: Dict[ImagePipelineType, Type[BaseImagePipeline]] = {
216
169
  ImagePipelineType.IDENTITY: TensorflowIdentityImagePipeline,
217
- ImagePipelineType.RANDOM_CROP_FLIP: TensorflowCropFlipImagePipeline
218
- }
219
-
220
- # Dictionary mapping ImageNormalizationType to corresponding normalization values
221
- image_normalization_dict: Dict[ImageNormalizationType, List[List[float]]] = {
222
- ImageNormalizationType.TORCHVISION: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
223
- ImageNormalizationType.KERAS_APPLICATIONS: [(127.5, 127.5, 127.5), (127.5, 127.5, 127.5)],
224
- ImageNormalizationType.NO_NORMALIZATION: [[0, 0, 0], [1, 1, 1]]
170
+ ImagePipelineType.SMOOTHING_AND_AUGMENTATION: TensorflowSmoothAugmentationImagePipeline
225
171
  }
@@ -13,12 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import time
16
- from typing import Callable, Tuple, List, Dict
16
+ from typing import Callable, Tuple, List, Dict, Union
17
17
  from tqdm import tqdm
18
18
 
19
19
  from model_compression_toolkit.constants import FOUND_TF
20
20
  from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
21
21
  from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
22
+ from model_compression_toolkit.data_generation.common.image_pipeline import image_normalization_dict
22
23
  from model_compression_toolkit.logger import Logger
23
24
  from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig, \
24
25
  ImageGranularity
@@ -30,9 +31,8 @@ if FOUND_TF:
30
31
  from tensorflow.keras.layers import BatchNormalization
31
32
  from tensorflow.keras.optimizers.legacy import Optimizer, Adam
32
33
  from model_compression_toolkit.data_generation.keras.constants import DEFAULT_KERAS_INITIAL_LR, \
33
- DEFAULT_KERAS_OUTPUT_LOSS_MULTIPLIER
34
- from model_compression_toolkit.data_generation.keras.image_pipeline import (image_pipeline_dict,
35
- image_normalization_dict)
34
+ DEFAULT_KERAS_EXTRA_PIXELS, DEFAULT_KERAS_OUTPUT_LOSS_MULTIPLIER
35
+ from model_compression_toolkit.data_generation.keras.image_pipeline import image_pipeline_dict
36
36
  from model_compression_toolkit.data_generation.keras.model_info_exctractors import (KerasActivationExtractor,
37
37
  KerasOriginalBNStatsHolder)
38
38
  from model_compression_toolkit.data_generation.keras.optimization_functions.batchnorm_alignment_functions import \
@@ -55,18 +55,17 @@ if FOUND_TF:
55
55
  data_gen_batch_size: int = DEFAULT_DATA_GEN_BS,
56
56
  initial_lr: float = DEFAULT_KERAS_INITIAL_LR,
57
57
  output_loss_multiplier: float = DEFAULT_KERAS_OUTPUT_LOSS_MULTIPLIER,
58
- scheduler_type: SchedulerType = SchedulerType.REDUCE_ON_PLATEAU,
58
+ scheduler_type: SchedulerType = SchedulerType.REDUCE_ON_PLATEAU ,
59
59
  bn_alignment_loss_type: BatchNormAlignemntLossType = BatchNormAlignemntLossType.L2_SQUARE,
60
60
  output_loss_type: OutputLossType = OutputLossType.REGULARIZED_MIN_MAX_DIFF,
61
61
  data_init_type: DataInitType = DataInitType.Gaussian,
62
62
  layer_weighting_type: BNLayerWeightingType = BNLayerWeightingType.AVERAGE,
63
63
  image_granularity: ImageGranularity = ImageGranularity.BatchWise,
64
- image_pipeline_type: ImagePipelineType = ImagePipelineType.RANDOM_CROP_FLIP,
64
+ image_pipeline_type: ImagePipelineType = ImagePipelineType.SMOOTHING_AND_AUGMENTATION,
65
65
  image_normalization_type: ImageNormalizationType = ImageNormalizationType.KERAS_APPLICATIONS,
66
- extra_pixels: int = 0,
66
+ extra_pixels: Union[int, Tuple[int, int]] = DEFAULT_KERAS_EXTRA_PIXELS,
67
67
  bn_layer_types: List = [BatchNormalization],
68
- clip_images: bool = True,
69
- reflection: bool = True,
68
+ image_clipping: bool = False,
70
69
  ) -> DataGenerationConfig:
71
70
  """
72
71
  Function to create a DataGenerationConfig object with the specified configuration parameters.
@@ -85,10 +84,9 @@ if FOUND_TF:
85
84
  image_granularity (ImageGranularity): The granularity of the images for optimization.
86
85
  image_pipeline_type (ImagePipelineType): The type of image pipeline to use.
87
86
  image_normalization_type (ImageNormalizationType): The type of image normalization to use.
88
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
87
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
89
88
  bn_layer_types (List): List of BatchNorm layer types to be considered for data generation.
90
- clip_images (bool): Whether to clip images during optimization.
91
- reflection (bool): Whether to use reflection during optimization.
89
+ image_clipping (bool): Whether to clip images during optimization.
92
90
 
93
91
  Returns:
94
92
  DataGenerationConfig: Data generation configuration object.
@@ -100,6 +98,7 @@ if FOUND_TF:
100
98
  optimizer=optimizer,
101
99
  data_gen_batch_size=data_gen_batch_size,
102
100
  initial_lr=initial_lr,
101
+ output_loss_multiplier=output_loss_multiplier,
103
102
  scheduler_type=scheduler_type,
104
103
  bn_alignment_loss_type=bn_alignment_loss_type,
105
104
  output_loss_type=output_loss_type,
@@ -110,15 +109,13 @@ if FOUND_TF:
110
109
  image_normalization_type=image_normalization_type,
111
110
  extra_pixels=extra_pixels,
112
111
  bn_layer_types=bn_layer_types,
113
- clip_images=clip_images,
114
- reflection=reflection,
115
- output_loss_multiplier=output_loss_multiplier)
112
+ image_clipping=image_clipping)
116
113
 
117
114
 
118
115
  def keras_data_generation_experimental(
119
116
  model: tf.keras.Model,
120
117
  n_images: int,
121
- output_image_size: Tuple,
118
+ output_image_size: Union[int, Tuple[int, int]],
122
119
  data_generation_config: DataGenerationConfig) -> tf.Tensor:
123
120
  """
124
121
  Function to perform data generation using the provided Keras model and data generation configuration.
@@ -126,7 +123,7 @@ if FOUND_TF:
126
123
  Args:
127
124
  model (Model): Keras model to generate data for.
128
125
  n_images (int): Number of images to generate.
129
- output_image_size (Tuple): Size of the output images.
126
+ output_image_size (Union[int, Tuple[int, int]]): Size of the output images.
130
127
  data_generation_config (DataGenerationConfig): Configuration for data generation.
131
128
 
132
129
  Returns:
@@ -180,17 +177,13 @@ if FOUND_TF:
180
177
  bn_alignment_loss_function_dict=bn_alignment_loss_function_dict,
181
178
  output_loss_function_dict=output_loss_function_dict)
182
179
 
183
- if not all(normalization[1]):
184
- Logger.critical(
185
- f'Invalid normalization standard deviation {normalization[1]} set to zero, which will lead to division by zero. Please select a non-zero normalization standard deviation.')
186
-
187
180
  # Get the scheduler functions corresponding to the specified scheduler type
188
181
  scheduler_get_fn = scheduler_step_function_dict.get(data_generation_config.scheduler_type)
189
182
 
190
183
  # Check if the scheduler type is valid
191
184
  if scheduler_get_fn is None:
192
185
  Logger.critical(
193
- f'Invalid scheduler_type {data_generation_config.scheduler_type}. Please select one from {SchedulerType.get_values()}.')
186
+ f'Invalid scheduler_type {data_generation_config.scheduler_type}. Please select one from {SchedulerType.get_values()}.') # pragma: no cover
194
187
 
195
188
  # Create a scheduler object with the specified number of iterations
196
189
  scheduler = scheduler_get_fn(n_iter=data_generation_config.n_iter,
@@ -202,10 +195,7 @@ if FOUND_TF:
202
195
  # Create an activation extractor object to extract activations from the model
203
196
  activation_extractor = KerasActivationExtractor(model=model,
204
197
  layer_types_to_extract_inputs=
205
- data_generation_config.bn_layer_types,
206
- image_granularity=data_generation_config.image_granularity,
207
- image_input_manipulation=
208
- image_pipeline.image_input_manipulation)
198
+ data_generation_config.bn_layer_types)
209
199
 
210
200
  # Create an orig_bn_stats_holder object to hold original BatchNorm statistics
211
201
  orig_bn_stats_holder = KerasOriginalBNStatsHolder(model=model,
@@ -223,9 +213,6 @@ if FOUND_TF:
223
213
  model=model,
224
214
  orig_bn_stats_holder=orig_bn_stats_holder)
225
215
 
226
- # Compute the layer weights based on orig_bn_stats_holder
227
- bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder=orig_bn_stats_holder)
228
-
229
216
  # Get the current time to measure the total time taken
230
217
  total_time = time.time()
231
218
 
@@ -233,7 +220,7 @@ if FOUND_TF:
233
220
  ibar = tqdm(range(data_generation_config.n_iter))
234
221
 
235
222
  # Perform data generation iterations
236
- for i_ter in ibar:
223
+ for i_iter in ibar:
237
224
 
238
225
  # Randomly reorder the batches
239
226
  all_imgs_opt_handler.random_batch_reorder()
@@ -246,6 +233,12 @@ if FOUND_TF:
246
233
  # Get the images to optimize and the optimizer for the batch
247
234
  imgs_to_optimize = all_imgs_opt_handler.get_images_by_batch_index(batch_index=random_batch_index)
248
235
 
236
+ # Compute the layer weights based on orig_bn_stats_holder
237
+ bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder=orig_bn_stats_holder,
238
+ activation_extractor=activation_extractor,
239
+ i_iter=i_iter,
240
+ n_iter=data_generation_config.n_iter)
241
+
249
242
  # Compute the gradients and the loss for the batch
250
243
  gradients, total_loss, bn_loss, output_loss = keras_compute_grads(imgs_to_optimize=imgs_to_optimize,
251
244
  batch_index=random_batch_index,
@@ -266,7 +259,7 @@ if FOUND_TF:
266
259
  images=imgs_to_optimize,
267
260
  gradients=gradients,
268
261
  loss=total_loss,
269
- i_ter=i_ter)
262
+ i_iter=i_iter)
270
263
 
271
264
  # Update the statistics based on the updated images
272
265
  if all_imgs_opt_handler.use_all_data_stats:
@@ -335,14 +328,13 @@ if FOUND_TF:
335
328
  bn_layer_weights=bn_layer_weights)
336
329
 
337
330
  # Compute output loss
338
- # If output_loss_multiplier is zero return 0
339
- output_loss = output_loss_multiplier * output_loss_fn(
340
- output_imgs=output,
331
+ output_loss = output_loss_fn(
332
+ model_outputs=output,
341
333
  activation_extractor=activation_extractor,
342
- tape=tape) if output_loss_multiplier > 0 else tf.zeros(1)
334
+ tape=tape)
343
335
 
344
336
  # Compute total loss
345
- total_loss = bn_loss + output_loss
337
+ total_loss = bn_loss + output_loss_multiplier * output_loss
346
338
 
347
339
  # Get the trainable variables
348
340
  variables = [imgs_to_optimize]
@@ -72,8 +72,6 @@ class KerasActivationExtractor(ActivationExtractor):
72
72
  def __init__(self,
73
73
  model: tf.keras.Model,
74
74
  layer_types_to_extract_inputs: List,
75
- image_granularity: ImageGranularity,
76
- image_input_manipulation: Callable,
77
75
  linear_layers: Tuple = (Dense, Conv2D)):
78
76
  """
79
77
  Initializes the KerasActivationExtractor.
@@ -81,14 +79,10 @@ class KerasActivationExtractor(ActivationExtractor):
81
79
  Args:
82
80
  model (Model): Keras model to generate data for.
83
81
  layer_types_to_extract_inputs (List): Tuple or list of layer types.
84
- image_granularity (ImageGranularity): The granularity of the images for optimization.
85
- image_input_manipulation (Callable): Function for image input manipulation.
86
82
  linear_layers (Tuple): Tuple of linear layers types to retrieve the output of the last linear layer
87
83
 
88
84
  """
89
85
  self.model = model
90
- self.image_input_manipulation = image_input_manipulation
91
- self.image_granularity = image_granularity
92
86
  self.layer_types_to_extract_inputs = tuple(layer_types_to_extract_inputs)
93
87
  self.linear_layers = linear_layers
94
88
 
@@ -96,7 +90,6 @@ class KerasActivationExtractor(ActivationExtractor):
96
90
  self.bn_layer_names = [layer.name for layer in model.layers if isinstance(layer,
97
91
  self.layer_types_to_extract_inputs)]
98
92
  self.num_layers = len(self.bn_layer_names)
99
- Logger.info(f'Number of layers = {self.num_layers}')
100
93
 
101
94
  # Initialize stats containers
102
95
  self.activations = {}
@@ -206,9 +199,3 @@ class KerasActivationExtractor(ActivationExtractor):
206
199
  last_layer = layer
207
200
  break
208
201
  return last_layer
209
-
210
- def remove(self):
211
- """
212
- Remove the stats containers.
213
- """
214
- self.activations = {}
@@ -15,16 +15,22 @@
15
15
  from typing import Dict, Callable
16
16
 
17
17
  from model_compression_toolkit.data_generation.common.enums import BNLayerWeightingType
18
- from model_compression_toolkit.data_generation.keras.model_info_exctractors import KerasOriginalBNStatsHolder
18
+ from model_compression_toolkit.data_generation.keras.model_info_exctractors import KerasOriginalBNStatsHolder, \
19
+ KerasActivationExtractor
19
20
 
20
21
 
21
- def average_layer_weighting_fn(orig_bn_stats_holder: KerasOriginalBNStatsHolder, **kwargs) -> Dict[str, float]:
22
+ def average_layer_weighting_fn(orig_bn_stats_holder: KerasOriginalBNStatsHolder,
23
+ activation_extractor: KerasActivationExtractor,
24
+ i_iter: int,
25
+ n_iter: int) -> Dict[str, float]:
22
26
  """
23
27
  Calculate average weighting for each batch normalization layer.
24
28
 
25
29
  Args:
26
30
  orig_bn_stats_holder (KerasOriginalBNStatsHolder): Holder for original batch normalization statistics.
27
- **kwargs: Additional arguments if needed.
31
+ activation_extractor (KerasActivationExtractor): The activation extractor for the model.
32
+ i_iter (int): Current optimization iteration.
33
+ n_iter (int): Total number of optimization iterations.
28
34
 
29
35
  Returns:
30
36
  Dict[str, float]: A dictionary containing layer names as keys and average weightings as values.
@@ -33,14 +39,18 @@ def average_layer_weighting_fn(orig_bn_stats_holder: KerasOriginalBNStatsHolder,
33
39
  return {bn_layer_name: 1 / num_bn_layers for bn_layer_name in orig_bn_stats_holder.get_bn_layer_names()}
34
40
 
35
41
 
36
- def first_bn_multiplier_weighting_fn(orig_bn_stats_holder: KerasOriginalBNStatsHolder, **kwargs) -> Dict[str, float]:
42
+ def first_bn_multiplier_weighting_fn(orig_bn_stats_holder: KerasOriginalBNStatsHolder,
43
+ activation_extractor: KerasActivationExtractor,
44
+ i_iter: int,
45
+ n_iter: int) -> Dict[str, float]:
37
46
  """
38
47
  Calculate layer weightings with a higher multiplier for the first batch normalization layer.
39
48
 
40
49
  Args:
41
50
  orig_bn_stats_holder (KerasOriginalBNStatsHolder): Holder for original batch normalization statistics.
42
- **kwargs: Additional arguments if needed.
43
-
51
+ activation_extractor (KerasActivationExtractor): The activation extractor for the model.
52
+ i_iter (int): Current optimization iteration.
53
+ n_iter (int): Total number of optimization iterations.
44
54
  Returns:
45
55
  Dict[str, float]: A dictionary containing layer names as keys and weightings as values.
46
56
  """