mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +43 -29
- model_compression_toolkit/core/common/hessian/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
- model_compression_toolkit/core/keras/data_util.py +67 -0
- model_compression_toolkit/core/keras/keras_implementation.py +7 -1
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/data_util.py +163 -0
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
- model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
- model_compression_toolkit/core/pytorch/utils.py +22 -19
- model_compression_toolkit/core/quantization_prep_runner.py +2 -1
- model_compression_toolkit/core/runner.py +1 -2
- model_compression_toolkit/gptq/common/gptq_config.py +0 -2
- model_compression_toolkit/gptq/common/gptq_training.py +58 -114
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
- model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
- tests_pytest/keras/__init__.py +14 -0
- tests_pytest/keras/core/__init__.py +14 -0
- tests_pytest/keras/core/test_data_util.py +91 -0
- tests_pytest/pytorch/core/__init__.py +14 -0
- tests_pytest/pytorch/core/test_data_util.py +125 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -12,21 +12,102 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
import
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import List, Dict, Tuple, TYPE_CHECKING
|
16
17
|
|
17
18
|
import numpy as np
|
18
|
-
from functools import partial
|
19
|
-
from tqdm import tqdm
|
20
|
-
from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING
|
21
19
|
|
22
20
|
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
|
23
|
-
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
|
24
|
-
HessianScoresGranularity
|
25
|
-
|
21
|
+
from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, \
|
22
|
+
HessianScoresGranularity
|
23
|
+
|
26
24
|
if TYPE_CHECKING: # pragma: no cover
|
27
25
|
from model_compression_toolkit.core.common import BaseNode
|
28
26
|
|
29
27
|
|
28
|
+
# type hints aliases
|
29
|
+
LayerName = str
|
30
|
+
Tensor = np.ndarray
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass(eq=True, frozen=True)
|
34
|
+
class Query:
|
35
|
+
""" Query key for hessians cache. """
|
36
|
+
mode: HessianMode
|
37
|
+
granularity: HessianScoresGranularity
|
38
|
+
node: LayerName
|
39
|
+
|
40
|
+
|
41
|
+
class HessianCache:
|
42
|
+
""" Hessian cache """
|
43
|
+
def __init__(self):
|
44
|
+
self._data: Dict[Query, Tensor] = {}
|
45
|
+
|
46
|
+
def update(self, layers_hessians: Dict[str, np.ndarray], request: HessianScoresRequest) -> int:
|
47
|
+
"""
|
48
|
+
Updates the cache with new hessians estimations.
|
49
|
+
|
50
|
+
Note: we assume that the new hessians were computed on different samples than previously stored hessians.
|
51
|
+
If same samples were used more than once, duplicates will be stored. This can only be a problem if hessians
|
52
|
+
for the same query were computed via multiple requests and dataloader in each request yields same samples.
|
53
|
+
We cannot just filter out duplicates since in some cases we can get valid identical hessians on different
|
54
|
+
samples.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
layers_hessians: a dictionary from layer names to their hessian score tensors.
|
58
|
+
request: request per which hessians were computed.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
Minimal samples count after update (among updated layers).
|
62
|
+
|
63
|
+
"""
|
64
|
+
assert set(layers_hessians.keys()) == set(n.name for n in request.target_nodes)
|
65
|
+
n_nodes_samples = [] # samples count per node after update
|
66
|
+
for node_name, hess in layers_hessians.items():
|
67
|
+
query = Query(request.mode, request.granularity, node_name)
|
68
|
+
saved_hess = self._data.get(query)
|
69
|
+
new_hess = hess if saved_hess is None else np.concatenate([saved_hess, hess], axis=0)
|
70
|
+
self._data[query] = new_hess
|
71
|
+
n_nodes_samples.append(new_hess.shape[0])
|
72
|
+
|
73
|
+
return min(n_nodes_samples)
|
74
|
+
|
75
|
+
def fetch_hessian(self, request: HessianScoresRequest) -> Tuple[Dict[LayerName, Tensor], Dict[LayerName, int]]:
|
76
|
+
"""
|
77
|
+
Fetch available hessians per request and identify missing samples.
|
78
|
+
|
79
|
+
Note: if fewer samples are available than requested, hessians tensor will contain the available samples.
|
80
|
+
|
81
|
+
Args:
|
82
|
+
request: hessians fetch request.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
A tuple of two dictionaries:
|
86
|
+
- A dictionary from layer name to a tensor of its hessian.
|
87
|
+
- A dictionary from layer name to a number of missing samples.
|
88
|
+
"""
|
89
|
+
assert request.n_samples is not None
|
90
|
+
|
91
|
+
result = {}
|
92
|
+
missing = {}
|
93
|
+
for node in request.target_nodes:
|
94
|
+
query = Query(request.mode, request.granularity, node.name)
|
95
|
+
hess = self._data.get(query)
|
96
|
+
if hess is None:
|
97
|
+
missing[node.name] = request.n_samples
|
98
|
+
continue
|
99
|
+
n_missing = request.n_samples - hess.shape[0]
|
100
|
+
if n_missing > 0:
|
101
|
+
missing[node.name] = n_missing
|
102
|
+
result[node.name] = hess[:request.n_samples, ...]
|
103
|
+
|
104
|
+
return result, missing
|
105
|
+
|
106
|
+
def clear(self):
|
107
|
+
""" Clear the cache. """
|
108
|
+
self._data.clear()
|
109
|
+
|
110
|
+
|
30
111
|
class HessianInfoService:
|
31
112
|
"""
|
32
113
|
A service to manage, store, and compute information based on the Hessian matrix approximation.
|
@@ -44,411 +125,181 @@ class HessianInfoService:
|
|
44
125
|
|
45
126
|
def __init__(self,
|
46
127
|
graph,
|
47
|
-
representative_dataset_gen: Callable,
|
48
128
|
fw_impl,
|
49
129
|
num_iterations_for_approximation: int = HESSIAN_NUM_ITERATIONS):
|
50
130
|
"""
|
51
|
-
|
52
131
|
Args:
|
53
132
|
graph: Float graph.
|
54
|
-
representative_dataset_gen: A callable that provides a dataset for sampling.
|
55
133
|
fw_impl: Framework-specific implementation for Hessian approximation scores computation.
|
134
|
+
num_iterations_for_approximation: the number of iterations for hessian estimation.
|
56
135
|
"""
|
57
136
|
self.graph = graph
|
58
|
-
|
59
|
-
self.representative_dataset_gen = representative_dataset_gen
|
60
|
-
|
61
137
|
self.fw_impl = fw_impl
|
62
138
|
self.num_iterations_for_approximation = num_iterations_for_approximation
|
139
|
+
self.cache = HessianCache()
|
63
140
|
|
64
|
-
|
65
|
-
|
66
|
-
def _sample_batch_representative_dataset(self,
|
67
|
-
representative_dataset: Any,
|
68
|
-
num_hessian_samples: int,
|
69
|
-
num_inputs: int,
|
70
|
-
last_iter_remain_samples: List[List[np.ndarray]] = None
|
71
|
-
) -> Tuple[List[np.ndarray], List[List[np.ndarray]]]:
|
141
|
+
def fetch_hessian(self, request: HessianScoresRequest,
|
142
|
+
force_compute: bool = False) -> Dict[LayerName, Tensor]:
|
72
143
|
"""
|
73
|
-
|
144
|
+
Fetch hessians per request.
|
145
|
+
If 'force_compute' is False, will first try to retrieve previously cached hessians. If no or not enough
|
146
|
+
hessians are found in the cache, will compute the remaining number of hessians to fulfill the request.
|
147
|
+
If 'force_compute' is True, will compute the hessians (use when you need hessians for specific inputs).
|
74
148
|
|
75
149
|
Args:
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from
|
80
|
-
previous iterations.
|
81
|
-
|
82
|
-
Returns: A tuple with two lists:
|
83
|
-
(1) A list of inputs - a tensor of the requested batch size for each input layer.
|
84
|
-
(2) A list of remaining samples - for each input layer.
|
85
|
-
"""
|
86
|
-
|
87
|
-
if num_inputs < 0: # pragma: no cover
|
88
|
-
Logger.critical(f"Number of images to compute Hessian approximation must be positive, "
|
89
|
-
f"but given {num_inputs}.")
|
90
|
-
|
91
|
-
all_inp_hessian_samples = [[] for _ in range(num_inputs)]
|
92
|
-
all_inp_remaining_samples = [[] for _ in range(num_inputs)]
|
93
|
-
|
94
|
-
# Collect the requested number of samples from the representative dataset
|
95
|
-
# In case there are samples left from previous iterations, we use them first
|
96
|
-
# otherwise, we take a batch from the representative dataset generator
|
97
|
-
while len(all_inp_hessian_samples[0]) < num_hessian_samples:
|
98
|
-
batch = None
|
99
|
-
sampling_from_repr = True
|
100
|
-
if last_iter_remain_samples is not None and len(last_iter_remain_samples[0]) >= num_hessian_samples:
|
101
|
-
batch = last_iter_remain_samples
|
102
|
-
sampling_from_repr = False
|
103
|
-
else:
|
104
|
-
try:
|
105
|
-
batch = next(representative_dataset)
|
106
|
-
except StopIteration:
|
107
|
-
Logger.critical(
|
108
|
-
f"Not enough samples in the provided representative dataset to compute Hessian approximation on "
|
109
|
-
f"{num_hessian_samples} samples.")
|
110
|
-
|
111
|
-
if batch is not None and not isinstance(batch, list):
|
112
|
-
Logger.critical(f'Expected batch to be a list; found type: {type(batch)}.') # pragma: no cover
|
113
|
-
|
114
|
-
for inp_idx in range(len(batch)):
|
115
|
-
inp_batch = batch[inp_idx] if sampling_from_repr else np.stack(batch[inp_idx], axis=0)
|
116
|
-
if not sampling_from_repr:
|
117
|
-
last_iter_remain_samples[inp_idx] = []
|
118
|
-
|
119
|
-
# Compute number of missing samples to get to the requested amount from the current batch
|
120
|
-
num_missing = min(num_hessian_samples - len(all_inp_hessian_samples[inp_idx]), inp_batch.shape[0])
|
121
|
-
|
122
|
-
# Append each sample separately
|
123
|
-
samples = [s for s in inp_batch[0:num_missing, ...]]
|
124
|
-
remaining_samples = [s for s in inp_batch[num_missing:, ...]]
|
125
|
-
|
126
|
-
all_inp_hessian_samples[inp_idx] += [sample.reshape(1, *sample.shape) for sample in samples]
|
127
|
-
|
128
|
-
# This list can only get filled on the last batch iteration
|
129
|
-
all_inp_remaining_samples[inp_idx] += remaining_samples
|
130
|
-
|
131
|
-
if len(all_inp_hessian_samples[0]) > num_hessian_samples:
|
132
|
-
Logger.critical(f"Requested {num_hessian_samples} samples for computing Hessian approximation but "
|
133
|
-
f"{len(all_inp_hessian_samples[0])} were collected.") # pragma: no cover
|
134
|
-
|
135
|
-
# Collected enough samples, constructing a dataset with the requested batch size
|
136
|
-
hessian_samples_for_input = []
|
137
|
-
for inp_samples in all_inp_hessian_samples:
|
138
|
-
inp_samples = np.concatenate(inp_samples, axis=0)
|
139
|
-
num_collected_samples = inp_samples.shape[0]
|
140
|
-
inp_samples = np.split(inp_samples,
|
141
|
-
num_collected_samples // min(num_collected_samples, num_hessian_samples))
|
142
|
-
hessian_samples_for_input.append(inp_samples[0])
|
143
|
-
|
144
|
-
return hessian_samples_for_input, all_inp_remaining_samples
|
145
|
-
|
146
|
-
def _clear_saved_hessian_info(self):
|
147
|
-
"""Clears the saved info approximations."""
|
148
|
-
self.hessian_scores_request_to_scores_list={}
|
149
|
-
|
150
|
-
def count_saved_scores_of_request(self, hessian_request: HessianScoresRequest) -> Dict:
|
151
|
-
"""
|
152
|
-
Counts the saved approximations of Hessian scores for a specific request.
|
153
|
-
If some approximations were computed for this request before, the amount of approximations (per image)
|
154
|
-
will be returned. If not, zero is returned.
|
155
|
-
|
156
|
-
Args:
|
157
|
-
hessian_request: The request configuration for which to count the saved data.
|
150
|
+
request: request per which to fetch the hessians.
|
151
|
+
force_compute: if True, will compute the hessians.
|
152
|
+
If False, will look for cached hessians first.
|
158
153
|
|
159
154
|
Returns:
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
per_node_counter = {}
|
164
|
-
|
165
|
-
for n in hessian_request.target_nodes:
|
166
|
-
if n.reuse:
|
167
|
-
# Reused nodes supposed to have been replaced with a reuse_group
|
168
|
-
# representing node before calling this method.
|
169
|
-
Logger.critical(f"Expecting the Hessian request to include only non-reused nodes at this point, "
|
170
|
-
f"but found node {n.name} with 'reuse' status.")
|
171
|
-
# Check if the request for this node is in the saved info and store its count, otherwise store 0
|
172
|
-
per_node_counter[n] = len(self.hessian_scores_request_to_scores_list.get(hessian_request, []))
|
173
|
-
|
174
|
-
return per_node_counter
|
175
|
-
|
176
|
-
def compute(self,
|
177
|
-
hessian_scores_request: HessianScoresRequest,
|
178
|
-
representative_dataset_gen,
|
179
|
-
num_hessian_samples: int,
|
180
|
-
last_iter_remain_samples: List[List[np.ndarray]] = None):
|
181
|
-
"""
|
182
|
-
Computes scores based on the Hessian matrix approximation according to the
|
183
|
-
provided request configuration and stores it in the cache.
|
184
|
-
|
185
|
-
Args:
|
186
|
-
hessian_scores_request: Configuration for which to compute the approximation.
|
187
|
-
representative_dataset_gen: A callable that provides a dataset for sampling.
|
188
|
-
num_hessian_samples: Number of requested samples to compute batch Hessian approximation scores.
|
189
|
-
last_iter_remain_samples: A list of input samples (for each input layer) with remaining samples from
|
190
|
-
previous iterations.
|
155
|
+
A dictionary of layers' hessian tensors of shape (samples, ...). The exact shape depends on the
|
156
|
+
requested granularity.
|
191
157
|
"""
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
#
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
# We collect the results as a list of a result for images, which is combined across batches.
|
217
|
-
# After conversion, hessian_scores_request_to_scores_list for a request of a single node should be a list of
|
218
|
-
# results of all images, where each result is a tensor of the shape depending on the granularity.
|
219
|
-
if single_node_request in self.hessian_scores_request_to_scores_list:
|
220
|
-
self.hessian_scores_request_to_scores_list[single_node_request] += (
|
221
|
-
self._convert_tensor_to_list_of_appx_results(hessian))
|
222
|
-
else:
|
223
|
-
self.hessian_scores_request_to_scores_list[single_node_request] = (
|
224
|
-
self._convert_tensor_to_list_of_appx_results(hessian))
|
225
|
-
|
226
|
-
# In case that we are required to return a number of scores that is larger that the computation batch size
|
227
|
-
# and if in this case the computation batch size is smaller than the representative dataset batch size
|
228
|
-
# we need to carry over remaining samples from the last fetched batch to the next computation, otherwise,
|
229
|
-
# we might skip samples or remain without enough samples to complete the computations for the
|
230
|
-
# requested number of scores.
|
231
|
-
return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
|
232
|
-
and len(next_iter_remain_samples[0]) > 0 else None
|
233
|
-
|
234
|
-
def compute_trackable_per_sample_hessian(self,
|
235
|
-
hessian_scores_request: HessianScoresRequest,
|
236
|
-
inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
|
158
|
+
if request.n_samples is None and not force_compute:
|
159
|
+
raise ValueError('Number of samples can be None only when force_compute is True.')
|
160
|
+
|
161
|
+
orig_request = request
|
162
|
+
# replace reused nodes with primary nodes
|
163
|
+
# TODO need to check if there is a bug in reuse. While this is the same layer, the compare tensors and their
|
164
|
+
# gradients are not. It seems that currently the same compare tensor of the primary node is used multiple times
|
165
|
+
target_nodes = [self._get_primary_node(n) for n in request.target_nodes]
|
166
|
+
request = request.clone(target_nodes=target_nodes)
|
167
|
+
|
168
|
+
if force_compute:
|
169
|
+
res = self._compute_hessians(request, self.num_iterations_for_approximation, count_by_cache=False)
|
170
|
+
else:
|
171
|
+
res = self._fetch_hessians_with_compute(request, self.num_iterations_for_approximation)
|
172
|
+
|
173
|
+
# restore nodes from the original request
|
174
|
+
res = {n_orig.name: res[n.name] for n_orig, n in zip(orig_request.target_nodes, request.target_nodes)}
|
175
|
+
return res
|
176
|
+
|
177
|
+
def clear_cache(self):
|
178
|
+
""" Purge the cached hessians. """
|
179
|
+
self.cache.clear()
|
180
|
+
|
181
|
+
def _fetch_hessians_with_compute(self, request: HessianScoresRequest, n_iterations: int) -> Dict[LayerName, Tensor]:
|
237
182
|
"""
|
238
|
-
|
239
|
-
as data generator might yield different images each time, depending on how it was defined,
|
183
|
+
Fetch pre-computed hessians for the request if available. Otherwise, compute the missing hessians.
|
240
184
|
|
241
185
|
Args:
|
242
|
-
|
243
|
-
|
186
|
+
request: hessian estimation request.
|
187
|
+
n_iterations: the number of iterations for hessian estimation.
|
244
188
|
|
245
189
|
Returns:
|
246
|
-
A
|
190
|
+
A dictionary from layers (by name) to their hessians.
|
247
191
|
"""
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
#
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
return hessian_score_by_image_hash
|
271
|
-
|
272
|
-
@staticmethod
|
273
|
-
def calc_image_hash(image):
|
192
|
+
res, missing = self.cache.fetch_hessian(request)
|
193
|
+
if not missing:
|
194
|
+
return res
|
195
|
+
|
196
|
+
if request.data_loader is None:
|
197
|
+
raise ValueError(f'Not enough hessians are cached to fulfill the request, but data loader was not passed '
|
198
|
+
f'for additional computation. Requested {request.n_samples}, '
|
199
|
+
f'available {min(missing.values())}.')
|
200
|
+
|
201
|
+
orig_request = request
|
202
|
+
# if some hessians were found generate a new request only for missing nodes.
|
203
|
+
if res:
|
204
|
+
target_nodes = [n for n in orig_request.target_nodes if n.name in missing]
|
205
|
+
request = request.clone(target_nodes=target_nodes)
|
206
|
+
self._compute_hessians(request, n_iterations, count_by_cache=True)
|
207
|
+
res, missing = self.cache.fetch_hessian(request)
|
208
|
+
assert not missing
|
209
|
+
return res
|
210
|
+
|
211
|
+
def _compute_hessians(self, request: HessianScoresRequest,
|
212
|
+
n_iterations: int, count_by_cache: bool) -> Dict[LayerName, Tensor]:
|
274
213
|
"""
|
275
|
-
|
276
|
-
|
277
|
-
Args:
|
278
|
-
image: input 3d image (without batch).
|
214
|
+
Computes hessian estimation per request.
|
279
215
|
|
280
|
-
|
281
|
-
|
216
|
+
Data loader from request is used as is, i.e. it should reflect the required batch size (e.g. if
|
217
|
+
hessians should be estimated sample by sample, the data loader should yield a single sample at a time).
|
282
218
|
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
image_bytes = image.astype(np.float32).tobytes()
|
287
|
-
return hashlib.md5(image_bytes).hexdigest()
|
288
|
-
|
289
|
-
def fetch_hessian(self,
|
290
|
-
hessian_scores_request: HessianScoresRequest,
|
291
|
-
required_size: int,
|
292
|
-
batch_size: int = 1) -> List[List[np.ndarray]]:
|
293
|
-
"""
|
294
|
-
Fetches the computed approximations of the Hessian-based scores for the given
|
295
|
-
request and required size.
|
219
|
+
NOTE: the returned value only contains hessians that were computed here, which may differ from the requested
|
220
|
+
number of samples. It's only intended for use when you specifically need sample-wise hessians for the
|
221
|
+
samples in the request.
|
296
222
|
|
297
223
|
Args:
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
224
|
+
request: hessian estimation request.
|
225
|
+
n_iterations: the number of iterations for hessian estimation.
|
226
|
+
count_by_cache: if True, computes hessians until the cache contains the requested number of samples.
|
227
|
+
if False, computes hessian for the first requested number of sample in the dataloader.
|
302
228
|
Returns:
|
303
|
-
|
304
|
-
|
305
|
-
The inner list length dependent on the granularity (1 for per-tensor,
|
306
|
-
OC for per-output-channel when the requested node has OC output-channels, etc.)
|
307
|
-
"""
|
308
|
-
|
309
|
-
if len(hessian_scores_request.target_nodes) == 0: # pragma: no cover
|
310
|
-
return []
|
311
|
-
|
312
|
-
if required_size == 0:
|
313
|
-
return [[] for _ in hessian_scores_request.target_nodes]
|
314
|
-
|
315
|
-
Logger.info(f"\nEnsuring {required_size} Hessian-approximation scores for nodes "
|
316
|
-
f"{hessian_scores_request.target_nodes}.")
|
317
|
-
|
318
|
-
# Replace node in reused target nodes with a representing node from the 'reuse group'.
|
319
|
-
hessian_scores_request.target_nodes = [
|
320
|
-
self._get_representing_of_reuse_group(node) if node.reuse else node
|
321
|
-
for node in hessian_scores_request.target_nodes
|
322
|
-
]
|
323
|
-
|
324
|
-
# Ensure the saved info has the required number of approximations
|
325
|
-
self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
|
326
|
-
|
327
|
-
# Return the saved approximations for the given request
|
328
|
-
return self._collect_saved_hessians_for_request(hessian_scores_request, required_size)
|
329
|
-
|
330
|
-
def _get_representing_of_reuse_group(self, node) -> Any:
|
331
|
-
"""
|
332
|
-
For each reused group we compute and fetch its members using a single request.
|
333
|
-
This method creates and returns a request for the reused group the node is in.
|
334
|
-
|
335
|
-
Args:
|
336
|
-
node: The node to get its reuse group representative node.
|
337
|
-
|
338
|
-
Returns: A reuse group representative node (BaseNode).
|
339
|
-
"""
|
340
|
-
father_nodes = [n for n in self.graph.nodes if not n.reuse and n.reuse_group == node.reuse_group]
|
341
|
-
if len(father_nodes) != 1: # pragma: no cover
|
342
|
-
Logger.critical(f"Expected a single non-reused node in the reused group, "
|
343
|
-
f"but found {len(father_nodes)}.")
|
344
|
-
|
345
|
-
return father_nodes[0]
|
346
|
-
|
347
|
-
def _populate_saved_info_to_size(self,
|
348
|
-
hessian_scores_request: HessianScoresRequest,
|
349
|
-
required_size: int,
|
350
|
-
batch_size: int = 1):
|
229
|
+
A dictionary from layers (by name) to their hessian tensors that *were computed in this invocation*.
|
230
|
+
First axis corresponds to samples in the order determined by the data loader.
|
351
231
|
"""
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
232
|
+
if count_by_cache:
|
233
|
+
assert request.n_samples is not None
|
234
|
+
|
235
|
+
n_samples = 0
|
236
|
+
hess_per_layer = []
|
237
|
+
for batch in request.data_loader:
|
238
|
+
batch_hess_per_layer = self._compute_hessian_for_batch(request, batch, n_iterations)
|
239
|
+
hess_per_layer.append(batch_hess_per_layer)
|
240
|
+
min_count = self.cache.update(batch_hess_per_layer, request)
|
241
|
+
n_samples = min_count if count_by_cache else (n_samples + batch[0].shape[0])
|
242
|
+
if request.n_samples and n_samples >= request.n_samples:
|
243
|
+
break
|
244
|
+
|
245
|
+
hess_per_layer = {
|
246
|
+
layer.name: np.concatenate([hess[layer.name] for hess in hess_per_layer], axis=0)
|
247
|
+
for layer in request.target_nodes
|
248
|
+
}
|
249
|
+
|
250
|
+
if request.n_samples:
|
251
|
+
if n_samples < request.n_samples:
|
252
|
+
raise ValueError(f'Could not compute the requested number of Hessians ({request.n_samples}), '
|
253
|
+
f'not enough samples in the provided representative dataset.')
|
254
|
+
|
255
|
+
if n_samples > request.n_samples:
|
256
|
+
hess_per_layer = {
|
257
|
+
layer: hess[:request.n_samples, ...] for layer, hess in hess_per_layer.items()
|
258
|
+
}
|
259
|
+
return hess_per_layer
|
260
|
+
|
261
|
+
def _compute_hessian_for_batch(self,
|
262
|
+
request: HessianScoresRequest,
|
263
|
+
inputs_batch: List[Tensor],
|
264
|
+
n_iterations: int) -> Dict[LayerName, Tensor]:
|
358
265
|
"""
|
359
|
-
|
360
|
-
# Get the current number of saved approximations for each node in the request
|
361
|
-
current_existing_hessians = self.count_saved_scores_of_request(hessian_scores_request)
|
362
|
-
|
363
|
-
# Compute the required number of approximations to meet the required size.
|
364
|
-
# Since we allow batch and multi-nodes computation, we take the node with the maximal number of missing
|
365
|
-
# approximations to compute, and run batch computations until meeting the requirement.
|
366
|
-
min_exist_hessians = min(current_existing_hessians.values())
|
367
|
-
max_remaining_hessians = required_size - min_exist_hessians
|
368
|
-
|
369
|
-
Logger.info(
|
370
|
-
f"Running Hessian approximation computation for {len(hessian_scores_request.target_nodes)} nodes.\n "
|
371
|
-
f"The node with minimal existing Hessian-approximation scores has {min_exist_hessians} "
|
372
|
-
f"approximated scores computed.\n"
|
373
|
-
f"{max_remaining_hessians} approximations left to compute...")
|
374
|
-
|
375
|
-
hessian_representative_dataset = partial(self._sample_batch_representative_dataset,
|
376
|
-
num_inputs=len(self.graph.input_nodes),
|
377
|
-
representative_dataset=self.representative_dataset_gen())
|
378
|
-
|
379
|
-
next_iter_remaining_samples = None
|
380
|
-
pbar = tqdm(desc="Computing Hessian approximations...", total=None)
|
381
|
-
while max_remaining_hessians > 0:
|
382
|
-
# If batch_size < max_remaining_hessians then we run each computation on a batch_size of images.
|
383
|
-
# This way, we always run a computation for a single batch.
|
384
|
-
pbar.update(1)
|
385
|
-
size_to_compute = min(max_remaining_hessians, batch_size)
|
386
|
-
next_iter_remaining_samples = (
|
387
|
-
self.compute(hessian_scores_request, hessian_representative_dataset, size_to_compute,
|
388
|
-
last_iter_remain_samples=next_iter_remaining_samples))
|
389
|
-
max_remaining_hessians -= size_to_compute
|
390
|
-
|
391
|
-
def _collect_saved_hessians_for_request(self,
|
392
|
-
hessian_scores_request: HessianScoresRequest,
|
393
|
-
required_size: int) -> List[List[np.ndarray]]:
|
394
|
-
"""
|
395
|
-
Collects Hessian approximation for the nodes in the given request.
|
266
|
+
Use hessian score calculator to compute hessian approximations for a batch of inputs.
|
396
267
|
|
397
268
|
Args:
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
Returns: A list with List of computed Hessian approximation (a tensor for each score) for each node
|
402
|
-
in the request.
|
269
|
+
request: hessian estimation request.
|
270
|
+
inputs_batch: a batch of inputs to estimate hessians on.
|
271
|
+
n_iterations: the number of iterations for hessian estimation.
|
403
272
|
|
273
|
+
Returns:
|
274
|
+
A dictionary from layers (by name) to their hessians.
|
404
275
|
"""
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
res_for_node = self.hessian_scores_request_to_scores_list.get(single_node_request)
|
412
|
-
if res_for_node is None: # pragma: no cover
|
413
|
-
Logger.critical(f"Couldn't find saved Hessian approximations for node {node.name}.")
|
414
|
-
if len(res_for_node) < required_size: # pragma: no cover
|
415
|
-
Logger.critical(f"Missing Hessian approximations for node {node.name}, requested {required_size} "
|
416
|
-
f"but found only {len(res_for_node)}.")
|
417
|
-
|
418
|
-
res_for_node = res_for_node[:required_size]
|
276
|
+
fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(
|
277
|
+
graph=self.graph,
|
278
|
+
input_images=inputs_batch,
|
279
|
+
hessian_scores_request=request,
|
280
|
+
num_iterations_for_approximation=n_iterations
|
281
|
+
)
|
419
282
|
|
420
|
-
|
283
|
+
hessian_scores: list = fw_hessian_calculator.compute()
|
421
284
|
|
422
|
-
|
285
|
+
layers_hessian_scores = {
|
286
|
+
layer.name: score for layer, score in zip(request.target_nodes, hessian_scores)
|
287
|
+
}
|
288
|
+
return layers_hessian_scores
|
423
289
|
|
424
|
-
|
425
|
-
def _construct_single_node_request(mode: HessianMode,
|
426
|
-
granularity: HessianScoresGranularity,
|
427
|
-
target_nodes: List) -> HessianScoresRequest:
|
290
|
+
def _get_primary_node(self, node: 'BaseNode') -> 'BaseNode':
|
428
291
|
"""
|
429
|
-
|
292
|
+
Get node's primary node that it reuses, or itself if not reused.
|
430
293
|
|
431
294
|
Args:
|
432
|
-
|
433
|
-
granularity (HessianScoresGranularity): Granularity level for the approximation.
|
434
|
-
target_nodes (List[BaseNode]): The node in the float graph for which the Hessian's approximation scores is targeted.
|
295
|
+
node: node's object to get its primary node.
|
435
296
|
|
436
|
-
Returns:
|
437
|
-
|
438
|
-
"""
|
439
|
-
return HessianScoresRequest(mode,
|
440
|
-
granularity,
|
441
|
-
target_nodes=[target_nodes])
|
442
|
-
|
443
|
-
@staticmethod
|
444
|
-
def _convert_tensor_to_list_of_appx_results(t: Any) -> List:
|
297
|
+
Returns:
|
298
|
+
Node's primary node.
|
445
299
|
"""
|
446
|
-
|
447
|
-
|
448
|
-
Args:
|
449
|
-
t: A tensor with Hessian approximation results.
|
450
|
-
|
451
|
-
Returns: A list with split batch into individual results.
|
300
|
+
if node.reuse is False:
|
301
|
+
return node
|
452
302
|
|
453
|
-
|
454
|
-
|
303
|
+
father_nodes = [n for n in self.graph.nodes if not n.reuse and n.reuse_group == node.reuse_group]
|
304
|
+
assert len(father_nodes) == 1
|
305
|
+
return father_nodes[0]
|