mct-nightly 2.2.0.20241006.532-py3-none-any.whl → 2.2.0.20241008.450-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.2.0.20241006.532
+ Version: 2.2.0.20241008.450
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=kSEKSXjzPQru90gByS3M6uomZGoS5vT50wU-WzVOQEU,1573
+ model_compression_toolkit/__init__.py,sha256=N9yCh68lSsYuGo6DuxotIhOSedwXIAg8XDYshb0Nz4g,1573
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -46,11 +46,11 @@ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=aPdXJPP5a
  model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=crV2NCLVO8jx9MlryZBYuJKFe_G9HfM7rUR64fDymlw,17045
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
- model_compression_toolkit/core/common/hessian/__init__.py,sha256=6216QgHl7h4DXGn5ForP9Tija-wrBSONNtQ769ikP2s,1025
- model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=TfgSIh5pmZcJM9335aAxZriCzMljnk3mYhmKBsK2x5Y,20848
+ model_compression_toolkit/core/common/hessian/__init__.py,sha256=Sj3I9mLBq-yrcBFxpUkOy0Rb5pxJQBPcECvgyOqhHSY,1064
+ model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=fUgW-AUhRu609_RSRd1WKaQAfPk2SmLnlkT74v6TZwY,23769
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
  model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
- model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=atGJgJBL9uwYRC3t9NnzGgHYxV4XJj4Ai_xPpQH0rhY,3229
+ model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=fYXcOMa2bpbJjQ2S4r021WOvhoDWFa_jy95hofqVBFA,3632
  model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
  model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
  model_compression_toolkit/core/common/matchers/base_matcher.py,sha256=JCj-NLAXOJa-GcSX-94PVUTWjooQUd0NemiyNg5uKGQ,2210
@@ -256,7 +256,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transfo
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=WmEa8Xjji-_tIbthDxlLAGSr69nWk-YKcHNaVqLa7sg,1375
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/weights_activation_split.py,sha256=tp78axmUQc0Zpj3KwVmV0PGYHvCf7sAW_sRmXXw7gsY,1616
  model_compression_toolkit/core/pytorch/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
- model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=xc_-utc9_Hq915X02VbT8zXxGqxE4fFz6dhiiZwU3ok,8578
+ model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py,sha256=fKeql1cXOieHTbxQDOIMpFO1sVktqXVCRBgZkv3R13Q,10929
  model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=vXluX-awgavv7DGihG9HrlvLhak8qIHy837PPTOd4jg,3471
  model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=C4-7naBQUh8TN6fEwkyKY6rlY_nvHSAmCnWT4iMBs8E,8497
  model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -345,11 +345,11 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/gptq/common/gptq_config.py,sha256=xogD4mM2825NXyX7rKWBaKBhBFo31bMUmxECREGgtWc,6132
+ model_compression_toolkit/gptq/common/gptq_config.py,sha256=GP4lcDeyVgXA-QFArDW28UucOOKY0zeYJpq2pvyNVM8,6510
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
- model_compression_toolkit/gptq/common/gptq_training.py,sha256=CtSpjG27BQ3rLPGWeBnZYYiGnMREpdBd6dx7SQf_wDk,14965
+ model_compression_toolkit/gptq/common/gptq_training.py,sha256=dRNEjjKdVqlazbGWjZNE9q-MsU0PBffGKHfDpy3NX5Q,16661
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
@@ -368,19 +368,19 @@ model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quanti
  model_compression_toolkit/gptq/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=pgZADwaNWUwm9QTrYaW6yXE3-zfedPZSa9TKBVedNd4,8356
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
+ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=nVW3URcCWQywoXfmTOBMxliZVvosshf4-G0Sq7dNwzU,3877
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=bnL4DyPLBz2-pip3RV_jBmExvQKZ4N1vXzQudc1VgMY,17117
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=j_FZcs8ey_9voI83TrL4q1Mne59zO2_v0MzdhZcxWuY,20071
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=Z1xCEDiRWE6xtjVjgVGpgGazuY9l9IhUOPNiRZegLMQ,15408
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=7UPaLBx66mJIlDTpT1uLI9LpHPzOr8EtywZ0aawveDA,16527
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
  model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py,sha256=nngu2TeXjngkqt_6-wciFmCvo-dbpeh_tJJxBV_cfHk,3686
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
  model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5EyAzvlU01vLyXmMwY_8dNyb7GwYktXmnrvUON8n8WI,4696
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=lkeEBgAAhC1VHu4DHoqDz8GC7BIU4cU0HIAXFYfgUFU,2098
+ model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=H6pARLK-jq3cKoaipY0SK9wMGrqy6CSEZTk14KdrKA0,2105
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=UZwVCpG8WOw7r0-cmPYXNkJYpTZciW66KWtKG004J6Q,2683
+ model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=vlQEhif-R49UstORkXmpMA4ZE82Aqh-mJqKCnB31gag,3005
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -551,8 +551,8 @@ tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2
  tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
  tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
  tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
- mct_nightly-2.2.0.20241006.532.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.2.0.20241006.532.dist-info/METADATA,sha256=skhhX9UH3JERO3bWA-6PalAx6JLwSUkVJyWsT07eFrs,20830
- mct_nightly-2.2.0.20241006.532.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- mct_nightly-2.2.0.20241006.532.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
- mct_nightly-2.2.0.20241006.532.dist-info/RECORD,,
+ mct_nightly-2.2.0.20241008.450.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.2.0.20241008.450.dist-info/METADATA,sha256=J3vzhM5gpeuXgdgaqJRl6bQc17gSCONWmeNCPLYvyTs,20830
+ mct_nightly-2.2.0.20241008.450.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ mct_nightly-2.2.0.20241008.450.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
+ mct_nightly-2.2.0.20241008.450.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.2.0.20241006.000532"
+ __version__ = "2.2.0.20241008.000450"
@@ -12,6 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity
+ from model_compression_toolkit.core.common.hessian.hessian_scores_request import (
+     HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution
+ )
  from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
  import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils
@@ -12,16 +12,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import hashlib

  import numpy as np
  from functools import partial
  from tqdm import tqdm
- from typing import Callable, List, Dict, Any, Tuple
+ from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING

  from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
  from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \
      HessianScoresGranularity, HessianMode
  from model_compression_toolkit.logger import Logger
+ if TYPE_CHECKING:  # pragma: no cover
+     from model_compression_toolkit.core.common import BaseNode


  class HessianInfoService:
@@ -228,6 +231,61 @@ class HessianInfoService:
          return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \
              and len(next_iter_remain_samples[0]) > 0 else None

+     def compute_trackable_per_sample_hessian(self,
+                                              hessian_scores_request: HessianScoresRequest,
+                                              inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]:
+         """
+         Compute the hessian score per image hash. We compute the score directly for images rather than via a data
+         generator, as a data generator might yield different images each time, depending on how it was defined.
+
+         Args:
+             hessian_scores_request: hessian scores request.
+             inputs_batch: a list containing a batch of inputs.
+
+         Returns:
+             A dict of Hessian scores per image hash per layer {image hash: {layer: score}}.
+         """
+         topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()]
+         hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name))
+
+         hessian_score_by_image_hash = {}
+
+         if not inputs_batch or not isinstance(inputs_batch, list):
+             raise TypeError('Expected a non-empty list of inputs')  # pragma: no cover
+         if len(inputs_batch) > 1:
+             raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs')  # pragma: no cover
+
+         # Get the framework-specific calculator for Hessian-approximation scores
+         fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph,
+                                                                            input_images=inputs_batch,
+                                                                            hessian_scores_request=hessian_scores_request,
+                                                                            num_iterations_for_approximation=self.num_iterations_for_approximation)
+         hessian_scores = fw_hessian_calculator.compute()
+         for i in range(inputs_batch[0].shape[0]):
+             img_hash = self.calc_image_hash(inputs_batch[0][i])
+             hessian_score_by_image_hash[img_hash] = {
+                 node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores)
+             }
+
+         return hessian_score_by_image_hash
+
+     @staticmethod
+     def calc_image_hash(image):
+         """
+         Calculates a hash for an input image.
+
+         Args:
+             image: input 3d image (without batch).
+
+         Returns:
+             Image hash.
+         """
+         if not len(image.shape) == 3:  # pragma: no cover
+             raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}')
+         image_bytes = image.astype(np.float32).tobytes()
+         return hashlib.md5(image_bytes).hexdigest()
+
      def fetch_hessian(self,
                        hessian_scores_request: HessianScoresRequest,
                        required_size: int,
@@ -248,7 +306,7 @@ class HessianInfoService:
              OC for per-output-channel when the requested node has OC output-channels, etc.)
          """

-         if len(hessian_scores_request.target_nodes) == 0:
+         if len(hessian_scores_request.target_nodes) == 0:  # pragma: no cover
              return []

          if required_size == 0:
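
The new `compute_trackable_per_sample_hessian` keys Hessian scores by an MD5 digest of the raw `float32` image bytes (`calc_image_hash`), so a score computed once can be looked up again for the same sample regardless of data-loader ordering. A minimal standalone sketch of the hashing scheme (illustrative only, not part of the package):

```python
import hashlib

import numpy as np

def image_hash(image: np.ndarray) -> str:
    """Hash a single 3-d image (no batch dim) by its raw float32 bytes."""
    assert image.ndim == 3, 'expected a single image without a batch dimension'
    return hashlib.md5(image.astype(np.float32).tobytes()).hexdigest()

batch = np.random.rand(8, 3, 32, 32).astype(np.float32)
# One cache entry per sample; values would hold {layer: score} dicts.
scores_by_hash = {image_hash(img): {} for img in batch}
```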
@@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum):
      PER_TENSOR = 2


+ class HessianEstimationDistribution(str, Enum):
+     """
+     Distribution for the Hutchinson estimator's random vectors.
+     """
+     GAUSSIAN = 'gaussian'
+     RADEMACHER = 'rademacher'
+
+
  class HessianScoresRequest:
      """
      Request configuration for the Hessian-approximation scores.
@@ -53,7 +61,8 @@ class HessianScoresRequest:
      def __init__(self,
                   mode: HessianMode,
                   granularity: HessianScoresGranularity,
-                  target_nodes: List):
+                  target_nodes: List,
+                  distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN):
          """
          Attributes:
              mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations).
@@ -64,6 +73,7 @@ class HessianScoresRequest:
          self.mode = mode  # w.r.t activations or weights
          self.granularity = granularity  # per element, per layer, per channel
          self.target_nodes = target_nodes
+         self.distribution = distribution

      def __eq__(self, other):
          # Checks if the other object is an instance of HessianScoresRequest
@@ -71,9 +81,10 @@ class HessianScoresRequest:
          return isinstance(other, HessianScoresRequest) and \
                 self.mode == other.mode and \
                 self.granularity == other.granularity and \
-                self.target_nodes == other.target_nodes
+                self.target_nodes == other.target_nodes and \
+                self.distribution == other.distribution

      def __hash__(self):
          # Computes the hash based on the attributes.
          # The use of a tuple here ensures that the hash is influenced by all the attributes.
-         return hash((self.mode, self.granularity, tuple(self.target_nodes)))
+         return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution))
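
With the `distribution` field added to `HessianScoresRequest`, callers opt into Rademacher-based estimation per request. A sketch of constructing such a request (`target_nodes` is assumed to come from the caller's graph):

```python
from model_compression_toolkit.core.common.hessian import (
    HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution)

request = HessianScoresRequest(
    mode=HessianMode.ACTIVATION,
    granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
    target_nodes=target_nodes,  # nodes selected from the model graph by the caller
    distribution=HessianEstimationDistribution.RADEMACHER)
```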
@@ -21,7 +21,8 @@ import numpy as np

  from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS
  from model_compression_toolkit.core.common import Graph
- from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity
+ from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianScoresGranularity,
+                                                            HessianEstimationDistribution)
  from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
  from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \
      HessianScoresCalculatorPytorch
@@ -55,6 +56,66 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
                                                           hessian_scores_request=hessian_scores_request,
                                                           num_iterations_for_approximation=num_iterations_for_approximation)

+     def forward_pass(self):
+         model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
+
+         if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
+             Logger.critical("Activation Hessian approximation cannot be computed for model outputs. "
+                             "Exclude output nodes from Hessian request targets.")
+
+         grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
+         model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
+         model.eval()
+
+         # Run model inference
+         # Set inputs to track gradients during inference
+         for input_tensor in self.input_images:
+             input_tensor.requires_grad_()
+             input_tensor.retain_grad()
+
+         outputs = model(*self.input_images)
+
+         if len(outputs) != len(grad_model_outputs):  # pragma: no cover
+             Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. "
+                             f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.")
+
+         # Extracting the intermediate activation tensors and the model real output.
+         # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
+         num_target_nodes = len(self.hessian_request.target_nodes)
+         # Extract activation tensors of nodes for which we want to compute Hessian
+         target_activation_tensors = outputs[:num_target_nodes]
+         # Extract the model outputs
+         output_tensors = outputs[num_target_nodes:]
+         device = output_tensors[0].device
+
+         # Concat outputs
+         # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
+         output = self.concat_tensors(output_tensors)
+         return output, target_activation_tensors
+
+     def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution,
+                                        device: torch.device) -> torch.Tensor:
+         """
+         Generate a batch of random vectors for Hutchinson estimation.
+
+         Args:
+             shape: target shape.
+             distribution: distribution to sample from.
+             device: target device.
+
+         Returns:
+             Random tensor.
+         """
+         if distribution == HessianEstimationDistribution.GAUSSIAN:
+             return torch.randn(shape, device=device)
+
+         if distribution == HessianEstimationDistribution.RADEMACHER:
+             v = torch.randint(high=2, size=shape, device=device)
+             v[v == 0] = -1
+             return v
+
+         raise ValueError(f'Unknown distribution {distribution}')  # pragma: no cover
+
      def compute(self) -> List[np.ndarray]:
          """
          Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations.
@@ -62,91 +123,79 @@ class ActivationHessianScoresCalculatorPytorch(HessianScoresCalculatorPytorch):
          Returns:
              List[np.ndarray]: Scores based on the approximated Hessian for the requested nodes.
          """
-         if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
-
-             model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
-
-             if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
-                 Logger.critical("Activation Hessian approximation cannot be computed for model outputs. "
-                                 "Exclude output nodes from Hessian request targets.")
-
-             grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
-             model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
-             model.eval()
-
-             # Run model inference
-             # Set inputs to track gradients during inference
-             for input_tensor in self.input_images:
-                 input_tensor.requires_grad_()
-                 input_tensor.retain_grad()
-
-             outputs = model(*self.input_images)
-
-             if len(outputs) != len(grad_model_outputs):  # pragma: no cover
-                 Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. "
-                                 f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.")
-
-             # Extracting the intermediate activation tensors and the model real output.
-             # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
-             num_target_nodes = len(self.hessian_request.target_nodes)
-             # Extract activation tensors of nodes for which we want to compute Hessian
-             target_activation_tensors = outputs[:num_target_nodes]
-             # Extract the model outputs
-             output_tensors = outputs[num_target_nodes:]
-             device = output_tensors[0].device
-
-             # Concat outputs
-             # First, we need to unfold all outputs that are given as list, to extract the actual output tensors
-             output = self.concat_tensors(output_tensors)
-
-             ipts_hessian_approx_scores = [torch.tensor([0.0],
-                                                        requires_grad=True,
-                                                        device=device)
-                                           for _ in range(len(target_activation_tensors))]
-             prev_mean_results = None
-             for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
-                 # Getting a random vector with normal distribution
-                 v = torch.randn(output.shape, device=device)
-                 f_v = torch.sum(v * output)
-                 for i, ipt_tensor in enumerate(target_activation_tensors):  # Per interest point activation tensor
-                     # Computing the hessian-approximation scores by getting the gradient of (output * v)
-                     hess_v = autograd.grad(outputs=f_v,
-                                            inputs=ipt_tensor,
-                                            retain_graph=True,
-                                            allow_unused=True)[0]
-
-                     if hess_v is None:
-                         # In case we have an output node, which is an interest point, but it is not differentiable,
-                         # we consider its Hessian to be the initial value 0.
-                         continue  # pragma: no cover
-
-                     # Mean over all dims but the batch (CxHxW for conv)
-                     hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape))))
-
-                     # Update node Hessian approximation mean over random iterations
-                     ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
-
-                 # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
-                 if j > MIN_HESSIAN_ITER:
-                     if prev_mean_results is not None:
-                         new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
-                         relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
-                                                    (torch.abs(new_mean_res) + 1e-6))
-                         max_delta = torch.max(relative_delta_per_node)
-                         if max_delta < HESSIAN_COMP_TOLERANCE:
-                             break
-                 prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
-
-             # Convert results to list of numpy arrays
-             hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores]
-             # Extend the Hessian tensors shape to align with expected return type
-             # TODO: currently, only per-tensor Hessian is available for activation.
-             # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
-             hessian_results = [h[..., np.newaxis] for h in hessian_results]
-
-             return hessian_results
-
-         else:  # pragma: no cover
-             Logger.critical(f"PyTorch activation Hessian's approximation scores does not support "
-                             f"{self.hessian_request.granularity} granularity.")
+         output, target_activation_tensors = self.forward_pass()

+         if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
+             hessian_scores = self._compute_per_tensor(output, target_activation_tensors)
+         elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
+             hessian_scores = self._compute_per_channel(output, target_activation_tensors)
+         else:
+             raise NotImplementedError(f'{self.hessian_request.granularity} is not supported')  # pragma: no cover
+
+         # Convert results to list of numpy arrays
+         hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores]
+         return hessian_results
+
+     def _compute_per_tensor(self, output, target_activation_tensors):
+         assert self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR
+         ipts_hessian_approx_scores = [torch.tensor([0.0], requires_grad=True, device=output.device)
+                                       for _ in range(len(target_activation_tensors))]
+         prev_mean_results = None
+         for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
+             # Getting a random vector from the configured distribution
+             v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
+             f_v = torch.sum(v * output)
+             for i, ipt_tensor in enumerate(target_activation_tensors):  # Per interest point activation tensor
+                 # Computing the hessian-approximation scores by getting the gradient of (output * v)
+                 hess_v = autograd.grad(outputs=f_v,
+                                        inputs=ipt_tensor,
+                                        retain_graph=True,
+                                        allow_unused=True)[0]
+
+                 if hess_v is None:
+                     # In case we have an output node, which is an interest point, but it is not differentiable,
+                     # we consider its Hessian to be the initial value 0.
+                     continue  # pragma: no cover
+
+                 # Mean over all dims but the batch (CxHxW for conv)
+                 hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape))))
+
+                 # Update node Hessian approximation mean over random iterations
+                 ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
+
+             # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation
+             if j > MIN_HESSIAN_ITER:
+                 if prev_mean_results is not None:
+                     new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
+                     relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) /
+                                                (torch.abs(new_mean_res) + 1e-6))
+                     max_delta = torch.max(relative_delta_per_node)
+                     if max_delta < HESSIAN_COMP_TOLERANCE:
+                         break
+             prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1)
+
+         # Add an extra dimension to preserve the previous behaviour
+         ipts_hessian_approx_scores = [torch.unsqueeze(t, -1) for t in ipts_hessian_approx_scores]
+         return ipts_hessian_approx_scores
+
+     def _compute_per_channel(self, output, target_activation_tensors):
+         assert self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL
+         ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device)
+                                       for _ in range(len(target_activation_tensors))]
+
+         for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"):  # Approximation iterations
+             v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device)
+             f_v = torch.sum(v * output)
+             for i, ipt_tensor in enumerate(target_activation_tensors):  # Per interest point activation tensor
+                 hess_v = autograd.grad(outputs=f_v,
+                                        inputs=ipt_tensor,
+                                        retain_graph=True)[0]
+                 hessian_approx_scores = hess_v ** 2
+                 rank = len(hess_v.shape)
+                 if rank > 2:
+                     hessian_approx_scores = torch.mean(hessian_approx_scores, dim=tuple(range(2, rank)))
+
+                 # Update node Hessian approximation mean over random iterations
+                 ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1)
+
+         return ipts_hessian_approx_scores
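
Both `_compute_per_tensor` and `_compute_per_channel` are Hutchinson-style estimators: they probe the Hessian with random vectors v whose entries have zero mean and unit variance, a property both the Gaussian and the new Rademacher (±1) distributions satisfy; Rademacher probes are a common choice because they often give lower estimator variance. A self-contained sketch of the probe-vector generation (mirrors `_generate_random_vectors_batch`):

```python
import torch

def rademacher(shape, device=None):
    # ±1 entries with equal probability: zero mean, unit variance,
    # the same first two moments as a standard Gaussian probe.
    v = torch.randint(high=2, size=shape, device=device)
    v[v == 0] = -1
    return v

v_gauss = torch.randn(4, 10)   # Gaussian probe batch
v_rad = rademacher((4, 10))    # Rademacher probe batch
```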
@@ -17,6 +17,7 @@ from enum import Enum
  from typing import Callable, Any, Dict, Optional

  from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
+ from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT


@@ -39,17 +40,21 @@ class GPTQHessianScoresConfig:
      Configuration to use for computing the Hessian-based scores for GPTQ loss metric.

      Args:
-         hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
+         hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores.
+             If None, compute the Hessian for all images.
          norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
          log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
          scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
          hessian_batch_size (int): The Hessian computation batch size. Used only if using GPTQ with a Hessian-based objective.
+         per_sample (bool): Whether to use per-sample attention scores.
      """
-     hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES
+     hessians_num_samples: Optional[int] = GPTQ_HESSIAN_NUM_SAMPLES
      norm_scores: bool = True
      log_norm: bool = True
      scale_log_norm: bool = False
      hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
+     per_sample: bool = False
+     estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN


  @dataclass
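
Taken together, a sample-layer-attention configuration disables every form of score normalization and computes Hessian scores for all images; this mirrors the defaults wired up in `get_pytorch_gptq_config` further down (sketch):

```python
from model_compression_toolkit.core.common.hessian import HessianEstimationDistribution
from model_compression_toolkit.gptq.common.gptq_config import GPTQHessianScoresConfig

sla_config = GPTQHessianScoresConfig(
    hessians_num_samples=None,   # None = compute Hessian scores for all images
    norm_scores=False,
    log_norm=False,
    scale_log_norm=False,
    hessian_batch_size=32,
    per_sample=True,
    estimator_distribution=HessianEstimationDistribution.RADEMACHER)
```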
@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  import copy
+ import hashlib
  from abc import ABC, abstractmethod
  import numpy as np
  from typing import Callable, List, Any, Dict
@@ -143,7 +144,11 @@ class GPTQTrainer(ABC):
              return np.asarray([1 / num_nodes for _ in range(num_nodes)])

          # Fetch hessian approximations for each target node
-         compare_point_to_hessian_approx_scores = self._fetch_hessian_approximations()
+         # TODO this smells like a potential bug. In hessian calculation target nodes are topo sorted and results are returned
+         # TODO also target nodes are replaced for reuse. Does this work correctly?
+         approximations = self._fetch_hessian_approximations(HessianScoresGranularity.PER_TENSOR)
+         compare_point_to_hessian_approx_scores = {node: score for node, score in zip(self.compare_points, approximations)}
+
          # Process the fetched hessian approximations to gather them per images
          hessian_approx_score_by_image = (
              self._process_hessian_approximations(compare_point_to_hessian_approx_scores))
@@ -172,29 +177,55 @@ class GPTQTrainer(ABC):
          # If log normalization is not enabled, return the mean of the approximations across images
          return np.mean(hessian_approx_score_by_image, axis=0)

-     def _fetch_hessian_approximations(self) -> Dict[BaseNode, List[List[float]]]:
+     def _compute_sample_layer_attention_scores(self, inputs_batch) -> Dict[str, Dict[BaseNode, np.ndarray]]:
+         """
+         Compute sample-layer attention scores per image hash per layer.
+
+         Args:
+             inputs_batch: a list containing a batch of inputs.
+
+         Returns:
+             A dictionary with the structure {img_hash: {layer: score}}.
+         """
+         request = self._build_hessian_request(HessianScoresGranularity.PER_OUTPUT_CHANNEL)
+         hessian_batch_size = self.gptq_config.hessian_weights_config.hessian_batch_size
+
+         hessian_score_per_image_per_layer = {}
+         # If the hessian batch is smaller than the inputs batch, split the inputs into hessian batches.
+         # If the hessian batch is larger, it is currently ignored (TODO).
+         for i in range(0, inputs_batch[0].shape[0], hessian_batch_size):
+             inputs = [t[i: i+hessian_batch_size] for t in inputs_batch]
+             hessian_score_per_image_per_layer.update(
+                 self.hessian_service.compute_trackable_per_sample_hessian(request, inputs)
+             )
+         for img_hash, v in hessian_score_per_image_per_layer.items():
+             hessian_score_per_image_per_layer[img_hash] = {k: t.max(axis=0) for k, t in v.items()}
+         return hessian_score_per_image_per_layer
+
+     def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]:
          """
          Fetches hessian approximations for each target node.

          Returns:
              Mapping of target nodes to their hessian approximations.
          """
-         approximations = {}
-         hessian_scores_request = HessianScoresRequest(
-             mode=HessianMode.ACTIVATION,
-             granularity=HessianScoresGranularity.PER_TENSOR,
-             target_nodes=self.compare_points
-         )
+         hessian_scores_request = self._build_hessian_request(granularity)
+
          node_approximations = self.hessian_service.fetch_hessian(
              hessian_scores_request=hessian_scores_request,
              required_size=self.gptq_config.hessian_weights_config.hessians_num_samples,
              batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size
          )
+         return node_approximations

-         for i, target_node in enumerate(self.compare_points):
-             approximations[target_node] = node_approximations[i]
-
-         return approximations
+     def _build_hessian_request(self, granularity):
+         return HessianScoresRequest(
+             mode=HessianMode.ACTIVATION,
+             granularity=granularity,
+             target_nodes=self.compare_points,
+             distribution=self.gptq_config.hessian_weights_config.estimator_distribution
+         )

      def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[List[float]]]) -> List:
          """
@@ -13,8 +13,10 @@
  # limitations under the License.
  # ==============================================================================
  from typing import List
+
  import torch

+
  def mse_loss(y: torch.Tensor, x: torch.Tensor, normalized: bool = True) -> torch.Tensor:
      """
      Compute the MSE of two tensors.
@@ -25,7 +27,7 @@ def mse_loss(y: torch.Tensor, x: torch.Tensor, normalized: bool = True) -> torch
      Returns:
          The MSE of two tensors.
      """
-     loss = torch.nn.MSELoss()(x,y)
+     loss = torch.nn.MSELoss()(x, y)
      return loss / torch.mean(torch.square(x)) if normalized else loss


@@ -62,3 +64,36 @@ def multiple_tensors_mse_loss(y_list: List[torch.Tensor],
      else:
          return torch.mean(torch.stack(loss_values_list))

+
+ def sample_layer_attention_loss(y_list: List[torch.Tensor],
+                                 x_list: List[torch.Tensor],
+                                 fxp_w_list,
+                                 flp_w_list,
+                                 act_bn_mean,
+                                 act_bn_std,
+                                 loss_weights: torch.Tensor) -> torch.Tensor:
+     """
+     Compute the Sample Layer Attention loss between two lists of tensors.
+
+     Args:
+         y_list: First list of tensors.
+         x_list: Second list of tensors.
+         fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
+         loss_weights: layer-sample weights tensor of shape (layers, batch).
+
+     Returns:
+         Sample Layer Attention loss (a scalar).
+     """
+     loss = 0
+     layers_mean_w = []
+
+     for i, (y, x, w) in enumerate(zip(y_list, x_list, loss_weights)):
+         norm = (y - x).pow(2).sum(1)
+         if len(norm.shape) > 1:
+             norm = norm.flatten(1).mean(1)
+         loss += torch.mean(w * norm)
+         layers_mean_w.append(w.mean())
+
+     loss = loss / torch.stack(layers_mean_w).max()
+     return loss
+
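
The new loss consumes per-layer activation pairs plus a (layers × batch) weight tensor; the four unused arguments only keep the existing loss interface. A toy invocation with random tensors (shapes are illustrative):

```python
import torch

layers, batch = 3, 4
y_list = [torch.randn(batch, 16, 8, 8) for _ in range(layers)]  # quantized-model activations
x_list = [y + 0.01 * torch.randn_like(y) for y in y_list]       # float-model activations
loss_weights = torch.rand(layers, batch)                        # per-layer, per-sample attention scores

loss = sample_layer_attention_loss(y_list, x_list, None, None, None, None, loss_weights)
print(loss)  # scalar tensor
```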
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable, List, Tuple, Union
+ from typing import Callable, List, Tuple, Union, Dict

  import numpy as np
  from torch.nn import Module
@@ -105,8 +105,18 @@ class PytorchGPTQTrainer(GPTQTrainer):
          self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
                                                                    trainable_bias,
                                                                    trainable_threshold)
-
-         self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights())
+         hessian_cfg = self.gptq_config.hessian_weights_config
+         self.use_sample_layer_attention = hessian_cfg.per_sample
+         self.hessian_score_per_layer = None  # for fixed layer weights
+         self.hessian_score_per_image_per_layer = None  # for sample-layer attention
+         if self.use_sample_layer_attention:
+             # normalization is currently not supported, make sure the config reflects it.
+             if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
+                 raise NotImplementedError()
+             # Per-sample hessian scores are calculated on demand during the training loop
+             self.hessian_score_per_image_per_layer = {}
+         else:
+             self.hessian_score_per_layer = to_torch_tensor(self.compute_hessian_based_weights())

          self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps)

@@ -210,13 +220,17 @@ class PytorchGPTQTrainer(GPTQTrainer):

      def compute_gradients(self,
                            y_float: List[torch.Tensor],
-                           input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, List[np.ndarray]]:
+                           input_tensors: List[torch.Tensor],
+                           distill_loss_weights: torch.Tensor,
+                           round_reg_weights: torch.Tensor) -> Tuple[torch.Tensor, List[np.ndarray]]:
          """
          Get outputs from both teacher and student networks. Compute the observed error,
          and use it to compute the gradients and apply them to the student weights.
          Args:
              y_float: A list of reference tensors from the floating-point network.
              input_tensors: A list of input tensors to pass through the networks.
+             distill_loss_weights: Weights for the distillation loss.
+             round_reg_weights: Weights for the rounding regularization loss.
          Returns:
              Loss and gradients.
          """
@@ -231,9 +245,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
                                             self.flp_weights_list,
                                             self.compare_points_mean,
                                             self.compare_points_std,
-                                            self.weights_for_average_loss)
-
-         reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+                                            distill_loss_weights)
+         reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, round_reg_weights)

          loss_value += reg_value

@@ -261,10 +274,11 @@ class PytorchGPTQTrainer(GPTQTrainer):
          for _ in epochs_pbar:
              with tqdm(data_function(), position=1, leave=False) as data_pbar:
                  for data in data_pbar:
+                     distill_weights, reg_weights = to_torch_tensor(self._get_loss_weights(data))
                      input_data = [d * self.input_scale for d in data]
                      input_tensor = to_torch_tensor(input_data)
                      y_float = self.float_model(input_tensor)  # running float model
-                     loss_value, grads = self.compute_gradients(y_float, input_tensor)
+                     loss_value, grads = self.compute_gradients(y_float, input_tensor, distill_weights, reg_weights)
                      # Run one step of gradient descent by updating the value of the variables to minimize the loss.
                      for (optimizer, _) in self.optimizer_with_param:
                          optimizer.step()
@@ -276,6 +290,42 @@ class PytorchGPTQTrainer(GPTQTrainer):
              self.loss_list.append(loss_value.item())
              Logger.debug(f'last loss value: {self.loss_list[-1]}')

+     def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Fetches weights for distillation and round regularization parts of loss.
+
+         Args:
+             input_tensors: list containing a batch of inputs.
+
+         Returns:
+             A tuple of two tensors:
+                 - weights for the distillation loss
+                 - weights for the rounding regularization loss
+         """
+         if self.use_sample_layer_attention is False:
+             return self.hessian_score_per_layer, torch.ones_like(self.hessian_score_per_layer)
+
+         if len(input_tensors) > 1:
+             raise NotImplementedError('Sample-Layer Attention is not currently supported for networks with multiple inputs')
+
+         image_scores = []
+         batch = input_tensors[0]
+         img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch]
+         for img_hash in img_hashes:
+             # If the sample-layer attention score for the image is not found, compute and store it for the whole batch.
+             if img_hash not in self.hessian_score_per_image_per_layer:
+                 score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors)
+                 self.hessian_score_per_image_per_layer.update(score_per_image_per_layer)
+             img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash]
+             # Fetch the image's scores for all layers and combine them into a single tensor
+             img_scores = np.stack(list(img_scores_per_layer.values()), axis=0)
+             image_scores.append(img_scores)
+
+         layer_sample_weights = np.stack(image_scores, axis=1)  # layers X images
+         layer_weights = layer_sample_weights.mean(axis=1)
+         return layer_sample_weights, layer_weights
+

      def update_graph(self) -> Graph:
          """
          Update a graph using GPTQ after minimizing the loss between the float model's output
@@ -18,6 +18,7 @@ from typing import Callable, Union
  from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
  from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
+ from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
      MixedPrecisionQuantizationConfig
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
@@ -43,7 +44,7 @@ if FOUND_TORCH:
      from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
      from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
      from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
-     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss
+     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss
      from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model
      import torch
      from torch.nn import Module
@@ -55,11 +56,12 @@ if FOUND_TORCH:
      def get_pytorch_gptq_config(n_epochs: int,
                                  optimizer: Optimizer = None,
                                  optimizer_rest: Optimizer = None,
-                                 loss: Callable = multiple_tensors_mse_loss,
+                                 loss: Callable = None,
                                  log_function: Callable = None,
                                  use_hessian_based_weights: bool = True,
                                  regularization_factor: float = REG_DEFAULT,
                                  hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                                 use_hessian_sample_attention: bool = False,
                                  gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
                                  ) -> GradientPTQConfig:
          """
@@ -74,6 +76,7 @@ if FOUND_TORCH:
              use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
              regularization_factor (float): A floating point number that defines the regularization factor.
              hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+             use_hessian_sample_attention (bool): Whether to use the Sample-Layer Attention score for the weighted loss.
              gradual_activation_quantization (bool, GradualActivationQuantizationConfig):
                  If False, GradualActivationQuantization is disabled.
                  If True, GradualActivationQuantization is enabled with the default settings.
@@ -105,19 +108,37 @@ if FOUND_TORCH:

          bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)

+         if use_hessian_sample_attention:
+             if not use_hessian_based_weights:  # pragma: no cover
+                 raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
+
+             hessian_weights_config = GPTQHessianScoresConfig(
+                 hessians_num_samples=None,
+                 norm_scores=False,
+                 log_norm=False,
+                 scale_log_norm=False,
+                 hessian_batch_size=hessian_batch_size,
+                 per_sample=True,
+                 estimator_distribution=HessianEstimationDistribution.RADEMACHER
+             )
+             loss = loss or sample_layer_attention_loss
+         else:
+             hessian_weights_config = GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)
+             loss = loss or multiple_tensors_mse_loss
+
          if isinstance(gradual_activation_quantization, bool):
              gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
          elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
              gradual_quant_config = gradual_activation_quantization
-         else:
+         else:  # pragma: no cover
              raise TypeError(f'gradual_activation_quantization argument should be bool or '
-                             f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')  # pragma: no cover
+                             f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')

          return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
                                   log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
                                   use_hessian_based_weights=use_hessian_based_weights,
                                   regularization_factor=regularization_factor,
-                                  hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size),
+                                  hessian_weights_config=hessian_weights_config,
                                   gradual_activation_quantization_config=gradual_quant_config)

      def pytorch_gradient_post_training_quantization(model: Module,
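
End to end, sample-layer attention is enabled with a single flag; the facade then picks `sample_layer_attention_loss`, the Rademacher estimator, and per-sample scores automatically. A usage sketch (assuming the usual `mct.gptq` entry point):

```python
import model_compression_toolkit as mct

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5,
                                               use_hessian_sample_attention=True)
```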
@@ -185,11 +206,11 @@ if FOUND_TORCH:

          """

-         if core_config.is_mixed_precision_enabled:
+         if core_config.is_mixed_precision_enabled:  # pragma: no cover
              if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                  Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                  "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
-                                 "or provide a valid mixed-precision configuration.")  # pragma: no cover
+                                 "or provide a valid mixed-precision configuration.")

          tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)

@@ -41,4 +41,4 @@ def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn:
          scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2)
          return SoftQuantizerRegularization(scheduler)
      else:
-         return lambda m, e_reg: 0
+         return lambda *args, **kwargs: 0
@@ -40,32 +40,34 @@ class SoftQuantizerRegularization:

          self.count_iter = 0

-     def __call__(self, model: nn.Module, entropy_reg: float):
+     def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor):
          """
          Returns the soft quantizer regularization value for SoftRounding.

          Args:
              model: A model to be quantized with SoftRounding.
              entropy_reg: Entropy value to scale the quantizer regularization.
+             layer_weights: a vector of layer weights.

          Returns: Regularization value.
          """
+         layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)]

-         soft_reg_aux: List[torch.Tensor] = []
-         b = self.beta_scheduler(self.count_iter)
-         for layer in model.modules():
-             if isinstance(layer, PytorchQuantizationWrapper):
-                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                       fw_info=DEFAULT_PYTORCH_INFO)
-
-                 st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
-                 soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
+         if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers):
+             raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.')  # pragma: no cover
+         max_w = layer_weights.max()

+         b = self.beta_scheduler(self.count_iter)
          reg = 0
+         for layer, w in zip(layers, layer_weights):
+             kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                   fw_info=DEFAULT_PYTORCH_INFO)

-         for sq in soft_reg_aux:
-             reg += sq
+             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
+             soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
+             reg += w * soft_loss

+         reg = reg / max_w
          self.count_iter += 1

          return entropy_reg * reg
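
The reworked regularizer thus computes reg = Σᵢ wᵢ·sᵢ / max(w) over the per-layer soft-rounding penalties sᵢ; with all-ones weights (the non-attention path in `_get_loss_weights`) it reduces to the previous unweighted sum. An equivalent computation on precomputed values (sketch):

```python
import torch

s = torch.tensor([0.8, 1.2, 0.5])  # per-layer soft-rounding penalties
w = torch.tensor([1.0, 2.0, 0.5])  # per-layer weights (all ones in the non-attention path)
reg = (w * s).sum() / w.max()      # matches the weighted loop above
```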