mct-nightly 2.3.0.20250310.500__py3-none-any.whl → 2.3.0.20250313.526__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {mct_nightly-2.3.0.20250310.500.dist-info → mct_nightly-2.3.0.20250313.526.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.3.0.20250310.500.dist-info → mct_nightly-2.3.0.20250313.526.dist-info}/RECORD +22 -21
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_node.py +8 -0
  5. model_compression_toolkit/core/common/graph/memory_graph/cut.py +20 -12
  6. model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +2 -1
  7. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +50 -69
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +16 -46
  9. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +1 -0
  10. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +1 -3
  11. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +162 -70
  12. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +35 -15
  13. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +3 -3
  14. model_compression_toolkit/core/pytorch/constants.py +1 -0
  15. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/convtranspose_dynamic_padding.py +77 -0
  16. model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -1
  17. model_compression_toolkit/core/runner.py +2 -2
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +4 -0
  19. model_compression_toolkit/target_platform_capabilities/__init__.py +3 -2
  20. {mct_nightly-2.3.0.20250310.500.dist-info → mct_nightly-2.3.0.20250313.526.dist-info}/LICENSE.md +0 -0
  21. {mct_nightly-2.3.0.20250310.500.dist-info → mct_nightly-2.3.0.20250313.526.dist-info}/WHEEL +0 -0
  22. {mct_nightly-2.3.0.20250310.500.dist-info → mct_nightly-2.3.0.20250313.526.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: mct-nightly
- Version: 2.3.0.20250310.500
+ Version: 2.3.0.20250313.526
  Summary: A Model Compression Toolkit for neural networks
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: Apache Software License
@@ -20,7 +20,7 @@ Requires-Dist: PuLP
  Requires-Dist: matplotlib<3.10.0
  Requires-Dist: scipy
  Requires-Dist: protobuf
- Requires-Dist: mct-quantizers==1.5.2
+ Requires-Dist: mct-quantizers-nightly
  Requires-Dist: pydantic<2.0
  Dynamic: classifier
  Dynamic: description
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=IvTOp-U0QerJ9UCdyeAzLvpX3qRiOQGbOeGL7ps8zGg,1557
+ model_compression_toolkit/__init__.py,sha256=9GncQIw01bNCuq697TP39EeUYSVCocW8hYFq_GlA4NY,1557
  model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -8,7 +8,7 @@ model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZ
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
  model_compression_toolkit/core/graph_prep_runner.py,sha256=CVTjBaci8F6EP3IKDnRMfxkP-Sv8qY8GpkGt6FyII2U,11376
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8K5f5v9g5lbT3r1GSmhLAk3NkL40Fg,6593
- model_compression_toolkit/core/runner.py,sha256=iJpDasfs7wtdAelIRaBPxDbN64phPern1O86QDM2HeY,13706
+ model_compression_toolkit/core/runner.py,sha256=qblr8WM6R5v4jip94kBeWHKsjc-FUOteVgMtunGf8lU,13716
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
  model_compression_toolkit/core/common/framework_implementation.py,sha256=s3yiqnbWkwfnAB1sSal_KAuqVg27rLhAJ2O8LHUbSHE,22494
@@ -34,17 +34,17 @@ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=0zsiEldkV_wjDoTjaGtL8DOMGEv2yQqhajwEAnFgqR8,37819
- model_compression_toolkit/core/common/graph/base_node.py,sha256=LYiF4Pv0doX9dJhXGBM78Ay40qYDp0gXHd19JwS11Uo,33463
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=kZbmAMh5cPAwYzlY8KYa8w0ipL58yApB09-WXQ8plrE,33763
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
+ model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=gPlGMyC5jdUTQy8jYU_Rz7cPXSH6JhV4Dnwt3-1FAKM,9849
  model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/core/common/graph/memory_graph/bipartite_graph.py,sha256=X6FK3C3y8ixFRPjC_wm3ClloCX8_06SOdA1TRi7o_LA,3800
  model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py,sha256=oyz260JXDbvL8aI-DVtUvLHtLRWC2Yu4SBYlGL68c2Y,3498
- model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=7Dfq4TVJIrnencHLJqjhxYKhY7ooUo_ml33WH2IIAgc,2576
- model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=E8xKMUxtEF0GjztUk-3CmMtivPPBcADnZTUaSN24o6A,17816
+ model_compression_toolkit/core/common/graph/memory_graph/cut.py,sha256=ZUGgn-vDA7unzc9UWhK2v_2i5nfdkSG1xOpgpDmziEo,2870
+ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256=1TWLVAOlT8g8q_YyOdjm5cQfiSDZ5EHGQcb509Gnzjg,17895
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=ISD2BvJWj5mB91jrFjG8VQb0oOoLBoita_thCZWzCPI,4238
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=FCzK4HmX4lWI4qGoGv94wpGv7o6_f5wPBfeBPMerZ18,7752
  model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
@@ -66,15 +66,15 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_uti
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=7iJ2YprFvm2Dk9EkXYrwO7-Sf89f537D-KrQP7XhvPs,8889
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=aAuGIzSDKIDiq07nheeWRXLEatzr6Fvoa5ZHv-2BtCI,7130
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=uDX1mEaq7qqWk2FQCfVXUYVlpGWS0OBP0C1CsGCkZYY,32791
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=k7LjEmcvlkiV995DU7S1CrNOllu6qPZrhUUKXcZDIUQ,7538
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=4YH9tsFPOn6rCcedfyocZhZwDLNX5kB1tebu0-nvhyA,7226
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=ItBWNZYOf-Zzi8FaRv1y170wYRXYcR3pJysClOtH8qc,32525
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=gsigifJ-ykWNafF4t7UMEC_-nd6YPERAk1_z0kT-Y88,27172
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=MQZnBcpBDMd5y6rOunUtH3t41GQH0aBmxVB4muoxNfk,9477
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=T5yVr7lay-6QLuTDBZNI1Ufj02EMBWuY_yHjC8eHx5I,3998
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=DyiE84ECgwtaCATWcisv-7ndmBUbj_TaddZ7GeIjlrU,35307
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=Oj-tVGUyBXtTpxNFQVPja8fFcUOpi6B2PdpNKHkAlbc,39314
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=J7gqUGs4ITo4ufl84A5vACxm670LG6RhQyXkejfpbn8,8834
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
  model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=9Hh85pr0VL65umhf9mPnrrssJXwJPAsIkBwCZnfzjHY,17575
@@ -146,8 +146,8 @@ model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha25
  model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
  model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=oiiN16OqDrax4FPP5VeyTz0rhb0-eZJACKznTBlKkio,30013
  model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=aXzUOJfgKPfQpEGfiIun26fgfCqazBG1mBpzoc4Ezxs,3477
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
+ model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=w43dRmaG96a8SNECgghxoFCTSoZ-vUb33dXGm2PbomE,4251
+ model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=gt07lXRUvYunJKiwv_w20zfXhcplSW4oT2C1dqiNNXc,4719
  model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
  model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
@@ -219,11 +219,11 @@ model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_c
  model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py,sha256=lq6yw9r1u0ZGA95JFvzsV-HQax66qAkJBmGeKnG9OrM,3409
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
+ model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3KqhVz5ORnA46S1uq9Tj5qhtHc,2828
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=NLdmiig5a2EBxutJeDHjp8px4g_2EKt3zmntmK-NrT4,4309
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=ToGadW24_9ajSWc_J8jlARw7OOO5BRt0_HvN1FfijgI,30575
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=QBCKYimTbHGFmXGz84Ioni5C9qKntp9FMEBLMUrIKkY,30771
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=aIHl-dTAC4ISnWSKLD99c-1W3827vfRGyLjMBib-l3s,5618
  model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
@@ -244,6 +244,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchno
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=sw3jIOUSvfWUeD8l3rGcUOtC6QuzpMIQm8V3RQAM53Q,4741
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/convtranspose_dynamic_padding.py,sha256=N0VQr7hYkj1BN6O91nqiLkV3ZtclLkqlNNJwOEKv62g,3205
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=7GZY7lU3LUUaO5iiccHkUP62PB0QeGAGOZdUSGMkFBY,4450
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=XhiLVcnCc_gF-6mjxbf9C4bYg5YL_GCvDJmcdLkBNAg,4151
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py,sha256=3-OHYPun5Rt7GITqV3ZekJk59tsuY9ZYSpRpxKsNEVA,3450
@@ -347,7 +348,7 @@ model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer
  model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=vQUGbCi8_pGoN8DwQ0IblSeN6L9t6Cr0reZNuCbBpkM,3469
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=0sx6PLcnJ42LHKn79Qx1FOH615YBqM9OJMF6S1W6plE,6255
+ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=dpN2Hyb56Wt4INEtBJAOxZeFdhIwdx__WFTmOVkxMLc,6470
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=Pl8a8MSZMzNbm5vngujFjCt_iSMbSmKjlcL1DvN9nTM,9292
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
@@ -428,7 +429,7 @@ model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py,sha256=KefO2Z
  model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=p1JqtBZZVHTV5caR1U0d1t2UcTz0ACNyLcJTBFUEq98,6173
  model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha256=wWehe5R0xVHSm3ruMrUc8RzW5UVAVCMgUTUMPDsvy9g,5487
- model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=9ZcT9JVlYzy8k7MlAXhj086gn6SxlGFsjMvy7ubcnfc,1392
+ model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=8RVOriZg-XNjSt53h_4Yum0oRgOe2gp5H45dfG_lZxE,1415
  model_compression_toolkit/target_platform_capabilities/constants.py,sha256=BFSgDwYWU1sZShjoW2S7eH3AI0D4SqDOeOu_sQ971LE,1518
  model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
  model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=4ydTWWKv_PEOAFok2JtxFNj8rav-0IlqcXKF6lnhHNE,4157
@@ -524,8 +525,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.3.0.20250310.500.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.3.0.20250310.500.dist-info/METADATA,sha256=sONRBJhRO4oeP6vCk6tpSIEekAc8Y8EH-6HfO-a9ZG0,27079
- mct_nightly-2.3.0.20250310.500.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
- mct_nightly-2.3.0.20250310.500.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.3.0.20250310.500.dist-info/RECORD,,
+ mct_nightly-2.3.0.20250313.526.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.3.0.20250313.526.dist-info/METADATA,sha256=aCFIGI9kNuUGhU8Koa0PDvXhqkHxK05E6a10mBpQAgU,27080
+ mct_nightly-2.3.0.20250313.526.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+ mct_nightly-2.3.0.20250313.526.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.3.0.20250313.526.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
- __version__ = "2.3.0.20250310.000500"
+ __version__ = "2.3.0.20250313.000526"
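For reference, the bundled version string is exposed at runtime, so the installed nightly can be verified with a two-liner (illustrative check, not part of the diff):

    import model_compression_toolkit as mct
    print(mct.__version__)   # expected to print 2.3.0.20250313.000526 for this wheel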
@@ -167,6 +167,14 @@ class BaseNode:
          """
          return self.is_weights_quantization_enabled(attr_name) and not self.is_all_weights_candidates_equal(attr_name)
 
+     def has_any_configurable_weight(self) -> bool:
+         """
+         Check whether any of the node's weights is configurable.
+         Returns:
+             Whether any of the node's weights is configurable.
+         """
+         return any(self.is_configurable_weight(attr) for attr in self.weights)
+
      def has_configurable_activation(self) -> bool:
          """
          Checks whether the activation has a configurable quantization.
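The new helper simply folds the existing per-attribute check over all of the node's weight attributes. A rough usage sketch (graph and variable names are illustrative, not from the diff):

    # Sketch: collect nodes that expose any mixed-precision degree of freedom.
    mp_candidates = [n for n in graph.nodes
                     if n.has_any_configurable_weight() or n.has_configurable_activation()]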
@@ -12,28 +12,36 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ from dataclasses import dataclass, field
+
  from typing import List, Set
 
  from model_compression_toolkit.core.common import BaseNode
  from model_compression_toolkit.core.common.graph.memory_graph.memory_element import MemoryElements
 
 
+ @dataclass(frozen=True)
  class Cut:
      """
      A Cut object that contains a set of ordered nodes and their memory elements.
-     """
 
-     def __init__(self, op_order: List[BaseNode], op_record: Set[BaseNode], mem_elements: MemoryElements):
-         """
-         Args:
-             op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
-             op_record: A (unordered) set of the nodes in the cut.
-             mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
-         """
-
-         self.op_order = op_order
-         self.op_record = op_record
-         self.mem_elements = mem_elements
+     Args:
+         op_order: A list of the cut's nodes (model layers), ordered by their addition to the cut (first-to-last).
+         op_record: A (unordered) set of the nodes in the cut.
+         mem_elements: MemoryElements object which represents the activation tensors of the cut's nodes.
+     """
+     op_order: List[BaseNode]
+     op_record: Set[BaseNode]
+     mem_elements: MemoryElements
+
+     _sorted_elements_signature: str = field(init=False, default=None)
+
+     @property
+     def sorted_elements_signature(self):
+         if self._sorted_elements_signature is None:
+             object.__setattr__(self, '_sorted_elements_signature',
+                                '_'.join(sorted([e.node_name for e in self.mem_elements.elements])))
+         return self._sorted_elements_signature
 
      def memory_size(self) -> float:
          """
@@ -232,7 +232,8 @@ class MaxCutAstar:
          max_cut_len = max([len(routes[c]) for c in open_list])
          ordered_cuts_list = sorted(open_list,
                                     key=lambda c: (self.accumulate(costs[c], self.estimate(c, estimate)),
-                                                    max_cut_len - len(routes[c])))
+                                                    max_cut_len - len(routes[c]),
+                                                    c.sorted_elements_signature))
 
          assert len(ordered_cuts_list) > 0
          return ordered_cuts_list[0]
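Appending the cut signature as a final sort key makes the choice among equally ranked cuts deterministic across runs, since string comparison resolves the remaining ties. Roughly (primary_cost is a hypothetical stand-in for the accumulated cost):

    # Illustrative only: identical primary keys fall back to the signature string.
    best = sorted(cuts, key=lambda c: (primary_cost[c], c.sorted_elements_signature))[0]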
@@ -24,7 +24,6 @@ import numpy as np
 
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
      CandidateNodeQuantizationConfig
- from model_compression_toolkit.logger import Logger
 
 
  class VirtualSplitNode(BaseNode):
@@ -73,11 +72,14 @@ class VirtualSplitWeightsNode(VirtualSplitNode):
          super().__init__(origin_node)
 
          self.name = origin_node.name + VIRTUAL_WEIGHTS_SUFFIX
-
-         self.candidates_quantization_cfg = origin_node.get_unique_weights_candidates(kernel_attr)
-         for c in self.candidates_quantization_cfg:
-             c.activation_quantization_cfg.enable_activation_quantization = False
-             c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH
+         # Virtual weights node is created only to be absorbed into virtual composed node right away.
+         # However, in some cases composition is impossible and virtual weights node can remain in the graph.
+         # In such case it messes up resource utilization computation, specifically activation cuts. In order to minimize
+         # the impact, we preserve the behavior of the original node wrt activation (shape and quantization),
+         # so that prev - virtualW cut is identical to prev-origin_node. Only the cut virtualW-virtualA will be different
+         # from the original graph, so in the worst case the utilization will be higher in virtual graph.
+         # This should guarantee that the utilization of the original graph does not exceed the requested target.
+         self.candidates_quantization_cfg = origin_node.candidates_quantization_cfg
 
 
  class VirtualSplitActivationNode(VirtualSplitNode):
@@ -126,89 +128,68 @@ class VirtualActivationWeightsNode(BaseNode):
      def __init__(self,
                   act_node: BaseNode,
                   weights_node: BaseNode,
-                  name: str,
-                  framework_attr: Dict[str, Any],
-                  input_shape: Tuple[Any],
-                  output_shape: Tuple[Any],
-                  weights: Dict[str, np.ndarray],
-                  layer_class: type,
-                  fw_info: FrameworkInfo,
-                  reuse: bool = False,
-                  reuse_group: str = None,
-                  quantization_attr: Dict[str, Any] = None,
-                  has_activation: bool = True,
-                  **kwargs):
+                  fw_info: FrameworkInfo):
          """
          Init a VirtualActivationWeightsNode object.
 
          Args:
              act_node: The original activation node.
              weights_node: The original weights node.
-             name: Node's name
-             framework_attr: Framework attributes the layer had which the node holds.
-             input_shape: Input tensor shape of the node.
-             output_shape: Input tensor shape of the node.
-             weights: Dictionary from a variable name to the weights with that name in the layer the node represents.
-             layer_class: Class path of the layer this node represents.
-             fw_info: A FrameworkInfo object with framework specific information,
-             reuse: Whether this node was duplicated and represents a reused layer.
-             reuse_group: Name of group of nodes from the same reused layer.
-             quantization_attr: Attributes the node holds regarding how it should be quantized.
-             has_activation: Whether the node has activations that we might want to quantize.
-             **kwargs: Additional arguments that can be passed but are not used (allows to init the object with an
-                 existing node's __dict__).
-
+             fw_info: A FrameworkInfo object with framework specific information.
          """
-
+         # Validate weights node
+         kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
+         assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, 'Expected exactly one kernel attr.'
+         kernel_attr = kernel_attrs[0]
+         conf_weights = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
+         if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(kernel_attr):
+             raise NotImplementedError('Only kernel weight can be configurable.')  # pragma: no cover
+
+         weights = weights_node.weights
+         if act_node.weights:
+             assert fw_info.get_kernel_op_attributes(act_node)[0] is None, \
+                 f'Node {act_node} with kernel cannot be used as activation for VirtualActivationWeightsNode.'
+             if set(weights_node.weights.keys()).intersection(set(act_node.weights.keys())):
+                 raise ValueError('Activation and weight nodes are not expected to have the same weight attribute')  # pragma: no cover
+             if act_node.has_any_configurable_weight():
+                 raise NotImplementedError('Node with a configurable weight cannot be used as activation for '
+                                           'VirtualActivationWeightsNode.')  # pragma: no cover
+             # combine weights from activation and weights
+             weights.update(act_node.weights)
+
+         name = f"{VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX}_{act_node.name}_{weights_node.name}"
          super().__init__(name,
-                          framework_attr,
-                          input_shape,
-                          output_shape,
-                          weights,
-                          layer_class,
-                          reuse,
-                          reuse_group,
-                          quantization_attr,
-                          has_activation)
-
-         self.name = f"{VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX}_{act_node.name}_{weights_node.name}"
+                          framework_attr=weights_node.framework_attr,
+                          input_shape=act_node.input_shape,
+                          output_shape=act_node.output_shape,
+                          weights=weights,
+                          layer_class=weights_node.layer_class,
+                          reuse=weights_node.reuse,
+                          reuse_group=weights_node.reuse_group,
+                          quantization_attr=weights_node.quantization_attr,
+                          has_activation=False)
 
          self.original_activation_node = act_node
          self.original_weights_node = weights_node
 
          v_candidates = []
+         weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(kernel_attr)
          for c_a in act_node.candidates_quantization_cfg:
-             for c_w in weights_node.candidates_quantization_cfg:
+             for c_w in weights_candidates_quantization_cfg:
                  composed_candidate = CandidateNodeQuantizationConfig(activation_quantization_cfg=c_a.activation_quantization_cfg,
                                                                       weights_quantization_cfg=c_w.weights_quantization_cfg)
+                 if act_node.weights:
+                     # add non-kernel weights cfg from activation node to the composed node's weights cfg
+                     composed_candidate.weights_quantization_cfg.attributes_config_mapping.update(
+                         c_a.weights_quantization_cfg.attributes_config_mapping
+                     )
+                     composed_candidate.weights_quantization_cfg.pos_attributes_config_mapping.update(
+                         c_a.weights_quantization_cfg.pos_attributes_config_mapping
+                     )
                  v_candidates.append(composed_candidate)
 
          # sorting the candidates by weights number of bits first and then by activation number of bits (reversed order)
-         kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
          v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
                                           c.activation_quantization_cfg.activation_n_bits), reverse=True)
 
          self.candidates_quantization_cfg = v_candidates
-
-     def get_bops_count(self, fw_impl: Any, fw_info: FrameworkInfo, candidate_idx: int) -> float:
-         """
-         Computes the composed node's (edge) bit-operation count.
-
-         Args:
-             fw_impl: A FrameworkImplementation object with framework specific methods.
-             fw_info: A FrameworkInfo object with framework specific information,
-             candidate_idx: The index of the node's quantization candidate configuration.
-
-         Returns: The BOPS count of the composed node.
-
-         """
-         kernel_attr = fw_info.get_kernel_op_attributes(self.original_weights_node.type)[0]
-         node_mac = fw_impl.get_node_mac_operations(self.original_weights_node, fw_info)
-         candidate = self.candidates_quantization_cfg[candidate_idx]
-         kernel_attr_cfg = candidate.weights_quantization_cfg.get_attr_config(kernel_attr)
-         weights_bit = kernel_attr_cfg.weights_n_bits if \
-             kernel_attr_cfg.enable_weights_quantization else FLOAT_BITWIDTH
-         activation_bit = candidate.activation_quantization_cfg.activation_n_bits if \
-             candidate.activation_quantization_cfg.enable_activation_quantization else FLOAT_BITWIDTH
-         node_bops = weights_bit * activation_bit * node_mac
-         return node_bops
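With this change the composed node derives its name, shapes, weights and candidate configurations from the two original nodes, so a construction site only needs the node pair and the framework info. A hedged sketch of the new call (variable names illustrative):

    # Before: name, shapes, weights, layer_class, etc. had to be passed explicitly.
    # After (sketch): the composed node is built from the original nodes alone.
    v_node = VirtualActivationWeightsNode(act_node=preceding_act_node,
                                          weights_node=conv_node,
                                          fw_info=fw_info)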
@@ -19,7 +19,6 @@ import numpy as np
  from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
- from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
      RUTarget
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
@@ -28,9 +27,6 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
      NodeActivationQuantizationConfig
 
 
- # TODO take into account Virtual nodes. Are candidates defined with respect to virtual or original nodes?
- # Can we use the virtual graph only for bops and the original graph for everything else?
-
  class MixedPrecisionRUHelper:
      """ Helper class for resource utilization computations for mixed precision optimization. """
 
@@ -65,7 +61,7 @@ class MixedPrecisionRUHelper:
          ru[RUTarget.ACTIVATION] = np.array(list(au.values()))
 
          if RUTarget.BOPS in ru_targets:
-             ru[RUTarget.BOPS] = self._bops_utilization(mp_cfg)
+             ru[RUTarget.BOPS] = self._bops_utilization(act_qcs=act_qcs, w_qcs=w_qcs)
 
          if RUTarget.TOTAL in ru_targets:
              raise ValueError('Total target should be computed based on weights and activations targets.')
@@ -88,8 +84,8 @@
          """
          mp_nodes = self.graph.get_configurable_sorted_nodes(self.fw_info)
          node_qcs = {n: n.candidates_quantization_cfg[mp_cfg[i]] for i, n in enumerate(mp_nodes)}
-         act_qcs = {n: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
-         w_qcs = {n: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
+         act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
+         w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
          return act_qcs, w_qcs
 
      def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]:
@@ -137,51 +133,25 @@
          cuts_util = {c: u.bytes for c, u in cuts_util.items()}
          return cuts_util
 
-     def _bops_utilization(self, mp_cfg: List[int]) -> np.ndarray:
+     def _bops_utilization(self,
+                           act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]],
+                           w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> np.ndarray:
          """
-         Computes a resource utilization vector with the respective bit-operations (BOPS) count for each configurable node,
-         according to the given mixed-precision configuration of a virtual graph with composed nodes.
+         Computes a resource utilization vector with the respective bit-operations (BOPS) count
+         according to the given mixed-precision configuration.
 
          Args:
-             mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node)
+             act_qcs: nodes activation configuration or None.
+             w_qcs: nodes quantization configuration to compute, or None.
+                 Either both are provided, or both are None.
 
          Returns:
              A vector of node's BOPS count.
          """
-         # bops is computed for all nodes, so non-configurable memory is already covered by the computation of
-         # configurable nodes
-         if not mp_cfg:
+         assert [act_qcs, w_qcs].count(None) in [0, 2], 'act_qcs and w_qcs should both be provided or both be None.'
+         if act_qcs is None:
              return np.array([])
 
-         # TODO keeping old implementation for now
-         virtual_bops_nodes = [n for n in self.graph.get_topo_sorted_nodes() if isinstance(n, VirtualActivationWeightsNode)]
-
-         mp_nodes = self.graph.get_configurable_sorted_nodes_names(self.fw_info)
-
-         bops = [n.get_bops_count(self.fw_impl, self.fw_info, candidate_idx=_get_node_cfg_idx(n, mp_cfg, mp_nodes))
-                 for n in virtual_bops_nodes]
-
-         return np.array(bops)
-
-
- def _get_node_cfg_idx(node: BaseNode, mp_cfg: List[int], sorted_configurable_nodes_names: List[str]) -> int:
-     """
-     Returns the index of a node's quantization configuration candidate according to the given
-     mixed-precision configuration. If the node is not configurable, then it must have a single configuration,
-     therefore, the index 0 is returned.
-
-     Args:
-         node: A node to get its candidate configuration index.
-         mp_cfg: A mixed-precision configuration (list of candidates index for each configurable node)
-         sorted_configurable_nodes_names: A list of configurable nodes names.
-
-     Returns: An index (integer) of a node's quantization configuration candidate.
-     """
-
-     if node.name in sorted_configurable_nodes_names:
-         node_idx = sorted_configurable_nodes_names.index(node.name)
-         return mp_cfg[node_idx]
-     else:  # pragma: no cover
-         assert len(node.candidates_quantization_cfg) > 0, \
-             "Any node should have at least one candidate configuration."
-         return 0
+         _, detailed_bops = self.ru_calculator.compute_bops(TargetInclusionCriterion.Any, BitwidthMode.QCustom,
+                                                            act_qcs=act_qcs, w_qcs=w_qcs)
+         return np.array(list(detailed_bops.values()))
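The helper no longer walks virtual nodes or resolves candidate indices itself; it forwards the per-node-name configurations to the shared calculator and flattens the detailed result. Roughly (ru_helper and mp_cfg are illustrative stand-ins for objects from an existing MCT mixed-precision flow):

    # Sketch of the new flow (configs are dicts keyed by node name).
    act_qcs, w_qcs = ru_helper.get_quantization_candidates(mp_cfg)
    _, per_node_bops = ru_helper.ru_calculator.compute_bops(
        TargetInclusionCriterion.Any, BitwidthMode.QCustom, act_qcs=act_qcs, w_qcs=w_qcs)
    bops_vector = np.array(list(per_node_bops.values()))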
@@ -83,6 +83,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
      # Set graph for MP search
      graph = copy.deepcopy(graph_to_search_cfg)  # Copy graph before searching
      if target_resource_utilization.bops_restricted():
+         # TODO: we only need the virtual graph if both activations and weights are configurable
          # Since Bit-operations count target resource utilization is set, we need to reconstruct the graph for the MP search
          graph = substitute(graph, fw_impl.get_substitutions_virtual_weights_activation_coupling())
 
@@ -189,11 +189,9 @@ class MixedPrecisionSearchManager:
 
          """
          act_qcs, w_qcs = self.ru_helper.get_quantization_candidates(config)
-         act_qcs = None if (RUTarget.ACTIVATION not in self.ru_targets_to_compute and RUTarget.TOTAL not in self.ru_targets_to_compute) else act_qcs
-         w_qcs = None if (RUTarget.WEIGHTS not in self.ru_targets_to_compute and RUTarget.TOTAL not in self.ru_targets_to_compute) else w_qcs
          ru = self.ru_helper.ru_calculator.compute_resource_utilization(
              target_criterion=TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs,
-             w_qcs=w_qcs, ru_targets=self.ru_targets_to_compute)
+             w_qcs=w_qcs, ru_targets=self.ru_targets_to_compute, allow_unused_qcs=True)
          return ru
 
      def finalize_distance_metric(self, layer_to_metrics_mapping: Dict[int, Dict[int, float]]):
@@ -13,24 +13,27 @@
  # limitations under the License.
  # ==============================================================================
  from collections import defaultdict
+
  from copy import deepcopy
  from enum import Enum, auto
  from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence
 
- from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
  from model_compression_toolkit.core import FrameworkInfo
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
- from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
  from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import compute_graph_max_cut
  from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
  from model_compression_toolkit.core.common.graph.memory_graph.memory_graph import MemoryGraph
+ from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
+     VirtualSplitWeightsNode
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
      RUTarget, ResourceUtilization
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
-     NodeActivationQuantizationConfig
+     NodeActivationQuantizationConfig, BaseNodeQuantizationConfig
+ from model_compression_toolkit.core.common.substitutions.virtual_activation_weights_composition import \
+     get_input_activation_if_composable
 
 
  class BitwidthMode(Enum):
@@ -101,6 +104,12 @@ class Utilization(NamedTuple):
          return self.bytes < other.bytes
 
 
+ NodeName = str
+ ActivationQCfgPerNode = Dict[NodeName, NodeActivationQuantizationConfig]
+ WeightsQCfgPerNode = Dict[NodeName, NodeWeightsQuantizationConfig]
+ DetailedMem = Dict[Union[BaseNode, Cut], float]
+
+
  class ResourceUtilizationCalculator:
      """ Resource utilization calculator. """
 
@@ -110,6 +119,7 @@ class ResourceUtilizationCalculator:
      }
 
      unexpected_qc_error = 'Custom quantization configuration is not expected for non-custom bit mode.'
+     unexpected_qc_nodes_error = 'Custom quantization configuration contains unexpected node names.'
 
      def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
          self.graph = graph
@@ -121,10 +131,11 @@
          self._act_tensors_size = {}
          self._params_cnt = {}
          for n in graph.nodes:
-             self._act_tensors_size[n] = n.get_total_output_params()
+             self._act_tensors_size[n.name] = n.get_total_output_params()
              if n.weights:
-                 self._params_cnt[n] = {k: v.size for k, v in n.weights.items()}
+                 self._params_cnt[n.name] = {k: v.size for k, v in n.weights.items()}
          self._cuts: Optional[Dict[Cut, List[BaseNode]]] = None
+         self._nodes_names = set(n.name for n in graph.nodes)
 
      @property
      def cuts(self) -> Dict[Cut, List[BaseNode]]:
@@ -142,10 +153,12 @@
      def compute_resource_utilization(self,
                                       target_criterion: TargetInclusionCriterion,
                                       bitwidth_mode: BitwidthMode,
-                                      act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
-                                      w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None,
+                                      act_qcs: Optional[ActivationQCfgPerNode] = None,
+                                      w_qcs: Optional[WeightsQCfgPerNode] = None,
                                       ru_targets: Iterable[RUTarget] = None,
-                                      allow_unused_qcs: bool = False) -> ResourceUtilization:
+                                      allow_unused_qcs: bool = False,
+                                      return_detailed=False) \
+             -> Union[ResourceUtilization, Tuple[ResourceUtilization, Dict[RUTarget, DetailedMem]]]:
          """
          Compute network's resource utilization.
 
@@ -161,14 +174,17 @@
              ru_targets: metrics to include for computation. If None, all metrics are calculated.
              allow_unused_qcs: by default, if custom quantization configs are passed, but are not going to be used for
                  any of the requested targets, an error is raised. To disable the validation, pass True.
+             return_detailed: whether to return an additional dictionary with detailed utilization per element.
 
          Returns:
-             Resource utilization object.
+             Resource utilization object, or a tuple of resource utilization object and a dict containing detailed
+             memory utilization per ru target: for weights and bops targets - bytes per node,
+             for activations and total targets - bytes per cut.
          """
          ru_targets = set(ru_targets) if ru_targets else set(RUTarget)
 
-         if (w_qcs or act_qcs) and bitwidth_mode != BitwidthMode.QCustom:
-             raise ValueError(self.unexpected_qc_error)
+         self._validate_custom_qcs(act_qcs, bitwidth_mode)
+         self._validate_custom_qcs(w_qcs, bitwidth_mode)
 
          if w_qcs and not {RUTarget.WEIGHTS, RUTarget.TOTAL, RUTarget.BOPS}.intersection(ru_targets):
              if not allow_unused_qcs:
@@ -180,31 +196,46 @@
                  raise ValueError('Activation configuration passed but no relevant ru_targets requested.')
              act_qcs = None
 
-         w_total, a_total = None, None
+         w_total, w_per_node = None, None
          if {RUTarget.WEIGHTS, RUTarget.TOTAL}.intersection(ru_targets):
-             w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs)
+             w_total, w_per_node, _ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs)
 
+         a_total, a_per_cut = None, None
          if {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets):
-             a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)
+             a_total, a_per_cut, _ = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)
 
          ru = ResourceUtilization()
+         detailed = {}
          if RUTarget.WEIGHTS in ru_targets:
              ru.weights_memory = w_total
+             if return_detailed:
+                 detailed[RUTarget.WEIGHTS] = {n: u.bytes for n, u in w_per_node.items()}
          if RUTarget.ACTIVATION in ru_targets:
              ru.activation_memory = a_total
+             if return_detailed:
+                 detailed[RUTarget.ACTIVATION] = {cut: u.bytes for cut, u in a_per_cut.items()}
          if RUTarget.TOTAL in ru_targets:
              ru.total_memory = w_total + a_total
+             if return_detailed:
+                 detailed[RUTarget.TOTAL] = {cut: u.bytes + w_total for cut, u in a_per_cut.items()}
          if RUTarget.BOPS in ru_targets:
-             ru.bops, _ = self.compute_bops(target_criterion, bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs)
+             ru.bops, bops_per_node = self.compute_bops(target_criterion, bitwidth_mode, act_qcs=act_qcs, w_qcs=w_qcs)
+             if return_detailed:
+                 detailed[RUTarget.BOPS] = bops_per_node
+
+         assert ru.get_restricted_targets() == set(ru_targets), \
+             'Mismatch between the number of requested and computed metrics'
+
+         if return_detailed:
+             return ru, detailed
 
-         assert ru.get_restricted_targets() == set(ru_targets), 'Mismatch between the number of requested and computed metrics'
          return ru
 
      def compute_weights_utilization(self,
                                      target_criterion: TargetInclusionCriterion,
                                      bitwidth_mode: BitwidthMode,
-                                     w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None) \
-             -> Tuple[float, Dict[BaseNode, Utilization], Dict[BaseNode, Dict[str, Utilization]]]:
+                                     w_qcs: Optional[WeightsQCfgPerNode] = None) \
+             -> Tuple[float, Dict[NodeName, Utilization], Dict[NodeName, Dict[str, Utilization]]]:
          """
          Compute graph's weights resource utilization.
 
@@ -220,19 +251,18 @@
              - Per node total weights utilization. Dict keys are nodes in a topological order.
              - Detailed per node per weight attribute utilization. Dict keys are nodes in a topological order.
          """
-         if w_qcs and bitwidth_mode != BitwidthMode.QCustom:
-             raise ValueError(self.unexpected_qc_error)
+         self._validate_custom_qcs(w_qcs, bitwidth_mode)
 
          node_attrs = self._collect_target_nodes_w_attrs(target_criterion, include_reused=False)
 
-         util_per_node: Dict[BaseNode, Utilization] = {}
+         util_per_node: Dict[NodeName, Utilization] = {}
          util_per_node_per_weight = {}
          for n in self._topo_sort(list(node_attrs.keys())):
-             w_qc = w_qcs.get(n) if w_qcs else None
+             w_qc = w_qcs.get(n.name) if w_qcs else None
              node_weights_util, per_weight_util = self.compute_node_weights_utilization(n, node_attrs[n],
                                                                                         bitwidth_mode, w_qc)
-             util_per_node[n] = node_weights_util
-             util_per_node_per_weight[n] = per_weight_util
+             util_per_node[n.name] = node_weights_util
+             util_per_node_per_weight[n.name] = per_weight_util
 
          total_util = sum(util_per_node.values()) if util_per_node else Utilization(0, 0)
          return total_util.bytes, util_per_node, util_per_node_per_weight
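Together with return_detailed=True on compute_resource_utilization, these per-node-name results let a caller inspect where the budget goes (bytes per node for weights and BOPS, bytes per cut for activations and total). A hedged usage sketch (calc and the config dicts are illustrative objects from an existing MCT flow):

    ru, detailed = calc.compute_resource_utilization(
        TargetInclusionCriterion.AnyQuantized, BitwidthMode.QCustom,
        act_qcs=act_qcs, w_qcs=w_qcs, return_detailed=True)
    per_cut_activation_bytes = detailed[RUTarget.ACTIVATION]
    heaviest_cut = max(per_cut_activation_bytes, key=per_cut_activation_bytes.get)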
@@ -276,7 +306,7 @@
 
          attr_util = {}
          for attr in weight_attrs:
-             size = self._params_cnt[n][attr]
+             size = self._params_cnt[n.name][attr]
              nbits = self._get_weight_nbits(n, attr, bitwidth_mode, qc)
              bytes_ = size * nbits / 8
              attr_util[attr] = Utilization(size, bytes_)
@@ -287,7 +317,7 @@
      def compute_activations_utilization(self,
                                          target_criterion: TargetInclusionCriterion,
                                          bitwidth_mode: BitwidthMode,
-                                         act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None):
+                                         act_qcs: Optional[ActivationQCfgPerNode] = None):
          """
          Compute total activations utilization in the graph.
 
@@ -299,14 +329,16 @@
              activations, if not provided, the default configuration will be extracted from the node.
 
          Returns:
-             Total activation utilization of the network.
+             - Total activation utilization of the network.
+             - Total activation utilization per cut.
+             - Detailed activation utilization per cut per node.
          """
-         return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)[0]
+         return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)
 
      def compute_activation_utilization_by_cut(self,
                                                target_criterion: TargetInclusionCriterion,
                                                bitwidth_mode: BitwidthMode,
-                                               act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None) \
+                                               act_qcs: Optional[ActivationQCfgPerNode] = None) \
              -> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
          """
          Compute graph activation cuts utilization.
@@ -323,8 +355,7 @@
              - Total activation utilization per cut.
              - Detailed activation utilization per cut per node.
          """
-         if act_qcs and not bitwidth_mode == BitwidthMode.QCustom:
-             raise ValueError(self.unexpected_qc_error)
+         self._validate_custom_qcs(act_qcs, bitwidth_mode)
 
          graph_target_nodes = self._get_target_activation_nodes(target_criterion, include_reused=True)
          # if there are no target activations in the graph, don't waste time looking for cuts
@@ -338,9 +369,9 @@
              if not cut_target_nodes:
                  continue
              for n in cut_target_nodes:
-                 qc = act_qcs.get(n) if act_qcs else None
-                 util_per_cut_per_node[cut][n] = self.compute_node_activation_tensor_utilization(n, target_criterion,
-                                                                                                 bitwidth_mode, qc)
+                 qc = act_qcs.get(n.name) if act_qcs else None
+                 util_per_cut_per_node[cut][n.name] = self.compute_node_activation_tensor_utilization(n, target_criterion,
+                                                                                                      bitwidth_mode, qc)
              util_per_cut[cut] = sum(util_per_cut_per_node[cut].values())  # type: ignore
 
          total_util = max(util_per_cut.values())
@@ -349,9 +380,9 @@
      def compute_activation_tensors_utilization(self,
                                                 target_criterion: TargetInclusionCriterion,
                                                 bitwidth_mode: BitwidthMode,
-                                                act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
+                                                act_qcs: Optional[ActivationQCfgPerNode] = None,
                                                 include_reused=False) \
-             -> Tuple[float, Dict[BaseNode, Utilization]]:
+             -> Tuple[float, Dict[NodeName, Utilization]]:
          """
          Compute resource utilization for graph's activations tensors.
 
@@ -362,21 +393,21 @@
              In custom mode, must provide configuration for all configurable activations. For non-configurable
              activations, if not provided, the default configuration will be extracted from the node.
              include_reused: whether to include reused nodes.
+
          Returns:
              - Total activation utilization of the network.
              - Detailed utilization per node. Dict keys are nodes in a topological order.
 
          """
-         if act_qcs and bitwidth_mode != BitwidthMode.QCustom:
-             raise ValueError(self.unexpected_qc_error)
+         self._validate_custom_qcs(act_qcs, bitwidth_mode)
 
          nodes = self._get_target_activation_nodes(target_criterion, include_reused=include_reused)
 
-         util_per_node: Dict[BaseNode, Utilization] = {}
+         util_per_node: Dict[NodeName, Utilization] = {}
          for n in self._topo_sort(nodes):
-             qc = act_qcs.get(n) if act_qcs else None
+             qc = act_qcs.get(n.name) if act_qcs else None
              util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
-             util_per_node[n] = util
+             util_per_node[n.name] = util
 
          total_util = max(util_per_node.values()).bytes if util_per_node else 0
          return total_util, util_per_node
@@ -396,6 +427,7 @@
              qc: activation quantization config for the node. Should be provided only in custom bit mode.
                  In custom mode, must be provided if the activation is configurable. For non-configurable activation, if
                  not passed, the default configuration will be extracted from the node.
+
          Returns:
              Node's activation utilization.
          """
@@ -408,7 +440,7 @@
          if not nodes:
              return Utilization(0, 0)
 
-         size = self._act_tensors_size[n]
+         size = self._act_tensors_size[n.name]
          nbits = self._get_activation_nbits(n, bitwidth_mode, qc)
          bytes_ = size * nbits / 8
          return Utilization(size, bytes_)
@@ -416,9 +448,9 @@
      def compute_bops(self,
                       target_criterion: TargetInclusionCriterion,
                       bitwidth_mode: BitwidthMode,
-                      act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
-                      w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]] = None) \
-             -> Tuple[int, Dict[BaseNode, int]]:
+                      act_qcs: Optional[ActivationQCfgPerNode] = None,
+                      w_qcs: Optional[WeightsQCfgPerNode] = None) \
+             -> Tuple[int, Dict[NodeName, int]]:
          """
          Compute bit operations based on nodes with kernel.
          Note that 'target_criterion' applies to weights, and BOPS are computed for the selected nodes regardless
@@ -438,30 +470,30 @@
              - Total BOPS count of the network.
              - Detailed BOPS count per node.
          """
-         if target_criterion != TargetInclusionCriterion.AnyQuantized:  # pragma: no cover
-             raise NotImplementedError('BOPS computation is currently only supported for quantized targets.')
-
-         nodes = self._collect_target_nodes_w_attrs(target_criterion, include_reused=True)
-         # filter out nodes with only positional weights # TODO add as arg to get target nodes
-         nodes = [n for n in nodes if n.has_kernel_weight_to_quantize(self.fw_info)]
+         self._validate_custom_qcs(act_qcs, bitwidth_mode)
+         self._validate_custom_qcs(w_qcs, bitwidth_mode)
 
          nodes_bops = {}
-         for n in nodes:
-             w_qc = w_qcs.get(n) if w_qcs else None
-             nodes_bops[n] = self.compute_node_bops(n, bitwidth_mode, act_qcs=act_qcs, w_qc=w_qc)
+         for n in self.graph.get_topo_sorted_nodes():
+             w_qc = w_qcs.get(n.name) if w_qcs else None
+             bops = self.compute_node_bops(n, target_criterion, bitwidth_mode, act_qcs=act_qcs, w_qc=w_qc)
+             if bops:
+                 nodes_bops[n.name] = bops
 
          return sum(nodes_bops.values()), nodes_bops
 
      def compute_node_bops(self,
                            n: BaseNode,
+                           target_criterion: Optional[TargetInclusionCriterion],
                            bitwidth_mode: BitwidthMode,
-                           act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]] = None,
+                           act_qcs: Optional[ActivationQCfgPerNode] = None,
                            w_qc: Optional[NodeWeightsQuantizationConfig] = None) -> Union[float, int]:
          """
          Compute Bit Operations of a node.
 
          Args:
              n: node.
+             target_criterion: criterion to include nodes for computation.
              bitwidth_mode: bit-width mode for the computation.
              act_qcs: custom activations quantization configuration. Should be provided for custom bit mode only.
                  In custom mode, must provide configuration for all configurable activations. For non-configurable
@@ -473,26 +505,58 @@ class ResourceUtilizationCalculator:
473
505
  Returns:
474
506
  Node's BOPS count.
475
507
  """
476
- node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
477
- if node_mac == 0: # pragma: no cover
478
- return node_mac
508
+ if target_criterion is None:
509
+ target_criterion = TargetInclusionCriterion.Any
510
+ if target_criterion not in [TargetInclusionCriterion.AnyQuantized, TargetInclusionCriterion.Any]:
511
+ raise ValueError('BOPS computation is supported only for Any and AnyQuantized targets.')
479
512
 
480
- incoming_edges = self.graph.incoming_edges(n, sort_by_attr=EDGE_SINK_INDEX)
481
- # TODO temporary adding this for const_representation test in torch which has Linear with const input
482
- if not incoming_edges: # pragma: no cover
513
+ self._validate_custom_qcs(act_qcs, bitwidth_mode)
514
+ self._validate_custom_qcs(w_qc, bitwidth_mode)
515
+
516
+ if isinstance(n, VirtualSplitWeightsNode):
517
+ # Virtual weights node can only be present if it couldn't be merged into VirtualActivationWeightsNode.
518
+ # This means that during MP search we cannot compute bops for all A/W nbits combinations. To prevent
519
+ # inconsistencies we ignore such nodes for bops computation.
483
520
  return 0
484
- assert len(incoming_edges) == 1, \
485
- f'Unexpected number of inputs {len(incoming_edges)} for BOPS calculation. Expected 1.'
486
- input_act_node = incoming_edges[0].source_node
487
- act_qc = act_qcs.get(input_act_node) if act_qcs else None
488
- a_nbits = self._get_activation_nbits(input_act_node, bitwidth_mode, act_qc)
489
521
 
522
+ # Fetch the original weights node for the MAC computation (VirtualActivationWeightsNode input/output shapes
523
+ # are based on the original activation node, not on the original weights node)
524
+ orig_w_node = n
525
+ if isinstance(n, VirtualActivationWeightsNode):
526
+ orig_w_node = n.original_weights_node
527
+ if isinstance(orig_w_node, VirtualSplitWeightsNode):
528
+ orig_w_node = orig_w_node.origin_node
529
+
530
+ # check if the node has a kernel attribute
490
531
  kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
491
- if len(kernel_attrs) > 1: # pragma: no cover
532
+ if len(kernel_attrs) > 1: # pragma: no cover
492
533
  raise NotImplementedError('Multiple kernel attributes are not supported for BOPS computation.')
534
+ if not kernel_attrs or not kernel_attrs[0]:
535
+ return 0
536
+
493
537
  kernel_attr = kernel_attrs[0]
494
- w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
538
+ node_mac = self.fw_impl.get_node_mac_operations(orig_w_node, self.fw_info)
539
+ if node_mac == 0:
540
+ return node_mac
541
+
542
+ # find the activation node from which to take quantization info and which to look up in the custom configuration
543
+ if isinstance(n, VirtualActivationWeightsNode):
544
+ # we don't need the original node (and cannot use it for custom configuration anyway)
545
+ a_node = n
546
+ else:
547
+ # if we are running on the original (non-virtual) graph, we only compute bops if it would be computed in an
548
+ # equivalent virtual graph for consistency.
549
+ a_node = get_input_activation_if_composable(self.graph, n, warn=False)
550
+ if a_node is None:
551
+ return 0
552
+
553
+ if (target_criterion == TargetInclusionCriterion.AnyQuantized and
554
+ not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
555
+ return 0
495
556
 
557
+ act_qc = act_qcs.get(a_node.name) if act_qcs else None
558
+ a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
559
+ w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
496
560
  node_bops = a_nbits * w_nbits * node_mac
497
561
  return node_bops
498
562
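For reference, a hedged worked example of the BOPS formula used above (node_bops = a_nbits * w_nbits * node_mac); the Conv2d shape and bit-widths are assumed for illustration and are not taken from this diff:

    # assumed Conv2d layer shape
    in_ch, out_ch, k, out_h, out_w = 64, 128, 3, 56, 56
    node_mac = in_ch * out_ch * k * k * out_h * out_w   # 231,211,008 MACs
    # assumed activation / weight bit-widths
    a_nbits, w_nbits = 8, 4
    node_bops = a_nbits * w_nbits * node_mac             # 8 * 4 * 231,211,008 = 7,398,752,256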
 
@@ -531,10 +595,11 @@ class ResourceUtilizationCalculator:
531
595
  """
532
596
  nodes_attrs = {n: attrs for n in self.graph.nodes
533
597
  if (attrs := self._get_target_weight_attrs(n, target_criterion))
534
- and (include_reused or not n.reuse)}
598
+ and (include_reused or not n.reuse)}
535
599
  return nodes_attrs
536
600
 
537
- def _get_target_weight_attrs(self, n: BaseNode, target_criterion: TargetInclusionCriterion) -> List[str]:
601
+ @staticmethod
602
+ def _get_target_weight_attrs(n: BaseNode, target_criterion: TargetInclusionCriterion) -> List[str]:
538
603
  """
539
604
  Collect weight attributes of a node per criterion.
540
605
 
@@ -692,3 +757,30 @@ class ResourceUtilizationCalculator:
692
757
  return w_qcs[0].weights_n_bits
693
758
 
694
759
  raise ValueError(f'Unknown mode {bitwidth_mode.name}') # pragma: no cover
760
+
761
+ def _validate_custom_qcs(self,
762
+ qcs: Union[BaseNodeQuantizationConfig, Dict[NodeName, BaseNodeQuantizationConfig]],
763
+ bitwidth_mode: BitwidthMode):
764
+ """
765
+ Validate custom quantization configuration.
766
+
767
+ Args:
768
+ qcs: either a mapping from node names to quantization configuration, or just a quantization configuration.
769
+ bitwidth_mode: bit-width mode.
770
+
771
+ Raises:
772
+ ValueError: if validation fails.
773
+
774
+ """
775
+ if qcs is None:
776
+ return
777
+
778
+ if bitwidth_mode != BitwidthMode.QCustom:
779
+ raise ValueError(self.unexpected_qc_error)
780
+
781
+ if isinstance(qcs, (NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig)):
782
+ return
783
+
784
+ unknown_nodes = set(qcs.keys()) - self._nodes_names
785
+ if unknown_nodes:
786
+ raise ValueError(self.unexpected_qc_nodes_error, unknown_nodes)
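For reference, a hedged usage sketch of the name-keyed custom configuration that _validate_custom_qcs checks; it mirrors the runner.py change further below, and the variable names are illustrative only:

    ru_calc = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
    # custom configs are keyed by node name (NodeName) rather than by node objects;
    # _validate_custom_qcs verifies the keys against the graph's node names and that QCustom mode is used
    w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
    a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
    ru = ru_calc.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                              BitwidthMode.QCustom, act_qcs=a_qcs, w_qcs=w_qcs)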
@@ -12,10 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from typing import Optional
15
16
 
16
17
  from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution
17
18
  from model_compression_toolkit.logger import Logger
18
- from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode
19
+ from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualActivationWeightsNode, \
20
+ VirtualSplitWeightsNode
19
21
 
20
22
 
21
23
  class BaseVirtualActivationWeightsComposition(BaseSubstitution):
@@ -39,26 +41,18 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):
39
41
  Returns:
40
42
  Graph after applying the substitution.
41
43
  """
44
+ if not isinstance(weights_node, VirtualSplitWeightsNode):
45
+ raise TypeError(f'Matched node {weights_node} was expected to be of type VirtualSplitWeightsNode. '
46
+ f'This substitution is expected to be called after activation-weights split.')
42
47
 
43
- predecessors = graph.get_prev_nodes(weights_node)
44
- if len(predecessors) != 1:
45
- return graph
46
-
47
- act_node = predecessors[0]
48
-
49
- if len(graph.out_edges(act_node)) > 1:
50
- Logger.warning(f"Node {act_node.name} has multiple outgoing edges, which is not supported with "
51
- f"mixed-precision bit-operations utilization, thus, edge {act_node.name} --> {weights_node.name} "
52
- f"would not be counted in the bit-operations calculations.")
48
+ act_node = get_input_activation_if_composable(graph, weights_node, warn=True)
49
+ if act_node is None:
53
50
  return graph
54
51
 
55
52
  # Virtual composed activation-weights node
56
- # we pass a dummy initialization dict to initialize the super BaseNode class,
57
- # the actual arguments values are irrelevant because they are being overridden or not used
58
53
  v_node = VirtualActivationWeightsNode(act_node,
59
54
  weights_node,
60
- fw_info=graph.fw_info,
61
- **weights_node.__dict__)
55
+ fw_info=graph.fw_info)
62
56
 
63
57
  # Update graph
64
58
  graph.add_node(v_node)
@@ -71,3 +65,29 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):
71
65
  graph.remove_node(act_node)
72
66
 
73
67
  return graph
68
+
69
+
70
+ def get_input_activation_if_composable(graph: Graph, weights_node: BaseNode, warn: bool) -> Optional[BaseNode]:
71
+ """
72
+ Get input activation node for composition, or None if not composable.
73
+
74
+ Args:
75
+ graph: graph.
76
+ weights_node: weights node for composition.
77
+ warn: whether to log a warning if not composable.
78
+
79
+ Returns:
80
+ Input activation node or None.
81
+ """
82
+ predecessors = graph.get_prev_nodes(weights_node)
83
+ assert len(predecessors) == 1, (f'Weights node is expected to have exactly one input, '
84
+ f'node {weights_node} has {len(predecessors)}')
85
+ act_node = predecessors[0]
86
+ if len(graph.out_edges(act_node)) > 1:
87
+ if warn:
88
+ Logger.warning(f"Node {act_node.name} has multiple outgoing edges, which is not supported with "
89
+ f"mixed-precision search under bit-operations constraint. In such case, it might result in "
90
+ f"incorrect resource utilization computation and suboptimal bits selection.")
91
+ return None
92
+
93
+ return act_node
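A brief illustration of the composability rule above (the topology is hypothetical): if an activation node act feeds two consumers, e.g. act -> conv1_w and act -> conv2, then get_input_activation_if_composable(graph, conv1_w, warn=True) logs the warning and returns None, so no VirtualActivationWeightsNode is composed; compute_node_bops, which calls the same helper with warn=False, likewise excludes conv1_w from the BOPS count. With a single consumer, the helper returns act and the composition proceeds.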
@@ -52,9 +52,9 @@ class BaseWeightsActivationSplit(BaseSubstitution):
52
52
  # The decomposition works on linear nodes, that is, nodes with kernel ops
53
53
  kernel_attr = graph.fw_info.get_kernel_op_attributes(node.type)[0]
54
54
  if kernel_attr is None:
55
- Logger.error(f"Trying to split node weights and activation, but node "
56
- f"{node.name} doesn't have a kernel attribute.")
57
- if not node.is_all_weights_candidates_equal(kernel_attr) and not node.is_all_activation_candidates_equal():
55
+ Logger.critical(f"Trying to split node weights and activation, but node "
56
+ f"{node.name} doesn't have a kernel attribute.")
57
+ if node.is_configurable_weight(kernel_attr) and node.has_configurable_activation():
58
58
  # Node has both different weights and different activation configuration candidates
59
59
  weights_bits = [c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits
60
60
  for c in node.get_unique_weights_candidates(kernel_attr)]
@@ -33,6 +33,7 @@ STRIDES = 'stride'
33
33
  DILATIONS = 'dilation'
34
34
  TENSOR_META = 'tensor_meta'
35
35
  FILTERS = 'out_channels'
36
+ OUTPUT_PADDING = 'output_padding'
36
37
  TYPE = 'type'
37
38
  PAD = 'pad'
38
39
  VALUE = 'value'
@@ -0,0 +1,77 @@
1
+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from typing import Tuple
16
+ import torch.nn as nn
17
+ import torch
18
+ from model_compression_toolkit.core.pytorch.constants import OUTPUT_PADDING
19
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
20
+ from model_compression_toolkit.core import common
21
+ from model_compression_toolkit.core.common import BaseNode, Graph
22
+ from model_compression_toolkit.logger import Logger
23
+
24
+
25
+ class ConvtransposeDynamicPadding(common.BaseSubstitution):
26
+ """
27
+ Replace the output_padding of nn.ConvTranspose2d to align with a dynamic output_size input.
28
+ If a dynamic output_size is passed to the ConvTranspose2d forward function, the output_padding is
29
+ recalculated here from node.output_shape (which equals the dynamic output_size when one was given).
30
+ """
31
+
32
+ def __init__(self):
33
+ """
34
+ Matches: nn.ConvTranspose2d
35
+ """
36
+ convtr_node = NodeOperationMatcher(nn.ConvTranspose2d)
37
+ super().__init__(matcher_instance=convtr_node)
38
+
39
+
40
+ def calc_dynamic_output_size(self, node: BaseNode) -> Tuple[int]:
41
+ """
42
+ Calculate the output padding to support a dynamic output_size of nn.ConvTranspose2d
43
+ Args:
44
+ node: node for which to calculate the output padding
45
+
46
+ Returns:
47
+ corrected output padding
48
+ """
49
+ convtr = nn.ConvTranspose2d(**node.framework_attr)
50
+ num_spatial_dims = 2
51
+ output_padding = convtr._output_padding(torch.randn(size=node.input_shape[0]),
52
+ node.output_shape[0],
53
+ convtr.stride,
54
+ convtr.padding,
55
+ convtr.kernel_size,
56
+ num_spatial_dims,
57
+ convtr.dilation)
58
+ return tuple(output_padding)
59
+
60
+
61
+ def substitute(self,
62
+ graph: Graph,
63
+ node: BaseNode) -> Graph:
64
+ """
65
+ Substitute nn.ConvTranspose2d with corrected output_padding for cases of a dynamic output_size.
66
+ Args:
67
+ graph: Graph we apply the substitution on.
68
+ node: node that matches the pattern in the substitution init.
69
+
70
+ Returns:
71
+ Graph after applying the substitution.
72
+ """
73
+
74
+ if not node.reuse:
75
+ output_padding = self.calc_dynamic_output_size(node)
76
+ node.framework_attr.update({OUTPUT_PADDING: output_padding})
77
+ return graph
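For reference, a hedged worked example of the output_padding correction above (the layer hyper-parameters are assumed for illustration). PyTorch computes the ConvTranspose2d output height as

    H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1

so with H_in=14, stride=2, padding=1, kernel_size=3 and dilation=1, output_padding=0 gives H_out=27 while output_padding=1 gives H_out=28. If the model's forward pass requested output_size=28 dynamically, node.output_shape records 28, and the substitution bakes output_padding=1 into the node's framework_attr so the exported module no longer depends on the dynamic argument.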
@@ -62,6 +62,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.sc
62
62
  ScaledDotProductDecomposition
63
63
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
64
64
  TransformFunctionCallMethod
65
+ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.convtranspose_dynamic_padding import \
66
+ ConvtransposeDynamicPadding
65
67
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
66
68
  FunctionalConvSubstitution
67
69
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
@@ -286,7 +288,8 @@ class PytorchImplementation(FrameworkImplementation):
286
288
  FunctionalBatchNorm(),
287
289
  FunctionalLayerNorm(),
288
290
  FunctionalLinear(),
289
- RemoveIdentity()]
291
+ RemoveIdentity(),
292
+ ConvtransposeDynamicPadding()]
290
293
 
291
294
  def get_substitutions_pre_statistics_collection(self,
292
295
  quant_config: QuantizationConfig
@@ -229,8 +229,8 @@ def _set_final_resource_utilization(graph: Graph,
229
229
  final_ru = None
230
230
  if ru_targets:
231
231
  ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
232
- w_qcs = {n: n.final_weights_quantization_cfg for n in graph.nodes}
233
- a_qcs = {n: n.final_activation_quantization_cfg for n in graph.nodes}
232
+ w_qcs = {n.name: n.final_weights_quantization_cfg for n in graph.nodes}
233
+ a_qcs = {n.name: n.final_activation_quantization_cfg for n in graph.nodes}
234
234
  final_ru = ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
235
235
  BitwidthMode.QCustom, act_qcs=a_qcs, w_qcs=w_qcs,
236
236
  ru_targets=ru_targets, allow_unused_qcs=True)
@@ -45,6 +45,10 @@ if FOUND_TORCH:
45
45
  """
46
46
  weight_quantizers, _ = fw_impl.get_inferable_quantizers(node)
47
47
  if len(weight_quantizers) > 0:
48
+ # Set reuse for weight quantizers if the node is reused
49
+ for _, quantizer in weight_quantizers.items():
50
+ if node.reuse_group:
51
+ quantizer.enable_reuse_quantizer()
48
52
  # for positional weights we need to extract the weight's value.
49
53
  weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
50
54
  for attr in weight_quantizers if isinstance(attr, int)}
@@ -17,7 +17,8 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
17
17
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import (
18
18
  FrameworkQuantizationCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater,
19
19
  LayerFilterParams, OperationsToLayers, get_current_tpc)
20
- from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformCapabilities, OperatorsSet, \
21
- OperatorSetGroup, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
20
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import (
21
+ TargetPlatformCapabilities, OperatorsSet, OperatorSetGroup, Signedness, AttributeQuantizationConfig,
22
+ OpQuantizationConfig, QuantizationConfigOptions, Fusing, OperatorSetNames)
22
23
 
23
24
  from mct_quantizers import QuantizationMethod