mct-nightly 2.2.0.20241022.507__py3-none-any.whl → 2.2.0.20241024.501__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/RECORD +38 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +43 -29
- model_compression_toolkit/core/common/hessian/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +222 -371
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +27 -41
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +8 -10
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +11 -9
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +17 -15
- model_compression_toolkit/core/keras/data_util.py +67 -0
- model_compression_toolkit/core/keras/keras_implementation.py +7 -1
- model_compression_toolkit/core/keras/tf_tensor_numpy.py +1 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -1
- model_compression_toolkit/core/pytorch/data_util.py +163 -0
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +6 -31
- model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py +11 -21
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +9 -7
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +8 -2
- model_compression_toolkit/core/pytorch/utils.py +22 -19
- model_compression_toolkit/core/quantization_prep_runner.py +2 -1
- model_compression_toolkit/core/runner.py +1 -2
- model_compression_toolkit/gptq/common/gptq_config.py +0 -2
- model_compression_toolkit/gptq/common/gptq_training.py +58 -114
- model_compression_toolkit/gptq/keras/gptq_training.py +15 -6
- model_compression_toolkit/gptq/pytorch/gptq_loss.py +3 -2
- model_compression_toolkit/gptq/pytorch/gptq_training.py +97 -64
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +0 -2
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +4 -3
- tests_pytest/keras/__init__.py +14 -0
- tests_pytest/keras/core/__init__.py +14 -0
- tests_pytest/keras/core/test_data_util.py +91 -0
- tests_pytest/pytorch/core/__init__.py +14 -0
- tests_pytest/pytorch/core/test_data_util.py +125 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241022.507.dist-info → mct_nightly-2.2.0.20241024.501.dist-info}/top_level.txt +0 -0
@@ -12,10 +12,14 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from typing import
|
15
|
+
from typing import Iterable, Sequence, Optional, TYPE_CHECKING
|
16
|
+
import dataclasses
|
16
17
|
|
17
18
|
from enum import Enum
|
18
19
|
|
20
|
+
if TYPE_CHECKING: # pragma: no cover
|
21
|
+
from model_compression_toolkit.core.common import BaseNode
|
22
|
+
|
19
23
|
|
20
24
|
class HessianMode(Enum):
|
21
25
|
"""
|
@@ -40,14 +44,7 @@ class HessianScoresGranularity(Enum):
|
|
40
44
|
PER_TENSOR = 2
|
41
45
|
|
42
46
|
|
43
|
-
|
44
|
-
"""
|
45
|
-
Distribution for Hutchinson estimator random vector
|
46
|
-
"""
|
47
|
-
GAUSSIAN = 'gaussian'
|
48
|
-
RADEMACHER = 'rademacher'
|
49
|
-
|
50
|
-
|
47
|
+
@dataclasses.dataclass
|
51
48
|
class HessianScoresRequest:
|
52
49
|
"""
|
53
50
|
Request configuration for the Hessian-approximation scores.
|
@@ -55,36 +52,25 @@ class HessianScoresRequest:
|
|
55
52
|
This class defines the parameters for the scores based on the Hessian matrix approximation.
|
56
53
|
It specifies the mode (weights/activations), granularity (element/channel/tensor), and the target node.
|
57
54
|
|
58
|
-
|
55
|
+
Attributes:
|
56
|
+
mode: Mode of Hessian-approximation score (w.r.t weights or activations).
|
57
|
+
granularity: Granularity level for the approximation.
|
58
|
+
target_nodes: The node objects in the float graph for which the Hessian's approximation scores is targeted.
|
59
|
+
data_loader: Data loader to compute hessian approximations on. Should reflect the desired batch size for
|
60
|
+
the computation. Can be None if all hessians for the request are expected to be pre-computed previously.
|
61
|
+
n_samples: The number of samples to fetch hessian estimations for. If None, fetch hessians for a full pass
|
62
|
+
of the data loader.
|
59
63
|
"""
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
"""
|
72
|
-
|
73
|
-
self.mode = mode # w.r.t activations or weights
|
74
|
-
self.granularity = granularity # per element, per layer, per channel
|
75
|
-
self.target_nodes = target_nodes
|
76
|
-
self.distribution = distribution
|
77
|
-
|
78
|
-
def __eq__(self, other):
|
79
|
-
# Checks if the other object is an instance of HessianScoresRequest
|
80
|
-
# and then checks if all attributes are equal.
|
81
|
-
return isinstance(other, HessianScoresRequest) and \
|
82
|
-
self.mode == other.mode and \
|
83
|
-
self.granularity == other.granularity and \
|
84
|
-
self.target_nodes == other.target_nodes and \
|
85
|
-
self.distribution == other.distribution
|
86
|
-
|
87
|
-
def __hash__(self):
|
88
|
-
# Computes the hash based on the attributes.
|
89
|
-
# The use of a tuple here ensures that the hash is influenced by all the attributes.
|
90
|
-
return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
|
64
|
+
mode: HessianMode
|
65
|
+
granularity: HessianScoresGranularity
|
66
|
+
target_nodes: Sequence['BaseNode']
|
67
|
+
data_loader: Optional[Iterable]
|
68
|
+
n_samples: Optional[int]
|
69
|
+
|
70
|
+
def __post_init__(self):
|
71
|
+
if self.data_loader is None and self.n_samples is None:
|
72
|
+
raise ValueError('Data loader and the number of samples cannot both be None.')
|
73
|
+
|
74
|
+
def clone(self, **kwargs):
|
75
|
+
""" Create a clone with optional overrides """
|
76
|
+
return dataclasses.replace(self, **kwargs)
|
@@ -238,22 +238,20 @@ class SensitivityEvaluation:
|
|
238
238
|
"""
|
239
239
|
# Create a request for Hessian approximation scores with specific configurations
|
240
240
|
# (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
|
241
|
+
fw_dataloader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen,
|
242
|
+
batch_size=self.quant_config.hessian_batch_size)
|
241
243
|
hessian_info_request = HessianScoresRequest(mode=HessianMode.ACTIVATION,
|
242
244
|
granularity=HessianScoresGranularity.PER_TENSOR,
|
243
|
-
target_nodes=self.interest_points
|
245
|
+
target_nodes=self.interest_points,
|
246
|
+
data_loader=fw_dataloader,
|
247
|
+
n_samples=self.quant_config.num_of_images)
|
244
248
|
|
245
249
|
# Fetch the Hessian approximation scores for the current interest point
|
246
|
-
nodes_approximations = self.hessian_info_service.fetch_hessian(
|
247
|
-
|
248
|
-
batch_size=self.quant_config.hessian_batch_size)
|
249
|
-
|
250
|
-
# Store the approximations for each node for each image
|
251
|
-
approx_by_image = [[nodes_approximations[j][image_idx]
|
252
|
-
for j, _ in enumerate(self.interest_points)]
|
253
|
-
for image_idx in range(self.quant_config.num_of_images)]
|
250
|
+
nodes_approximations = self.hessian_info_service.fetch_hessian(request=hessian_info_request)
|
251
|
+
approx_by_image = np.stack([nodes_approximations[n.name] for n in self.interest_points], axis=1) # samples X nodes
|
254
252
|
|
255
253
|
# Return the mean approximation value across all images for each interest point
|
256
|
-
return np.mean(
|
254
|
+
return np.mean(approx_by_image, axis=0)
|
257
255
|
|
258
256
|
def _configure_bitwidths_model(self,
|
259
257
|
mp_model_configuration: List[int],
|
@@ -120,22 +120,24 @@ class LFHImportanceMetric(BaseImportanceMetric):
|
|
120
120
|
"""
|
121
121
|
|
122
122
|
# Initialize HessianInfoService for score computation.
|
123
|
+
|
123
124
|
hessian_info_service = HessianInfoService(graph=self.float_graph,
|
124
|
-
representative_dataset_gen=self.representative_data_gen,
|
125
125
|
fw_impl=self.fw_impl)
|
126
126
|
|
127
127
|
# Fetch and process Hessian scores for output channels of entry nodes.
|
128
|
-
|
128
|
+
data_loader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen, batch_size=1)
|
129
|
+
nodes_scores = {}
|
129
130
|
for node in entry_nodes:
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
131
|
+
request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
|
132
|
+
granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
|
133
|
+
target_nodes=[node],
|
134
|
+
data_loader=data_loader,
|
135
|
+
n_samples=self.pruning_config.num_score_approximations)
|
136
|
+
node_scores = hessian_info_service.fetch_hessian(request)
|
137
|
+
nodes_scores.update(node_scores)
|
136
138
|
|
137
139
|
# Average and map scores to nodes.
|
138
|
-
self._entry_node_to_hessian_score = {node: np.mean(
|
140
|
+
self._entry_node_to_hessian_score = {node: np.mean(nodes_scores[node.name], axis=0).flatten() for node in entry_nodes}
|
139
141
|
|
140
142
|
self._entry_node_count_oc_nparams = self._count_oc_nparams(entry_nodes=entry_nodes)
|
141
143
|
_entry_node_l2_oc_norm = self._get_squaredl2norm(entry_nodes=entry_nodes)
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
from copy import deepcopy
|
16
|
-
from typing import Tuple, Callable, List
|
16
|
+
from typing import Tuple, Callable, List, Iterable, Optional
|
17
17
|
import numpy as np
|
18
18
|
import model_compression_toolkit.core.common.quantization.quantization_config as qc
|
19
19
|
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianScoresGranularity, \
|
@@ -377,7 +377,8 @@ def _get_sliced_histogram(bins: np.ndarray,
|
|
377
377
|
|
378
378
|
def _compute_hessian_for_hmse(node,
                              hessian_info_service: HessianInfoService,
                              num_hessian_samples: int,
                              dataloader: Optional[Iterable]) -> dict:
    """
    Compute and retrieve Hessian-based scores for use during HMSE error computation.

    Args:
        node: The node to compute Hessian-based scores for.
        hessian_info_service: HessianInfoService object for retrieving Hessian-based scores.
        num_hessian_samples: Number of samples to approximate Hessian-based scores on.
        dataloader: Data loader for computing Hessian-based scores. Can be None if hessians are expected to be
            available, i.e. have been already computed previously.

    Returns:
        The result of ``hessian_info_service.fetch_hessian`` for a single-node request: a mapping from
        node name to that node's Hessian-based scores (callers index it with ``node.name``).
    """
    _request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
                                    granularity=HessianScoresGranularity.PER_ELEMENT,
                                    data_loader=dataloader,
                                    n_samples=num_hessian_samples,
                                    target_nodes=[node])
    _scores_for_node = hessian_info_service.fetch_hessian(_request)

    return _scores_for_node
|
399
403
|
|
@@ -476,11 +480,11 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat
|
|
476
480
|
per_channel=True)
|
477
481
|
|
478
482
|
if quant_error_method == qc.QuantizationErrorMethod.HMSE:
|
479
|
-
node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples)
|
483
|
+
node_hessian_scores = _compute_hessian_for_hmse(node, hessian_info_service, num_hessian_samples, None)
|
480
484
|
if len(node_hessian_scores) != 1:
|
481
485
|
Logger.critical(f"Expecting single node Hessian score request to return a list of length 1, but got a list "
|
482
486
|
f"of length {len(node_hessian_scores)}.")
|
483
|
-
node_hessian_scores = np.sqrt(np.mean(node_hessian_scores[
|
487
|
+
node_hessian_scores = np.sqrt(np.mean(node_hessian_scores[node.name], axis=0))
|
484
488
|
|
485
489
|
return lambda x, y, threshold: _hmse_error_function_wrapper(x, y, norm=norm, axis=axis,
|
486
490
|
hessian_scores=node_hessian_scores)
|
@@ -15,11 +15,12 @@
|
|
15
15
|
import copy
|
16
16
|
|
17
17
|
from tqdm import tqdm
|
18
|
-
from typing import List
|
18
|
+
from typing import List, Callable, Generator
|
19
19
|
|
20
20
|
from model_compression_toolkit.constants import NUM_QPARAM_HESSIAN_SAMPLES
|
21
21
|
from model_compression_toolkit.core import QuantizationErrorMethod
|
22
22
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
23
|
+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
23
24
|
from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresRequest, HessianMode, \
|
24
25
|
HessianScoresGranularity
|
25
26
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_activations_computation \
|
@@ -55,26 +56,25 @@ def _collect_nodes_for_hmse(nodes_list: List[BaseNode], graph: Graph) -> List[Ba
|
|
55
56
|
|
56
57
|
|
57
58
|
def calculate_quantization_params(graph: Graph,
|
58
|
-
|
59
|
-
|
59
|
+
fw_impl: FrameworkImplementation,
|
60
|
+
repr_data_gen_fn: Callable[[], Generator],
|
61
|
+
nodes: List[BaseNode] = None,
|
60
62
|
hessian_info_service: HessianInfoService = None,
|
61
63
|
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES):
|
62
64
|
"""
|
63
65
|
For a graph, go over its nodes, compute quantization params (for both weights and activations according
|
64
66
|
to the given framework info), and create and attach a NodeQuantizationConfig to each node (containing the
|
65
67
|
computed params).
|
66
|
-
By default, the function goes over all nodes in the graph. However,
|
67
|
-
to compute quantization params for
|
68
|
-
a list of nodes should be passed as well.
|
68
|
+
By default, the function goes over all nodes in the graph. However, specific nodes can be passed
|
69
|
+
to compute quantization params only for them.
|
69
70
|
|
70
71
|
Args:
|
71
|
-
groups of layers by how they should be quantized, etc.)
|
72
72
|
graph: Graph to compute its nodes' thresholds.
|
73
|
+
fw_impl: FrameworkImplementation object.
|
74
|
+
repr_data_gen_fn: callable returning representative dataset generator.
|
73
75
|
nodes: List of nodes to compute their thresholds instead of computing it for all nodes in the graph.
|
74
|
-
specific_nodes: Flag to compute thresholds for only specific nodes.
|
75
76
|
hessian_info_service: HessianInfoService object for retrieving Hessian-based scores (used only with HMSE error method).
|
76
77
|
num_hessian_samples: Number of samples to approximate Hessian-based scores on (used only with HMSE error method).
|
77
|
-
|
78
78
|
"""
|
79
79
|
|
80
80
|
Logger.info(f"\nRunning quantization parameters search. "
|
@@ -82,18 +82,20 @@ def calculate_quantization_params(graph: Graph,
|
|
82
82
|
f"depending on the model size and the selected quantization methods.\n")
|
83
83
|
|
84
84
|
# Create a list of nodes to compute their thresholds
|
85
|
-
nodes_list: List[BaseNode] = nodes
|
85
|
+
nodes_list: List[BaseNode] = nodes or graph.nodes()
|
86
86
|
|
87
87
|
# Collecting nodes that are configured to search weights quantization parameters using HMSE optimization
|
88
88
|
# and computing required Hessian information to be used for HMSE parameters selection.
|
89
89
|
# The Hessian scores are computed and stored in the hessian_info_service object.
|
90
90
|
nodes_for_hmse = _collect_nodes_for_hmse(nodes_list, graph)
|
91
91
|
if len(nodes_for_hmse) > 0:
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
92
|
+
dataloader = fw_impl.convert_data_gen_to_dataloader(repr_data_gen_fn, batch_size=1)
|
93
|
+
request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
|
94
|
+
granularity=HessianScoresGranularity.PER_ELEMENT,
|
95
|
+
data_loader=dataloader,
|
96
|
+
n_samples=num_hessian_samples,
|
97
|
+
target_nodes=nodes_for_hmse)
|
98
|
+
hessian_info_service.fetch_hessian(request)
|
97
99
|
|
98
100
|
for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
|
99
101
|
for candidate_qc in n.candidates_quantization_cfg:
|
@@ -0,0 +1,67 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from typing import Generator, Callable
|
16
|
+
|
17
|
+
import tensorflow as tf
|
18
|
+
|
19
|
+
from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor
|
20
|
+
|
21
|
+
|
22
|
+
def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
    """
    Wrap a batched data generator factory so that it yields one sample at a time.

    Args:
        data_gen_fn: factory returning a generator; each yielded item is expected to be a list of
            batched tensors sharing the same leading (batch) dimension.

    Returns:
        A zero-argument callable that creates a fresh generator on every call. That generator yields
        the per-sample tuples of the input batches, each converted with to_tf_tensor.
    """
    def _sample_gen():
        for batch_inputs in data_gen_fn():
            # zip(*batch) walks the leading dimension of all input tensors in lockstep.
            yield from (to_tf_tensor(per_sample) for per_sample in zip(*batch_inputs))

    return _sample_gen
|
37
|
+
|
38
|
+
|
39
|
+
# TODO in tf dataset and dataloader are combined within tf.data.Dataset. For advanced use cases such as gptq sla we
# need to separate dataset from dataloader similarly to torch data_util.
class TFDatasetFromGenerator:
    """
    Re-batch the samples of a representative data generator into a tf.data.Dataset.

    The input generator may yield batches of any size. They are flattened to single samples
    (via flat_gen_fn) and re-batched with the requested batch size.

    Attributes:
        orig_batch_size: leading dimension of the first input tensor in the first batch yielded by
            the input generator.
        dataset: the batched tf.data.Dataset.
    """
    def __init__(self, data_gen, batch_size):
        """
        Args:
            data_gen: factory returning a generator that yields lists of tensors.
            batch_size: batch size of the resulting dataset.

        Raises:
            TypeError: if the generator yields something other than a list.
        """
        # Peek at one batch (from a separate generator instance) to validate the input format and
        # derive the per-sample signature.
        inputs = next(data_gen())
        if not isinstance(inputs, list):
            raise TypeError(f'Representative data generator is expected to generate a list of tensors, '
                            f'got {type(inputs)}')  # pragma: no cover

        self.orig_batch_size = inputs[0].shape[0]

        # flat_gen_fn yields to_tf_tensor() output (e.g. numpy arrays are cast to float32), so the
        # signature must be derived from the converted tensors rather than the raw input dtypes.
        # Otherwise tf.data rejects the yielded samples.
        converted_inputs = to_tf_tensor(inputs)
        output_signature = tuple(tf.TensorSpec(shape=t.shape[1:], dtype=t.dtype) for t in converted_inputs)
        dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen), output_signature=output_signature)
        self.dataset = dataset.batch(batch_size)
        self._size = None

    def __iter__(self):
        return iter(self.dataset)

    def __len__(self):
        """ Returns the number of batches (counted with a full pass on first call, then cached). """
        if self._size is None:
            # Counting re-runs the whole generator, so store the result in the attribute that is checked.
            self._size = sum(1 for _ in self)
        return self._size
|
63
|
+
|
64
|
+
|
65
|
+
def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size) -> TFDatasetFromGenerator:
    """
    Build a batched dataset from a representative data generator factory.

    Args:
        data_gen_fn: factory returning a generator that yields lists of tensors.
        batch_size: batch size of the returned dataset.

    Returns:
        A TFDatasetFromGenerator wrapping the re-batched samples.
    """
    loader = TFDatasetFromGenerator(data_gen_fn, batch_size)
    return loader
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
from functools import partial
|
16
|
-
from typing import List, Any, Tuple, Callable, Dict, Union
|
16
|
+
from typing import List, Any, Tuple, Callable, Dict, Union, Generator
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
import tensorflow as tf
|
@@ -23,6 +23,7 @@ from tensorflow.keras.models import Model
|
|
23
23
|
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
|
24
24
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
25
25
|
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianInfoService
|
26
|
+
from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
|
26
27
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
|
27
28
|
from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
|
28
29
|
ActivationHessianScoresCalculatorKeras
|
@@ -628,3 +629,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
628
629
|
get_weights_quantizer_for_node,
|
629
630
|
get_activations_quantizer_for_node,
|
630
631
|
attribute_names)
|
632
|
+
|
633
|
+
@staticmethod
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
    """
    Wrap a representative data generator factory in a Keras/TF data loader.

    Args:
        data_gen_fn: factory returning a generator that yields lists of tensors.
        batch_size: batch size of the returned loader.

    Returns:
        The loader created by the Keras ``data_gen_to_dataloader`` helper.
    """
    return data_gen_to_dataloader(data_gen_fn=data_gen_fn, batch_size=batch_size)
|
@@ -33,7 +33,7 @@ def to_tf_tensor(tensor):
|
|
33
33
|
elif isinstance(tensor, list):
|
34
34
|
return [to_tf_tensor(t) for t in tensor]
|
35
35
|
elif isinstance(tensor, tuple):
|
36
|
-
return (to_tf_tensor(t) for t in tensor)
|
36
|
+
return tuple(to_tf_tensor(t) for t in tensor)
|
37
37
|
elif isinstance(tensor, np.ndarray):
|
38
38
|
return tf.convert_to_tensor(tensor.astype(np.float32))
|
39
39
|
else: # pragma: no cover
|
@@ -134,7 +134,7 @@ def _run_operation(n: BaseNode,
|
|
134
134
|
input_tensors = n.insert_positional_weights_to_input_list(input_tensors)
|
135
135
|
# convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
|
136
136
|
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
|
137
|
-
input_tensors = [to_torch_tensor(t,
|
137
|
+
input_tensors = [to_torch_tensor(t, None) if isinstance(t, np.ndarray) else t
|
138
138
|
for t in input_tensors]
|
139
139
|
_tensor_input_allocs = None
|
140
140
|
|
@@ -0,0 +1,163 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from typing import Generator, Callable, Sequence, Any
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from torch.utils.data import IterableDataset, Dataset, DataLoader, default_collate
|
19
|
+
|
20
|
+
|
21
|
+
def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
    """
    Turn a factory of batched data generators into a factory of per-sample generators.

    Args:
        data_gen_fn: factory returning a generator; each yielded item is expected to be a list of
            batched tensors/arrays sharing the same leading (batch) dimension.

    Returns:
        A zero-argument callable that creates a fresh generator on every call. Each yielded item
        is a list with one torch tensor per model input, holding a single sample.
    """
    def _per_sample():
        for batch in data_gen_fn():
            for sample_parts in zip(*batch):
                # Converted with torch.as_tensor and no explicit device: moving to the target device
                # here would cause issues with num_workers > 0.
                yield list(map(torch.as_tensor, sample_parts))

    return _per_sample
|
37
|
+
|
38
|
+
|
39
|
+
class IterableDatasetFromGenerator(IterableDataset):
    """
    Iterable PyTorch dataset backed by a data generator factory.

    Every iteration calls the factory again and walks the new generator sample by sample. A
    factory that produces different data on each call therefore keeps doing so here.

    Attributes:
        orig_batch_size: leading dimension of the first tensor in the first batch yielded by the
            input generator.
    """

    def __init__(self, data_gen_fn: Callable[[], Generator]):
        """
        Args:
            data_gen_fn: factory returning a generator that yields lists of tensors.

        Raises:
            TypeError: if the generator yields something other than a list.
        """
        # Pull a single batch from a throw-away generator instance to check the input format.
        first_batch = next(data_gen_fn())
        if not isinstance(first_batch, list):
            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(first_batch)}')
        self.orig_batch_size = first_batch[0].shape[0]

        self._size = None
        self._gen_fn = flat_gen_fn(data_gen_fn)

    def __iter__(self):
        """ Return an iterator for the dataset. """
        return self._gen_fn()

    def __len__(self):
        """ Number of samples, counted with one full pass on first use and cached afterwards. """
        if self._size is None:
            count = 0
            for _ in self:
                count += 1
            self._size = count
        return self._size
|
70
|
+
|
71
|
+
|
72
|
+
class FixedDatasetFromGenerator(Dataset):
    """
    Map-style dataset holding a fixed set of samples pulled once from a data generator.

    The samples are stored in memory, so every epoch sees exactly the same data.

    Attributes:
        orig_batch_size: leading dimension of the first tensor in the first batch yielded by the
            input generator.
        samples: the stored per-sample tuples.
    """
    def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
        """
        Args:
            data_gen_fn: data generator factory; the generator is expected to yield lists of tensors.
            n_samples: number of samples to keep. If None, keep every sample from one full pass.

        Raises:
            TypeError: if the generator yields something other than a list.
            ValueError: if the generator is exhausted before n_samples samples were collected.
        """
        first_batch = next(data_gen_fn())
        if not isinstance(first_batch, list):
            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(first_batch)}')
        self.orig_batch_size = first_batch[0].shape[0]

        collected = []
        limited = n_samples is not None
        for raw_batch in data_gen_fn():
            # convert to torch tensor but do not move to device yet (it will cause issues with num_workers > 0)
            as_tensors = [torch.as_tensor(x) for x in raw_batch]
            collected.extend(zip(*as_tensors))
            if limited and len(collected) >= n_samples:
                del collected[n_samples:]
                break

        if limited and len(collected) < n_samples:
            raise ValueError(f'Not enough samples in the data generator to create a dataset with {n_samples}')
        self.samples = collected

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        return list(sample)
|
110
|
+
|
111
|
+
|
112
|
+
class FixedSampleInfoDataset(Dataset):
    """
    Map-style dataset pairing each sample with per-sample complementary data.

    Indexing returns a tuple ``(sample, info_1[i], info_2[i], ...)``.
    """
    def __init__(self, samples: Sequence, *sample_info: Sequence):
        """
        Args:
            samples: a sequence of input samples.
            *sample_info: one or more sequences of per-sample complementary data, each with the same
                length as ``samples``.

        Raises:
            ValueError: if any complementary sequence has a different length than ``samples``.
        """
        if any(len(info) != len(samples) for info in sample_info):
            raise ValueError('Mismatch in the number of samples between samples and complementary data.')
        self.samples = samples
        self.sample_info = sample_info

    def __getitem__(self, index):
        head = self.samples[index]
        return (head, *(info[index] for info in self.sample_info))

    def __len__(self):
        return len(self.samples)
|
133
|
+
|
134
|
+
|
135
|
+
class IterableSampleWithConstInfoDataset(IterableDataset):
    """
    Iterable dataset that appends the same extra entities to every sample of another dataset.

    Each yielded item is a tuple ``(sample, *info)``.
    """
    def __init__(self, samples_dataset: Dataset, *info: Any):
        """
        Args:
            samples_dataset: any dataset containing samples.
            *info: one or more static entities appended to each sample.
        """
        self.samples_dataset = samples_dataset
        self.info = info

    def __iter__(self):
        for sample in self.samples_dataset:
            yield (sample, *self.info)
|
151
|
+
|
152
|
+
|
153
|
+
def get_collate_fn_with_extra_outputs(*extra_outputs: Any) -> Callable:
    """
    Build a collate function that appends constant extra outputs to each collated batch.

    Args:
        *extra_outputs: entities appended, unchanged, after the collated batch items.

    Returns:
        A collate function for torch DataLoader. It applies default_collate to the batch and returns
        a list made of the collated items, followed by ``extra_outputs``. The collated items are the
        per-position items when samples are lists/tuples, or the single collated object otherwise.
    """
    def f(batch):
        collated = default_collate(batch)
        if isinstance(collated, (list, tuple)):
            items = list(collated)
        else:
            # e.g. plain-tensor samples collate into one stacked tensor; "tensor + list" would be
            # evaluated as tensor arithmetic, so wrap the result instead of concatenating to it.
            items = [collated]
        return items + list(extra_outputs)
    return f
|
158
|
+
|
159
|
+
|
160
|
+
def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size, **kwargs):
    """
    Create a torch DataLoader over the individual samples yielded by a data generator factory.

    Args:
        data_gen_fn: factory returning a generator that yields lists of tensors.
        batch_size: batch size of the returned DataLoader.
        **kwargs: forwarded to torch.utils.data.DataLoader.

    Returns:
        A DataLoader over an IterableDatasetFromGenerator built from ``data_gen_fn``.
    """
    return DataLoader(IterableDatasetFromGenerator(data_gen_fn), batch_size=batch_size, **kwargs)
|
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py
CHANGED
@@ -15,20 +15,19 @@
|
|
15
15
|
|
16
16
|
from typing import List
|
17
17
|
|
18
|
+
import numpy as np
|
19
|
+
import torch
|
18
20
|
from torch import autograd
|
19
21
|
from tqdm import tqdm
|
20
|
-
import numpy as np
|
21
22
|
|
22
23
|
from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
|
23
24
|
from model_compression_toolkit.core.common import Graph
|
24
|
-
from model_compression_toolkit.core.common.hessian import
|
25
|
-
HessianEstimationDistribution)
|
25
|
+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
|
26
26
|
from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
|
27
27
|
from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
|
28
28
|
HessianScoresCalculatorPytorch
|
29
29
|
from model_compression_toolkit.core.pytorch.utils import torch_tensor_to_numpy
|
30
30
|
from model_compression_toolkit.logger import Logger
|
31
|
-
import torch
|
32
31
|
|
33
32
|
|
34
33
|
class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
@@ -86,36 +85,12 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
86
85
|
target_activation_tensors = outputs[:num_target_nodes]
|
87
86
|
# Extract the model outputs
|
88
87
|
output_tensors = outputs[num_target_nodes:]
|
89
|
-
device = output_tensors[0].device
|
90
88
|
|
91
89
|
# Concat outputs
|
92
90
|
# First, we need to unfold all outputs that are given as list, to extract the actual output tensors
|
93
91
|
output = self.concat_tensors(output_tensors)
|
94
92
|
return output, target_activation_tensors
|
95
93
|
|
96
|
-
def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
|
97
|
-
device: torch.device) -> torch.Tensor:
|
98
|
-
"""
|
99
|
-
Generate a batch of random vectors for Hutchinson estimation
|
100
|
-
|
101
|
-
Args:
|
102
|
-
shape: target shape
|
103
|
-
distribution: distribution to sample from
|
104
|
-
device: target device
|
105
|
-
|
106
|
-
Returns:
|
107
|
-
Random tensor
|
108
|
-
"""
|
109
|
-
if distribution == HessianEstimationDistribution.GAUSSIAN:
|
110
|
-
return torch.randn(shape, device=device)
|
111
|
-
|
112
|
-
if distribution == HessianEstimationDistribution.RADEMACHER:
|
113
|
-
v = torch.randint(high=2, size=shape, device=device)
|
114
|
-
v[v == 0] = -1
|
115
|
-
return v
|
116
|
-
|
117
|
-
raise ValueError(f'Unknown distribution {distribution}') # pragma: no cover
|
118
|
-
|
119
94
|
def compute(self) -> List[np.ndarray]:
|
120
95
|
"""
|
121
96
|
Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations.
|
@@ -142,8 +117,8 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
142
117
|
for _ in range(len(target_activation_tensors))]
|
143
118
|
prev_mean_results = None
|
144
119
|
for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
|
145
|
-
# Getting a random vector
|
146
|
-
v = self._generate_random_vectors_batch(output.shape,
|
120
|
+
# Getting a random vector
|
121
|
+
v = self._generate_random_vectors_batch(output.shape, output.device)
|
147
122
|
f_v = torch.sum(v * output)
|
148
123
|
for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
|
149
124
|
# Computing the hessian-approximation scores by getting the gradient of (output * v)
|
@@ -184,7 +159,7 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
|
|
184
159
|
for _ in range(len(target_activation_tensors))]
|
185
160
|
|
186
161
|
for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations
|
187
|
-
v = self._generate_random_vectors_batch(output.shape,
|
162
|
+
v = self._generate_random_vectors_batch(output.shape, output.device)
|
188
163
|
f_v = torch.sum(v * output)
|
189
164
|
for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor
|
190
165
|
hess_v = autograd.grad(outputs=f_v,
|