mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +43 -29
  5. model_compression_toolkit/core/common/hessian/__init__.py +1 -1
  6. model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
  7. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
  8. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
  9. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
  12. model_compression_toolkit/core/keras/data_util.py +67 -0
  13. model_compression_toolkit/core/keras/keras_implementation.py +7 -1
  14. model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
  15. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
  16. model_compression_toolkit/core/pytorch/data_util.py +163 -0
  17. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
  18. model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
  19. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
  20. model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
  21. model_compression_toolkit/core/pytorch/utils.py +22 -19
  22. model_compression_toolkit/core/quantization_prep_runner.py +2 -1
  23. model_compression_toolkit/core/runner.py +1 -2
  24. model_compression_toolkit/gptq/common/gptq_config.py +0 -2
  25. model_compression_toolkit/gptq/common/gptq_training.py +58 -114
  26. model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
  27. model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
  28. model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
  29. model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
  30. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
  31. tests_pytest/keras/__init__.py +14 -0
  32. tests_pytest/keras/core/__init__.py +14 -0
  33. tests_pytest/keras/core/test_data_util.py +91 -0
  34. tests_pytest/pytorch/core/__init__.py +14 -0
  35. tests_pytest/pytorch/core/test_data_util.py +125 -0
  36. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
  37. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
  38. {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -12,21 +12,102 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- import hashlib
15
+ from dataclasses import dataclass
16
+ from typing import List, Dict, Tuple, TYPE_CHECKING
16
17
 
17
18
  import numpy as np
18
- from functools import partial
19
- from tqdm import tqdm
20
- from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING
21
19
 
22
20
  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
23
- from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
24
- HessianScoresGranularity, HessianMode
25
- from model_compression_toolkit.logger import Logger
21
+ from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, \
22
+ HessianScoresGranularity
23
+
26
24
  if TYPE_CHECKING: # pragma: no cover
27
25
  from model_compression_toolkit.core.common import BaseNode
28
26
 
29
27
 
28
+ # type hints aliases
29
+ LayerName = str
30
+ Tensor = np.ndarray
31
+
32
+
33
+ @dataclass(eq=True, frozen=True)
34
+ class Query:
35
+ """ Query key for hessians cache. """
36
+ mode: HessianMode
37
+ granularity: HessianScoresGranularity
38
+ node: LayerName
39
+
40
+
41
+ class HessianCache:
42
+ """ Hessian cache """
43
+ def __init__(self):
44
+ self._data: Dict[Query, Tensor] = {}
45
+
46
+ def update(self, layers_hessians: Dict[str, np.ndarray], request: HessianScoresRequest) -> int:
47
+ """
48
+ Updates the cache with new hessians estimations.
49
+
50
+ Note: we assume that the new hessians were computed on different samples than previously stored hessians.
51
+ If same samples were used more than once, duplicates will be stored. This can only be a problem if hessians
52
+ for the same query were computed via multiple requests and dataloader in each request yields same samples.
53
+ We cannot just filter out duplicates since in some cases we can get valid identical hessians on different
54
+ samples.
55
+
56
+ Args:
57
+ layers_hessians: a dictionary from layer names to their hessian score tensors.
58
+ request: request per which hessians were computed.
59
+
60
+ Returns:
61
+ Minimal samples count after update (among updated layers).
62
+
63
+ """
64
+ assert set(layers_hessians.keys()) == set(n.name for n in request.target_nodes)
65
+ n_nodes_samples = [] # samples count per node after update
66
+ for node_name, hess in layers_hessians.items():
67
+ query = Query(request.mode, request.granularity, node_name)
68
+ saved_hess = self._data.get(query)
69
+ new_hess = hess if saved_hess is None else np.concatenate([saved_hess, hess], axis=0)
70
+ self._data[query] = new_hess
71
+ n_nodes_samples.append(new_hess.shape[0])
72
+
73
+ return min(n_nodes_samples)
74
+
75
+ def fetch_hessian(self, request: HessianScoresRequest) -> Tuple[Dict[LayerName, Tensor], Dict[LayerName, int]]:
76
+ """
77
+ Fetch available hessians per request and identify missing samples.
78
+
79
+ Note: if fewer samples are available than requested, hessians tensor will contain the available samples.
80
+
81
+ Args:
82
+ request: hessians fetch request.
83
+
84
+ Returns:
85
+ A tuple of two dictionaries:
86
+ - A dictionary from layer name to a tensor of its hessian.
87
+ - A dictionary from layer name to a number of missing samples.
88
+ """
89
+ assert request.n_samples is not None
90
+
91
+ result = {}
92
+ missing = {}
93
+ for node in request.target_nodes:
94
+ query = Query(request.mode, request.granularity, node.name)
95
+ hess = self._data.get(query)
96
+ if hess is None:
97
+ missing[node.name] = request.n_samples
98
+ continue
99
+ n_missing = request.n_samples - hess.shape[0]
100
+ if n_missing > 0:
101
+ missing[node.name] = n_missing
102
+ result[node.name] = hess[:request.n_samples, ...]
103
+
104
+ return result, missing
105
+
106
+ def clear(self):
107
+ """ Clear the cache. """
108
+ self._data.clear()
109
+
110
+
30
111
  class HessianInfoService:
31
112
  """
32
113
  A service to manage, store, and compute information based on the Hessian matrix approximation.
@@ -44,411 +125,181 @@ class HessianInfoService:
44
125
 
45
126
  def __init__(self,
46
127
  graph,
47
- representative_dataset_gen: Callable,
48
128
  fw_impl,
49
129
  num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
50
130
  """
51
-
52
131
  Args:
53
132
  graph: Float graph.
54
- representative_dataset_gen: A callable that provides a dataset for sampling.
55
133
  fw_impl: Framework-specific implementation for Hessian approximation scores computation.
134
+ num_iterations_for_approximation: the number of iterations for hessian estimation.
56
135
  """
57
136
  self.graph = graph
58
-
59
- self.representative_dataset_gen = representative_dataset_gen
60
-
61
137
  self.fw_impl = fw_impl
62
138
  self.num_iterations_for_approximation = num_iterations_for_approximation
139
+ self.cache = HessianCache()
63
140
 
64
- self.hessian_scores_request_to_scores_list = {}
65
-
66
- def _sample_batch_representative_dataset(self,
67
- representative_dataset: Any,
68
- num_hessian_samples: int,
69
- num_inputs: int,
70
- last_iter_remain_samples: List[List[np.ndarray]] = None
71
- ) -> Tuple[List[np.ndarray], List[List[np.ndarray]]]:
141
+ def fetch_hessian(self, request: HessianScoresRequest,
142
+ force_compute: bool = False) -> Dict[LayerName, Tensor]:
72
143
  """
73
- Get a batch of samples from a representative dataset with the requested num_hessian_samples.
144
+ Fetch hessians per request.
145
+ If 'force_compute' is False, will first try to retrieve previously cached hessians. If no or not enough
146
+ hessians are found in the cache, will compute the remaining number of hessians to fulfill the request.
147
+ If 'force_compute' is True, will compute the hessians (use when you need hessians for specific inputs).
74
148
 
75
149
  Args:
76
- representative_dataset: A generator which yields batches of input samples.
77
- num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores.
78
- num_inputs: Number of input layers of the model on which the scores are computed.
79
- last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from
80
- previous iterations.
81
-
82
- Returns: A tuple with two lists:
83
- (1) A list of inputs - a tensor of the requested batch size for each input layer.
84
- (2) A list of remaining samples - for each input layer.
85
- """
86
-
87
- if num_inputs < 0: # pragma: no cover
88
- Logger.critical(f"Number of images to compute Hessian approximation must be positive, "
89
- f"but given {num_inputs}.")
90
-
91
- all_inp_hessian_samples = [[] for _ in range(num_inputs)]
92
- all_inp_remaining_samples = [[] for _ in range(num_inputs)]
93
-
94
- # Collect the requested number of samples from the representative dataset
95
- # In case there are samples left from previous iterations, we use them first
96
- # otherwise, we take a batch from the representative dataset generator
97
- while len(all_inp_hessian_samples[0]) < num_hessian_samples:
98
- batch = None
99
- sampling_from_repr = True
100
- if last_iter_remain_samples is not None and len(last_iter_remain_samples[0]) >= num_hessian_samples:
101
- batch = last_iter_remain_samples
102
- sampling_from_repr = False
103
- else:
104
- try:
105
- batch = next(representative_dataset)
106
- except StopIteration:
107
- Logger.critical(
108
- f"Not enough samples in the provided representative dataset to compute Hessian approximation on "
109
- f"{num_hessian_samples} samples.")
110
-
111
- if batch is not None and not isinstance(batch, list):
112
- Logger.critical(f'Expected batch to be a list; found type: {type(batch)}.') # pragma: no cover
113
-
114
- for inp_idx in range(len(batch)):
115
- inp_batch = batch[inp_idx] if sampling_from_repr else np.stack(batch[inp_idx], axis=0)
116
- if not sampling_from_repr:
117
- last_iter_remain_samples[inp_idx] = []
118
-
119
- # Compute number of missing samples to get to the requested amount from the current batch
120
- num_missing = min(num_hessian_samples - len(all_inp_hessian_samples[inp_idx]), inp_batch.shape[0])
121
-
122
- # Append each sample separately
123
- samples = [s for s in inp_batch[0:num_missing, ...]]
124
- remaining_samples = [s for s in inp_batch[num_missing:, ...]]
125
-
126
- all_inp_hessian_samples[inp_idx] += [sample.reshape(1, *sample.shape) for sample in samples]
127
-
128
- # This list can only get filled on the last batch iteration
129
- all_inp_remaining_samples[inp_idx] += remaining_samples
130
-
131
- if len(all_inp_hessian_samples[0]) > num_hessian_samples:
132
- Logger.critical(f"Requested {num_hessian_samples} samples for computing Hessian approximation but "
133
- f"{len(all_inp_hessian_samples[0])} were collected.") # pragma: no cover
134
-
135
- # Collected enough samples, constructing a dataset with the requested batch size
136
- hessian_samples_for_input = []
137
- for inp_samples in all_inp_hessian_samples:
138
- inp_samples = np.concatenate(inp_samples, axis=0)
139
- num_collected_samples = inp_samples.shape[0]
140
- inp_samples = np.split(inp_samples,
141
- num_collected_samples // min(num_collected_samples, num_hessian_samples))
142
- hessian_samples_for_input.append(inp_samples[0])
143
-
144
- return hessian_samples_for_input, all_inp_remaining_samples
145
-
146
- def _clear_saved_hessian_info(self):
147
- """Clears the saved info approximations."""
148
- self.hessian_scores_request_to_scores_list={}
149
-
150
- def count_saved_scores_of_request(self, hessian_request: HessianScoresRequest) -> Dict:
151
- """
152
- Counts the saved approximations of Hessian scores for a specific request.
153
- If some approximations were computed for this request before, the amount of approximations (per image)
154
- will be returned. If not, zero is returned.
155
-
156
- Args:
157
- hessian_request: The request configuration for which to count the saved data.
150
+ request: request per which to fetch the hessians.
151
+ force_compute: if True, will compute the hessians.
152
+ If False, will look for cached hessians first.
158
153
 
159
154
  Returns:
160
- Number of saved approximations for the given request.
161
- """
162
-
163
- per_node_counter = {}
164
-
165
- for n in hessian_request.target_nodes:
166
- if n.reuse:
167
- # Reused nodes supposed to have been replaced with a reuse_group
168
- # representing node before calling this method.
169
- Logger.critical(f"Expecting the Hessian request to include only non-reused nodes at this point, "
170
- f"but found node {n.name} with 'reuse' status.")
171
- # Check if the request for this node is in the saved info and store its count, otherwise store 0
172
- per_node_counter[n] = len(self.hessian_scores_request_to_scores_list.get(hessian_request, []))
173
-
174
- return per_node_counter
175
-
176
- def compute(self,
177
- hessian_scores_request: HessianScoresRequest,
178
- representative_dataset_gen,
179
- num_hessian_samples: int,
180
- last_iter_remain_samples: List[List[np.ndarray]] = None):
181
- """
182
- Computes scores based on the Hessian matrix approximation according to the
183
- provided request configuration and stores it in the cache.
184
-
185
- Args:
186
- hessian_scores_request: Configuration for which to compute the approximation.
187
- representative_dataset_gen: A callable that provides a dataset for sampling.
188
- num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores.
189
- last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from
190
- previous iterations.
155
+ A dictionary of layers' hessian tensors of shape (samples, ...). The exact shape depends on the
156
+ requested granularity.
191
157
  """
192
- Logger.debug(f"Computing Hessian-scores approximations for nodes {hessian_scores_request.target_nodes}.")
193
-
194
- images, next_iter_remain_samples = representative_dataset_gen(num_hessian_samples=num_hessian_samples,
195
- last_iter_remain_samples=last_iter_remain_samples)
196
-
197
- # Compute and store the computed approximation in the saved info
198
- topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
199
- hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
200
-
201
- # Get the framework-specific calculator Hessian-approximation scores
202
- fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
203
- input_images=images,
204
- hessian_scores_request=hessian_scores_request,
205
- num_iterations_for_approximation=self.num_iterations_for_approximation)
206
-
207
- hessian_scores = fw_hessian_calculator.compute()
208
-
209
- for node, hessian in zip(hessian_scores_request.target_nodes, hessian_scores):
210
- single_node_request = self._construct_single_node_request(hessian_scores_request.mode,
211
- hessian_scores_request.granularity,
212
- node)
213
-
214
- # The hessian for each node is expected to be a tensor where the first axis represents the number of
215
- # images in the batch on which the approximation was computed.
216
- # We collect the results as a list of a result for images, which is combined across batches.
217
- # After conversion, hessian_scores_request_to_scores_list for a request of a single node should be a list of
218
- # results of all images, where each result is a tensor of the shape depending on the granularity.
219
- if single_node_request in self.hessian_scores_request_to_scores_list:
220
- self.hessian_scores_request_to_scores_list[single_node_request] += (
221
- self._convert_tensor_to_list_of_appx_results(hessian))
222
- else:
223
- self.hessian_scores_request_to_scores_list[single_node_request] = (
224
- self._convert_tensor_to_list_of_appx_results(hessian))
225
-
226
- # In case that we are required to return a number of scores that is larger that the computation batch size
227
- # and if in this case the computation batch size is smaller than the representative dataset batch size
228
- # we need to carry over remaining samples from the last fetched batch to the next computation, otherwise,
229
- # we might skip samples or remain without enough samples to complete the computations for the
230
- # requested number of scores.
231
- return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
232
- and len(next_iter_remain_samples[0]) > 0 else None
233
-
234
- def compute_trackable_per_sample_hessian(self,
235
- hessian_scores_request: HessianScoresRequest,
236
- inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
158
+ if request.n_samples is None and not force_compute:
159
+ raise ValueError('Number of samples can be None only when force_compute is True.')
160
+
161
+ orig_request = request
162
+ # replace reused nodes with primary nodes
163
+ # TODO need to check if there is a bug in reuse. While this is the same layer, the compare tensors and their
164
+ # gradients are not. It seems that currently the same compare tensor of the primary node is used multiple times
165
+ target_nodes = [self._get_primary_node(n) for n in request.target_nodes]
166
+ request = request.clone(target_nodes=target_nodes)
167
+
168
+ if force_compute:
169
+ res = self._compute_hessians(request, self.num_iterations_for_approximation, count_by_cache=False)
170
+ else:
171
+ res = self._fetch_hessians_with_compute(request, self.num_iterations_for_approximation)
172
+
173
+ # restore nodes from the original request
174
+ res = {n_orig.name: res[n.name] for n_orig, n in zip(orig_request.target_nodes, request.target_nodes)}
175
+ return res
176
+
177
+ def clear_cache(self):
178
+ """ Purge the cached hessians. """
179
+ self.cache.clear()
180
+
181
+ def _fetch_hessians_with_compute(self, request: HessianScoresRequest, n_iterations: int) -> Dict[LayerName, Tensor]:
237
182
  """
238
- Compute hessian score per image hash. We compute the score directly for images rather than via data generator,
239
- as data generator might yield different images each time, depending on how it was defined,
183
+ Fetch pre-computed hessians for the request if available. Otherwise, compute the missing hessians.
240
184
 
241
185
  Args:
242
- hessian_scores_request: hessian scores request
243
- inputs_batch: a list containing a batch of inputs.
186
+ request: hessian estimation request.
187
+ n_iterations: the number of iterations for hessian estimation.
244
188
 
245
189
  Returns:
246
- A dict of Hessian scores per image hash per layer {image hash: {layer: score}}
190
+ A dictionary from layers (by name) to their hessians.
247
191
  """
248
- topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
249
- hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
250
-
251
- hessian_score_by_image_hash = {}
252
-
253
- if not inputs_batch or not isinstance(inputs_batch, list):
254
- raise TypeError('Expected a non-empty list of inputs') # pragma: no cover
255
- if len(inputs_batch) > 1:
256
- raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') # pragma: no cover
257
-
258
- # Get the framework-specific calculator Hessian-approximation scores
259
- fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
260
- input_images=inputs_batch,
261
- hessian_scores_request=hessian_scores_request,
262
- num_iterations_for_approximation=self.num_iterations_for_approximation)
263
- hessian_scores = fw_hessian_calculator.compute()
264
- for i in range(inputs_batch[0].shape[0]):
265
- img_hash = self.calc_image_hash(inputs_batch[0][i])
266
- hessian_score_by_image_hash[img_hash] = {
267
- node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
268
- }
269
-
270
- return hessian_score_by_image_hash
271
-
272
- @staticmethod
273
- def calc_image_hash(image):
192
+ res, missing = self.cache.fetch_hessian(request)
193
+ if not missing:
194
+ return res
195
+
196
+ if request.data_loader is None:
197
+ raise ValueError(f'Not enough hessians are cached to fulfill the request, but data loader was not passed '
198
+ f'for additional computation. Requested {request.n_samples}, '
199
+ f'available {min(missing.values())}.')
200
+
201
+ orig_request = request
202
+ # if some hessians were found generate a new request only for missing nodes.
203
+ if res:
204
+ target_nodes = [n for n in orig_request.target_nodes if n.name in missing]
205
+ request = request.clone(target_nodes=target_nodes)
206
+ self._compute_hessians(request, n_iterations, count_by_cache=True)
207
+ res, missing = self.cache.fetch_hessian(request)
208
+ assert not missing
209
+ return res
210
+
211
+ def _compute_hessians(self, request: HessianScoresRequest,
212
+ n_iterations: int, count_by_cache: bool) -> Dict[LayerName, Tensor]:
274
213
  """
275
- Calculates hash for an input image.
276
-
277
- Args:
278
- image: input 3d image (without batch).
214
+ Computes hessian estimation per request.
279
215
 
280
- Returns:
281
- Image hash.
216
+ Data loader from request is used as is, i.e. it should reflect the required batch size (e.g. if
217
+ hessians should be estimated sample by sample, the data loader should yield a single sample at a time).
282
218
 
283
- """
284
- if not len(image.shape) == 3: # pragma: no cover
285
- raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
286
- image_bytes = image.astype(np.float32).tobytes()
287
- return hashlib.md5(image_bytes).hexdigest()
288
-
289
- def fetch_hessian(self,
290
- hessian_scores_request: HessianScoresRequest,
291
- required_size: int,
292
- batch_size: int = 1) -> List[List[np.ndarray]]:
293
- """
294
- Fetches the computed approximations of the Hessian-based scores for the given
295
- request and required size.
219
+ NOTE: the returned value only contains hessians that were computed here, which may differ from the requested
220
+ number of samples. It's only intended for use when you specifically need sample-wise hessians for the
221
+ samples in the request.
296
222
 
297
223
  Args:
298
- hessian_scores_request: Configuration for which to fetch the approximation.
299
- required_size: Number of approximations required.
300
- batch_size: The Hessian computation batch size.
301
-
224
+ request: hessian estimation request.
225
+ n_iterations: the number of iterations for hessian estimation.
226
+ count_by_cache: if True, computes hessians until the cache contains the requested number of samples.
227
+ if False, computes hessian for the first requested number of sample in the dataloader.
302
228
  Returns:
303
- List[List[np.ndarray]]: For each target node, returns a list of computed approximations.
304
- The outer list is per image (thus, has the length as required_size).
305
- The inner list length dependent on the granularity (1 for per-tensor,
306
- OC for per-output-channel when the requested node has OC output-channels, etc.)
307
- """
308
-
309
- if len(hessian_scores_request.target_nodes) == 0: # pragma: no cover
310
- return []
311
-
312
- if required_size == 0:
313
- return [[] for _ in hessian_scores_request.target_nodes]
314
-
315
- Logger.info(f"\nEnsuring {required_size} Hessian-approximation scores for nodes "
316
- f"{hessian_scores_request.target_nodes}.")
317
-
318
- # Replace node in reused target nodes with a representing node from the 'reuse group'.
319
- hessian_scores_request.target_nodes = [
320
- self._get_representing_of_reuse_group(node) if node.reuse else node
321
- for node in hessian_scores_request.target_nodes
322
- ]
323
-
324
- # Ensure the saved info has the required number of approximations
325
- self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
326
-
327
- # Return the saved approximations for the given request
328
- return self._collect_saved_hessians_for_request(hessian_scores_request, required_size)
329
-
330
- def _get_representing_of_reuse_group(self, node) -> Any:
331
- """
332
- For each reused group we compute and fetch its members using a single request.
333
- This method creates and returns a request for the reused group the node is in.
334
-
335
- Args:
336
- node: The node to get its reuse group representative node.
337
-
338
- Returns: A reuse group representative node (BaseNode).
339
- """
340
- father_nodes = [n for n in self.graph.nodes if not n.reuse and n.reuse_group == node.reuse_group]
341
- if len(father_nodes) != 1: # pragma: no cover
342
- Logger.critical(f"Expected a single non-reused node in the reused group, "
343
- f"but found {len(father_nodes)}.")
344
-
345
- return father_nodes[0]
346
-
347
- def _populate_saved_info_to_size(self,
348
- hessian_scores_request: HessianScoresRequest,
349
- required_size: int,
350
- batch_size: int = 1):
229
+ A dictionary from layers (by name) to their hessian tensors that *were computed in this invocation*.
230
+ First axis corresponds to samples in the order determined by the data loader.
351
231
  """
352
- Ensures that the saved info has the required size of Hessian approximation scores for the given request.
353
-
354
- Args:
355
- hessian_scores_request: Configuration of the request to ensure the saved info size.
356
- required_size: Required number of Hessian-approximation scores.
357
- batch_size: The Hessian computation batch size.
232
+ if count_by_cache:
233
+ assert request.n_samples is not None
234
+
235
+ n_samples = 0
236
+ hess_per_layer = []
237
+ for batch in request.data_loader:
238
+ batch_hess_per_layer = self._compute_hessian_for_batch(request, batch, n_iterations)
239
+ hess_per_layer.append(batch_hess_per_layer)
240
+ min_count = self.cache.update(batch_hess_per_layer, request)
241
+ n_samples = min_count if count_by_cache else (n_samples + batch[0].shape[0])
242
+ if request.n_samples and n_samples >= request.n_samples:
243
+ break
244
+
245
+ hess_per_layer = {
246
+ layer.name: np.concatenate([hess[layer.name] for hess in hess_per_layer], axis=0)
247
+ for layer in request.target_nodes
248
+ }
249
+
250
+ if request.n_samples:
251
+ if n_samples < request.n_samples:
252
+ raise ValueError(f'Could not compute the requested number of Hessians ({request.n_samples}), '
253
+ f'not enough samples in the provided representative dataset.')
254
+
255
+ if n_samples > request.n_samples:
256
+ hess_per_layer = {
257
+ layer: hess[:request.n_samples, ...] for layer, hess in hess_per_layer.items()
258
+ }
259
+ return hess_per_layer
260
+
261
+ def _compute_hessian_for_batch(self,
262
+ request: HessianScoresRequest,
263
+ inputs_batch: List[Tensor],
264
+ n_iterations: int) -> Dict[LayerName, Tensor]:
358
265
  """
359
-
360
- # Get the current number of saved approximations for each node in the request
361
- current_existing_hessians = self.count_saved_scores_of_request(hessian_scores_request)
362
-
363
- # Compute the required number of approximations to meet the required size.
364
- # Since we allow batch and multi-nodes computation, we take the node with the maximal number of missing
365
- # approximations to compute, and run batch computations until meeting the requirement.
366
- min_exist_hessians = min(current_existing_hessians.values())
367
- max_remaining_hessians = required_size - min_exist_hessians
368
-
369
- Logger.info(
370
- f"Running Hessian approximation computation for {len(hessian_scores_request.target_nodes)} nodes.\n "
371
- f"The node with minimal existing Hessian-approximation scores has {min_exist_hessians} "
372
- f"approximated scores computed.\n"
373
- f"{max_remaining_hessians} approximations left to compute...")
374
-
375
- hessian_representative_dataset = partial(self._sample_batch_representative_dataset,
376
- num_inputs=len(self.graph.input_nodes),
377
- representative_dataset=self.representative_dataset_gen())
378
-
379
- next_iter_remaining_samples = None
380
- pbar = tqdm(desc="Computing Hessian approximations...", total=None)
381
- while max_remaining_hessians > 0:
382
- # If batch_size < max_remaining_hessians then we run each computation on a batch_size of images.
383
- # This way, we always run a computation for a single batch.
384
- pbar.update(1)
385
- size_to_compute = min(max_remaining_hessians, batch_size)
386
- next_iter_remaining_samples = (
387
- self.compute(hessian_scores_request, hessian_representative_dataset, size_to_compute,
388
- last_iter_remain_samples=next_iter_remaining_samples))
389
- max_remaining_hessians -= size_to_compute
390
-
391
- def _collect_saved_hessians_for_request(self,
392
- hessian_scores_request: HessianScoresRequest,
393
- required_size: int) -> List[List[np.ndarray]]:
394
- """
395
- Collects Hessian approximation for the nodes in the given request.
266
+ Use hessian score calculator to compute hessian approximations for a batch of inputs.
396
267
 
397
268
  Args:
398
- hessian_scores_request: Configuration for which to fetch the approximation.
399
- required_size: Required number of Hessian-approximation scores.
400
-
401
- Returns: A list with List of computed Hessian approximation (a tensor for each score) for each node
402
- in the request.
269
+ request: hessian estimation request.
270
+ inputs_batch: a batch of inputs to estimate hessians on.
271
+ n_iterations: the number of iterations for hessian estimation.
403
272
 
273
+ Returns:
274
+ A dictionary from layers (by name) to their hessians.
404
275
  """
405
- collected_results = []
406
- for node in hessian_scores_request.target_nodes:
407
- single_node_request = self._construct_single_node_request(hessian_scores_request.mode,
408
- hessian_scores_request.granularity,
409
- node)
410
-
411
- res_for_node = self.hessian_scores_request_to_scores_list.get(single_node_request)
412
- if res_for_node is None: # pragma: no cover
413
- Logger.critical(f"Couldn't find saved Hessian approximations for node {node.name}.")
414
- if len(res_for_node) < required_size: # pragma: no cover
415
- Logger.critical(f"Missing Hessian approximations for node {node.name}, requested {required_size} "
416
- f"but found only {len(res_for_node)}.")
417
-
418
- res_for_node = res_for_node[:required_size]
276
+ fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(
277
+ graph=self.graph,
278
+ input_images=inputs_batch,
279
+ hessian_scores_request=request,
280
+ num_iterations_for_approximation=n_iterations
281
+ )
419
282
 
420
- collected_results.append(res_for_node)
283
+ hessian_scores: list = fw_hessian_calculator.compute()
421
284
 
422
- return collected_results
285
+ layers_hessian_scores = {
286
+ layer.name: score for layer, score in zip(request.target_nodes, hessian_scores)
287
+ }
288
+ return layers_hessian_scores
423
289
 
424
- @staticmethod
425
- def _construct_single_node_request(mode: HessianMode,
426
- granularity: HessianScoresGranularity,
427
- target_nodes: List) -> HessianScoresRequest:
290
+ def _get_primary_node(self, node: 'BaseNode') -> 'BaseNode':
428
291
  """
429
- Constructs a Hessian request with for a single node. Used for retrieving and maintaining cached results.
292
+ Get node's primary node that it reuses, or itself if not reused.
430
293
 
431
294
  Args:
432
- mode (HessianMode): Mode of Hessian's approximation (w.r.t weights or activations).
433
- granularity (HessianScoresGranularity): Granularity level for the approximation.
434
- target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's approximation scores is targeted.
295
+ node: node's object to get its primary node.
435
296
 
436
- Returns: A HessianScoresRequest with the given details for the requested node.
437
-
438
- """
439
- return HessianScoresRequest(mode,
440
- granularity,
441
- target_nodes=[target_nodes])
442
-
443
- @staticmethod
444
- def _convert_tensor_to_list_of_appx_results(t: Any) -> List:
297
+ Returns:
298
+ Node's primary node.
445
299
  """
446
- Converts a tensor with batch computation results to a list of individual result for each sample in batch.
447
-
448
- Args:
449
- t: A tensor with Hessian approximation results.
450
-
451
- Returns: A list with split batch into individual results.
300
+ if node.reuse is False:
301
+ return node
452
302
 
453
- """
454
- return [t[i:i+1, :] for i in range(t.shape[0])]
303
+ father_nodes = [n for n in self.graph.nodes if not n.reuse and n.reuse_group == node.reuse_group]
304
+ assert len(father_nodes) == 1
305
+ return father_nodes[0]