mct-nightly 2.1.0.20240724.437__py3-none-any.whl → 2.1.0.20240726.430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/RECORD +35 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +6 -1
- model_compression_toolkit/core/pytorch/utils.py +27 -0
- model_compression_toolkit/data_generation/common/data_generation.py +20 -18
- model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
- model_compression_toolkit/data_generation/common/enums.py +24 -12
- model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
- model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
- model_compression_toolkit/data_generation/keras/constants.py +5 -2
- model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
- model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
- model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
- model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
- model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
- model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
- model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
- model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
- model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
- model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
- model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
- model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
- model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/top_level.txt +0 -0
model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py
CHANGED
@@ -17,37 +17,66 @@ from typing import Dict, Callable
|
|
17
17
|
import torch
|
18
18
|
from torch import Tensor
|
19
19
|
|
20
|
-
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
|
21
20
|
from model_compression_toolkit.data_generation.common.enums import OutputLossType
|
22
21
|
from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import PytorchActivationExtractor
|
23
22
|
|
24
|
-
def
|
25
|
-
|
23
|
+
def inverse_min_max_diff(
|
24
|
+
model_outputs: Tensor,
|
26
25
|
activation_extractor: PytorchActivationExtractor,
|
26
|
+
device: torch.device,
|
27
27
|
eps: float = 1e-6) -> Tensor:
|
28
28
|
"""
|
29
|
-
Calculate the
|
29
|
+
Calculate the inverse of the maximum - minimum difference of the model output on the input images.
|
30
30
|
|
31
31
|
Args:
|
32
|
-
|
32
|
+
model_outputs (Tensor or List[Tensor]): The output of the model on images.
|
33
33
|
activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
|
34
|
+
device (torch.device): The current device set for PyTorch operations.
|
34
35
|
eps (float): Small value for numerical stability.
|
35
36
|
|
36
37
|
Returns:
|
37
38
|
Tensor: The computed minimum-maximum difference loss.
|
38
39
|
"""
|
39
|
-
if not isinstance(
|
40
|
-
|
41
|
-
output_loss =
|
42
|
-
for output in
|
40
|
+
if not isinstance(model_outputs, (list, tuple)):
|
41
|
+
model_outputs = [model_outputs]
|
42
|
+
output_loss = torch.zeros(1).to(device)
|
43
|
+
for output in model_outputs:
|
43
44
|
output = torch.reshape(output, [output.shape[0], -1])
|
44
45
|
output_loss += 1 / torch.mean(torch.max(output, 1)[0] - torch.min(output, 1)[0] + eps)
|
45
46
|
return output_loss
|
46
47
|
|
48
|
+
def negative_min_max_diff(
|
49
|
+
model_outputs: Tensor,
|
50
|
+
activation_extractor: PytorchActivationExtractor,
|
51
|
+
device: torch.device,
|
52
|
+
eps: float = 1e-6) -> Tensor:
|
53
|
+
"""
|
54
|
+
Calculate the mean of the negative maximum - minimum difference of the model output on the input images.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
model_outputs (Tensor or List[Tensor]): The output of the model on images.
|
58
|
+
activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
|
59
|
+
device (torch.device): The current device set for PyTorch operations.
|
60
|
+
eps (float): Small value for numerical stability.
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
Tensor: The computed minimum-maximum difference loss.
|
64
|
+
"""
|
65
|
+
if not isinstance(model_outputs, (list, tuple)):
|
66
|
+
model_outputs = [model_outputs]
|
67
|
+
output_loss = torch.zeros(1).to(device)
|
68
|
+
for output in model_outputs:
|
69
|
+
output = torch.reshape(output, [output.shape[0], -1])
|
70
|
+
out_max, out_argmax = torch.max(output, dim=1)
|
71
|
+
out_min, out_argmin = torch.min(output, dim=1)
|
72
|
+
output_loss += torch.mean(-(out_max - out_min))
|
73
|
+
return output_loss
|
74
|
+
|
47
75
|
|
48
76
|
def regularized_min_max_diff(
|
49
|
-
|
77
|
+
model_outputs: Tensor,
|
50
78
|
activation_extractor: PytorchActivationExtractor,
|
79
|
+
device: torch.device,
|
51
80
|
eps: float = 1e-6) -> Tensor:
|
52
81
|
"""
|
53
82
|
Calculate the regularized minimum-maximum difference of output images. We want to maximize
|
@@ -56,8 +85,9 @@ def regularized_min_max_diff(
|
|
56
85
|
the last layer's weights.
|
57
86
|
|
58
87
|
Args:
|
59
|
-
|
88
|
+
model_outputs (Tensor or List[Tensor]): The output of the model on images.
|
60
89
|
activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
|
90
|
+
device (torch.device): The current device set for PyTorch operations.
|
61
91
|
eps (float): Small value for numerical stability.
|
62
92
|
|
63
93
|
Returns:
|
@@ -69,13 +99,13 @@ def regularized_min_max_diff(
|
|
69
99
|
# get the weights of the last linear layers of the model
|
70
100
|
weights_output_layers = activation_extractor.get_last_linear_layers_weights()
|
71
101
|
|
72
|
-
if not isinstance(
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
output_loss = torch.zeros(1).to(
|
102
|
+
if not isinstance(model_outputs, (list, tuple)):
|
103
|
+
model_outputs = torch.reshape(model_outputs, [model_outputs.shape[0], model_outputs.shape[1], -1])
|
104
|
+
model_outputs = torch.mean(model_outputs, dim=-1)
|
105
|
+
model_outputs = [model_outputs]
|
106
|
+
output_loss = torch.zeros(1).to(device)
|
77
107
|
|
78
|
-
for output_weight, output, last_layer_input in zip(weights_output_layers,
|
108
|
+
for output_weight, output, last_layer_input in zip(weights_output_layers, model_outputs, output_layers_inputs):
|
79
109
|
weights_norm = torch.linalg.norm(output_weight.squeeze(), dim=1)
|
80
110
|
out_max, out_argmax = torch.max(output, dim=1)
|
81
111
|
out_min, out_argmin = torch.min(output, dim=1)
|
@@ -88,27 +118,31 @@ def regularized_min_max_diff(
|
|
88
118
|
output_loss += torch.mean(reg_min + reg_max + dynamic_loss)
|
89
119
|
return output_loss
|
90
120
|
|
121
|
+
|
91
122
|
def no_output_loss(
|
92
|
-
|
123
|
+
model_outputs: Tensor,
|
93
124
|
activation_extractor: PytorchActivationExtractor,
|
125
|
+
device: torch.device,
|
94
126
|
eps: float = 1e-6) -> Tensor:
|
95
127
|
"""
|
96
128
|
Calculate no output loss.
|
97
129
|
|
98
130
|
Args:
|
99
|
-
|
131
|
+
model_outputs (Tensor): The output of the model on images.
|
100
132
|
activation_extractor (PytorchActivationExtractor): The activation extractor for the model.
|
133
|
+
device (torch.device): The current device set for PyTorch operations.
|
101
134
|
eps (float): Small value for numerical stability.
|
102
135
|
|
103
136
|
Returns:
|
104
137
|
Tensor: A tensor with zero value for the loss.
|
105
138
|
"""
|
106
|
-
return torch.zeros(1).to(
|
139
|
+
return torch.zeros(1).to(device)
|
107
140
|
|
108
141
|
|
109
142
|
# Dictionary of output loss functions
|
110
143
|
output_loss_function_dict: Dict[OutputLossType, Callable] = {
|
111
144
|
OutputLossType.NONE: no_output_loss,
|
112
|
-
OutputLossType.
|
145
|
+
OutputLossType.NEGATIVE_MIN_MAX_DIFF: negative_min_max_diff,
|
146
|
+
OutputLossType.INVERSE_MIN_MAX_DIFF: inverse_min_max_diff,
|
113
147
|
OutputLossType.REGULARIZED_MIN_MAX_DIFF: regularized_min_max_diff,
|
114
148
|
}
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py
CHANGED
@@ -17,6 +17,8 @@ from typing import Callable, Any, Dict, Tuple
|
|
17
17
|
|
18
18
|
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
19
19
|
from model_compression_toolkit.data_generation.common.enums import SchedulerType
|
20
|
+
from model_compression_toolkit.data_generation.pytorch.optimization_functions.lr_scheduler import \
|
21
|
+
ReduceLROnPlateauWithReset
|
20
22
|
|
21
23
|
|
22
24
|
def get_reduce_lr_on_plateau_scheduler(n_iter: int) -> Callable:
|
@@ -31,6 +33,18 @@ def get_reduce_lr_on_plateau_scheduler(n_iter: int) -> Callable:
|
|
31
33
|
"""
|
32
34
|
return partial(ReduceLROnPlateau, min_lr=1e-4, factor=0.5, patience=int(n_iter / 50))
|
33
35
|
|
36
|
+
def get_reduce_lr_on_plateau_with_reset_scheduler(n_iter: int) -> Callable:
|
37
|
+
"""
|
38
|
+
Get a ReduceLROnPlateauWithReset scheduler.
|
39
|
+
|
40
|
+
Args:
|
41
|
+
n_iter (int): The number of iterations.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
Callable: A partial function to create ReduceLROnPlateauWithReset scheduler with specified parameters.
|
45
|
+
"""
|
46
|
+
return partial(ReduceLROnPlateauWithReset, min_lr=1e-4, factor=0.5, patience=int(n_iter / 50))
|
47
|
+
|
34
48
|
def get_step_lr_scheduler(n_iter: int) -> Callable:
|
35
49
|
"""
|
36
50
|
Get a StepLR scheduler.
|
@@ -69,5 +83,6 @@ def scheduler_step_fn(scheduler: Any, i_iter: int, loss_value: float):
|
|
69
83
|
# Dictionary of scheduler functions and their corresponding step functions
|
70
84
|
scheduler_step_function_dict: Dict[SchedulerType, Tuple[Callable, Callable]] = {
|
71
85
|
SchedulerType.REDUCE_ON_PLATEAU: (get_reduce_lr_on_plateau_scheduler, reduce_lr_on_platu_step_fn),
|
86
|
+
SchedulerType.REDUCE_ON_PLATEAU_WITH_RESET: (get_reduce_lr_on_plateau_with_reset_scheduler, reduce_lr_on_platu_step_fn),
|
72
87
|
SchedulerType.STEP: (get_step_lr_scheduler, scheduler_step_fn),
|
73
88
|
}
|
@@ -20,15 +20,16 @@ from torch import Tensor
|
|
20
20
|
from torch.nn import Module
|
21
21
|
from torch.optim import Optimizer
|
22
22
|
from torch.utils.data import DataLoader, Dataset
|
23
|
-
from
|
23
|
+
from torch.cuda.amp import GradScaler
|
24
24
|
|
25
|
-
from model_compression_toolkit.core.pytorch.
|
25
|
+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, clip_inf_values_float16
|
26
26
|
from model_compression_toolkit.data_generation.common.enums import ImageGranularity
|
27
27
|
from model_compression_toolkit.data_generation.common.image_pipeline import BaseImagePipeline
|
28
28
|
from model_compression_toolkit.data_generation.common.optimization_utils import BatchStatsHolder, AllImagesStatsHolder, \
|
29
29
|
BatchOptimizationHolder, ImagesOptimizationHandler
|
30
30
|
from model_compression_toolkit.data_generation.common.constants import IMAGE_INPUT
|
31
31
|
from model_compression_toolkit.data_generation.pytorch.constants import BATCH_AXIS, H_AXIS, W_AXIS
|
32
|
+
from model_compression_toolkit.data_generation.pytorch.image_operations import create_valid_grid
|
32
33
|
from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import ActivationExtractor
|
33
34
|
|
34
35
|
|
@@ -58,8 +59,7 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
|
|
58
59
|
initial_lr: float,
|
59
60
|
normalization_mean: List[float],
|
60
61
|
normalization_std: List[float],
|
61
|
-
|
62
|
-
reflection: bool,
|
62
|
+
device: str,
|
63
63
|
eps: float = 1e-6):
|
64
64
|
"""
|
65
65
|
Constructor for the PytorchImagesOptimizationHandler class.
|
@@ -77,8 +77,7 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
|
|
77
77
|
initial_lr (float): The initial learning rate used by the optimizer.
|
78
78
|
normalization_mean (List[float]): The mean values for image normalization.
|
79
79
|
normalization_std (List[float]): The standard deviation values for image normalization.
|
80
|
-
|
81
|
-
reflection (bool): Whether to use reflection during image clipping.
|
80
|
+
device (torch.device): The current device set for PyTorch operations.
|
82
81
|
eps (float): A small value added for numerical stability.
|
83
82
|
"""
|
84
83
|
super(PytorchImagesOptimizationHandler, self).__init__(model=model,
|
@@ -93,16 +92,12 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
|
|
93
92
|
initial_lr=initial_lr,
|
94
93
|
normalization_mean=normalization_mean,
|
95
94
|
normalization_std=normalization_std,
|
96
|
-
clip_images=clip_images,
|
97
|
-
reflection=reflection,
|
98
95
|
eps=eps)
|
99
96
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
self.valid_grid =
|
104
|
-
std=normalization_std)(t.transpose(1, 0)[None, :, :, None]).squeeze().to(self.device)
|
105
|
-
|
97
|
+
# Initialize mixed-precision scaler
|
98
|
+
self.scaler = GradScaler()
|
99
|
+
self.device = device
|
100
|
+
self.valid_grid = create_valid_grid(normalization_mean, normalization_std)
|
106
101
|
|
107
102
|
# Set the mean axis based on the image granularity
|
108
103
|
if self.image_granularity == ImageGranularity.ImageWise:
|
@@ -155,42 +150,40 @@ class PytorchImagesOptimizationHandler(ImagesOptimizationHandler):
|
|
155
150
|
total_mean, total_second_moment = 0, 0
|
156
151
|
for i_batch in range(self.n_batches):
|
157
152
|
mean, second_moment, std = self.all_imgs_stats_holder.get_stats(i_batch, layer_name)
|
158
|
-
|
159
|
-
|
153
|
+
if mean is not None:
|
154
|
+
total_mean += mean
|
155
|
+
if second_moment is not None:
|
156
|
+
total_second_moment += second_moment
|
160
157
|
|
161
158
|
total_mean /= self.n_batches
|
162
159
|
total_second_moment /= self.n_batches
|
163
|
-
total_var = total_second_moment - torch.pow(total_mean, 2)
|
160
|
+
total_var = to_torch_tensor(total_second_moment) - torch.pow(to_torch_tensor(total_mean), 2)
|
164
161
|
total_std = torch.sqrt(total_var + self.eps)
|
165
162
|
return total_mean, total_std
|
166
163
|
|
167
164
|
def optimization_step(self,
|
168
165
|
batch_index: int,
|
169
166
|
loss: Tensor,
|
170
|
-
|
167
|
+
i_iter: int):
|
171
168
|
"""
|
172
169
|
Perform an optimization step.
|
173
170
|
|
174
171
|
Args:
|
175
172
|
batch_index (int): Index of the batch.
|
176
173
|
loss (Tensor): Loss value.
|
177
|
-
|
174
|
+
i_iter (int): Current optimization iteration.
|
178
175
|
"""
|
179
176
|
# Get optimizer and scheduler for the specific batch index
|
180
177
|
optimizer = self.get_optimizer_by_batch_index(batch_index)
|
181
178
|
scheduler = self.get_scheduler_by_batch_index(batch_index)
|
182
179
|
|
183
180
|
# Backward pass
|
184
|
-
loss.backward()
|
185
|
-
|
186
|
-
|
187
|
-
optimizer.step()
|
181
|
+
self.scaler.scale(loss).backward()
|
182
|
+
self.scaler.step(optimizer)
|
183
|
+
self.scaler.update()
|
188
184
|
|
189
185
|
# Perform scheduler step
|
190
|
-
self.scheduler_step_fn(scheduler,
|
191
|
-
|
192
|
-
if self.clip_images:
|
193
|
-
self.batch_opt_holders_list[batch_index].clip_images(self.valid_grid, reflection=self.reflection)
|
186
|
+
self.scheduler_step_fn(scheduler, i_iter, loss.item())
|
194
187
|
|
195
188
|
|
196
189
|
def zero_grad(self, batch_index: int):
|
@@ -259,25 +252,6 @@ class PytorchBatchOptimizationHolder(BatchOptimizationHolder):
|
|
259
252
|
self.optimizer = optimizer([self.images], lr=initial_lr)
|
260
253
|
self.scheduler = scheduler(self.optimizer)
|
261
254
|
|
262
|
-
def clip_images(self,
|
263
|
-
valid_grid: Tensor,
|
264
|
-
reflection: bool = True):
|
265
|
-
"""
|
266
|
-
Clip the images.
|
267
|
-
|
268
|
-
Args:
|
269
|
-
valid_grid (Tensor): A tensor containing valid values for image clipping.
|
270
|
-
reflection (bool): Whether to use reflection during image clipping. Defaults to True.
|
271
|
-
"""
|
272
|
-
with torch.no_grad():
|
273
|
-
for i_ch in range(valid_grid.shape[0]):
|
274
|
-
clamp = torch.clamp(self.images[:, i_ch, :, :], valid_grid[i_ch, :].min(), valid_grid[i_ch, :].max())
|
275
|
-
if reflection:
|
276
|
-
self.images[:, i_ch, :, :] = 2 * clamp - self.images[:, i_ch, :, :]
|
277
|
-
else:
|
278
|
-
self.images[:, i_ch, :, :] = clamp
|
279
|
-
self.images.requires_grad = True
|
280
|
-
|
281
255
|
|
282
256
|
class PytorchAllImagesStatsHolder(AllImagesStatsHolder):
|
283
257
|
"""
|
@@ -332,8 +306,9 @@ class PytorchBatchStatsHolder(BatchStatsHolder):
|
|
332
306
|
"""
|
333
307
|
mean = self.get_mean(bn_layer_name)
|
334
308
|
second_moment = self.get_second_moment(bn_layer_name)
|
335
|
-
|
336
|
-
|
309
|
+
if mean is not None and second_moment is not None:
|
310
|
+
return second_moment - torch.pow(mean, 2.0)
|
311
|
+
return None
|
337
312
|
|
338
313
|
|
339
314
|
def get_std(self, bn_layer_name: str) -> Tensor:
|
@@ -347,7 +322,9 @@ class PytorchBatchStatsHolder(BatchStatsHolder):
|
|
347
322
|
Tensor: The standard deviation for the specified layer.
|
348
323
|
"""
|
349
324
|
var = self.get_var(bn_layer_name)
|
350
|
-
|
325
|
+
if var is not None:
|
326
|
+
return torch.sqrt(var + self.eps)
|
327
|
+
return None
|
351
328
|
|
352
329
|
def calc_bn_stats_from_activations(self,
|
353
330
|
input_imgs: Tensor,
|
@@ -374,12 +351,13 @@ class PytorchBatchStatsHolder(BatchStatsHolder):
|
|
374
351
|
# Extract statistics of intermediate convolution outputs before the BatchNorm layers
|
375
352
|
for bn_layer_name in activation_extractor.get_extractor_layer_names():
|
376
353
|
bn_input_activations = activation_extractor.get_layer_input_activation(bn_layer_name)
|
377
|
-
if not
|
378
|
-
|
354
|
+
if bn_input_activations is not None:
|
355
|
+
if not to_differentiate:
|
356
|
+
bn_input_activations = bn_input_activations.detach()
|
379
357
|
|
380
|
-
|
381
|
-
|
382
|
-
|
358
|
+
collected_mean = torch.mean(bn_input_activations, dim=self.mean_axis)
|
359
|
+
collected_second_moment = clip_inf_values_float16(torch.mean(torch.pow(bn_input_activations, 2.0), dim=self.mean_axis))
|
360
|
+
self.update_layer_stats(bn_layer_name, collected_mean, collected_second_moment)
|
383
361
|
|
384
362
|
def clear(self):
|
385
363
|
"""Clear the statistics."""
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import time
|
16
|
-
from typing import Callable, Any, Tuple, List
|
16
|
+
from typing import Callable, Any, Tuple, List, Union
|
17
17
|
|
18
18
|
from tqdm import tqdm
|
19
19
|
|
@@ -25,10 +25,12 @@ from model_compression_toolkit.data_generation.common.data_generation_config imp
|
|
25
25
|
from model_compression_toolkit.data_generation.common.enums import ImageGranularity, SchedulerType, \
|
26
26
|
BatchNormAlignemntLossType, DataInitType, BNLayerWeightingType, ImagePipelineType, ImageNormalizationType, \
|
27
27
|
OutputLossType
|
28
|
+
from model_compression_toolkit.data_generation.common.image_pipeline import image_normalization_dict
|
28
29
|
from model_compression_toolkit.data_generation.pytorch.constants import DEFAULT_PYTORCH_INITIAL_LR, \
|
29
|
-
|
30
|
+
DEFAULT_PYTORCH_BN_LAYER_TYPES, DEFAULT_PYTORCH_LAST_LAYER_TYPES, DEFAULT_PYTORCH_EXTRA_PIXELS, \
|
31
|
+
DEFAULT_PYTORCH_OUTPUT_LOSS_MULTIPLIER
|
30
32
|
from model_compression_toolkit.data_generation.pytorch.image_pipeline import image_pipeline_dict, \
|
31
|
-
|
33
|
+
BaseImagePipeline
|
32
34
|
from model_compression_toolkit.data_generation.pytorch.model_info_exctractors import PytorchActivationExtractor, \
|
33
35
|
PytorchOriginalBNStatsHolder
|
34
36
|
from model_compression_toolkit.data_generation.pytorch.optimization_functions.batchnorm_alignment_functions import \
|
@@ -51,6 +53,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
51
53
|
from torch.nn import Module
|
52
54
|
from torch.optim import RAdam, Optimizer
|
53
55
|
from torch.fx import symbolic_trace
|
56
|
+
from torch.cuda.amp import autocast
|
54
57
|
|
55
58
|
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
|
56
59
|
|
@@ -61,19 +64,18 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
61
64
|
data_gen_batch_size=DEFAULT_DATA_GEN_BS,
|
62
65
|
initial_lr=DEFAULT_PYTORCH_INITIAL_LR,
|
63
66
|
output_loss_multiplier=DEFAULT_PYTORCH_OUTPUT_LOSS_MULTIPLIER,
|
64
|
-
scheduler_type: SchedulerType = SchedulerType.
|
67
|
+
scheduler_type: SchedulerType = SchedulerType.REDUCE_ON_PLATEAU_WITH_RESET,
|
65
68
|
bn_alignment_loss_type: BatchNormAlignemntLossType = BatchNormAlignemntLossType.L2_SQUARE,
|
66
|
-
output_loss_type: OutputLossType = OutputLossType.
|
67
|
-
data_init_type: DataInitType = DataInitType.
|
69
|
+
output_loss_type: OutputLossType = OutputLossType.NEGATIVE_MIN_MAX_DIFF,
|
70
|
+
data_init_type: DataInitType = DataInitType.Gaussian,
|
68
71
|
layer_weighting_type: BNLayerWeightingType = BNLayerWeightingType.AVERAGE,
|
69
72
|
image_granularity=ImageGranularity.AllImages,
|
70
|
-
image_pipeline_type: ImagePipelineType = ImagePipelineType.
|
73
|
+
image_pipeline_type: ImagePipelineType = ImagePipelineType.SMOOTHING_AND_AUGMENTATION,
|
71
74
|
image_normalization_type: ImageNormalizationType = ImageNormalizationType.TORCHVISION,
|
72
|
-
extra_pixels: int =
|
75
|
+
extra_pixels: Union[int, Tuple[int, int]] = DEFAULT_PYTORCH_EXTRA_PIXELS,
|
73
76
|
bn_layer_types: List = DEFAULT_PYTORCH_BN_LAYER_TYPES,
|
74
77
|
last_layer_types: List = DEFAULT_PYTORCH_LAST_LAYER_TYPES,
|
75
|
-
|
76
|
-
reflection: bool = True,
|
78
|
+
image_clipping: bool = True,
|
77
79
|
) -> DataGenerationConfig:
|
78
80
|
"""
|
79
81
|
Function to create a DataGenerationConfig object with the specified configuration parameters.
|
@@ -92,11 +94,10 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
92
94
|
image_granularity (ImageGranularity): The granularity of the images for optimization.
|
93
95
|
image_pipeline_type (ImagePipelineType): The type of image pipeline to use.
|
94
96
|
image_normalization_type (ImageNormalizationType): The type of image normalization to use.
|
95
|
-
extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
|
97
|
+
extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
|
96
98
|
bn_layer_types (List): List of BatchNorm layer types to be considered for data generation.
|
97
99
|
last_layer_types (List): List of layer types to be considered for the output loss.
|
98
|
-
|
99
|
-
reflection (bool): Whether to use reflection during optimization.
|
100
|
+
image_clipping (bool): Whether to clip images during optimization.
|
100
101
|
|
101
102
|
|
102
103
|
Returns:
|
@@ -121,15 +122,14 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
121
122
|
extra_pixels=extra_pixels,
|
122
123
|
bn_layer_types=bn_layer_types,
|
123
124
|
last_layer_types=last_layer_types,
|
124
|
-
|
125
|
-
reflection=reflection
|
125
|
+
image_clipping=image_clipping,
|
126
126
|
)
|
127
127
|
|
128
128
|
|
129
129
|
def pytorch_data_generation_experimental(
|
130
130
|
model: Module,
|
131
131
|
n_images: int,
|
132
|
-
output_image_size: int,
|
132
|
+
output_image_size: Union[int, Tuple[int, int]],
|
133
133
|
data_generation_config: DataGenerationConfig) -> List[Tensor]:
|
134
134
|
"""
|
135
135
|
Function to perform data generation using the provided model and data generation configuration.
|
@@ -137,7 +137,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
137
137
|
Args:
|
138
138
|
model (Module): PyTorch model to generate data for.
|
139
139
|
n_images (int): Number of images to generate.
|
140
|
-
output_image_size (int): The hight and width size of the output images.
|
140
|
+
output_image_size (Union[int, Tuple[int, int]]): The hight and width size of the output images.
|
141
141
|
data_generation_config (DataGenerationConfig): Configuration for data generation.
|
142
142
|
|
143
143
|
Returns:
|
@@ -176,6 +176,9 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
176
176
|
f"If you encounter an issue, please open an issue in our GitHub "
|
177
177
|
f"project https://github.com/sony/model_optimization")
|
178
178
|
|
179
|
+
# get the model device
|
180
|
+
device = get_working_device()
|
181
|
+
|
179
182
|
# get a static graph representation of the model using torch.fx
|
180
183
|
fx_model = symbolic_trace(model)
|
181
184
|
|
@@ -198,8 +201,8 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
198
201
|
|
199
202
|
# Check if the scheduler type is valid
|
200
203
|
if scheduler_get_fn is None or scheduler_step_fn is None:
|
201
|
-
Logger.critical(f'Invalid
|
202
|
-
f'Please select one from {SchedulerType.get_values()}.')
|
204
|
+
Logger.critical(f'Invalid scheduler_type {data_generation_config.scheduler_type}. '
|
205
|
+
f'Please select one from {SchedulerType.get_values()}.') # pragma: no cover
|
203
206
|
|
204
207
|
# Create a scheduler object with the specified number of iterations
|
205
208
|
scheduler = scheduler_get_fn(data_generation_config.n_iter)
|
@@ -218,23 +221,22 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
218
221
|
orig_bn_stats_holder = PytorchOriginalBNStatsHolder(model, data_generation_config.bn_layer_types)
|
219
222
|
if orig_bn_stats_holder.get_num_bn_layers() == 0:
|
220
223
|
Logger.critical(
|
221
|
-
f'Data generation requires a model with at least one BatchNorm layer.')
|
224
|
+
f'Data generation requires a model with at least one BatchNorm layer.') # pragma: no cover
|
222
225
|
|
223
226
|
# Create an ImagesOptimizationHandler object for handling optimization
|
224
227
|
all_imgs_opt_handler = PytorchImagesOptimizationHandler(model=model,
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
reflection=data_generation_config.reflection)
|
228
|
+
data_gen_batch_size=data_generation_config.data_gen_batch_size,
|
229
|
+
init_dataset=init_dataset,
|
230
|
+
optimizer=data_generation_config.optimizer,
|
231
|
+
image_pipeline=image_pipeline,
|
232
|
+
activation_extractor=activation_extractor,
|
233
|
+
image_granularity=data_generation_config.image_granularity,
|
234
|
+
scheduler_step_fn=scheduler_step_fn,
|
235
|
+
scheduler=scheduler,
|
236
|
+
initial_lr=data_generation_config.initial_lr,
|
237
|
+
normalization_mean=normalization[0],
|
238
|
+
normalization_std=normalization[1],
|
239
|
+
device=device)
|
238
240
|
|
239
241
|
# Perform data generation and obtain a list of generated images
|
240
242
|
generated_images_list = data_generation(
|
@@ -247,6 +249,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
247
249
|
bn_alignment_loss_fn=bn_alignment_loss_fn,
|
248
250
|
output_loss_fn=output_loss_fn,
|
249
251
|
output_loss_multiplier=data_generation_config.output_loss_multiplier,
|
252
|
+
device=device,
|
250
253
|
)
|
251
254
|
# Return the list of finalized generated images
|
252
255
|
return generated_images_list
|
@@ -261,7 +264,8 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
261
264
|
bn_layer_weighting_fn: Callable,
|
262
265
|
bn_alignment_loss_fn: Callable,
|
263
266
|
output_loss_fn: Callable,
|
264
|
-
output_loss_multiplier: float
|
267
|
+
output_loss_multiplier: float,
|
268
|
+
device: torch.device
|
265
269
|
) -> List[Any]:
|
266
270
|
"""
|
267
271
|
Function to perform data generation using the provided model and data generation configuration.
|
@@ -276,14 +280,11 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
276
280
|
bn_alignment_loss_fn (Callable): Function to compute BatchNorm alignment loss.
|
277
281
|
output_loss_fn (Callable): Function to compute output loss.
|
278
282
|
output_loss_multiplier (float): Multiplier for the output loss.
|
283
|
+
device (torch.device): The current device set for PyTorch operations.
|
279
284
|
|
280
285
|
Returns:
|
281
286
|
List: Finalized list containing generated images.
|
282
287
|
"""
|
283
|
-
|
284
|
-
# Compute the layer weights based on orig_bn_stats_holder
|
285
|
-
bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder)
|
286
|
-
|
287
288
|
# Get the current time to measure the total time taken
|
288
289
|
total_time = time.time()
|
289
290
|
|
@@ -291,7 +292,7 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
291
292
|
ibar = tqdm(range(data_generation_config.n_iter))
|
292
293
|
|
293
294
|
# Perform data generation iterations
|
294
|
-
for
|
295
|
+
for i_iter in ibar:
|
295
296
|
|
296
297
|
# Randomly reorder the batches
|
297
298
|
all_imgs_opt_handler.random_batch_reorder()
|
@@ -311,7 +312,11 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
311
312
|
input_imgs = image_pipeline.image_input_manipulation(imgs_to_optimize)
|
312
313
|
|
313
314
|
# Forward pass to extract activations
|
314
|
-
|
315
|
+
with autocast():
|
316
|
+
output = activation_extractor.run_model(input_imgs)
|
317
|
+
|
318
|
+
# Compute the layer weights based on orig_bn_stats_holder
|
319
|
+
bn_layer_weights = bn_layer_weighting_fn(orig_bn_stats_holder, activation_extractor, i_iter, data_generation_config.n_iter)
|
315
320
|
|
316
321
|
# Compute BatchNorm alignment loss
|
317
322
|
bn_loss = all_imgs_opt_handler.compute_bn_loss(input_imgs=input_imgs,
|
@@ -322,33 +327,33 @@ if FOUND_TORCH and FOUND_TORCHVISION:
|
|
322
327
|
bn_layer_weights=bn_layer_weights)
|
323
328
|
|
324
329
|
# Compute output loss
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
else:
|
330
|
-
output_loss = torch.zeros(1).to(get_working_device())
|
330
|
+
output_loss = output_loss_fn(
|
331
|
+
model_outputs=output,
|
332
|
+
activation_extractor=activation_extractor,
|
333
|
+
device=device)
|
331
334
|
|
332
335
|
# Compute total loss
|
333
336
|
total_loss = bn_loss + output_loss_multiplier * output_loss
|
334
337
|
|
335
338
|
# Perform optimiztion step
|
336
|
-
all_imgs_opt_handler.optimization_step(random_batch_index, total_loss,
|
339
|
+
all_imgs_opt_handler.optimization_step(random_batch_index, total_loss, i_iter)
|
337
340
|
|
338
341
|
# Update the statistics based on the updated images
|
339
342
|
if all_imgs_opt_handler.use_all_data_stats:
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
343
|
+
with autocast():
|
344
|
+
final_imgs = image_pipeline.image_output_finalize(imgs_to_optimize)
|
345
|
+
all_imgs_opt_handler.update_statistics(input_imgs=final_imgs,
|
346
|
+
batch_index=random_batch_index,
|
347
|
+
activation_extractor=activation_extractor)
|
344
348
|
|
345
349
|
ibar.set_description(f"Total Loss: {total_loss.item():.5f}, "
|
346
350
|
f"BN Loss: {bn_loss.item():.5f}, "
|
347
|
-
f"Output Loss: {
|
351
|
+
f"Output Loss: {output_loss.item():.5f}")
|
348
352
|
|
349
353
|
# Return a list containing the finalized generated images
|
350
354
|
finalized_imgs = all_imgs_opt_handler.get_finalized_images()
|
351
355
|
Logger.info(f'Total time to generate {len(finalized_imgs)} images (seconds): {int(time.time() - total_time)}')
|
356
|
+
Logger.info(f'Final Loss: Total {total_loss.item()}, BN loss {bn_loss.item()}, Output loss {output_loss.item()}')
|
352
357
|
return finalized_imgs
|
353
358
|
else:
|
354
359
|
# If torch is not installed,
|
{mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/LICENSE.md
RENAMED
File without changes
|
File without changes
|
{mct_nightly-2.1.0.20240724.437.dist-info → mct_nightly-2.1.0.20240726.430.dist-info}/top_level.txt
RENAMED
File without changes
|