mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +43 -29
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
  7. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
  9. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
  12. model_compression_toolkit/core/keras/data_util.py +67 -0
  13. model_compression_toolkit/core/keras/keras_implementation.py +7 -1
  14. model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
  15. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  16. model_compression_toolkit/core/pytorch/data_util.py +163 -0
  17. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
  18. model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
  19. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
  21. model_compression_toolkit/core/pytorch/utils.py +22 -19
  22. model_compression_toolkit/core/quantization_prep_runner.py +2 -1
  23. model_compression_toolkit/core/runner.py +1 -2
  24. model_compression_toolkit/gptq/common/gptq_config.py +0 -2
  25. model_compression_toolkit/gptq/common/gptq_training.py +58 -114
  26. model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
  27. model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
  29. model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
  30. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
  31. tests_pytest/keras/__init__.py +14 -0
  32. tests_pytest/keras/core/__init__.py +14 -0
  33. tests_pytest/keras/core/test_data_util.py +91 -0
  34. tests_pytest/pytorch/core/__init__.py +14 -0
  35. tests_pytest/pytorch/core/test_data_util.py +125 -0
  36. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
  37. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
  38. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -12,32 +12,35 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable, List, Tuple, Union, Dict
+ import copy
+ from typing import Callable, List, Tuple, Union, Generator

  import numpy as np
+ import torch
+ from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
  from torch.nn import Module
+ from torch.utils.data import DataLoader
  from tqdm import tqdm
- import copy
- import torch

- from model_compression_toolkit.core.common.hessian import HessianInfoService
- from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
-     get_gradual_activation_quantizer_wrapper_factory
- from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
- from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
- from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
  from model_compression_toolkit.core.common import Graph, BaseNode
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
+ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity
+ from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
  from model_compression_toolkit.core.pytorch.constants import BIAS
+ from model_compression_toolkit.core.pytorch.data_util import FixedDatasetFromGenerator, IterableDatasetFromGenerator, \
+     IterableSampleWithConstInfoDataset, FixedSampleInfoDataset, get_collate_fn_with_extra_outputs
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model, torch_tensor_to_numpy
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+ from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
  from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable_parameters, \
      get_weights_for_loss
+ from model_compression_toolkit.gptq.pytorch.quantizer.gradual_activation_quantization import \
+     get_gradual_activation_quantizer_wrapper_factory
  from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
  from model_compression_toolkit.gptq.pytorch.quantizer.regularization_factory import get_regularization
- from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
+ from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.trainable_infrastructure.pytorch.util import get_total_grad_steps

@@ -70,6 +73,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
              hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
          """
          def _get_total_grad_steps():
+             # TODO get it from the dataset
              return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs

          # must be set prior to model building in the base class constructor
@@ -81,6 +85,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
                           gptq_config,
                           fw_impl,
                           fw_info,
+                          representative_data_gen_fn=representative_data_gen,
                           hessian_info_service=hessian_info_service)

          self.loss_list = []
@@ -106,20 +111,87 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                                               trainable_bias,
                                                               trainable_threshold)
          hessian_cfg = self.gptq_config.hessian_weights_config
+
          self.use_sample_layer_attention = hessian_cfg.per_sample
-         self.hessian_score_per_layer = None  # for fixed layer weights
-         self.hessian_score_per_image_per_layer = None  # for sample-layer attention
          if self.use_sample_layer_attention:
              # normalization is currently not supported, make sure the config reflects it.
              if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
                  raise NotImplementedError()
-             # Per sample hessian scores are calculated on-demand during the training loop
-             self.hessian_score_per_image_per_layer = {}
+             self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
          else:
-             self.hessian_score_per_layer = to_torch_tensor(self.compute_hessian_based_weights())
+             self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)

          self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)

+     def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
+         """
+         Computes Sample-Layer Attention score and builds a train dataloader.
+
+         Args:
+             data_gen_fn: factory for representative dataset generator.
+
+         Returns:
+             PyTorch dataloader yielding three outputs - samples, weights for the distillation loss and
+             weights for regularization.
+         """
+         fixed_dataset = FixedDatasetFromGenerator(data_gen_fn)
+         orig_batch_size = fixed_dataset.orig_batch_size
+         # compute hessians for the whole dataset
+         hess_data_loader = DataLoader(fixed_dataset,
+                                       batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
+                                       shuffle=False)
+         request = self._build_hessian_request(granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
+                                               data_loader=hess_data_loader,
+                                               n_samples=None)
+         layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
+
+         # compute sla score defined as max over channels
+         layers_hessians = {layer: to_torch_tensor(hess.max(1)) for layer, hess in layers_hessians.items()}
+
+         # build train dataset and dataloader
+         hessians_tensor = torch.stack([layers_hessians[layer.name] for layer in self.compare_points], dim=1)  # samples X layers
+         assert hessians_tensor.shape[1] == len(self.compare_points)
+         loss_weights = list(hessians_tensor)
+         sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
+
+         reg_weights = hessians_tensor.mean(dim=0)
+         # use collate to add a single value to each batch
+         collate_fn = get_collate_fn_with_extra_outputs(reg_weights)
+
+         return DataLoader(sla_train_dataset, batch_size=orig_batch_size, shuffle=True, collate_fn=collate_fn)
+
+     def _prepare_train_dataloader_for_non_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
+         """
+         Computes loss weights and builds a train dataloader.
+
+         Args:
+             data_gen_fn: factory for representative dataset generator.
+
+         Returns:
+             PyTorch dataloader yielding three outputs - samples, weights for the distillation loss and
+             weights for regularization.
+         """
+         dataset = IterableDatasetFromGenerator(data_gen_fn)
+         num_nodes = len(self.compare_points)
+
+         if self.gptq_config.use_hessian_based_weights:
+             hess_dataloader = DataLoader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
+             loss_weights = torch.from_numpy(self.compute_hessian_based_weights(hess_dataloader))
+         else:
+             loss_weights = torch.ones(num_nodes) / num_nodes
+
+         train_dataset = IterableSampleWithConstInfoDataset(dataset, loss_weights)
+
+         reg_weights = torch.ones(num_nodes)
+         # use collate to add a single value to each batch
+         collate_fn = get_collate_fn_with_extra_outputs(reg_weights)
+
+         # NOTE: Don't just increase num_workers! With iterable dataset each worker fetches a full pass, so having
+         # more workers will result in multiple passes within the same epoch. Special handling is needed either
+         # in dataset or in worker_init_fn passed to dataloader, and it might not speed anything up anyway.
+         return DataLoader(train_dataset, batch_size=dataset.orig_batch_size,
+                           collate_fn=collate_fn, num_workers=1)
+
      def _is_gptq_weights_trainable(self,
                                     node: BaseNode) -> bool:
          """
@@ -195,11 +267,10 @@ class PytorchGPTQTrainer(GPTQTrainer):

          return gptq_model, gptq_user_info

-     def train(self, representative_data_gen: Callable):
+     def train(self):
          """
          GPTQ Training using pytorch framework
-         Args:
-             representative_data_gen: Dataset generator to get images.
+
          Returns:
              Graph after GPTQ training
          """
@@ -216,7 +287,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
          # ----------------------------------------------
          # Training loop
          # ----------------------------------------------
-         self.micro_training_loop(representative_data_gen, self.gptq_config.n_epochs)
+         self.micro_training_loop(self.gptq_config.n_epochs)

      def compute_gradients(self,
                            y_float: List[torch.Tensor],
@@ -262,23 +333,21 @@ class PytorchGPTQTrainer(GPTQTrainer):
          return loss_value, grads

      def micro_training_loop(self,
-                             data_function: Callable,
                              n_epochs: int):
          """
          This function run a micro training loop on given set of parameters.
          Args:
-             data_function: A callable function that give a batch of samples.
              n_epochs: Number of update iterations of representative dataset.
          """
          with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
              for _ in epochs_pbar:
-                 with tqdm(data_function(), position=1, leave=False) as data_pbar:
-                     for data in data_pbar:
-                         distill_weights, reg_weights = to_torch_tensor(self._get_loss_weights(data))
+                 with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
+                     for sample in data_pbar:
+                         data, loss_weight, reg_weight = to_torch_tensor(sample)
                          input_data = [d * self.input_scale for d in data]
                          input_tensor = to_torch_tensor(input_data)
                          y_float = self.float_model(input_tensor)  # running float model
-                         loss_value, grads = self.compute_gradients(y_float, input_tensor, distill_weights, reg_weights)
+                         loss_value, grads = self.compute_gradients(y_float, input_tensor, loss_weight, reg_weight)
                          # Run one step of gradient descent by updating the value of the variables to minimize the loss.
                          for (optimizer, _) in self.optimizer_with_param:
                              optimizer.step()
@@ -290,42 +359,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                  self.loss_list.append(loss_value.item())
                  Logger.debug(f'last loss value: {self.loss_list[-1]}')

-     def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Fetches weights for distillation and round regularization parts of loss.
-
-         Args:
-             input_tensors: list containing a batch of inputs.
-
-         Returns:
-             A tuple of two tensors:
-                 - weights for distillation loss
-                 - weights for rounding regularization loss
-
-         """
-         if self.use_sample_layer_attention is False:
-             return self.hessian_score_per_layer, torch.ones_like(self.hessian_score_per_layer)
-
-         if len(input_tensors) > 1:
-             raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs')
-
-         image_scores = []
-         batch = input_tensors[0]
-         img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]
-         for img_hash in img_hashes:
-             # If sample-layer attention score for the image is not found, compute and store it for the whole batch.
-             if img_hash not in self.hessian_score_per_image_per_layer:
-                 score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors)
-                 self.hessian_score_per_image_per_layer.update(score_per_image_per_layer)
-             img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash]
-             # fetch image scores for all layers and combine them into a single tensor
-             img_scores = np.stack(list(img_scores_per_layer.values()), axis=0)
-             image_scores.append(img_scores)
-
-         layer_sample_weights = np.stack(image_scores, axis=1)  # layers X images
-         layer_weights = layer_sample_weights.mean(axis=1)
-         return layer_sample_weights, layer_weights
-
      def update_graph(self) -> Graph:
          """
          Update a graph using GPTQ after minimizing the loss between the float model's output
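With the removal above, per-sample Sample-Layer Attention scores are no longer looked up by image hash during training; they are computed once in _prepare_train_dataloader_sla and carried by the dataloader. As a shape-level illustration of the reductions used there (max over channels per layer, stack into a samples-by-layers matrix, mean over samples for the regularization weights), here is a small, self-contained example with made-up values.

# Toy walk-through of the SLA score reductions introduced above (made-up values).
import torch

n_samples, n_channels, n_layers = 4, 3, 2

# Assume the Hessian service returned one (n_samples, n_channels) score array per compare point.
per_layer_scores = {f'layer_{i}': torch.rand(n_samples, n_channels) for i in range(n_layers)}

# SLA score per (sample, layer): max over the channel axis.
per_layer_sla = {name: scores.max(dim=1).values for name, scores in per_layer_scores.items()}

# Stack into a samples X layers matrix; each row is one sample's distillation-loss weight vector.
hessians_tensor = torch.stack(list(per_layer_sla.values()), dim=1)
assert hessians_tensor.shape == (n_samples, n_layers)

# Regularization weights: one value per layer, averaged over samples.
reg_weights = hessians_tensor.mean(dim=0)
assert reg_weights.shape == (n_layers,)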
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -18,7 +18,6 @@ from typing import Callable, Union
  from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
  from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
- from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
      MixedPrecisionQuantizationConfig
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
@@ -119,7 +118,6 @@ if FOUND_TORCH:
                  scale_log_norm=False,
                  hessian_batch_size=hessian_batch_size,
                  per_sample=True,
-                 estimator_distribution=HessianEstimationDistribution.RADEMACHER
              )
          loss = loss or sample_layer_attention_loss
      else:
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -47,14 +47,15 @@ class SoftQuantizerRegularization:
          Args:
              model: A model to be quantized with SoftRounding.
              entropy_reg: Entropy value to scale the quantizer regularization.
-             layer_weights: a vector of layer weights.
+             layer_weights: a vector of layers weights.

          Returns: Regularization value.
          """
          layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)]

-         if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers):
-             raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.')  # pragma: no cover
+         if layer_weights.shape[0] != len(layers):
+             raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
+                              f'received shape {layer_weights.shape}.')  # pragma: no cover
          max_w = layer_weights.max()

          b = self.beta_scheduler(self.count_iter)
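The check above was relaxed from requiring a strictly one-dimensional weight vector to only constraining the leading dimension, so any tensor whose first dimension matches the number of wrapped layers now passes. A small illustration of the two checks side by side (shapes are hypothetical, not taken from MCT):

# Compares the old and new shape checks from the hunk above (hypothetical shapes).
import torch

num_layers = 5                               # stands in for len(layers)
vec_weights = torch.ones(num_layers)         # 1-D: accepted by both checks
mat_weights = torch.ones(num_layers, 32)     # 2-D: rejected by the old check, accepted by the new one

for w in (vec_weights, mat_weights):
    new_check_ok = w.shape[0] == num_layers
    old_check_ok = len(w.shape) == 1 and w.shape[0] == num_layers
    print(tuple(w.shape), 'old:', old_check_ok, 'new:', new_check_ok)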
tests_pytest/keras/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

tests_pytest/keras/core/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
tests_pytest/keras/core/test_data_util.py
@@ -0,0 +1,91 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import numpy as np
+ import pytest
+
+ from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, TFDatasetFromGenerator
+
+
+ @pytest.fixture(scope='session')
+ def fixed_dataset():
+     # generate 320 images with data1[i] = i and data2[i] = i+10
+     data1 = np.stack([np.full((3, 30, 20), v) for v in range(320)], axis=0)
+     data2 = np.stack([np.full((10,), v + 10) for v in range(320)], axis=0)
+     return data1, data2
+
+
+ @pytest.fixture
+ def fixed_gen(fixed_dataset):
+     def f():
+         for i in range(10):
+             yield [fixed_dataset[0][32 * i: 32 * (i + 1)], fixed_dataset[1][32 * i: 32 * (i + 1)]]
+
+     return f
+
+
+ def get_random_data_gen_fn(seed=42):
+     """ get gen factory for reproducible gen yielding different samples in each epoch """
+     rng = np.random.default_rng(seed)
+
+     def f():
+         for i in range(10):
+             yield [rng.random((32, 3, 20, 30)).astype(np.float32), rng.random((32, 10)).astype(np.float32)]
+     return f
+
+
+ class TestTFDataUtil:
+     create_dataloader_fn = data_gen_to_dataloader
+
+     def test_iterable_dataset_from_fixed_gen(self, fixed_gen):
+         """ tests iterable dataset from fixed gen - same samples are generated in each epoch in the same order """
+         ds = TFDatasetFromGenerator(fixed_gen, batch_size=1)
+         self._validate_ds_from_fixed_gen(ds, 320)
+
+     def test_iterable_dataset_from_random_gen(self):
+         """ test that dataset samples over epochs are identical to the original data generator """
+         ds = TFDatasetFromGenerator(get_random_data_gen_fn(), batch_size=1)
+         pass1 = np.concatenate([t[0] for t in ds], axis=0)
+         pass2 = np.concatenate([t[0] for t in ds], axis=0)
+
+         gen_fn = get_random_data_gen_fn()
+         # one invocation is used for validation and batch size in dataset, so promote the reference gen for comparison
+         next(gen_fn())
+         gen_pass1 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+         gen_pass2 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+         # check that each pass is identical to corresponding pass in the original gen
+         assert np.array_equal(pass1, gen_pass1)
+         assert np.array_equal(pass2, gen_pass2)
+         assert not np.allclose(pass1, pass2)
+
+     def test_dataloader(self, fixed_gen):
+         ds = TFDatasetFromGenerator(fixed_gen, batch_size=25)
+         ds_iter = iter(ds)
+         batch1 = next(ds_iter)
+         assert batch1[0].shape[0] == batch1[1].shape[0] == 25
+         assert np.array_equal(batch1[0][0], np.full((3, 30, 20), 0))
+         assert np.array_equal(batch1[1][0], np.full((10,), 10))
+         assert np.array_equal(batch1[0][-1], np.full((3, 30, 20), 24))
+         assert np.array_equal(batch1[1][-1], np.full((10,), 34))
+         assert len(ds) == 13
+         assert ds.orig_batch_size == 32
+
+     def _validate_ds_from_fixed_gen(self, ds, exp_len):
+         for _ in range(2):
+             for i, sample in enumerate(ds):
+                 assert np.array_equal(sample[0].cpu().numpy(), np.full((1, 3, 30, 20), i))
+                 assert np.array_equal(sample[1].cpu().numpy(), np.full((1, 10,), i + 10))
+             assert i == exp_len - 1
+         assert ds.orig_batch_size == 32
+         assert len(ds) == exp_len
tests_pytest/pytorch/core/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

tests_pytest/pytorch/core/test_data_util.py
@@ -0,0 +1,125 @@
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import pytest
+ import torch
+ import numpy as np
+ from torch.utils.data import IterableDataset, Dataset
+
+ from model_compression_toolkit.core.pytorch.data_util import (data_gen_to_dataloader, IterableDatasetFromGenerator,
+                                                               FixedDatasetFromGenerator, FixedSampleInfoDataset)
+
+
+ @pytest.fixture(scope='session')
+ def fixed_dataset():
+     # generate 320 images with data1[i] = i and data2[i] = i+10
+     data1 = np.stack([np.full((3, 30, 20), v) for v in range(320)], axis=0)
+     data2 = np.stack([np.full((10,), v + 10) for v in range(320)], axis=0)
+     return data1, data2
+
+
+ @pytest.fixture
+ def fixed_gen(fixed_dataset):
+     def f():
+         for i in range(10):
+             yield [fixed_dataset[0][32 * i: 32 * (i + 1)], fixed_dataset[1][32 * i: 32 * (i + 1)]]
+
+     return f
+
+
+ def get_random_data_gen_fn(seed=42):
+     """ get gen factory for reproducible gen yielding different samples in each epoch """
+     rng = np.random.default_rng(seed)
+
+     def f():
+         for i in range(10):
+             yield [rng.random((32, 3, 20, 30)), rng.random((32, 10))]
+     return f
+
+
+ class TestDataUtil:
+     create_dataloader_fn = data_gen_to_dataloader
+
+     def test_iterable_dataset_from_fixed_gen(self, fixed_gen):
+         """ tests iterable dataset from fixed gen - same samples are generated in each epoch in the same order """
+         ds = IterableDatasetFromGenerator(fixed_gen)
+         assert isinstance(ds, IterableDataset)
+         self._validate_ds_from_fixed_gen(ds, 320)
+
+     def test_iterable_dataset_from_random_gen(self):
+         """ test that dataset samples over epochs are identical to the original data generator """
+         ds = IterableDatasetFromGenerator(get_random_data_gen_fn())
+         pass1 = torch.stack([t[0] for t in ds], dim=0)
+         pass2 = torch.stack([t[0] for t in ds], dim=0)
+
+         gen_fn = get_random_data_gen_fn()
+         # one invocation is used for validation and batch size in dataset, so promote the reference gen for comparison
+         next(gen_fn())
+         gen_pass1 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+         gen_pass2 = np.concatenate([t[0] for t in gen_fn()], axis=0)
+         # check that each pass is identical to corresponding pass in the original gen
+         assert np.allclose(pass1.cpu().numpy(), gen_pass1)
+         assert np.allclose(pass2.cpu().numpy(), gen_pass2)
+         assert not torch.equal(pass1, pass2)
+
+     def test_fixed_dataset_from_fixed_gen_full(self, fixed_gen):
+         ds = FixedDatasetFromGenerator(fixed_gen)
+         assert isinstance(ds, Dataset) and not isinstance(ds, IterableDataset)
+         self._validate_ds_from_fixed_gen(ds, 320)
+
+     def test_fixed_dataset_from_const_gen_subset(self, fixed_gen):
+         ds = FixedDatasetFromGenerator(fixed_gen, n_samples=25)
+         self._validate_ds_from_fixed_gen(ds, 25)
+
+     def test_fixed_dataset_from_random_gen_full(self):
+         ds = FixedDatasetFromGenerator(get_random_data_gen_fn())
+         self._validate_fixed_ds(ds, exp_len=320, exp_batch_size=32)
+
+     def test_fixed_dataset_from_random_gen_subset(self):
+         ds = FixedDatasetFromGenerator(get_random_data_gen_fn(), n_samples=123)
+         self._validate_fixed_ds(ds, exp_len=123, exp_batch_size=32)
+
+     def test_not_enough_samples_in_datagen(self):
+         def gen():
+             yield [np.ones((10, 3))]
+         with pytest.raises(ValueError, match='Not enough samples in the data generator'):
+             FixedDatasetFromGenerator(gen, n_samples=11)
+
+     def test_extra_info_mismatch(self, fixed_gen):
+         with pytest.raises(ValueError, match='Mismatch in the number of samples between samples and complementary data'):
+             FixedSampleInfoDataset([1]*10, [2]*10, [3]*11)
+
+     @pytest.mark.parametrize('ds_cls', [FixedDatasetFromGenerator, IterableDatasetFromGenerator])
+     def test_invalid_gen(self, ds_cls):
+         def gen():
+             yield np.ones((10, 3))
+         with pytest.raises(TypeError, match='Data generator is expected to yield a list of tensors'):
+             ds_cls(gen)
+
+     def _validate_ds_from_fixed_gen(self, ds, exp_len):
+         for _ in range(2):
+             for i, sample in enumerate(ds):
+                 assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
+                 assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i + 10))
+             assert i == exp_len - 1
+         assert ds.orig_batch_size == 32
+         assert len(ds) == exp_len
+
+     def _validate_fixed_ds(self, ds, exp_len, exp_batch_size):
+         assert isinstance(ds, torch.utils.data.Dataset) and not isinstance(ds, torch.utils.data.IterableDataset)
+         full_pass1 = torch.concat([t[0] for t in ds], dim=0)
+         full_pass2 = torch.concat([t[0] for t in ds], dim=0)
+         assert torch.equal(full_pass1, full_pass2)
+         assert len(ds) == exp_len
+         assert ds.orig_batch_size == exp_batch_size
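The tests above pin down the contract that the new data utilities expect from a representative dataset: a zero-argument factory whose generator yields lists of equally batched arrays, from which the wrappers infer orig_batch_size and either re-iterate the generator (IterableDatasetFromGenerator) or materialize samples up front (FixedDatasetFromGenerator). A minimal usage sketch based on that contract follows; the generator and shapes are placeholders, not taken from MCT.

# Usage sketch based on the contract exercised by the tests above (placeholder data).
import numpy as np
from torch.utils.data import DataLoader

from model_compression_toolkit.core.pytorch.data_util import IterableDatasetFromGenerator


def representative_data_gen():
    # yields lists of numpy batches that all share the same leading (batch) dimension
    for _ in range(10):
        yield [np.random.rand(32, 3, 224, 224).astype(np.float32)]


dataset = IterableDatasetFromGenerator(representative_data_gen)
# iterating the dataset yields individual samples; the generator's batch size is preserved separately
dataloader = DataLoader(dataset, batch_size=dataset.orig_batch_size)

for batch in dataloader:
    images = batch[0]        # a torch tensor of shape (32, 3, 224, 224)
    break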