mct-nightly 2.1.0.20240725.446__py3-none-any.whl → 2.1.0.20240726.430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/RECORD +35 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/pytorch/constants.py +6 -1
  5. model_compression_toolkit/core/pytorch/utils.py +27 -0
  6. model_compression_toolkit/data_generation/common/data_generation.py +20 -18
  7. model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
  8. model_compression_toolkit/data_generation/common/enums.py +24 -12
  9. model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
  10. model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
  11. model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
  12. model_compression_toolkit/data_generation/keras/constants.py +5 -2
  13. model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
  14. model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
  15. model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
  16. model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
  17. model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
  18. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
  19. model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
  20. model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
  21. model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
  22. model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
  23. model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
  24. model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
  25. model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
  26. model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
  27. model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
  28. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
  29. model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
  30. model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
  31. model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
  32. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
  33. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/LICENSE.md +0 -0
  34. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/WHEEL +0 -0
  35. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/top_level.txt +0 -0
@@ -17,37 +17,66 @@ from typing import Dict, Callable
17
17
  import torch
18
18
  from torch import Tensor
19
19
 
20
- from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
21
20
  from model_compression_toolkit.data_generation.common.enums import OutputLossType
22
21
  from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import PytorchActivationExtractor
23
22
 
24
- def min_max_diff(
25
- output_imgs: Tensor,
23
+ def inverse_min_max_diff(
24
+ model_outputs: Tensor,
26
25
  activation_extractor: PytorchActivationExtractor,
26
+ device: torch.device,
27
27
  eps: float = 1e-6) -> Tensor:
28
28
  """
29
- Calculate the minimum-maximum difference of output images.
29
+ Calculate the inverse of the maximum - minimum difference of the model output on the input images.
30
30
 
31
31
  Args:
32
- output_imgs (Tensor or List[Tensor]): The output of the model on images.
32
+ model_outputs (Tensor or List[Tensor]): The output of the model on images.
33
33
  activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
34
+ device (torch.device): The current device set for PyTorch operations.
34
35
  eps (float): Small value for numerical stability.
35
36
 
36
37
  Returns:
37
38
  Tensor: The computed minimum-maximum difference loss.
38
39
  """
39
- if not isinstance(output_imgs, (list, tuple)):
40
- output_imgs = [output_imgs]
41
- output_loss = 0
42
- for output in output_imgs:
40
+ if not isinstance(model_outputs, (list, tuple)):
41
+ model_outputs = [model_outputs]
42
+ output_loss = torch.zeros(1).to(device)
43
+ for output in model_outputs:
43
44
  output = torch.reshape(output, [output.shape[0], -1])
44
45
  output_loss += 1 / torch.mean(torch.max(output, 1)[0] - torch.min(output, 1)[0] + eps)
45
46
  return output_loss
46
47
 
48
+ def negative_min_max_diff(
49
+ model_outputs: Tensor,
50
+ activation_extractor: PytorchActivationExtractor,
51
+ device: torch.device,
52
+ eps: float = 1e-6) -> Tensor:
53
+ """
54
+ Calculate the mean of the negative maximum - minimum difference of the model output on the input images.
55
+
56
+ Args:
57
+ model_outputs (Tensor or List[Tensor]): The output of the model on images.
58
+ activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
59
+ device (torch.device): The current device set for PyTorch operations.
60
+ eps (float): Small value for numerical stability.
61
+
62
+ Returns:
63
+ Tensor: The computed minimum-maximum difference loss.
64
+ """
65
+ if not isinstance(model_outputs, (list, tuple)):
66
+ model_outputs = [model_outputs]
67
+ output_loss = torch.zeros(1).to(device)
68
+ for output in model_outputs:
69
+ output = torch.reshape(output, [output.shape[0], -1])
70
+ out_max, out_argmax = torch.max(output, dim=1)
71
+ out_min, out_argmin = torch.min(output, dim=1)
72
+ output_loss += torch.mean(-(out_max - out_min))
73
+ return output_loss
74
+
47
75
 
48
76
  def regularized_min_max_diff(
49
- output_imgs: Tensor,
77
+ model_outputs: Tensor,
50
78
  activation_extractor: PytorchActivationExtractor,
79
+ device: torch.device,
51
80
  eps: float = 1e-6) -> Tensor:
52
81
  """
53
82
  Calculate the regularized minimum-maximum difference of output images. We want to maximize
@@ -56,8 +85,9 @@ def regularized_min_max_diff(
56
85
  the last layer's weights.
57
86
 
58
87
  Args:
59
- output_imgs (Tensor or List[Tensor]): The output of the model on images.
88
+ model_outputs (Tensor or List[Tensor]): The output of the model on images.
60
89
  activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
90
+ device (torch.device): The current device set for PyTorch operations.
61
91
  eps (float): Small value for numerical stability.
62
92
 
63
93
  Returns:
@@ -69,13 +99,13 @@ def regularized_min_max_diff(
69
99
  # get the weights of the last linear layers of the model
70
100
  weights_output_layers = activation_extractor.get_last_linear_layers_weights()
71
101
 
72
- if not isinstance(output_imgs, (list, tuple)):
73
- output_imgs = torch.reshape(output_imgs, [output_imgs.shape[0], output_imgs.shape[1], -1])
74
- output_imgs = torch.mean(output_imgs, dim=-1)
75
- output_imgs = [output_imgs]
76
- output_loss = torch.zeros(1).to(get_working_device())
102
+ if not isinstance(model_outputs, (list, tuple)):
103
+ model_outputs = torch.reshape(model_outputs, [model_outputs.shape[0], model_outputs.shape[1], -1])
104
+ model_outputs = torch.mean(model_outputs, dim=-1)
105
+ model_outputs = [model_outputs]
106
+ output_loss = torch.zeros(1).to(device)
77
107
 
78
- for output_weight, output, last_layer_input in zip(weights_output_layers, output_imgs, output_layers_inputs):
108
+ for output_weight, output, last_layer_input in zip(weights_output_layers, model_outputs, output_layers_inputs):
79
109
  weights_norm = torch.linalg.norm(output_weight.squeeze(), dim=1)
80
110
  out_max, out_argmax = torch.max(output, dim=1)
81
111
  out_min, out_argmin = torch.min(output, dim=1)
@@ -88,27 +118,31 @@ def regularized_min_max_diff(
88
118
  output_loss += torch.mean(reg_min + reg_max + dynamic_loss)
89
119
  return output_loss
90
120
 
121
+
91
122
  def no_output_loss(
92
- output_imgs: Tensor,
123
+ model_outputs: Tensor,
93
124
  activation_extractor: PytorchActivationExtractor,
125
+ device: torch.device,
94
126
  eps: float = 1e-6) -> Tensor:
95
127
  """
96
128
  Calculate no output loss.
97
129
 
98
130
  Args:
99
- output_imgs (Tensor): The output of the model on images.
131
+ model_outputs (Tensor): The output of the model on images.
100
132
  activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
133
+ device (torch.device): The current device set for PyTorch operations.
101
134
  eps (float): Small value for numerical stability.
102
135
 
103
136
  Returns:
104
137
  Tensor: A tensor with zero value for the loss.
105
138
  """
106
- return torch.zeros(1).to(get_working_device())
139
+ return torch.zeros(1).to(device)
107
140
 
108
141
 
109
142
  # Dictionary of output loss functions
110
143
  output_loss_function_dict: Dict[OutputLossType, Callable] = {
111
144
  OutputLossType.NONE: no_output_loss,
112
- OutputLossType.MIN_MAX_DIFF: min_max_diff,
145
+ OutputLossType.NEGATIVE_MIN_MAX_DIFF: negative_min_max_diff,
146
+ OutputLossType.INVERSE_MIN_MAX_DIFF: inverse_min_max_diff,
113
147
  OutputLossType.REGULARIZED_MIN_MAX_DIFF: regularized_min_max_diff,
114
148
  }
@@ -17,6 +17,8 @@ from typing import Callable, Any, Dict, Tuple
17
17
 
18
18
  from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
19
19
  from model_compression_toolkit.data_generation.common.enums import SchedulerType
20
+ from model_compression_toolkit.data_generation.pytorch.optimization_functions.lr_scheduler import \
21
+ ReduceLROnPlateauWithReset
20
22
 
21
23
 
22
24
  def get_reduce_lr_on_plateau_scheduler(n_iter: int) -> Callable:
@@ -31,6 +33,18 @@ def get_reduce_lr_on_plateau_scheduler(n_iter: int) -> Callable:
31
33
  """
32
34
  return partial(ReduceLROnPlateau, min_lr=1e-4, factor=0.5, patience=int(n_iter / 50))
33
35
 
36
+ def get_reduce_lr_on_plateau_with_reset_scheduler(n_iter: int) -> Callable:
37
+ """
38
+ Get a ReduceLROnPlateauWithReset scheduler.
39
+
40
+ Args:
41
+ n_iter (int): The number of iterations.
42
+
43
+ Returns:
44
+ Callable: A partial function to create ReduceLROnPlateauWithReset scheduler with specified parameters.
45
+ """
46
+ return partial(ReduceLROnPlateauWithReset, min_lr=1e-4, factor=0.5, patience=int(n_iter / 50))
47
+
34
48
  def get_step_lr_scheduler(n_iter: int) -> Callable:
35
49
  """
36
50
  Get a StepLR scheduler.
@@ -69,5 +83,6 @@ def scheduler_step_fn(scheduler: Any, i_iter: int, loss_value: float):
69
83
  # Dictionary of scheduler functions and their corresponding step functions
70
84
  scheduler_step_function_dict: Dict[SchedulerType, Tuple[Callable, Callable]] = {
71
85
  SchedulerType.REDUCE_ON_PLATEAU: (get_reduce_lr_on_plateau_scheduler, reduce_lr_on_platu_step_fn),
86
+ SchedulerType.REDUCE_ON_PLATEAU_WITH_RESET: (get_reduce_lr_on_plateau_with_reset_scheduler, reduce_lr_on_platu_step_fn),
72
87
  SchedulerType.STEP: (get_step_lr_scheduler, scheduler_step_fn),
73
88
  }
@@ -20,15 +20,16 @@ from torch import Tensor
20
20
  from torch.nn import Module
21
21
  from torch.optim import Optimizer
22
22
  from torch.utils.data import DataLoader, Dataset
23
- from torchvision.transforms import Normalize
23
+ from torch.cuda.amp import GradScaler
24
24
 
25
- from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
25
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, clip_inf_values_float16
26
26
  from model_compression_toolkit.data_generation.common.enums import ImageGranularity
27
27
  from model_compression_toolkit.data_generation.common.image_pipeline import BaseImagePipeline
28
28
  from model_compression_toolkit.data_generation.common.optimization_utils import BatchStatsHolder, AllImagesStatsHolder, \
29
29
  BatchOptimizationHolder, ImagesOptimizationHandler
30
30
  from model_compression_toolkit.data_generation.common.constants import IMAGE_INPUT
31
31
  from model_compression_toolkit.data_generation.pytorch.constants import BATCH_AXIS, H_AXIS, W_AXIS
32
+ from model_compression_toolkit.data_generation.pytorch.image_operations import create_valid_grid
32
33
  from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import ActivationExtractor
33
34
 
34
35
 
@@ -58,8 +59,7 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
58
59
  initial_lr: float,
59
60
  normalization_mean: List[float],
60
61
  normalization_std: List[float],
61
- clip_images: bool,
62
- reflection: bool,
62
+ device: str,
63
63
  eps: float = 1e-6):
64
64
  """
65
65
  Constructor for the PytorchImagesOptimizationHandler class.
@@ -77,8 +77,7 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
77
77
  initial_lr (float): The initial learning rate used by the optimizer.
78
78
  normalization_mean (List[float]): The mean values for image normalization.
79
79
  normalization_std (List[float]): The standard deviation values for image normalization.
80
- clip_images (bool): Whether to clip the images during optimization.
81
- reflection (bool): Whether to use reflection during image clipping.
80
+ device (torch.device): The current device set for PyTorch operations.
82
81
  eps (float): A small value added for numerical stability.
83
82
  """
84
83
  super(PytorchImagesOptimizationHandler, self).__init__(model=model,
@@ -93,16 +92,12 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
93
92
  initial_lr=initial_lr,
94
93
  normalization_mean=normalization_mean,
95
94
  normalization_std=normalization_std,
96
- clip_images=clip_images,
97
- reflection=reflection,
98
95
  eps=eps)
99
96
 
100
- self.device = get_working_device()
101
- # Image valid grid, each image value can only be 0 - 255 before normalization
102
- t = torch.from_numpy(np.array(list(range(256))).repeat(3).reshape(-1, 3) / 255)
103
- self.valid_grid = Normalize(mean=normalization_mean,
104
- std=normalization_std)(t.transpose(1, 0)[None, :, :, None]).squeeze().to(self.device)
105
-
97
+ # Initialize mixed-precision scaler
98
+ self.scaler = GradScaler()
99
+ self.device = device
100
+ self.valid_grid = create_valid_grid(normalization_mean, normalization_std)
106
101
 
107
102
  # Set the mean axis based on the image granularity
108
103
  if self.image_granularity == ImageGranularity.ImageWise:
@@ -155,42 +150,40 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
155
150
  total_mean, total_second_moment = 0, 0
156
151
  for i_batch in range(self.n_batches):
157
152
  mean, second_moment, std = self.all_imgs_stats_holder.get_stats(i_batch, layer_name)
158
- total_mean += mean
159
- total_second_moment += second_moment
153
+ if mean is not None:
154
+ total_mean += mean
155
+ if second_moment is not None:
156
+ total_second_moment += second_moment
160
157
 
161
158
  total_mean /= self.n_batches
162
159
  total_second_moment /= self.n_batches
163
- total_var = total_second_moment - torch.pow(total_mean, 2)
160
+ total_var = to_torch_tensor(total_second_moment) - torch.pow(to_torch_tensor(total_mean), 2)
164
161
  total_std = torch.sqrt(total_var + self.eps)
165
162
  return total_mean, total_std
166
163
 
167
164
  def optimization_step(self,
168
165
  batch_index: int,
169
166
  loss: Tensor,
170
- i_ter: int):
167
+ i_iter: int):
171
168
  """
172
169
  Perform an optimization step.
173
170
 
174
171
  Args:
175
172
  batch_index (int): Index of the batch.
176
173
  loss (Tensor): Loss value.
177
- i_ter (int): Current optimization iteration.
174
+ i_iter (int): Current optimization iteration.
178
175
  """
179
176
  # Get optimizer and scheduler for the specific batch index
180
177
  optimizer = self.get_optimizer_by_batch_index(batch_index)
181
178
  scheduler = self.get_scheduler_by_batch_index(batch_index)
182
179
 
183
180
  # Backward pass
184
- loss.backward()
185
-
186
- # Update weights
187
- optimizer.step()
181
+ self.scaler.scale(loss).backward()
182
+ self.scaler.step(optimizer)
183
+ self.scaler.update()
188
184
 
189
185
  # Perform scheduler step
190
- self.scheduler_step_fn(scheduler, i_ter, loss.item())
191
-
192
- if self.clip_images:
193
- self.batch_opt_holders_list[batch_index].clip_images(self.valid_grid, reflection=self.reflection)
186
+ self.scheduler_step_fn(scheduler, i_iter, loss.item())
194
187
 
195
188
 
196
189
  def zero_grad(self, batch_index: int):
@@ -259,25 +252,6 @@ class PytorchBatchOptimizationHolder(BatchOptimizationHolder):
259
252
  self.optimizer = optimizer([self.images], lr=initial_lr)
260
253
  self.scheduler = scheduler(self.optimizer)
261
254
 
262
- def clip_images(self,
263
- valid_grid: Tensor,
264
- reflection: bool = True):
265
- """
266
- Clip the images.
267
-
268
- Args:
269
- valid_grid (Tensor): A tensor containing valid values for image clipping.
270
- reflection (bool): Whether to use reflection during image clipping. Defaults to True.
271
- """
272
- with torch.no_grad():
273
- for i_ch in range(valid_grid.shape[0]):
274
- clamp = torch.clamp(self.images[:, i_ch, :, :], valid_grid[i_ch, :].min(), valid_grid[i_ch, :].max())
275
- if reflection:
276
- self.images[:, i_ch, :, :] = 2 * clamp - self.images[:, i_ch, :, :]
277
- else:
278
- self.images[:, i_ch, :, :] = clamp
279
- self.images.requires_grad = True
280
-
281
255
 
282
256
  class PytorchAllImagesStatsHolder(AllImagesStatsHolder):
283
257
  """
@@ -332,8 +306,9 @@ class PytorchBatchStatsHolder(BatchStatsHolder):
332
306
  """
333
307
  mean = self.get_mean(bn_layer_name)
334
308
  second_moment = self.get_second_moment(bn_layer_name)
335
- var = second_moment - torch.pow(mean, 2.0)
336
- return var
309
+ if mean is not None and second_moment is not None:
310
+ return second_moment - torch.pow(mean, 2.0)
311
+ return None
337
312
 
338
313
 
339
314
  def get_std(self, bn_layer_name: str) -> Tensor:
@@ -347,7 +322,9 @@ class PytorchBatchStatsHolder(BatchStatsHolder):
347
322
  Tensor: The standard deviation for the specified layer.
348
323
  """
349
324
  var = self.get_var(bn_layer_name)
350
- return torch.sqrt(var + self.eps)
325
+ if var is not None:
326
+ return torch.sqrt(var + self.eps)
327
+ return None
351
328
 
352
329
  def calc_bn_stats_from_activations(self,
353
330
  input_imgs: Tensor,
@@ -374,12 +351,13 @@ class PytorchBatchStatsHolder(BatchStatsHolder):
374
351
  # Extract statistics of intermediate convolution outputs before the BatchNorm layers
375
352
  for bn_layer_name in activation_extractor.get_extractor_layer_names():
376
353
  bn_input_activations = activation_extractor.get_layer_input_activation(bn_layer_name)
377
- if not to_differentiate:
378
- bn_input_activations = bn_input_activations.detach()
354
+ if bn_input_activations is not None:
355
+ if not to_differentiate:
356
+ bn_input_activations = bn_input_activations.detach()
379
357
 
380
- collected_mean = torch.mean(bn_input_activations, dim=self.mean_axis)
381
- collected_second_moment = torch.mean(torch.pow(bn_input_activations, 2.0), dim=self.mean_axis)
382
- self.update_layer_stats(bn_layer_name, collected_mean, collected_second_moment)
358
+ collected_mean = torch.mean(bn_input_activations, dim=self.mean_axis)
359
+ collected_second_moment = clip_inf_values_float16(torch.mean(torch.pow(bn_input_activations, 2.0), dim=self.mean_axis))
360
+ self.update_layer_stats(bn_layer_name, collected_mean, collected_second_moment)
383
361
 
384
362
  def clear(self):
385
363
  """Clear the statistics."""
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import time
16
- from typing import Callable, Any, Tuple, List
16
+ from typing import Callable, Any, Tuple, List, Union
17
17
 
18
18
  from tqdm import tqdm
19
19
 
@@ -25,10 +25,12 @@ from model_compression_toolkit.data_generation.common.data_generation_config imp
25
25
  from model_compression_toolkit.data_generation.common.enums import ImageGranularity, SchedulerType, \
26
26
  BatchNormAlignemntLossType, DataInitType, BNLayerWeightingType, ImagePipelineType, ImageNormalizationType, \
27
27
  OutputLossType
28
+ from model_compression_toolkit.data_generation.common.image_pipeline import image_normalization_dict
28
29
  from model_compression_toolkit.data_generation.pytorch.constants import DEFAULT_PYTORCH_INITIAL_LR, \
29
- DEFAULT_PYTORCH_OUTPUT_LOSS_MULTIPLIER, DEFAULT_PYTORCH_BN_LAYER_TYPES, DEFAULT_PYTORCH_LAST_LAYER_TYPES
30
+ DEFAULT_PYTORCH_BN_LAYER_TYPES, DEFAULT_PYTORCH_LAST_LAYER_TYPES, DEFAULT_PYTORCH_EXTRA_PIXELS, \
31
+ DEFAULT_PYTORCH_OUTPUT_LOSS_MULTIPLIER
30
32
  from model_compression_toolkit.data_generation.pytorch.image_pipeline import image_pipeline_dict, \
31
- image_normalization_dict, BaseImagePipeline
33
+ BaseImagePipeline
32
34
  from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import PytorchActivationExtractor, \
33
35
  PytorchOriginalBNStatsHolder
34
36
  from model_compression_toolkit.data_generation.pytorch.optimization_functions.batchnorm_alignment_functions import \
@@ -51,6 +53,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
51
53
  from torch.nn import Module
52
54
  from torch.optim import RAdam, Optimizer
53
55
  from torch.fx import symbolic_trace
56
+ from torch.cuda.amp import autocast
54
57
 
55
58
  from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
56
59
 
@@ -61,19 +64,18 @@ if FOUND_TORCH and FOUND_TORCHVISION:
61
64
  data_gen_batch_size=DEFAULT_DATA_GEN_BS,
62
65
  initial_lr=DEFAULT_PYTORCH_INITIAL_LR,
63
66
  output_loss_multiplier=DEFAULT_PYTORCH_OUTPUT_LOSS_MULTIPLIER,
64
- scheduler_type: SchedulerType = SchedulerType.REDUCE_ON_PLATEAU,
67
+ scheduler_type: SchedulerType = SchedulerType.REDUCE_ON_PLATEAU_WITH_RESET,
65
68
  bn_alignment_loss_type: BatchNormAlignemntLossType = BatchNormAlignemntLossType.L2_SQUARE,
66
- output_loss_type: OutputLossType = OutputLossType.REGULARIZED_MIN_MAX_DIFF,
67
- data_init_type: DataInitType = DataInitType.Diverse,
69
+ output_loss_type: OutputLossType = OutputLossType.NEGATIVE_MIN_MAX_DIFF,
70
+ data_init_type: DataInitType = DataInitType.Gaussian,
68
71
  layer_weighting_type: BNLayerWeightingType = BNLayerWeightingType.AVERAGE,
69
72
  image_granularity=ImageGranularity.AllImages,
70
- image_pipeline_type: ImagePipelineType = ImagePipelineType.RANDOM_CROP,
73
+ image_pipeline_type: ImagePipelineType = ImagePipelineType.SMOOTHING_AND_AUGMENTATION,
71
74
  image_normalization_type: ImageNormalizationType = ImageNormalizationType.TORCHVISION,
72
- extra_pixels: int = 0,
75
+ extra_pixels: Union[int, Tuple[int, int]] = DEFAULT_PYTORCH_EXTRA_PIXELS,
73
76
  bn_layer_types: List = DEFAULT_PYTORCH_BN_LAYER_TYPES,
74
77
  last_layer_types: List = DEFAULT_PYTORCH_LAST_LAYER_TYPES,
75
- clip_images: bool = True,
76
- reflection: bool = True,
78
+ image_clipping: bool = True,
77
79
  ) -> DataGenerationConfig:
78
80
  """
79
81
  Function to create a DataGenerationConfig object with the specified configuration parameters.
@@ -92,11 +94,10 @@ if FOUND_TORCH and FOUND_TORCHVISION:
92
94
  image_granularity (ImageGranularity): The granularity of the images for optimization.
93
95
  image_pipeline_type (ImagePipelineType): The type of image pipeline to use.
94
96
  image_normalization_type (ImageNormalizationType): The type of image normalization to use.
95
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
97
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
96
98
  bn_layer_types (List): List of BatchNorm layer types to be considered for data generation.
97
99
  last_layer_types (List): List of layer types to be considered for the output loss.
98
- clip_images (bool): Whether to clip images during optimization.
99
- reflection (bool): Whether to use reflection during optimization.
100
+ image_clipping (bool): Whether to clip images during optimization.
100
101
 
101
102
 
102
103
  Returns:
@@ -121,15 +122,14 @@ if FOUND_TORCH and FOUND_TORCHVISION:
121
122
  extra_pixels=extra_pixels,
122
123
  bn_layer_types=bn_layer_types,
123
124
  last_layer_types=last_layer_types,
124
- clip_images=clip_images,
125
- reflection=reflection
125
+ image_clipping=image_clipping,
126
126
  )
127
127
 
128
128
 
129
129
  def pytorch_data_generation_experimental(
130
130
  model: Module,
131
131
  n_images: int,
132
- output_image_size: int,
132
+ output_image_size: Union[int, Tuple[int, int]],
133
133
  data_generation_config: DataGenerationConfig) -> List[Tensor]:
134
134
  """
135
135
  Function to perform data generation using the provided model and data generation configuration.
@@ -137,7 +137,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
137
137
  Args:
138
138
  model (Module): PyTorch model to generate data for.
139
139
  n_images (int): Number of images to generate.
140
- output_image_size (int): The hight and width size of the output images.
140
+ output_image_size (Union[int, Tuple[int, int]]): The hight and width size of the output images.
141
141
  data_generation_config (DataGenerationConfig): Configuration for data generation.
142
142
 
143
143
  Returns:
@@ -176,6 +176,9 @@ if FOUND_TORCH and FOUND_TORCHVISION:
176
176
  f"If you encounter an issue, please open an issue in our GitHub "
177
177
  f"project https://github.com/sony/model_optimization")
178
178
 
179
+ # get the model device
180
+ device = get_working_device()
181
+
179
182
  # get a static graph representation of the model using torch.fx
180
183
  fx_model = symbolic_trace(model)
181
184
 
@@ -198,8 +201,8 @@ if FOUND_TORCH and FOUND_TORCHVISION:
198
201
 
199
202
  # Check if the scheduler type is valid
200
203
  if scheduler_get_fn is None or scheduler_step_fn is None:
201
- Logger.critical(f'Invalid output_loss_type {data_generation_config.scheduler_type}. '
202
- f'Please select one from {SchedulerType.get_values()}.')
204
+ Logger.critical(f'Invalid scheduler_type {data_generation_config.scheduler_type}. '
205
+ f'Please select one from {SchedulerType.get_values()}.') # pragma: no cover
203
206
 
204
207
  # Create a scheduler object with the specified number of iterations
205
208
  scheduler = scheduler_get_fn(data_generation_config.n_iter)
@@ -218,23 +221,22 @@ if FOUND_TORCH and FOUND_TORCHVISION:
218
221
  orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model, data_generation_config.bn_layer_types)
219
222
  if orig_bn_stats_holder.get_num_bn_layers() == 0:
220
223
  Logger.critical(
221
- f'Data generation requires a model with at least one BatchNorm layer.')
224
+ f'Data generation requires a model with at least one BatchNorm layer.') # pragma: no cover
222
225
 
223
226
  # Create an ImagesOptimizationHandler object for handling optimization
224
227
  all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model,
225
- data_gen_batch_size=data_generation_config.data_gen_batch_size,
226
- init_dataset=init_dataset,
227
- optimizer=data_generation_config.optimizer,
228
- image_pipeline=image_pipeline,
229
- activation_extractor=activation_extractor,
230
- image_granularity=data_generation_config.image_granularity,
231
- scheduler_step_fn=scheduler_step_fn,
232
- scheduler=scheduler,
233
- initial_lr=data_generation_config.initial_lr,
234
- normalization_mean=normalization[0],
235
- normalization_std=normalization[1],
236
- clip_images=data_generation_config.clip_images,
237
- reflection=data_generation_config.reflection)
228
+ data_gen_batch_size=data_generation_config.data_gen_batch_size,
229
+ init_dataset=init_dataset,
230
+ optimizer=data_generation_config.optimizer,
231
+ image_pipeline=image_pipeline,
232
+ activation_extractor=activation_extractor,
233
+ image_granularity=data_generation_config.image_granularity,
234
+ scheduler_step_fn=scheduler_step_fn,
235
+ scheduler=scheduler,
236
+ initial_lr=data_generation_config.initial_lr,
237
+ normalization_mean=normalization[0],
238
+ normalization_std=normalization[1],
239
+ device=device)
238
240
 
239
241
  # Perform data generation and obtain a list of generated images
240
242
  generated_images_list = data_generation(
@@ -247,6 +249,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
247
249
  bn_alignment_loss_fn=bn_alignment_loss_fn,
248
250
  output_loss_fn=output_loss_fn,
249
251
  output_loss_multiplier=data_generation_config.output_loss_multiplier,
252
+ device=device,
250
253
  )
251
254
  # Return the list of finalized generated images
252
255
  return generated_images_list
@@ -261,7 +264,8 @@ if FOUND_TORCH and FOUND_TORCHVISION:
261
264
  bn_layer_weighting_fn: Callable,
262
265
  bn_alignment_loss_fn: Callable,
263
266
  output_loss_fn: Callable,
264
- output_loss_multiplier: float
267
+ output_loss_multiplier: float,
268
+ device: torch.device
265
269
  ) -> List[Any]:
266
270
  """
267
271
  Function to perform data generation using the provided model and data generation configuration.
@@ -276,14 +280,11 @@ if FOUND_TORCH and FOUND_TORCHVISION:
276
280
  bn_alignment_loss_fn (Callable): Function to compute BatchNorm alignment loss.
277
281
  output_loss_fn (Callable): Function to compute output loss.
278
282
  output_loss_multiplier (float): Multiplier for the output loss.
283
+ device (torch.device): The current device set for PyTorch operations.
279
284
 
280
285
  Returns:
281
286
  List: Finalized list containing generated images.
282
287
  """
283
-
284
- # Compute the layer weights based on orig_bn_stats_holder
285
- bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder)
286
-
287
288
  # Get the current time to measure the total time taken
288
289
  total_time = time.time()
289
290
 
@@ -291,7 +292,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
291
292
  ibar = tqdm(range(data_generation_config.n_iter))
292
293
 
293
294
  # Perform data generation iterations
294
- for i_ter in ibar:
295
+ for i_iter in ibar:
295
296
 
296
297
  # Randomly reorder the batches
297
298
  all_imgs_opt_handler.random_batch_reorder()
@@ -311,7 +312,11 @@ if FOUND_TORCH and FOUND_TORCHVISION:
311
312
  input_imgs = image_pipeline.image_input_manipulation(imgs_to_optimize)
312
313
 
313
314
  # Forward pass to extract activations
314
- output = activation_extractor.run_model(input_imgs)
315
+ with autocast():
316
+ output = activation_extractor.run_model(input_imgs)
317
+
318
+ # Compute the layer weights based on orig_bn_stats_holder
319
+ bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder, activation_extractor, i_iter, data_generation_config.n_iter)
315
320
 
316
321
  # Compute BatchNorm alignment loss
317
322
  bn_loss = all_imgs_opt_handler.compute_bn_loss(input_imgs=input_imgs,
@@ -322,33 +327,33 @@ if FOUND_TORCH and FOUND_TORCHVISION:
322
327
  bn_layer_weights=bn_layer_weights)
323
328
 
324
329
  # Compute output loss
325
- if output_loss_multiplier > 0:
326
- output_loss = output_loss_fn(
327
- output_imgs=output,
328
- activation_extractor=activation_extractor)
329
- else:
330
- output_loss = torch.zeros(1).to(get_working_device())
330
+ output_loss = output_loss_fn(
331
+ model_outputs=output,
332
+ activation_extractor=activation_extractor,
333
+ device=device)
331
334
 
332
335
  # Compute total loss
333
336
  total_loss = bn_loss + output_loss_multiplier * output_loss
334
337
 
335
338
  # Perform optimiztion step
336
- all_imgs_opt_handler.optimization_step(random_batch_index, total_loss, i_ter)
339
+ all_imgs_opt_handler.optimization_step(random_batch_index, total_loss, i_iter)
337
340
 
338
341
  # Update the statistics based on the updated images
339
342
  if all_imgs_opt_handler.use_all_data_stats:
340
- final_imgs = image_pipeline.image_output_finalize(imgs_to_optimize)
341
- all_imgs_opt_handler.update_statistics(input_imgs=final_imgs,
342
- batch_index=random_batch_index,
343
- activation_extractor=activation_extractor)
343
+ with autocast():
344
+ final_imgs = image_pipeline.image_output_finalize(imgs_to_optimize)
345
+ all_imgs_opt_handler.update_statistics(input_imgs=final_imgs,
346
+ batch_index=random_batch_index,
347
+ activation_extractor=activation_extractor)
344
348
 
345
349
  ibar.set_description(f"Total Loss: {total_loss.item():.5f}, "
346
350
  f"BN Loss: {bn_loss.item():.5f}, "
347
- f"Output Loss: {output_loss_multiplier * output_loss.item():.5f}")
351
+ f"Output Loss: {output_loss.item():.5f}")
348
352
 
349
353
  # Return a list containing the finalized generated images
350
354
  finalized_imgs = all_imgs_opt_handler.get_finalized_images()
351
355
  Logger.info(f'Total time to generate {len(finalized_imgs)} images (seconds): {int(time.time() - total_time)}')
356
+ Logger.info(f'Final Loss: Total {total_loss.item()}, BN loss {bn_loss.item()}, Output loss {output_loss.item()}')
352
357
  return finalized_imgs
353
358
  else:
354
359
  # If torch is not installed,