mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +43 -29
- model_compression_toolkit/core/common/hessian/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
- model_compression_toolkit/core/keras/data_util.py +67 -0
- model_compression_toolkit/core/keras/keras_implementation.py +7 -1
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/data_util.py +163 -0
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
- model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
- model_compression_toolkit/core/pytorch/utils.py +22 -19
- model_compression_toolkit/core/quantization_prep_runner.py +2 -1
- model_compression_toolkit/core/runner.py +1 -2
- model_compression_toolkit/gptq/common/gptq_config.py +0 -2
- model_compression_toolkit/gptq/common/gptq_training.py +58 -114
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
- model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
- tests_pytest/keras/__init__.py +14 -0
- tests_pytest/keras/core/__init__.py +14 -0
- tests_pytest/keras/core/test_data_util.py +91 -0
- tests_pytest/pytorch/core/__init__.py +14 -0
- tests_pytest/pytorch/core/test_data_util.py +125 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
model_compression_toolkit/gptq/pytorch/gptq_training.py
CHANGED
@@ -12,32 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-…
+import copy
+from typing import Callable, List, Tuple, Union, Generator
 
 import numpy as np
+import torch
+from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
 from torch.nn import Module
+from torch.utils.data import DataLoader
 from tqdm import tqdm
-import copy
-import torch
 
-from model_compression_toolkit.core.common.hessian import HessianInfoService
-from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
-    get_gradual_activation_quantizer_wrapper_factory
-from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
-from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
-from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity
+from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
 from model_compression_toolkit.core.pytorch.constants import BIAS
+from model_compression_toolkit.core.pytorch.data_util import FixedDatasetFromGenerator, IterableDatasetFromGenerator, \
+    IterableSampleWithConstInfoDataset, FixedSampleInfoDataset, get_collate_fn_with_extra_outputs
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model, torch_tensor_to_numpy
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
+from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
 from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
     get_weights_for_loss
+from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
+    get_gradual_activation_quantizer_wrapper_factory
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
 from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
-from …
+from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps
 
 
@@ -70,6 +73,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
         def _get_total_grad_steps():
+            # TODO get it from the dataset
             return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
 
         # must be set prior to model building in the base class constructor
@@ -81,6 +85,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          gptq_config,
                          fw_impl,
                          fw_info,
+                         representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
         self.loss_list = []
@@ -106,20 +111,87 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                                          trainable_bias,
                                                          trainable_threshold)
         hessian_cfg = self.gptq_config.hessian_weights_config
+
         self.use_sample_layer_attention = hessian_cfg.per_sample
-        self.hessian_score_per_layer = None  # for fixed layer weights
-        self.hessian_score_per_image_per_layer = None  # for sample-layer attention
         if self.use_sample_layer_attention:
             # normalization is currently not supported, make sure the config reflects it.
             if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
                 raise NotImplementedError()
-…
-            self.hessian_score_per_image_per_layer = {}
+            self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
         else:
-            self.…
+            self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
 
         self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)
 
+    def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
+        """
+        Computes Sample-Layer Attention score and builds a train dataloader.
+
+        Args:
+            data_gen_fn: factory for representative dataset generator.
+
+        Returns:
+            PyTorch dataloader yielding three outputs - samples, weights for the distillation loss and
+            weights for regularization.
+        """
+        fixed_dataset = FixedDatasetFromGenerator(data_gen_fn)
+        orig_batch_size = fixed_dataset.orig_batch_size
+        # compute hessians for the whole dataset
+        hess_data_loader = DataLoader(fixed_dataset,
+                                      batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
+                                      shuffle=False)
+        request = self._build_hessian_request(granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
+                                              data_loader=hess_data_loader,
+                                              n_samples=None)
+        layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
+
+        # compute sla score defined as max over channels
+        layers_hessians = {layer: to_torch_tensor(hess.max(1)) for layer, hess in layers_hessians.items()}
+
+        # build train dataset and dataloader
+        hessians_tensor = torch.stack([layers_hessians[layer.name] for layer in self.compare_points], dim=1)  # samples X layers
+        assert hessians_tensor.shape[1] == len(self.compare_points)
+        loss_weights = list(hessians_tensor)
+        sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
+
+        reg_weights = hessians_tensor.mean(dim=0)
+        # use collate to add a single value to each batch
+        collate_fn = get_collate_fn_with_extra_outputs(reg_weights)
+
+        return DataLoader(sla_train_dataset, batch_size=orig_batch_size, shuffle=True, collate_fn=collate_fn)
+
+    def _prepare_train_dataloader_for_non_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
+        """
+        Computes loss weights and builds a train dataloader.
+
+        Args:
+            data_gen_fn: factory for representative dataset generator.
+
+        Returns:
+            PyTorch dataloader yielding three outputs - samples, weights for the distillation loss and
+            weights for regularization.
+        """
+        dataset = IterableDatasetFromGenerator(data_gen_fn)
+        num_nodes = len(self.compare_points)
+
+        if self.gptq_config.use_hessian_based_weights:
+            hess_dataloader = DataLoader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
+            loss_weights = torch.from_numpy(self.compute_hessian_based_weights(hess_dataloader))
+        else:
+            loss_weights = torch.ones(num_nodes) / num_nodes
+
+        train_dataset = IterableSampleWithConstInfoDataset(dataset, loss_weights)
+
+        reg_weights = torch.ones(num_nodes)
+        # use collate to add a single value to each batch
+        collate_fn = get_collate_fn_with_extra_outputs(reg_weights)
+
+        # NOTE: Don't just increase num_workers! With iterable dataset each worker fetches a full pass, so having
+        # more workers will result in multiple passes within the same epoch. Special handling is needed either
+        # in dataset or in worker_init_fn passed to dataloader, and it might not speed anything up anyway.
+        return DataLoader(train_dataset, batch_size=dataset.orig_batch_size,
+                          collate_fn=collate_fn, num_workers=1)
+
     def _is_gptq_weights_trainable(self,
                                    node: BaseNode) -> bool:
         """
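The two dataloader builders added above lean on helpers from the new model_compression_toolkit/core/pytorch/data_util.py (+163 lines in the change list, not displayed in this diff). Their real implementation is therefore not visible here; the sketch below is only an approximation of what FixedSampleInfoDataset and get_collate_fn_with_extra_outputs plausibly do, inferred from these call sites and from the new tests_pytest/pytorch/core/test_data_util.py further down, and anything beyond those call sites is an assumption.

```python
import torch
from torch.utils.data import Dataset, default_collate


class FixedSampleInfoDataset(Dataset):
    """Sketch (assumed): pre-fetched samples paired with per-sample auxiliary tensors,
    e.g. the per-layer SLA loss weights computed above."""

    def __init__(self, samples, *sample_info):
        if not all(len(info) == len(samples) for info in sample_info):
            raise ValueError('Mismatch in the number of samples between samples and complementary data.')
        self.samples = samples
        self.sample_info = sample_info

    def __getitem__(self, index):
        return (self.samples[index], *[info[index] for info in self.sample_info])

    def __len__(self):
        return len(self.samples)


def get_collate_fn_with_extra_outputs(*extra_outputs):
    """Sketch (assumed): a collate_fn that appends constant extra outputs
    (here the regularization-weight vector) to every collated batch."""

    def collate_fn(batch):
        return list(default_collate(batch)) + list(extra_outputs)

    return collate_fn
```

Either way, every batch produced by self.train_dataloader now carries three things: the input samples, the weights for the distillation loss, and the weights for the rounding regularization, which is exactly what the reworked training loop below unpacks.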
@@ -195,11 +267,10 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
         return gptq_model, gptq_user_info
 
-    def train(self…
+    def train(self):
         """
         GPTQ Training using pytorch framework
-…
-            representative_data_gen: Dataset generator to get images.
+
         Returns:
             Graph after GPTQ training
         """
@@ -216,7 +287,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
         # ----------------------------------------------
         # Training loop
         # ----------------------------------------------
-        self.micro_training_loop(…
+        self.micro_training_loop(self.gptq_config.n_epochs)
 
     def compute_gradients(self,
                           y_float: List[torch.Tensor],
@@ -262,23 +333,21 @@ class PytorchGPTQTrainer(GPTQTrainer):
         return loss_value, grads
 
     def micro_training_loop(self,
-                            data_function: Callable,
                             n_epochs: int):
         """
         This function run a micro training loop on given set of parameters.
         Args:
-            data_function: A callable function that give a batch of samples.
            n_epochs: Number of update iterations of representative dataset.
         """
         with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
             for _ in epochs_pbar:
-                with tqdm(…
-                    for …
-…
+                with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
+                    for sample in data_pbar:
+                        data, loss_weight, reg_weight = to_torch_tensor(sample)
                         input_data = [d * self.input_scale for d in data]
                         input_tensor = to_torch_tensor(input_data)
                         y_float = self.float_model(input_tensor)  # running float model
-                        loss_value, grads = self.compute_gradients(y_float, input_tensor,…
+                        loss_value, grads = self.compute_gradients(y_float, input_tensor, loss_weight, reg_weight)
                         # Run one step of gradient descent by updating the value of the variables to minimize the loss.
                         for (optimizer, _) in self.optimizer_with_param:
                             optimizer.step()
@@ -290,42 +359,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                 self.loss_list.append(loss_value.item())
                 Logger.debug(f'last loss value: {self.loss_list[-1]}')
 
-    def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Fetches weights for distillation and round regularization parts of loss.
-
-        Args:
-            input_tensors: list containing a batch of inputs.
-
-        Returns:
-            A tuple of two tensors:
-                - weights for distillation loss
-                - weights for rounding regularization loss
-
-        """
-        if self.use_sample_layer_attention is False:
-            return self.hessian_score_per_layer, torch.ones_like(self.hessian_score_per_layer)
-
-        if len(input_tensors) > 1:
-            raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs')
-
-        image_scores = []
-        batch = input_tensors[0]
-        img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]
-        for img_hash in img_hashes:
-            # If sample-layer attention score for the image is not found, compute and store it for the whole batch.
-            if img_hash not in self.hessian_score_per_image_per_layer:
-                score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors)
-                self.hessian_score_per_image_per_layer.update(score_per_image_per_layer)
-            img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash]
-            # fetch image scores for all layers and combine them into a single tensor
-            img_scores = np.stack(list(img_scores_per_layer.values()), axis=0)
-            image_scores.append(img_scores)
-
-        layer_sample_weights = np.stack(image_scores, axis=1)  # layers X images
-        layer_weights = layer_sample_weights.mean(axis=1)
-        return layer_sample_weights, layer_weights
-
     def update_graph(self) -> Graph:
         """
         Update a graph using GPTQ after minimizing the loss between the float model's output
model_compression_toolkit/gptq/pytorch/quantization_facade.py
CHANGED
@@ -18,7 +18,6 @@ from typing import Callable, Union
 from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
-from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
@@ -119,7 +118,6 @@ if FOUND_TORCH:
                 scale_log_norm=False,
                 hessian_batch_size=hessian_batch_size,
                 per_sample=True,
-                estimator_distribution=HessianEstimationDistribution.RADEMACHER
             )
             loss = loss or sample_layer_attention_loss
         else:
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
CHANGED
@@ -47,14 +47,15 @@ class SoftQuantizerRegularization:
         Args:
             model: A model to be quantized with SoftRounding.
             entropy_reg: Entropy value to scale the quantizer regularization.
-            layer_weights: a vector of…
+            layer_weights: a vector of layers weights.
 
         Returns: Regularization value.
         """
         layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)]
 
-        if …
-            raise ValueError(f'Expected weights to be…
+        if layer_weights.shape[0] != len(layers):
+            raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
+                             f'received shape {layer_weights.shape}.')  # pragma: no cover
         max_w = layer_weights.max()
 
         b = self.beta_scheduler(self.count_iter)
tests_pytest/keras/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
tests_pytest/keras/core/__init__.py
ADDED
@@ -0,0 +1,14 @@
[contents: the same 14-line Apache-2.0 license header as tests_pytest/keras/__init__.py above; no code]
tests_pytest/keras/core/test_data_util.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
[... 13 more license-header lines identical to tests_pytest/keras/__init__.py above ...]
+import numpy as np
+import pytest
+
+from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, TFDatasetFromGenerator
+
+
+@pytest.fixture(scope='session')
+def fixed_dataset():
+    # generate 320 images with data1[i] = i and data2[i] = i+10
+    data1 = np.stack([np.full((3, 30, 20), v) for v in range(320)], axis=0)
+    data2 = np.stack([np.full((10,), v + 10) for v in range(320)], axis=0)
+    return data1, data2
+
+
+@pytest.fixture
+def fixed_gen(fixed_dataset):
+    def f():
+        for i in range(10):
+            yield [fixed_dataset[0][32 * i: 32 * (i + 1)], fixed_dataset[1][32 * i: 32 * (i + 1)]]
+
+    return f
+
+
+def get_random_data_gen_fn(seed=42):
+    """ get gen factory for reproducible gen yielding different samples in each epoch """
+    rng = np.random.default_rng(seed)
+
+    def f():
+        for i in range(10):
+            yield [rng.random((32, 3, 20, 30)).astype(np.float32), rng.random((32, 10)).astype(np.float32)]
+    return f
+
+
+class TestTFDataUtil:
+    create_dataloader_fn = data_gen_to_dataloader
+
+    def test_iterable_dataset_from_fixed_gen(self, fixed_gen):
+        """ tests iterable dataset from fixed gen - same samples are generated in each epoch in the same order """
+        ds = TFDatasetFromGenerator(fixed_gen, batch_size=1)
+        self._validate_ds_from_fixed_gen(ds, 320)
+
+    def test_iterable_dataset_from_random_gen(self):
+        """ test that dataset samples over epochs are identical to the original data generator """
+        ds = TFDatasetFromGenerator(get_random_data_gen_fn(), batch_size=1)
+        pass1 = np.concatenate([t[0] for t in ds], axis=0)
+        pass2 = np.concatenate([t[0] for t in ds], axis=0)
+
+        gen_fn = get_random_data_gen_fn()
+        # one invocation is used for validation and batch size in dataset, so promote the reference gen for comparison
+        next(gen_fn())
+        gen_pass1 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+        gen_pass2 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+        # check that each pass is identical to corresponding pass in the original gen
+        assert np.array_equal(pass1, gen_pass1)
+        assert np.array_equal(pass2, gen_pass2)
+        assert not np.allclose(pass1, pass2)
+
+    def test_dataloader(self, fixed_gen):
+        ds = TFDatasetFromGenerator(fixed_gen, batch_size=25)
+        ds_iter = iter(ds)
+        batch1 = next(ds_iter)
+        assert batch1[0].shape[0] == batch1[1].shape[0] == 25
+        assert np.array_equal(batch1[0][0], np.full((3, 30, 20), 0))
+        assert np.array_equal(batch1[1][0], np.full((10,), 10))
+        assert np.array_equal(batch1[0][-1], np.full((3, 30, 20), 24))
+        assert np.array_equal(batch1[1][-1], np.full((10,), 34))
+        assert len(ds) == 13
+        assert ds.orig_batch_size == 32
+
+    def _validate_ds_from_fixed_gen(self, ds, exp_len):
+        for _ in range(2):
+            for i, sample in enumerate(ds):
+                assert np.array_equal(sample[0].cpu().numpy(), np.full((1, 3, 30, 20), i))
+                assert np.array_equal(sample[1].cpu().numpy(), np.full((1, 10,), i + 10))
+            assert i == exp_len - 1
+        assert ds.orig_batch_size == 32
+        assert len(ds) == exp_len
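These tests target TFDatasetFromGenerator and data_gen_to_dataloader from the new model_compression_toolkit/core/keras/data_util.py (+67 lines in the change list, not shown in this diff). Judging only by what the tests assert (one generator invocation is consumed up front for validation and for recording orig_batch_size, samples are then re-batched to the requested batch size, and len() reports the number of batches), a wrapper along these lines would satisfy most of them. It is a sketch under those assumptions, not the actual implementation:

```python
import tensorflow as tf


class TFDatasetFromGenerator:
    """Sketch (assumed): wraps a representative-data generator factory as a re-batched tf.data pipeline."""

    def __init__(self, data_gen_fn, batch_size):
        first = next(data_gen_fn())  # one invocation is consumed for validation and batch-size discovery
        if not isinstance(first, list):
            raise TypeError('Data generator is expected to yield a list of tensors.')
        self.orig_batch_size = first[0].shape[0]

        signature = tuple(tf.TensorSpec(shape=(None, *t.shape[1:]), dtype=t.dtype) for t in first)
        ds = tf.data.Dataset.from_generator(lambda: (tuple(b) for b in data_gen_fn()),
                                            output_signature=signature)
        # split the original batches into samples and regroup them to the requested batch size
        self._dataset = ds.unbatch().batch(batch_size)
        # NOTE: the real class also reports the number of batches via __len__
        # (the test expects len(ds) == 13 for 320 samples at batch_size=25); that bookkeeping is omitted here.

    def __iter__(self):
        return iter(self._dataset)
```

data_gen_to_dataloader, referenced via the create_dataloader_fn class attribute, presumably builds such a dataset from a generator factory and a batch size; that too is an inference from the test, not visible code.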
tests_pytest/pytorch/core/__init__.py
ADDED
@@ -0,0 +1,14 @@
[contents: the same 14-line Apache-2.0 license header as tests_pytest/keras/__init__.py above; no code]
tests_pytest/pytorch/core/test_data_util.py
ADDED
@@ -0,0 +1,125 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
[... 13 more license-header lines identical to tests_pytest/keras/__init__.py above ...]
+import pytest
+import torch
+import numpy as np
+from torch.utils.data import IterableDataset, Dataset
+
+from model_compression_toolkit.core.pytorch.data_util import (data_gen_to_dataloader, IterableDatasetFromGenerator,
+                                                               FixedDatasetFromGenerator, FixedSampleInfoDataset)
+
+
+@pytest.fixture(scope='session')
+def fixed_dataset():
+    # generate 320 images with data1[i] = i and data2[i] = i+10
+    data1 = np.stack([np.full((3, 30, 20), v) for v in range(320)], axis=0)
+    data2 = np.stack([np.full((10,), v + 10) for v in range(320)], axis=0)
+    return data1, data2
+
+
+@pytest.fixture
+def fixed_gen(fixed_dataset):
+    def f():
+        for i in range(10):
+            yield [fixed_dataset[0][32 * i: 32 * (i + 1)], fixed_dataset[1][32 * i: 32 * (i + 1)]]
+
+    return f
+
+
+def get_random_data_gen_fn(seed=42):
+    """ get gen factory for reproducible gen yielding different samples in each epoch """
+    rng = np.random.default_rng(seed)
+
+    def f():
+        for i in range(10):
+            yield [rng.random((32, 3, 20, 30)), rng.random((32, 10))]
+    return f
+
+
+class TestDataUtil:
+    create_dataloader_fn = data_gen_to_dataloader
+
+    def test_iterable_dataset_from_fixed_gen(self, fixed_gen):
+        """ tests iterable dataset from fixed gen - same samples are generated in each epoch in the same order """
+        ds = IterableDatasetFromGenerator(fixed_gen)
+        assert isinstance(ds, IterableDataset)
+        self._validate_ds_from_fixed_gen(ds, 320)
+
+    def test_iterable_dataset_from_random_gen(self):
+        """ test that dataset samples over epochs are identical to the original data generator """
+        ds = IterableDatasetFromGenerator(get_random_data_gen_fn())
+        pass1 = torch.stack([t[0] for t in ds], dim=0)
+        pass2 = torch.stack([t[0] for t in ds], dim=0)
+
+        gen_fn = get_random_data_gen_fn()
+        # one invocation is used for validation and batch size in dataset, so promote the reference gen for comparison
+        next(gen_fn())
+        gen_pass1 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+        gen_pass2 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+        # check that each pass is identical to corresponding pass in the original gen
+        assert np.allclose(pass1.cpu().numpy(), gen_pass1)
+        assert np.allclose(pass2.cpu().numpy(), gen_pass2)
+        assert not torch.equal(pass1, pass2)
+
+    def test_fixed_dataset_from_fixed_gen_full(self, fixed_gen):
+        ds = FixedDatasetFromGenerator(fixed_gen)
+        assert isinstance(ds, Dataset) and not isinstance(ds, IterableDataset)
+        self._validate_ds_from_fixed_gen(ds, 320)
+
+    def test_fixed_dataset_from_const_gen_subset(self, fixed_gen):
+        ds = FixedDatasetFromGenerator(fixed_gen, n_samples=25)
+        self._validate_ds_from_fixed_gen(ds, 25)
+
+    def test_fixed_dataset_from_random_gen_full(self):
+        ds = FixedDatasetFromGenerator(get_random_data_gen_fn())
+        self._validate_fixed_ds(ds, exp_len=320, exp_batch_size=32)
+
+    def test_fixed_dataset_from_random_gen_subset(self):
+        ds = FixedDatasetFromGenerator(get_random_data_gen_fn(), n_samples=123)
+        self._validate_fixed_ds(ds, exp_len=123, exp_batch_size=32)
+
+    def test_not_enough_samples_in_datagen(self):
+        def gen():
+            yield [np.ones((10, 3))]
+        with pytest.raises(ValueError, match='Not enough samples in the data generator'):
+            FixedDatasetFromGenerator(gen, n_samples=11)
+
+    def test_extra_info_mismatch(self, fixed_gen):
+        with pytest.raises(ValueError, match='Mismatch in the number of samples between samples and complementary data'):
+            FixedSampleInfoDataset([1]*10, [2]*10, [3]*11)
+
+    @pytest.mark.parametrize('ds_cls', [FixedDatasetFromGenerator, IterableDatasetFromGenerator])
+    def test_invalid_gen(self, ds_cls):
+        def gen():
+            yield np.ones((10, 3))
+        with pytest.raises(TypeError, match='Data generator is expected to yield a list of tensors'):
+            ds_cls(gen)
+
+    def _validate_ds_from_fixed_gen(self, ds, exp_len):
+        for _ in range(2):
+            for i, sample in enumerate(ds):
+                assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+                assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i + 10))
+            assert i == exp_len - 1
+        assert ds.orig_batch_size == 32
+        assert len(ds) == exp_len
+
+    def _validate_fixed_ds(self, ds, exp_len, exp_batch_size):
+        assert isinstance(ds, torch.utils.data.Dataset) and not isinstance(ds, torch.utils.data.IterableDataset)
+        full_pass1 = torch.concat([t[0] for t in ds], dim=0)
+        full_pass2 = torch.concat([t[0] for t in ds], dim=0)
+        assert torch.equal(full_pass1, full_pass2)
+        assert len(ds) == exp_len
+        assert ds.orig_batch_size == exp_batch_size
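As with the Keras tests, the classes exercised here come from the new model_compression_toolkit/core/pytorch/data_util.py, which this diff does not display. Based on the behavior the tests pin down (an IterableDataset that replays the generator factory every epoch, and a map-style Dataset that materializes an optional fixed number of samples up front, both validating the generator and exposing orig_batch_size), a rough sketch could look like the following; it is an assumption rather than the actual code:

```python
import torch
from torch.utils.data import Dataset, IterableDataset


class IterableDatasetFromGenerator(IterableDataset):
    """Sketch (assumed): streams individual samples; every epoch re-invokes the generator factory."""

    def __init__(self, data_gen_fn):
        first = next(data_gen_fn())  # consumed once for validation and batch-size discovery
        if not isinstance(first, list):
            raise TypeError('Data generator is expected to yield a list of tensors.')
        self.orig_batch_size = first[0].shape[0]
        self._data_gen_fn = data_gen_fn
        # NOTE: the real class also reports the total number of samples via __len__ (asserted by the tests);
        # that bookkeeping is omitted from this sketch.

    def __iter__(self):
        for batch in self._data_gen_fn():
            for i in range(batch[0].shape[0]):
                yield [torch.as_tensor(t[i]) for t in batch]


class FixedDatasetFromGenerator(Dataset):
    """Sketch (assumed): materializes samples from the generator, optionally truncated to n_samples."""

    def __init__(self, data_gen_fn, n_samples=None):
        self.samples = []
        self.orig_batch_size = None
        for batch in data_gen_fn():
            if not isinstance(batch, list):
                raise TypeError('Data generator is expected to yield a list of tensors.')
            if self.orig_batch_size is None:
                self.orig_batch_size = batch[0].shape[0]
            self.samples.extend(zip(*batch))  # split the batch into per-sample tuples
            if n_samples is not None and len(self.samples) >= n_samples:
                self.samples = self.samples[:n_samples]
                break
        if n_samples is not None and len(self.samples) < n_samples:
            raise ValueError('Not enough samples in the data generator to reach the requested n_samples.')

    def __getitem__(self, index):
        return [torch.as_tensor(t) for t in self.samples[index]]

    def __len__(self):
        return len(self.samples)
```

IterableSampleWithConstInfoDataset and data_gen_to_dataloader, imported by the GPTQ trainer and by these tests, are not sketched here; from their call sites they appear to attach a constant per-sample tensor to every streamed sample and to wrap a generator factory in a DataLoader, respectively.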
{mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL
RENAMED
File without changes
{mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt
RENAMED
File without changes