mct-nightly 2.2.0.20241007.529-py3-none-any.whl → 2.2.0.20241008.450-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/RECORD +17 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/__init__.py +3 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +60 -2
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py +14 -3
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +137 -88
- model_compression_toolkit/gptq/common/gptq_config.py +7 -2
- model_compression_toolkit/gptq/common/gptq_training.py +43 -12
- model_compression_toolkit/gptq/pytorch/gptq_loss.py +36 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +58 -8
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +28 -7
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +1 -1
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +14 -12
- {mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-model_compression_toolkit/__init__.py,sha256=
+model_compression_toolkit/__init__.py,sha256=N9yCh68lSsYuGo6DuxotIhOSedwXIAg8XDYshb0Nz4g,1573
 model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -46,11 +46,11 @@ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a
 model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
 model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
 model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
-model_compression_toolkit/core/common/hessian/__init__.py,sha256=
-model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=
+model_compression_toolkit/core/common/hessian/__init__.py,sha256=Sj3I9mLBq-yrcBFxpUkOy0Rb5pxJQBPcECvgyOqhHSY,1064
+model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=fUgW-AUhRu609_RSRd1WKaQAfPk2SmLnlkT74v6TZwY,23769
 model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
 model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
-model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=
+model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=fYXcOMa2bpbJjQ2S4r021WOvhoDWFa_jy95hofqVBFA,3632
 model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
 model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
@@ -256,7 +256,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transfo
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights_activation_split.py,sha256=tp78axmUQc0Zpj3KwVmV0PGYHvCf7sAW_sRmXXw7gsY,1616
 model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
-model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=
+model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=fKeql1cXOieHTbxQDOIMpFO1sVktqXVCRBgZkv3R13Q,10929
 model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=vXluX-awgavv7DGihG9HrlvLhak8qIHy837PPTOd4jg,3471
 model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=C4-7naBQUh8TN6fEwkyKY6rlY_nvHSAmCnWT4iMBs8E,8497
 model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -345,11 +345,11 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
 model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
 model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
 model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/gptq/common/gptq_config.py,sha256=
+model_compression_toolkit/gptq/common/gptq_config.py,sha256=GP4lcDeyVgXA-QFArDW28UucOOKY0zeYJpq2pvyNVM8,6510
 model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
 model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
 model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
-model_compression_toolkit/gptq/common/gptq_training.py,sha256=
+model_compression_toolkit/gptq/common/gptq_training.py,sha256=dRNEjjKdVqlazbGWjZNE9q-MsU0PBffGKHfDpy3NX5Q,16661
 model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
 model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
@@ -368,19 +368,19 @@ model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quanti
 model_compression_toolkit/gptq/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=pgZADwaNWUwm9QTrYaW6yXE3-zfedPZSa9TKBVedNd4,8356
 model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=
+model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=nVW3URcCWQywoXfmTOBMxliZVvosshf4-G0Sq7dNwzU,3877
 model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
-model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
+model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=j_FZcs8ey_9voI83TrL4q1Mne59zO2_v0MzdhZcxWuY,20071
 model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
-model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
+model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=7UPaLBx66mJIlDTpT1uLI9LpHPzOr8EtywZ0aawveDA,16527
 model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
 model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
 model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py,sha256=nngu2TeXjngkqt_6-wciFmCvo-dbpeh_tJJxBV_cfHk,3686
 model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
 model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5EyAzvlU01vLyXmMwY_8dNyb7GwYktXmnrvUON8n8WI,4696
-model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=
+model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=H6pARLK-jq3cKoaipY0SK9wMGrqy6CSEZTk14KdrKA0,2105
 model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
-model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=
+model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=vlQEhif-R49UstORkXmpMA4ZE82Aqh-mJqKCnB31gag,3005
 model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
 model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
 model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -551,8 +551,8 @@ tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2
 tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
 tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
 tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
+mct_nightly-2.2.0.20241008.450.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-2.2.0.20241008.450.dist-info/METADATA,sha256=J3vzhM5gpeuXgdgaqJRl6bQc17gSCONWmeNCPLYvyTs,20830
+mct_nightly-2.2.0.20241008.450.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+mct_nightly-2.2.0.20241008.450.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
+mct_nightly-2.2.0.20241008.450.dist-info/RECORD,,
model_compression_toolkit/__init__.py
CHANGED
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.2.0.
+__version__ = "2.2.0.20241008.000450"
model_compression_toolkit/core/common/hessian/__init__.py
CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from model_compression_toolkit.core.common.hessian.hessian_scores_request import
+from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
+    HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
+)
 from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
 import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
model_compression_toolkit/core/common/hessian/hessian_info_service.py
CHANGED
@@ -12,16 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import hashlib
 
 import numpy as np
 from functools import partial
 from tqdm import tqdm
-from typing import Callable, List, Dict, Any, Tuple
+from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING
 
 from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
 from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
     HessianScoresGranularity, HessianMode
 from model_compression_toolkit.logger import Logger
+if TYPE_CHECKING:  # pragma: no cover
+    from model_compression_toolkit.core.common import BaseNode
 
 
 class HessianInfoService:
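A note on the `TYPE_CHECKING` guard added above: it imports a name only while a static type checker runs, which avoids a runtime (often circular) import; at runtime the annotation is then written as a string. A generic sketch of the pattern, with hypothetical module names:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime, so no import cycle.
    from my_package.graph import BaseNode  # hypothetical module path

def describe(node: 'BaseNode') -> str:
    # The quoted annotation defers resolution until type-checking time.
    return f'node: {node}'
```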
@@ -228,6 +231,61 @@ class HessianInfoService:
         return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
             and len(next_iter_remain_samples[0]) > 0 else None
 
+    def compute_trackable_per_sample_hessian(self,
+                                             hessian_scores_request: HessianScoresRequest,
+                                             inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
+        """
+        Compute a hessian score per image hash. We compute the score directly for images rather than via a data
+        generator, as a data generator might yield different images each time, depending on how it was defined.
+
+        Args:
+            hessian_scores_request: hessian scores request.
+            inputs_batch: a list containing a batch of inputs.
+
+        Returns:
+            A dict of Hessian scores per image hash per layer {image hash: {layer: score}}.
+        """
+        topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
+        hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
+
+        hessian_score_by_image_hash = {}
+
+        if not inputs_batch or not isinstance(inputs_batch, list):
+            raise TypeError('Expected a non-empty list of inputs')  # pragma: no cover
+        if len(inputs_batch) > 1:
+            raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs')  # pragma: no cover
+
+        # Get the framework-specific calculator for Hessian-approximation scores
+        fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
+                                                                           input_images=inputs_batch,
+                                                                           hessian_scores_request=hessian_scores_request,
+                                                                           num_iterations_for_approximation=self.num_iterations_for_approximation)
+        hessian_scores = fw_hessian_calculator.compute()
+        for i in range(inputs_batch[0].shape[0]):
+            img_hash = self.calc_image_hash(inputs_batch[0][i])
+            hessian_score_by_image_hash[img_hash] = {
+                node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
+            }
+
+        return hessian_score_by_image_hash
+
+    @staticmethod
+    def calc_image_hash(image):
+        """
+        Calculates a hash for an input image.
+
+        Args:
+            image: input 3d image (without batch).
+
+        Returns:
+            Image hash.
+
+        """
+        if not len(image.shape) == 3:  # pragma: no cover
+            raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
+        image_bytes = image.astype(np.float32).tobytes()
+        return hashlib.md5(image_bytes).hexdigest()
+
     def fetch_hessian(self,
                       hessian_scores_request: HessianScoresRequest,
                       required_size: int,
@@ -248,7 +306,7 @@ class HessianInfoService:
             OC for per-output-channel when the requested node has OC output-channels, etc.)
         """
 
-        if len(hessian_scores_request.target_nodes) == 0:
+        if len(hessian_scores_request.target_nodes) == 0:  # pragma: no cover
             return []
 
         if required_size == 0:
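For context on `calc_image_hash` above: per-sample scores are cached under an MD5 digest of the raw float32 image bytes, so an identical image seen again in a later batch hits the cache. A minimal standalone sketch of the same idea (not MCT's API):

```python
import hashlib
import numpy as np

def image_hash(image: np.ndarray) -> str:
    # Hash the raw float32 bytes of a single (C, H, W) image, mirroring calc_image_hash.
    assert image.ndim == 3, 'expected a single image without a batch dimension'
    return hashlib.md5(image.astype(np.float32).tobytes()).hexdigest()

cache = {}
img = np.random.rand(3, 8, 8).astype(np.float32)
cache[image_hash(img)] = {'conv1': 0.42}   # hypothetical per-layer score
assert image_hash(img.copy()) in cache     # identical pixels -> same key
```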
model_compression_toolkit/core/common/hessian/hessian_scores_request.py
CHANGED
@@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum):
     PER_TENSOR = 2
 
 
+class HessianEstimationDistribution(str, Enum):
+    """
+    Distribution for Hutchinson estimator random vector
+    """
+    GAUSSIAN = 'gaussian'
+    RADEMACHER = 'rademacher'
+
+
 class HessianScoresRequest:
     """
     Request configuration for the Hessian-approximation scores.
@@ -53,7 +61,8 @@ class HessianScoresRequest:
     def __init__(self,
                  mode: HessianMode,
                  granularity: HessianScoresGranularity,
-                 target_nodes: List
+                 target_nodes: List,
+                 distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN):
         """
         Attributes:
             mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
@@ -64,6 +73,7 @@ class HessianScoresRequest:
         self.mode = mode  # w.r.t activations or weights
         self.granularity = granularity  # per element, per layer, per channel
         self.target_nodes = target_nodes
+        self.distribution = distribution
 
     def __eq__(self, other):
         # Checks if the other object is an instance of HessianScoresRequest
@@ -71,9 +81,10 @@ class HessianScoresRequest:
         return isinstance(other, HessianScoresRequest) and \
                self.mode == other.mode and \
                self.granularity == other.granularity and \
-               self.target_nodes == other.target_nodes
+               self.target_nodes == other.target_nodes and \
+               self.distribution == other.distribution
 
     def __hash__(self):
         # Computes the hash based on the attributes.
         # The use of a tuple here ensures that the hash is influenced by all the attributes.
-        return hash((self.mode, self.granularity, tuple(self.target_nodes)))
+        return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
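Background on the two choices: Hutchinson's estimator approximates tr(H) by averaging the quadratic forms v^T H v over random vectors v with zero mean and identity covariance. Both Gaussian and Rademacher (entries in {-1, +1}) vectors qualify; Rademacher vectors are a common choice because they give the lowest estimator variance among such distributions. A self-contained sketch with a toy matrix (not MCT code):

```python
import torch

def hutchinson_trace(H: torch.Tensor, k: int = 2000, distribution: str = 'rademacher') -> torch.Tensor:
    # Average v^T H v over k random probe vectors; E[v^T H v] = tr(H).
    n = H.shape[0]
    if distribution == 'gaussian':
        v = torch.randn(k, n)
    elif distribution == 'rademacher':
        v = torch.randint(0, 2, (k, n)).float() * 2 - 1  # entries in {-1, +1}
    else:
        raise ValueError(f'Unknown distribution {distribution}')
    return torch.einsum('ki,ij,kj->k', v, H, v).mean()

H = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
print(hutchinson_trace(H))  # ≈ 6.0 for either distribution
```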
model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py
CHANGED
@@ -21,7 +21,8 @@ import numpy as np
 
 from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
 from model_compression_toolkit.core.common import Graph
-from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
+from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianScoresGranularity,
+                                                           HessianEstimationDistribution)
 from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
 from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
     HessianScoresCalculatorPytorch
@@ -55,6 +56,66 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
                          hessian_scores_request=hessian_scores_request,
                          num_iterations_for_approximation=num_iterations_for_approximation)
 
+    def forward_pass(self):
+        model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
+
+        if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
+            Logger.critical("Activation Hessian approximation cannot be computed for model outputs. "
+                            "Exclude output nodes from Hessian request targets.")
+
+        grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
+        model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
+        model.eval()
+
+        # Run model inference
+        # Set inputs to track gradients during inference
+        for input_tensor in self.input_images:
+            input_tensor.requires_grad_()
+            input_tensor.retain_grad()
+
+        outputs = model(*self.input_images)
+
+        if len(outputs) != len(grad_model_outputs):  # pragma: no cover
+            Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. "
+                            f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.")
+
+        # Extracting the intermediate activation tensors and the model real output.
+        # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
+        num_target_nodes = len(self.hessian_request.target_nodes)
+        # Extract activation tensors of nodes for which we want to compute Hessian
+        target_activation_tensors = outputs[:num_target_nodes]
+        # Extract the model outputs
+        output_tensors = outputs[num_target_nodes:]
+        device = output_tensors[0].device
+
+        # Concat outputs
+        # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
+        output = self.concat_tensors(output_tensors)
+        return output, target_activation_tensors
+
+    def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
+                                       device: torch.device) -> torch.Tensor:
+        """
+        Generate a batch of random vectors for Hutchinson estimation
+
+        Args:
+            shape: target shape
+            distribution: distribution to sample from
+            device: target device
+
+        Returns:
+            Random tensor
+        """
+        if distribution == HessianEstimationDistribution.GAUSSIAN:
+            return torch.randn(shape, device=device)
+
+        if distribution == HessianEstimationDistribution.RADEMACHER:
+            v = torch.randint(high=2, size=shape, device=device)
+            v[v == 0] = -1
+            return v
+
+        raise ValueError(f'Unknown distribution {distribution}')  # pragma: no cover
+
     def compute(self) -> List[np.ndarray]:
         """
         Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations.
@@ -62,91 +123,79 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
         Returns:
             List[np.ndarray]: Scores based on the approximated Hessian for the requested nodes.
         """
-
-
-        model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
-
-        if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
-            Logger.critical("Activation Hessian approximation cannot be computed for model outputs. "
-                            "Exclude output nodes from Hessian request targets.")
-
-        grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
-        model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
-        model.eval()
-
-        # Run model inference
-        # Set inputs to track gradients during inference
-        for input_tensor in self.input_images:
-            input_tensor.requires_grad_()
-            input_tensor.retain_grad()
-
-        outputs = model(*self.input_images)
-
-        if len(outputs) != len(grad_model_outputs):  # pragma: no cover
-            Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. "
-                            f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.")
-
-        # Extracting the intermediate activation tensors and the model real output.
-        # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
-        num_target_nodes = len(self.hessian_request.target_nodes)
-        # Extract activation tensors of nodes for which we want to compute Hessian
-        target_activation_tensors = outputs[:num_target_nodes]
-        # Extract the model outputs
-        output_tensors = outputs[num_target_nodes:]
-        device = output_tensors[0].device
-
-        # Concat outputs
-        # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
-        output = self.concat_tensors(output_tensors)
-
-        ipts_hessian_approx_scores = [torch.tensor([0.0],
-                                                   requires_grad=True,
-                                                   device=device)
-                                      for _ in range(len(target_activation_tensors))]
-        prev_mean_results = None
-        for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
-            # Getting a random vector with normal distribution
-            v = torch.randn(output.shape, device=device)
-            f_v = torch.sum(v * output)
-            for i, ipt_tensor in enumerate(target_activation_tensors):  # Per interest point activation tensor
-                # Computing the hessian-approximation scores by getting the gradient of (output * v)
-                hess_v = autograd.grad(outputs=f_v,
-                                       inputs=ipt_tensor,
-                                       retain_graph=True,
-                                       allow_unused=True)[0]
-
-                if hess_v is None:
-                    # In case we have an output node, which is an interest point, but it is not differentiable,
-                    # we consider its Hessian to be the initial value 0.
-                    continue  # pragma: no cover
-
-                # Mean over all dims but the batch (CxHxW for conv)
-                hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape))))
-
-                # Update node Hessian approximation mean over random iterations
-                ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
-
-            # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
-            if j > MIN_HESSIAN_ITER:
-                if prev_mean_results is not None:
-                    new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
-                    relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
-                                               (torch.abs(new_mean_res) + 1e-6))
-                    max_delta = torch.max(relative_delta_per_node)
-                    if max_delta < HESSIAN_COMP_TOLERANCE:
-                        break
-            prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
-
-        # Convert results to list of numpy arrays
-        hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores]
-        # Extend the Hessian tensors shape to align with expected return type
-        # TODO: currently, only per-tensor Hessian is available for activation.
-        # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
-        hessian_results = [h[..., np.newaxis] for h in hessian_results]
-
-        return hessian_results
-
-        else:  # pragma: no cover
-            Logger.critical(f"PyTorch activation Hessian's approximation scores does not support "
-                            f"{self.hessian_request.granularity} granularity.")
+        output, target_activation_tensors = self.forward_pass()
 
+        if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
+            hessian_scores = self._compute_per_tensor(output, target_activation_tensors)
+        elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
+            hessian_scores = self._compute_per_channel(output, target_activation_tensors)
+        else:
+            raise NotImplementedError(f'{self.hessian_request.granularity} is not supported')  # pragma: no cover
+
+        # Convert results to a list of numpy arrays
+        hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores]
+        return hessian_results
+
+    def _compute_per_tensor(self, output, target_activation_tensors):
+        assert self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR
+        ipts_hessian_approx_scores = [torch.tensor([0.0], requires_grad=True, device=output.device)
+                                      for _ in range(len(target_activation_tensors))]
+        prev_mean_results = None
+        for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
+            # Getting a random vector from the requested distribution
+            v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
+            f_v = torch.sum(v * output)
+            for i, ipt_tensor in enumerate(target_activation_tensors):  # Per interest point activation tensor
+                # Computing the hessian-approximation scores by getting the gradient of (output * v)
+                hess_v = autograd.grad(outputs=f_v,
+                                       inputs=ipt_tensor,
+                                       retain_graph=True,
+                                       allow_unused=True)[0]
+
+                if hess_v is None:
+                    # In case we have an output node, which is an interest point, but it is not differentiable,
+                    # we consider its Hessian to be the initial value 0.
+                    continue  # pragma: no cover
+
+                # Mean over all dims but the batch (CxHxW for conv)
+                hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape))))
+
+                # Update node Hessian approximation mean over random iterations
+                ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
+
+            # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
+            if j > MIN_HESSIAN_ITER:
+                if prev_mean_results is not None:
+                    new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
+                    relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
+                                               (torch.abs(new_mean_res) + 1e-6))
+                    max_delta = torch.max(relative_delta_per_node)
+                    if max_delta < HESSIAN_COMP_TOLERANCE:
+                        break
+            prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
+
+        # Add an extra dimension to preserve the previous behaviour
+        ipts_hessian_approx_scores = [torch.unsqueeze(t, -1) for t in ipts_hessian_approx_scores]
+        return ipts_hessian_approx_scores
+
+    def _compute_per_channel(self, output, target_activation_tensors):
+        assert self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL
+        ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device)
+                                      for _ in range(len(target_activation_tensors))]
+
+        for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
+            v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
+            f_v = torch.sum(v * output)
+            for i, ipt_tensor in enumerate(target_activation_tensors):  # Per interest point activation tensor
+                hess_v = autograd.grad(outputs=f_v,
+                                       inputs=ipt_tensor,
+                                       retain_graph=True)[0]
+                hessian_approx_scores = hess_v ** 2
+                rank = len(hess_v.shape)
+                if rank > 2:
+                    hessian_approx_scores = torch.mean(hessian_approx_scores, dim=tuple(range(2, rank)))
+
+                # Update node Hessian approximation mean over random iterations
+                ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
+
+        return ipts_hessian_approx_scores
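To make the two granularities concrete: for a gradient tensor of shape (batch, channels, H, W), the per-tensor path reduces every non-batch dimension to one score per sample, while the per-channel path keeps the channel axis and averages over the spatial dimensions. A toy illustration with assumed shapes:

```python
import torch

hess_v = torch.rand(4, 16, 8, 8)  # stand-in HVP gradient: (batch, channels, H, W)

# Per-tensor (as in _compute_per_tensor): sum of squares over all non-batch dims.
per_tensor = torch.sum(hess_v ** 2.0, dim=tuple(range(1, hess_v.ndim)))
print(per_tensor.shape)   # torch.Size([4])

# Per-output-channel (as in _compute_per_channel): mean of squares over dims 2+.
per_channel = torch.mean(hess_v ** 2, dim=tuple(range(2, hess_v.ndim)))
print(per_channel.shape)  # torch.Size([4, 16])
```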
model_compression_toolkit/gptq/common/gptq_config.py
CHANGED
@@ -17,6 +17,7 @@ from enum import Enum
 from typing import Callable, Any, Dict, Optional
 
 from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
+from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
 from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 
 
@@ -39,17 +40,21 @@ class GPTQHessianScoresConfig:
     Configuration to use for computing the Hessian-based scores for GPTQ loss metric.
 
     Args:
-        hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
+        hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores.
+          If None, compute the Hessian for all images.
         norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
         log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
        scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
         hessian_batch_size (int): The Hessian computation batch size. Used only when running GPTQ with a Hessian-based objective.
+        per_sample (bool): Whether to use per-sample attention scores.
     """
-    hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES
+    hessians_num_samples: Optional[int] = GPTQ_HESSIAN_NUM_SAMPLES
     norm_scores: bool = True
     log_norm: bool = True
     scale_log_norm: bool = False
     hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
+    per_sample: bool = False
+    estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN
 
 
 @dataclass
model_compression_toolkit/gptq/common/gptq_training.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import copy
+import hashlib
 from abc import ABC, abstractmethod
 import numpy as np
 from typing import Callable, List, Any, Dict
@@ -143,7 +144,11 @@ class GPTQTrainer(ABC):
             return np.asarray([1 / num_nodes for _ in range(num_nodes)])
 
         # Fetch hessian approximations for each target node
-
+        # TODO this smells like a potential bug. In hessian calculation target nodes are topo sorted and results are returned
+        # TODO also target nodes are replaced for reuse. Does this work correctly?
+        approximations = self._fetch_hessian_approximations(HessianScoresGranularity.PER_TENSOR)
+        compare_point_to_hessian_approx_scores = {node: score for node, score in zip(self.compare_points, approximations)}
+
         # Process the fetched hessian approximations to gather them per images
         hessian_approx_score_by_image = (
             self._process_hessian_approximations(compare_point_to_hessian_approx_scores))
@@ -172,29 +177,55 @@ class GPTQTrainer(ABC):
         # If log normalization is not enabled, return the mean of the approximations across images
         return np.mean(hessian_approx_score_by_image, axis=0)
 
-    def
+    def _compute_sample_layer_attention_scores(self, inputs_batch) -> Dict[str, Dict[BaseNode, np.ndarray]]:
+        """
+        Compute sample layer attention scores per image hash per layer.
+
+        Args:
+            inputs_batch: a list containing a batch of inputs.
+
+        Returns:
+            A dictionary with the structure {img_hash: {layer: score}}.
+
+        """
+        request = self._build_hessian_request(HessianScoresGranularity.PER_OUTPUT_CHANNEL)
+        hessian_batch_size = self.gptq_config.hessian_weights_config.hessian_batch_size
+
+        hessian_score_per_image_per_layer = {}
+        # If the hessian batch is smaller than the inputs batch, split it into hessian batches. If the hessian batch
+        # is larger, it's currently ignored (TODO).
+        for i in range(0, inputs_batch[0].shape[0], hessian_batch_size):
+            inputs = [t[i: i+hessian_batch_size] for t in inputs_batch]
+            hessian_score_per_image_per_layer.update(
+                self.hessian_service.compute_trackable_per_sample_hessian(request, inputs)
+            )
+        for img_hash, v in hessian_score_per_image_per_layer.items():
+            hessian_score_per_image_per_layer[img_hash] = {k: t.max(axis=0) for k, t in v.items()}
+        return hessian_score_per_image_per_layer
+
+    def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]:
         """
         Fetches hessian approximations for each target node.
 
         Returns:
             Mapping of target nodes to their hessian approximations.
         """
-
-
-            mode=HessianMode.ACTIVATION,
-            granularity=HessianScoresGranularity.PER_TENSOR,
-            target_nodes=self.compare_points
-        )
+        hessian_scores_request = self._build_hessian_request(granularity)
+
         node_approximations = self.hessian_service.fetch_hessian(
             hessian_scores_request=hessian_scores_request,
             required_size=self.gptq_config.hessian_weights_config.hessians_num_samples,
            batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size
         )
+        return node_approximations
 
-
-
-
-
+    def _build_hessian_request(self, granularity):
+        return HessianScoresRequest(
+            mode=HessianMode.ACTIVATION,
+            granularity=granularity,
+            target_nodes=self.compare_points,
+            distribution=self.gptq_config.hessian_weights_config.estimator_distribution
+        )
 
     def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[List[float]]]) -> List:
         """
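Note how `_compute_sample_layer_attention_scores` collapses each per-channel score vector to a single scalar per (image, layer) by taking the channel-wise maximum. A toy sketch of that aggregation, with hypothetical hashes and layer names:

```python
import numpy as np

# Per-sample output shape: {img_hash: {layer: per-channel scores}}
scores = {'a1b2c3': {'conv1': np.array([0.1, 0.7, 0.3]),
                     'conv2': np.array([0.2, 0.2])}}

# Collapse the channel axis with max, one scalar per (image, layer).
sla = {h: {layer: s.max(axis=0) for layer, s in per_layer.items()}
       for h, per_layer in scores.items()}
print(sla['a1b2c3'])  # {'conv1': 0.7, 'conv2': 0.2}
```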
model_compression_toolkit/gptq/pytorch/gptq_loss.py
CHANGED
@@ -13,8 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 from typing import List
+
 import torch
 
+
 def mse_loss(y: torch.Tensor, x: torch.Tensor, normalized: bool = True) -> torch.Tensor:
     """
     Compute the MSE of two tensors.
@@ -25,7 +27,7 @@ def mse_loss(y: torch.Tensor, x: torch.Tensor, normalized: bool = True) -> torch
     Returns:
         The MSE of two tensors.
     """
-    loss = torch.nn.MSELoss()(x,y)
+    loss = torch.nn.MSELoss()(x, y)
     return loss / torch.mean(torch.square(x)) if normalized else loss
 
 
@@ -62,3 +64,36 @@ def multiple_tensors_mse_loss(y_list: List[torch.Tensor],
     else:
         return torch.mean(torch.stack(loss_values_list))
 
+
+def sample_layer_attention_loss(y_list: List[torch.Tensor],
+                                x_list: List[torch.Tensor],
+                                fxp_w_list,
+                                flp_w_list,
+                                act_bn_mean,
+                                act_bn_std,
+                                loss_weights: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the Sample Layer Attention loss between two lists of tensors.
+
+    Args:
+        y_list: First list of tensors.
+        x_list: Second list of tensors.
+        fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
+        loss_weights: layer-sample weights tensor of shape (layers, batch).
+
+    Returns:
+        Sample Layer Attention loss (a scalar).
+    """
+    loss = 0
+    layers_mean_w = []
+
+    for i, (y, x, w) in enumerate(zip(y_list, x_list, loss_weights)):
+        norm = (y - x).pow(2).sum(1)
+        if len(norm.shape) > 1:
+            norm = norm.flatten(1).mean(1)
+        loss += torch.mean(w * norm)
+        layers_mean_w.append(w.mean())
+
+    loss = loss / torch.stack(layers_mean_w).max()
+    return loss
+
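A quick shape check for `sample_layer_attention_loss`: with L layers and batch size B, `loss_weights` is an (L, B) tensor and each (y, x) pair is a per-layer activation batch; the unused arguments exist only to match the GPTQ loss interface. A dummy-tensor sketch, assuming the function above is importable from `model_compression_toolkit.gptq.pytorch.gptq_loss`:

```python
import torch
from model_compression_toolkit.gptq.pytorch.gptq_loss import sample_layer_attention_loss

L, B, C, H, W = 2, 4, 8, 5, 5
y_list = [torch.rand(B, C, H, W) for _ in range(L)]  # float-model activations
x_list = [torch.rand(B, C, H, W) for _ in range(L)]  # quantized-model activations
loss_weights = torch.rand(L, B)                      # per-layer, per-sample weights

loss = sample_layer_attention_loss(y_list, x_list, None, None, None, None, loss_weights)
print(loss)  # scalar tensor
```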
model_compression_toolkit/gptq/pytorch/gptq_training.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Tuple, Union, Dict
 
 import numpy as np
 from torch.nn import Module
@@ -105,8 +105,18 @@ class PytorchGPTQTrainer(GPTQTrainer):
         self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
                                                                   trainable_bias,
                                                                   trainable_threshold)
-
-        self.
+        hessian_cfg = self.gptq_config.hessian_weights_config
+        self.use_sample_layer_attention = hessian_cfg.per_sample
+        self.hessian_score_per_layer = None  # for fixed layer weights
+        self.hessian_score_per_image_per_layer = None  # for sample-layer attention
+        if self.use_sample_layer_attention:
+            # normalization is currently not supported, make sure the config reflects it.
+            if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
+                raise NotImplementedError()
+            # Per-sample hessian scores are calculated on demand during the training loop
+            self.hessian_score_per_image_per_layer = {}
+        else:
+            self.hessian_score_per_layer = to_torch_tensor(self.compute_hessian_based_weights())
 
         self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)
 
@@ -210,13 +220,17 @@ class PytorchGPTQTrainer(GPTQTrainer):
 
     def compute_gradients(self,
                           y_float: List[torch.Tensor],
-                          input_tensors: List[torch.Tensor]
+                          input_tensors: List[torch.Tensor],
+                          distill_loss_weights: torch.Tensor,
+                          round_reg_weights: torch.Tensor) -> Tuple[torch.Tensor, List[np.ndarray]]:
        """
         Get outputs from both teacher and student networks. Compute the observed error,
         and use it to compute the gradients and applying them to the student weights.
         Args:
             y_float: A list of reference tensor from the floating point network.
             input_tensors: A list of Input tensors to pass through the networks.
+            distill_loss_weights: Weights for the distillation loss.
+            round_reg_weights: Weights for the rounding regularization loss.
         Returns:
             Loss and gradients.
         """
@@ -231,9 +245,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                           self.flp_weights_list,
                                           self.compare_points_mean,
                                           self.compare_points_std,
-
-
-        reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+                                          distill_loss_weights)
+        reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, round_reg_weights)
 
         loss_value += reg_value
 
@@ -261,10 +274,11 @@ class PytorchGPTQTrainer(GPTQTrainer):
         for _ in epochs_pbar:
             with tqdm(data_function(), position=1, leave=False) as data_pbar:
                 for data in data_pbar:
+                    distill_weights, reg_weights = to_torch_tensor(self._get_loss_weights(data))
                     input_data = [d * self.input_scale for d in data]
                     input_tensor = to_torch_tensor(input_data)
                     y_float = self.float_model(input_tensor)  # running float model
-                    loss_value, grads = self.compute_gradients(y_float, input_tensor)
+                    loss_value, grads = self.compute_gradients(y_float, input_tensor, distill_weights, reg_weights)
                     # Run one step of gradient descent by updating the value of the variables to minimize the loss.
                     for (optimizer, _) in self.optimizer_with_param:
                         optimizer.step()
@@ -276,6 +290,42 @@ class PytorchGPTQTrainer(GPTQTrainer):
             self.loss_list.append(loss_value.item())
             Logger.debug(f'last loss value: {self.loss_list[-1]}')
 
+    def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Fetches weights for the distillation and rounding regularization parts of the loss.
+
+        Args:
+            input_tensors: list containing a batch of inputs.
+
+        Returns:
+            A tuple of two tensors:
+            - weights for the distillation loss
+            - weights for the rounding regularization loss
+
+        """
+        if self.use_sample_layer_attention is False:
+            return self.hessian_score_per_layer, torch.ones_like(self.hessian_score_per_layer)
+
+        if len(input_tensors) > 1:
+            raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs')
+
+        image_scores = []
+        batch = input_tensors[0]
+        img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]
+        for img_hash in img_hashes:
+            # If the sample-layer attention score for the image is not found, compute and store it for the whole batch.
+            if img_hash not in self.hessian_score_per_image_per_layer:
+                score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors)
+                self.hessian_score_per_image_per_layer.update(score_per_image_per_layer)
+            img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash]
+            # Fetch the image scores for all layers and combine them into a single tensor
+            img_scores = np.stack(list(img_scores_per_layer.values()), axis=0)
+            image_scores.append(img_scores)
+
+        layer_sample_weights = np.stack(image_scores, axis=1)  # layers X images
+        layer_weights = layer_sample_weights.mean(axis=1)
+        return layer_sample_weights, layer_weights
+
     def update_graph(self) -> Graph:
         """
         Update a graph using GPTQ after minimizing the loss between the float model's output
model_compression_toolkit/gptq/pytorch/quantization_facade.py
CHANGED
@@ -18,6 +18,7 @@ from typing import Callable, Union
 from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
 from model_compression_toolkit.core import CoreConfig
 from model_compression_toolkit.core.analyzer import analyzer_model_quantization
+from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
     MixedPrecisionQuantizationConfig
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
@@ -43,7 +44,7 @@ if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
-    from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
+    from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss
     from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
     import torch
     from torch.nn import Module
@@ -55,11 +56,12 @@ if FOUND_TORCH:
     def get_pytorch_gptq_config(n_epochs: int,
                                 optimizer: Optimizer = None,
                                 optimizer_rest: Optimizer = None,
-                                loss: Callable =
+                                loss: Callable = None,
                                 log_function: Callable = None,
                                 use_hessian_based_weights: bool = True,
                                 regularization_factor: float = REG_DEFAULT,
                                 hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                                use_hessian_sample_attention: bool = False,
                                 gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
                                 ) -> GradientPTQConfig:
         """
@@ -74,6 +76,7 @@ if FOUND_TORCH:
         use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
         regularization_factor (float): A floating point number that defines the regularization factor.
         hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+        use_hessian_sample_attention (bool): Whether to use Sample-Layer Attention scores for the weighted loss.
         gradual_activation_quantization (bool, GradualActivationQuantizationConfig):
             If False, GradualActivationQuantization is disabled.
             If True, GradualActivationQuantization is enabled with the default settings.
@@ -105,19 +108,37 @@ if FOUND_TORCH:
 
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
 
+        if use_hessian_sample_attention:
+            if not use_hessian_based_weights:  # pragma: no cover
+                raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
+
+            hessian_weights_config = GPTQHessianScoresConfig(
+                hessians_num_samples=None,
+                norm_scores=False,
+                log_norm=False,
+                scale_log_norm=False,
+                hessian_batch_size=hessian_batch_size,
+                per_sample=True,
+                estimator_distribution=HessianEstimationDistribution.RADEMACHER
+            )
+            loss = loss or sample_layer_attention_loss
+        else:
+            hessian_weights_config = GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)
+            loss = loss or multiple_tensors_mse_loss
+
         if isinstance(gradual_activation_quantization, bool):
             gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
         elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
             gradual_quant_config = gradual_activation_quantization
-        else:
+        else:  # pragma: no cover
             raise TypeError(f'gradual_activation_quantization argument should be bool or '
-                            f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
+                            f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
 
         return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
                                  log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
                                  use_hessian_based_weights=use_hessian_based_weights,
                                  regularization_factor=regularization_factor,
-                                 hessian_weights_config=
+                                 hessian_weights_config=hessian_weights_config,
                                  gradual_activation_quantization_config=gradual_quant_config)
 
     def pytorch_gradient_post_training_quantization(model: Module,
@@ -185,11 +206,11 @@ if FOUND_TORCH:
 
         """
 
-        if core_config.is_mixed_precision_enabled:
+        if core_config.is_mixed_precision_enabled:  # pragma: no cover
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                 "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
-                                "or provide a valid mixed-precision configuration.")
+                                "or provide a valid mixed-precision configuration.")
 
         tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
 
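With the facade change above, turning the feature on should only require the new flag; the branch then selects the Rademacher estimator, per-sample scores, and `sample_layer_attention_loss` internally. A usage sketch against this nightly's API:

```python
import model_compression_toolkit as mct

# Sample-Layer Attention GPTQ config; the loss and Hessian settings are wired in
# by the use_hessian_sample_attention branch shown above.
gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5,
                                               use_hessian_sample_attention=True,
                                               hessian_batch_size=32)
```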
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py
CHANGED
@@ -41,4 +41,4 @@ def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn:
         scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2)
         return SoftQuantizerRegularization(scheduler)
     else:
-        return lambda
+        return lambda *args, **kwargs: 0
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
CHANGED
@@ -40,32 +40,34 @@ class SoftQuantizerRegularization:
 
         self.count_iter = 0
 
-    def __call__(self, model: nn.Module, entropy_reg: float):
+    def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor):
         """
         Returns the soft quantizer regularization value for SoftRounding.
 
         Args:
             model: A model to be quantized with SoftRounding.
             entropy_reg: Entropy value to scale the quantizer regularization.
+            layer_weights: a vector of layer weights.
 
         Returns: Regularization value.
         """
+        layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)]
 
-
-
-
-            if isinstance(layer, PytorchQuantizationWrapper):
-                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                      fw_info=DEFAULT_PYTORCH_INFO)
-
-                st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
-                soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
+        if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers):
+            raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.')  # pragma: no cover
+        max_w = layer_weights.max()
 
+        b = self.beta_scheduler(self.count_iter)
         reg = 0
+        for layer, w in zip(layers, layer_weights):
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_PYTORCH_INFO)
 
-
-
+            st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
+            soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
+            reg += w * soft_loss
 
+        reg = reg / max_w
         self.count_iter += 1
 
         return entropy_reg * reg
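The per-layer term is the usual AdaRound-style soft-rounding penalty, Σ(1 − |2·st − 1|^b), which is largest when a soft target sits at 0.5 (maximally undecided) and vanishes as it commits to 0 or 1; the change above scales it by the layer weight w and normalizes by the largest weight. A toy computation with hypothetical soft targets:

```python
import torch

st = torch.tensor([0.1, 0.5, 0.9])  # soft rounding targets in [0, 1]
b = 4.0                             # annealed beta from the scheduler

penalty = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
print(penalty)  # the st == 0.5 entry contributes the full 1.0
```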
{mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/LICENSE.md
RENAMED
File without changes

{mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/WHEEL
RENAMED
File without changes

{mct_nightly-2.2.0.20241007.529.dist-info → mct_nightly-2.2.0.20241008.450.dist-info}/top_level.txt
RENAMED
File without changes