mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +43 -29
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
  7. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
  9. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
  12. model_compression_toolkit/core/keras/data_util.py +67 -0
  13. model_compression_toolkit/core/keras/keras_implementation.py +7 -1
  14. model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
  15. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  16. model_compression_toolkit/core/pytorch/data_util.py +163 -0
  17. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
  18. model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
  19. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
  21. model_compression_toolkit/core/pytorch/utils.py +22 -19
  22. model_compression_toolkit/core/quantization_prep_runner.py +2 -1
  23. model_compression_toolkit/core/runner.py +1 -2
  24. model_compression_toolkit/gptq/common/gptq_config.py +0 -2
  25. model_compression_toolkit/gptq/common/gptq_training.py +58 -114
  26. model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
  27. model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
  29. model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
  30. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
  31. tests_pytest/keras/__init__.py +14 -0
  32. tests_pytest/keras/core/__init__.py +14 -0
  33. tests_pytest/keras/core/test_data_util.py +91 -0
  34. tests_pytest/pytorch/core/__init__.py +14 -0
  35. tests_pytest/pytorch/core/test_data_util.py +125 -0
  36. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
  37. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
  38. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -12,10 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import List
15
+ from typing import Iterable, Sequence, Optional, TYPE_CHECKING
16
+ import dataclasses
16
17
 
17
18
  from enum import Enum
18
19
 
20
+ if TYPE_CHECKING: # pragma: no cover
21
+ from model_compression_toolkit.core.common import BaseNode
22
+
19
23
 
20
24
  class HessianMode(Enum):
21
25
  """
@@ -40,14 +44,7 @@ class HessianScoresGranularity(Enum):
40
44
  PER_TENSOR = 2
41
45
 
42
46
 
43
- class HessianEstimationDistribution(str, Enum):
44
- """
45
- Distribution for Hutchinson estimator random vector
46
- """
47
- GAUSSIAN = 'gaussian'
48
- RADEMACHER = 'rademacher'
49
-
50
-
47
+ @dataclasses.dataclass
51
48
  class HessianScoresRequest:
52
49
  """
53
50
  Request configuration for the Hessian-approximation scores.
@@ -55,36 +52,25 @@ class HessianScoresRequest:
55
52
  This class defines the parameters for the scores based on the Hessian matrix approximation.
56
53
  It specifies the mode (weights/activations), granularity (element/channel/tensor), and the target node.
57
54
 
58
- Note: This does not compute scores using the actual Hessian matrix but an approximation.
55
+ Attributes:
56
+ mode: Mode of Hessian-approximation score (w.r.t weights or activations).
57
+ granularity: Granularity level for the approximation.
58
+ target_nodes: The node objects in the float graph for which the Hessian's approximation scores is targeted.
59
+ data_loader: Data loader to compute hessian approximations on. Should reflect the desired batch size for
60
+ the computation. Can be None if all hessians for the request are expected to be pre-computed previously.
61
+ n_samples: The number of samples to fetch hessian estimations for. If None, fetch hessians for a full pass
62
+ of the data loader.
59
63
  """
60
-
61
- def __init__(self,
62
- mode: HessianMode,
63
- granularity: HessianScoresGranularity,
64
- target_nodes: List,
65
- distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN):
66
- """
67
- Attributes:
68
- mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
69
- granularity (HessianScoresGranularity): Granularity level for the approximation.
70
- target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's approximation scores is targeted.
71
- """
72
-
73
- self.mode = mode # w.r.t activations or weights
74
- self.granularity = granularity # per element, per layer, per channel
75
- self.target_nodes = target_nodes
76
- self.distribution = distribution
77
-
78
- def __eq__(self, other):
79
- # Checks if the other object is an instance of HessianScoresRequest
80
- # and then checks if all attributes are equal.
81
- return isinstance(other, HessianScoresRequest) and \
82
- self.mode == other.mode and \
83
- self.granularity == other.granularity and \
84
- self.target_nodes == other.target_nodes and \
85
- self.distribution == other.distribution
86
-
87
- def __hash__(self):
88
- # Computes the hash based on the attributes.
89
- # The use of a tuple here ensures that the hash is influenced by all the attributes.
90
- return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
64
+ mode: HessianMode
65
+ granularity: HessianScoresGranularity
66
+ target_nodes: Sequence['BaseNode']
67
+ data_loader: Optional[Iterable]
68
+ n_samples: Optional[int]
69
+
70
+ def __post_init__(self):
71
+ if self.data_loader is None and self.n_samples is None:
72
+ raise ValueError('Data loader and the number of samples cannot both be None.')
73
+
74
+ def clone(self, **kwargs):
75
+ """ Create a clone with optional overrides """
76
+ return dataclasses.replace(self, **kwargs)
@@ -238,22 +238,20 @@ class SensitivityEvaluation:
238
238
  """
239
239
  # Create a request for Hessian approximation scores with specific configurations
240
240
  # (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
241
+ fw_dataloader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen,
242
+ batch_size=self.quant_config.hessian_batch_size)
241
243
  hessian_info_request = HessianScoresRequest(mode=HessianMode.ACTIVATION,
242
244
  granularity=HessianScoresGranularity.PER_TENSOR,
243
- target_nodes=self.interest_points)
245
+ target_nodes=self.interest_points,
246
+ data_loader=fw_dataloader,
247
+ n_samples=self.quant_config.num_of_images)
244
248
 
245
249
  # Fetch the Hessian approximation scores for the current interest point
246
- nodes_approximations = self.hessian_info_service.fetch_hessian(hessian_scores_request=hessian_info_request,
247
- required_size=self.quant_config.num_of_images,
248
- batch_size=self.quant_config.hessian_batch_size)
249
-
250
- # Store the approximations for each node for each image
251
- approx_by_image = [[nodes_approximations[j][image_idx]
252
- for j, _ in enumerate(self.interest_points)]
253
- for image_idx in range(self.quant_config.num_of_images)]
250
+ nodes_approximations = self.hessian_info_service.fetch_hessian(request=hessian_info_request)
251
+ approx_by_image = np.stack([nodes_approximations[n.name] for n in self.interest_points], axis=1) # samples X nodes
254
252
 
255
253
  # Return the mean approximation value across all images for each interest point
256
- return np.mean(np.stack(approx_by_image), axis=0)
254
+ return np.mean(approx_by_image, axis=0)
257
255
 
258
256
  def _configure_bitwidths_model(self,
259
257
  mp_model_configuration: List[int],
@@ -120,22 +120,24 @@ class LFHImportanceMetric(BaseImportanceMetric):
120
120
  """
121
121
 
122
122
  # Initialize HessianInfoService for score computation.
123
+
123
124
  hessian_info_service = HessianInfoService(graph=self.float_graph,
124
- representative_dataset_gen=self.representative_data_gen,
125
125
  fw_impl=self.fw_impl)
126
126
 
127
127
  # Fetch and process Hessian scores for output channels of entry nodes.
128
- nodes_scores = []
128
+ data_loader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen, batch_size=1)
129
+ nodes_scores = {}
129
130
  for node in entry_nodes:
130
- _request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
131
- granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
132
- target_nodes=[node])
133
- _scores_for_node = hessian_info_service.fetch_hessian(_request,
134
- required_size=self.pruning_config.num_score_approximations)
135
- nodes_scores.append(_scores_for_node)
131
+ request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
132
+ granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
133
+ target_nodes=[node],
134
+ data_loader=data_loader,
135
+ n_samples=self.pruning_config.num_score_approximations)
136
+ node_scores = hessian_info_service.fetch_hessian(request)
137
+ nodes_scores.update(node_scores)
136
138
 
137
139
  # Average and map scores to nodes.
138
- self._entry_node_to_hessian_score = {node: np.mean(scores[0], axis=0).flatten() for node, scores in zip(entry_nodes, nodes_scores)}
140
+ self._entry_node_to_hessian_score = {node: np.mean(nodes_scores[node.name], axis=0).flatten() for node in entry_nodes}
139
141
 
140
142
  self._entry_node_count_oc_nparams = self._count_oc_nparams(entry_nodes=entry_nodes)
141
143
  _entry_node_l2_oc_norm = self._get_squaredl2norm(entry_nodes=entry_nodes)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from copy import deepcopy
16
- from typing import Tuple, Callable, List
16
+ from typing import Tuple, Callable, List, Iterable, Optional
17
17
  import numpy as np
18
18
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
19
19
  from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianScoresGranularity, \
@@ -377,7 +377,8 @@ def _get_sliced_histogram(bins: np.ndarray,
377
377
 
378
378
  def _compute_hessian_for_hmse(node,
379
379
  hessian_info_service: HessianInfoService,
380
- num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> List[List[np.ndarray]]:
380
+ num_hessian_samples: int,
381
+ dataloader: Optional[Iterable]) -> List[List[np.ndarray]]:
381
382
  """
382
383
  Compute and retrieve Hessian-based scores for using during HMSE error computation.
383
384
 
@@ -385,15 +386,18 @@ def _compute_hessian_for_hmse(node,
385
386
  node: The node to compute Hessian-based scores for.
386
387
  hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
387
388
  num_hessian_samples: Number of samples to approximate Hessian-based scores on.
389
+ dataloader: Data loader for computing Hessian-based scores. Can be None if hessians are expected to be
390
+ available, i.e. have been already computed previously.
388
391
 
389
392
  Returns: A list with computed Hessian-based scores tensors for the given node.
390
393
 
391
394
  """
392
395
  _request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
393
396
  granularity=HessianScoresGranularity.PER_ELEMENT,
397
+ data_loader=dataloader,
398
+ n_samples=num_hessian_samples,
394
399
  target_nodes=[node])
395
- _scores_for_node = hessian_info_service.fetch_hessian(_request,
396
- required_size=num_hessian_samples)
400
+ _scores_for_node = hessian_info_service.fetch_hessian(_request)
397
401
 
398
402
  return _scores_for_node
399
403
 
@@ -476,11 +480,11 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat
476
480
  per_channel=True)
477
481
 
478
482
  if quant_error_method == qc.QuantizationErrorMethod.HMSE:
479
- node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples)
483
+ node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples, None)
480
484
  if len(node_hessian_scores) != 1:
481
485
  Logger.critical(f"Expecting single node Hessian score request to return a list of length 1, but got a list "
482
486
  f"of length {len(node_hessian_scores)}.")
483
- node_hessian_scores = np.sqrt(np.mean(node_hessian_scores[0], axis=0))
487
+ node_hessian_scores = np.sqrt(np.mean(node_hessian_scores[node.name], axis=0))
484
488
 
485
489
  return lambda x, y, threshold: _hmse_error_function_wrapper(x, y, norm=norm, axis=axis,
486
490
  hessian_scores=node_hessian_scores)
@@ -15,11 +15,12 @@
15
15
  import copy
16
16
 
17
17
  from tqdm import tqdm
18
- from typing import List
18
+ from typing import List, Callable, Generator
19
19
 
20
20
  from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
21
21
  from model_compression_toolkit.core import QuantizationErrorMethod
22
22
  from model_compression_toolkit.core.common import Graph, BaseNode
23
+ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
23
24
  from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
24
25
  HessianScoresGranularity
25
26
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
@@ -55,26 +56,25 @@ def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[Ba
55
56
 
56
57
 
57
58
  def calculate_quantization_params(graph: Graph,
58
- nodes: List[BaseNode] = [],
59
- specific_nodes: bool = False,
59
+ fw_impl: FrameworkImplementation,
60
+ repr_data_gen_fn: Callable[[], Generator],
61
+ nodes: List[BaseNode] = None,
60
62
  hessian_info_service: HessianInfoService = None,
61
63
  num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES):
62
64
  """
63
65
  For a graph, go over its nodes, compute quantization params (for both weights and activations according
64
66
  to the given framework info), and create and attach a NodeQuantizationConfig to each node (containing the
65
67
  computed params).
66
- By default, the function goes over all nodes in the graph. However, the specific_nodes flag enables
67
- to compute quantization params for specific nodes if the default behavior is unnecessary. For that,
68
- a list of nodes should be passed as well.
68
+ By default, the function goes over all nodes in the graph. However, specific nodes can be passed
69
+ to compute quantization params only for them.
69
70
 
70
71
  Args:
71
- groups of layers by how they should be quantized, etc.)
72
72
  graph: Graph to compute its nodes' thresholds.
73
+ fw_impl: FrameworkImplementation object.
74
+ repr_data_gen_fn: callable returning representative dataset generator.
73
75
  nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
74
- specific_nodes: Flag to compute thresholds for only specific nodes.
75
76
  hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
76
77
  num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
77
-
78
78
  """
79
79
 
80
80
  Logger.info(f"\nRunning quantization parameters search. "
@@ -82,18 +82,20 @@ def calculate_quantization_params(graph: Graph,
82
82
  f"depending on the model size and the selected quantization methods.\n")
83
83
 
84
84
  # Create a list of nodes to compute their thresholds
85
- nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()
85
+ nodes_list: List[BaseNode] = nodes or graph.nodes()
86
86
 
87
87
  # Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
88
88
  # and computing required Hessian information to be used for HMSE parameters selection.
89
89
  # The Hessian scores are computed and stored in the hessian_info_service object.
90
90
  nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
91
91
  if len(nodes_for_hmse) > 0:
92
- hessian_info_service.fetch_hessian(HessianScoresRequest(mode=HessianMode.WEIGHTS,
93
- granularity=HessianScoresGranularity.PER_ELEMENT,
94
- target_nodes=nodes_for_hmse),
95
- required_size=num_hessian_samples,
96
- batch_size=1)
92
+ dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
93
+ request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
94
+ granularity=HessianScoresGranularity.PER_ELEMENT,
95
+ data_loader=dataloader,
96
+ n_samples=num_hessian_samples,
97
+ target_nodes=nodes_for_hmse)
98
+ hessian_info_service.fetch_hessian(request)
97
99
 
98
100
  for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
99
101
  for candidate_qc in n.candidates_quantization_cfg:
@@ -0,0 +1,67 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Generator, Callable
16
+
17
+ import tensorflow as tf
18
+
19
+ from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor
20
+
21
+
22
def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
    """
    Wrap a batched data generator factory so that the resulting generators yield one sample at a time.

    Args:
        data_gen_fn: factory returning a generator; each yielded item is expected to be a list of batched
            tensors (one entry per model input).

    Returns:
        A factory that, when called, returns a generator yielding single samples. Each sample is the tuple of
        per-input slices converted with ``to_tf_tensor``.
    """
    def gen():
        for batch in data_gen_fn():
            # zip(*batch) walks all inputs in lockstep along their first dimension, one tuple per sample.
            yield from (to_tf_tensor(sample) for sample in zip(*batch))
    return gen
37
+
38
+
39
# TODO in tf dataset and dataloader are combined within tf.data.Dataset. For advanced use cases such as gptq sla we
# need to separate dataset from dataloader similarly to torch data_util.
class TFDatasetFromGenerator:
    """
    Batched ``tf.data.Dataset`` built from a representative data generator factory.

    The input generator may use any batch size: its batches are flattened to single samples and then re-batched
    with ``batch_size``.

    Attributes:
        orig_batch_size: first dimension of the first input tensor of the first batch yielded by the generator.
        dataset: the re-batched ``tf.data.Dataset``.
    """
    def __init__(self, data_gen, batch_size):
        """
        Args:
            data_gen: factory returning a generator that yields lists of batched tensors.
            batch_size: batch size of the resulting dataset.

        Raises:
            TypeError: if the generator does not yield lists.
        """
        inputs = next(data_gen())
        if not isinstance(inputs, list):
            raise TypeError(f'Representative data generator is expected to generate a list of tensors, '
                            f'got {type(inputs)}')  # pragma: no cover

        self.orig_batch_size = inputs[0].shape[0]

        # Per-sample signature: the batch dimension of each input is dropped.
        # NOTE(review): for numpy inputs, to_tf_tensor casts to float32 while this signature keeps the input's own
        # dtype, so non-float32 numpy inputs may fail the signature check - confirm.
        output_signature = tuple([tf.TensorSpec(shape=t.shape[1:], dtype=t.dtype) for t in inputs])
        dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen), output_signature=output_signature)
        self.dataset = dataset.batch(batch_size)
        # Number of batches; computed lazily by __len__ and cached.
        self._size = None

    def __iter__(self):
        return iter(self.dataset)

    def __len__(self):
        """ Returns the number of batches. """
        # Counting requires a full pass over the generator, so the result is cached in the same attribute that is
        # checked (previously the count was stored elsewhere and the cache never hit).
        if self._size is None:
            self._size = sum(1 for _ in self)
        return self._size
63
+
64
+
65
def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size) -> TFDatasetFromGenerator:
    """
    Build a batched TF data loader from a data generator factory.

    Args:
        data_gen_fn: factory returning a generator that yields lists of batched tensors.
        batch_size: batch size of the produced loader (independent of the generator's own batch size).

    Returns:
        A ``TFDatasetFromGenerator`` wrapping ``data_gen_fn``.
    """
    loader = TFDatasetFromGenerator(data_gen_fn, batch_size)
    return loader
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from functools import partial
16
- from typing import List, Any, Tuple, Callable, Dict, Union
16
+ from typing import List, Any, Tuple, Callable, Dict, Union, Generator
17
17
 
18
18
  import numpy as np
19
19
  import tensorflow as tf
@@ -23,6 +23,7 @@ from tensorflow.keras.models import Model
23
23
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
24
24
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
25
25
  from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianInfoService
26
+ from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
26
27
  from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
27
28
  from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
28
29
  ActivationHessianScoresCalculatorKeras
@@ -628,3 +629,8 @@ class KerasImplementation(FrameworkImplementation):
628
629
  get_weights_quantizer_for_node,
629
630
  get_activations_quantizer_for_node,
630
631
  attribute_names)
632
+
633
+ @staticmethod
634
+ def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
635
+ """ Create DataLoader based on samples yielded by data_gen. """
636
+ return data_gen_to_dataloader(data_gen_fn, batch_size=batch_size)
@@ -33,7 +33,7 @@ def to_tf_tensor(tensor):
33
33
  elif isinstance(tensor, list):
34
34
  return [to_tf_tensor(t) for t in tensor]
35
35
  elif isinstance(tensor, tuple):
36
- return (to_tf_tensor(t) for t in tensor)
36
+ return tuple(to_tf_tensor(t) for t in tensor)
37
37
  elif isinstance(tensor, np.ndarray):
38
38
  return tf.convert_to_tensor(tensor.astype(np.float32))
39
39
  else: # pragma: no cover
@@ -134,7 +134,7 @@ def _run_operation(n: BaseNode,
134
134
  input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
135
135
  # convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
136
136
  # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
137
- input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
137
+ input_tensors = [to_torch_tensor(t, None) if isinstance(t, np.ndarray) else t
138
138
  for t in input_tensors]
139
139
  _tensor_input_allocs = None
140
140
 
@@ -0,0 +1,163 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Generator, Callable, Sequence, Any
16
+
17
+ import torch
18
+ from torch.utils.data import IterableDataset, Dataset, DataLoader, default_collate
19
+
20
+
21
def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
    """
    Turn a factory of batched data generators into a factory of per-sample generators.

    Args:
        data_gen_fn: factory returning a generator; each yielded item is expected to be a list of batched
            tensors (one entry per model input).

    Returns:
        A factory whose generators yield one sample at a time, as a list holding one ``torch.Tensor`` per input.
    """
    def gen():
        for batch in data_gen_fn():
            for per_input_slices in zip(*batch):
                # Only convert to torch tensors here; moving to a device is deferred because doing it inside the
                # generator causes issues with num_workers > 0.
                yield list(map(torch.as_tensor, per_input_slices))
    return gen
37
+
38
+
39
class IterableDatasetFromGenerator(IterableDataset):
    """
    Iterable PyTorch dataset backed by a data generator factory.

    Every iteration calls the factory again and streams the flattened (per-sample) output of the new generator,
    so a factory that produces different data on each call keeps doing so across iterations.

    Attributes:
        orig_batch_size: first dimension of the first tensor in the first batch yielded by the generator.
    """

    def __init__(self, data_gen_fn: Callable[[], Generator]):
        """
        Args:
            data_gen_fn: a factory for data generator that yields lists of tensors.

        Raises:
            TypeError: if the first yielded batch is not a list.
        """
        # Peek at one batch to validate its type and record the original batch size.
        first_batch = next(data_gen_fn())
        if not isinstance(first_batch, list):
            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(first_batch)}')
        self.orig_batch_size = first_batch[0].shape[0]

        self._gen_fn = flat_gen_fn(data_gen_fn)
        # Number of samples; computed lazily by __len__ and cached.
        self._size = None

    def __iter__(self):
        """ Start a fresh pass over a new data generator. """
        return self._gen_fn()

    def __len__(self):
        """ Number of samples in one full pass (counted once, then cached). """
        if self._size is None:
            self._size = sum(1 for _ in iter(self))
        return self._size
70
+
71
+
72
class FixedDatasetFromGenerator(Dataset):
    """
    Map-style dataset holding a fixed set of samples collected from a data generator, so the same samples are
    returned in every epoch. All samples are kept in memory.

    Attributes:
        orig_batch_size: the batch size of the input data generator (retrieved from the first batch).
        samples: collected samples; each one is a tuple with one tensor per model input.
    """
    def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
        """
        Args:
            data_gen_fn: data generator factory.
            n_samples: target size of the dataset. If None, use all samples yielded by the data generator in one pass.

        Raises:
            TypeError: if the first yielded batch is not a list.
            ValueError: if the generator yields fewer than ``n_samples`` samples.
        """
        first_batch = next(data_gen_fn())
        if not isinstance(first_batch, list):
            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(first_batch)}')
        self.orig_batch_size = first_batch[0].shape[0]

        collected = []
        # A new generator is created for collection; the one used for validation above is discarded.
        for batch in data_gen_fn():
            # Convert to torch tensors but do not move them to a device yet (it causes issues with num_workers > 0).
            tensors = [torch.as_tensor(t) for t in batch]
            collected.extend(zip(*tensors))
            if n_samples is not None and len(collected) >= n_samples:
                del collected[n_samples:]
                break

        if n_samples is not None and len(collected) < n_samples:
            raise ValueError(f'Not enough samples in the data generator to create a dataset with {n_samples}')
        self.samples = collected

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Samples are stored as tuples; each access returns a new list.
        return list(self.samples[index])
110
+
111
+
112
class FixedSampleInfoDataset(Dataset):
    """
    Map-style dataset pairing each sample with per-sample complementary info (e.g. Hessian scores).

    Element ``i`` is the tuple ``(samples[i], info_1[i], ..., info_k[i])`` for the ``k`` info sequences passed.
    """
    def __init__(self, samples: Sequence, *sample_info: Sequence):
        """
        Args:
            samples: a sequence of input samples.
            *sample_info: one or more sequences of per-sample complementary data, each of the same length as
                ``samples``.

        Raises:
            ValueError: if any ``sample_info`` sequence differs in length from ``samples``.
        """
        if not all(len(info) == len(samples) for info in sample_info):
            raise ValueError('Mismatch in the number of samples between samples and complementary data.')
        self.samples = samples
        self.sample_info = sample_info

    def __getitem__(self, index):
        return self.samples[index], *[info[index] for info in self.sample_info]

    def __len__(self):
        return len(self.samples)
133
+
134
+
135
class IterableSampleWithConstInfoDataset(IterableDataset):
    """
    A convenience dataset that augments each sample with additional info shared by all samples.

    Each yielded element is the tuple ``(sample, *info)``.
    """
    def __init__(self, samples_dataset: Dataset, *info: Any):
        """
        Args:
            samples_dataset: any dataset containing samples; it is iterated directly in ``__iter__``.
            *info: one or more static entities to augment each sample.
        """
        self.samples_dataset = samples_dataset
        self.info = info

    def __iter__(self):
        for sample in self.samples_dataset:
            yield sample, *self.info
151
+
152
+
153
def get_collate_fn_with_extra_outputs(*extra_outputs: Any) -> Callable:
    """
    Build a collate function that applies ``default_collate`` and then appends constant extra outputs.

    Args:
        *extra_outputs: values appended, in order, to every collated batch.

    Returns:
        A callable mapping a batch to ``default_collate(batch) + list(extra_outputs)``.
    """
    extras = list(extra_outputs)

    def collate(batch):
        # List concatenation builds a new list each call, so ``extras`` itself is never modified.
        return default_collate(batch) + extras
    return collate
158
+
159
+
160
def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size, **kwargs):
    """
    Create a torch ``DataLoader`` over the per-sample stream of a data generator factory.

    Args:
        data_gen_fn: factory returning a generator that yields lists of batched tensors.
        batch_size: batch size of the produced loader (independent of the generator's own batch size).
        **kwargs: forwarded unchanged to ``DataLoader``.

    Returns:
        A ``DataLoader`` wrapping an ``IterableDatasetFromGenerator`` built from ``data_gen_fn``.
    """
    return DataLoader(IterableDatasetFromGenerator(data_gen_fn), batch_size=batch_size, **kwargs)
@@ -15,20 +15,19 @@
15
15
 
16
16
  from typing import List
17
17
 
18
+ import numpy as np
19
+ import torch
18
20
  from torch import autograd
19
21
  from tqdm import tqdm
20
- import numpy as np
21
22
 
22
23
  from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
23
24
  from model_compression_toolkit.core.common import Graph
24
- from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianScoresGranularity,
25
- HessianEstimationDistribution)
25
+ from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
26
26
  from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
27
27
  from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
28
28
  HessianScoresCalculatorPytorch
29
29
  from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
30
30
  from model_compression_toolkit.logger import Logger
31
- import torch
32
31
 
33
32
 
34
33
  class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
@@ -86,36 +85,12 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
86
85
  target_activation_tensors = outputs[:num_target_nodes]
87
86
  # Extract the model outputs
88
87
  output_tensors = outputs[num_target_nodes:]
89
- device = output_tensors[0].device
90
88
 
91
89
  # Concat outputs
92
90
  # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
93
91
  output = self.concat_tensors(output_tensors)
94
92
  return output, target_activation_tensors
95
93
 
96
- def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
97
- device: torch.device) -> torch.Tensor:
98
- """
99
- Generate a batch of random vectors for Hutchinson estimation
100
-
101
- Args:
102
- shape: target shape
103
- distribution: distribution to sample from
104
- device: target device
105
-
106
- Returns:
107
- Random tensor
108
- """
109
- if distribution == HessianEstimationDistribution.GAUSSIAN:
110
- return torch.randn(shape, device=device)
111
-
112
- if distribution == HessianEstimationDistribution.RADEMACHER:
113
- v = torch.randint(high=2, size=shape, device=device)
114
- v[v == 0] = -1
115
- return v
116
-
117
- raise ValueError(f'Unknown distribution {distribution}') # pragma: no cover
118
-
119
94
  def compute(self) -> List[np.ndarray]:
120
95
  """
121
96
  Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations.
@@ -142,8 +117,8 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
142
117
  for _ in range(len(target_activation_tensors))]
143
118
  prev_mean_results = None
144
119
  for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
145
- # Getting a random vector with normal distribution
146
- v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
120
+ # Getting a random vector
121
+ v = self._generate_random_vectors_batch(output.shape, output.device)
147
122
  f_v = torch.sum(v * output)
148
123
  for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
149
124
  # Computing the hessian-approximation scores by getting the gradient of (output * v)
@@ -184,7 +159,7 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
184
159
  for _ in range(len(target_activation_tensors))]
185
160
 
186
161
  for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
187
- v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
162
+ v = self._generate_random_vectors_batch(output.shape, output.device)
188
163
  f_v = torch.sum(v * output)
189
164
  for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
190
165
  hess_v = autograd.grad(outputs=f_v,