mct-nightly 1.10.0.20231203.post417__py3-none-any.whl → 1.10.0.20231205.post412__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.10.0.20231203.post417.dist-info → mct_nightly-1.10.0.20231205.post412.dist-info}/METADATA +1 -1
- {mct_nightly-1.10.0.20231203.post417.dist-info → mct_nightly-1.10.0.20231205.post412.dist-info}/RECORD +16 -16
- model_compression_toolkit/core/common/similarity_analyzer.py +4 -1
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +46 -0
- model_compression_toolkit/core/runner.py +3 -37
- model_compression_toolkit/gptq/keras/quantization_facade.py +3 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -2
- model_compression_toolkit/legacy/keras_quantization_facade.py +4 -3
- model_compression_toolkit/legacy/pytorch_quantization_facade.py +4 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +3 -2
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -2
- model_compression_toolkit/qat/keras/quantization_facade.py +3 -2
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -2
- {mct_nightly-1.10.0.20231203.post417.dist-info → mct_nightly-1.10.0.20231205.post412.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.10.0.20231203.post417.dist-info → mct_nightly-1.10.0.20231205.post412.dist-info}/WHEEL +0 -0
- {mct_nightly-1.10.0.20231203.post417.dist-info → mct_nightly-1.10.0.20231205.post412.dist-info}/top_level.txt +0 -0
|
@@ -6,7 +6,7 @@ model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9j
|
|
|
6
6
|
model_compression_toolkit/core/exporter.py,sha256=U_-ea-zYHsnIt2ydameMLZ_gzDaCMI1dRa5IjA8RUuc,4233
|
|
7
7
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=SHhFl0vpC9YpRu40xkApFzmw_dT-nfIz1MDjmKcon8Q,9913
|
|
8
8
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=npv55-QsJFR7bnbHj4tBMf13Y18Ns7QGa-UDSI6WJRE,6554
|
|
9
|
-
model_compression_toolkit/core/runner.py,sha256=
|
|
9
|
+
model_compression_toolkit/core/runner.py,sha256=Cb8_TWAOBz4SO1O48ehxqC9PpaR4KifbCs0nV724zMM,10454
|
|
10
10
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
|
11
11
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
|
12
12
|
model_compression_toolkit/core/common/data_loader.py,sha256=7YF5Mqz64Xb4rVwY3knrdIZ4JEHybXxiQqx0deR_c5k,4017
|
|
@@ -18,7 +18,7 @@ model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3u
|
|
|
18
18
|
model_compression_toolkit/core/common/model_collector.py,sha256=pNmJsU7QPCQ8-YUrzz__85YwF7Mk4Q27gozDSYCpzrg,5005
|
|
19
19
|
model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
|
|
20
20
|
model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
|
|
21
|
-
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=
|
|
21
|
+
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=v7TBF_6iufFPTRQMT74Z2BuwSN4HfX7eFuqH4AF6JOM,8579
|
|
22
22
|
model_compression_toolkit/core/common/user_info.py,sha256=BM98W0jRF_M6zYdCs6z7eKgEOFGu1DZjJja7KSSUGJQ,1631
|
|
23
23
|
model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
24
24
|
model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
|
|
@@ -132,7 +132,7 @@ model_compression_toolkit/core/common/substitutions/weights_activation_split.py,
|
|
|
132
132
|
model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
133
133
|
model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
|
|
134
134
|
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=6EjZj_KE1tICTQ0XSKIx5ivsRFpRktFywda7pW7YnNQ,5955
|
|
135
|
-
model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=
|
|
135
|
+
model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=954742gUTrrKmcVjcuBJaKR-EfMMsrWZ7PXd07unA6E,21939
|
|
136
136
|
model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
137
137
|
model_compression_toolkit/core/keras/constants.py,sha256=oFYFagoFTOQTrs2RHVc93583EhOvcvbCYHleqsZdQ6s,3046
|
|
138
138
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
@@ -318,7 +318,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
|
|
|
318
318
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
|
319
319
|
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=75j56X2AcNI_5hInsLvXnWZOGMZIKQY2hStVKBaA_Bc,17705
|
|
320
320
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
|
|
321
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
|
321
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=DKh16lW24btKCrS6U-itaTHuTRX39e_nV5gn3hKUfxQ,15198
|
|
322
322
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
323
323
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
|
|
324
324
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
@@ -335,7 +335,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
|
|
|
335
335
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
|
336
336
|
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=RfkJNQbeqGZQvlmV0dZO7YJ894Gx2asLnnIHFdWNEZ0,15078
|
|
337
337
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=-0GDC2cr-XXS7cTFTnDflJivGN7VaPnzVPsxCE-vZNU,3955
|
|
338
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
|
338
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=q1CzOc0LY1i8Ux4uH2PFGlGhTAJ0t-PPkIJURzbMu9w,13233
|
|
339
339
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
|
340
340
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
|
|
341
341
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
|
@@ -348,19 +348,19 @@ model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quan
|
|
|
348
348
|
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
349
349
|
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=whVx94NCmYzobvZtPNr4qZTCN-CV8jfs4des_mkK3F8,8770
|
|
350
350
|
model_compression_toolkit/legacy/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
|
351
|
-
model_compression_toolkit/legacy/keras_quantization_facade.py,sha256=
|
|
352
|
-
model_compression_toolkit/legacy/pytorch_quantization_facade.py,sha256=
|
|
351
|
+
model_compression_toolkit/legacy/keras_quantization_facade.py,sha256=2pNJoc1mKMbikBS_uebLgFAbTqfA0y9ofDUNCVogSKI,18444
|
|
352
|
+
model_compression_toolkit/legacy/pytorch_quantization_facade.py,sha256=p-ZGKdGeRIJsR5XmFYgjs3VN49NrwHumNtTY2OSDW-4,17874
|
|
353
353
|
model_compression_toolkit/ptq/__init__.py,sha256=50QBTXOKdj9XLjXtrvf0mhC9FlW6TOi9-pjl96RLR14,930
|
|
354
354
|
model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
|
|
355
355
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
356
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
|
356
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=39-bHTOZ7OENxvbC_HKelwjSO3e8BccGexgwWPClIDk,9969
|
|
357
357
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
358
|
-
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
|
|
358
|
+
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=zsewr833uH91wpluQ4Z22B3HX5ohepdJZPTqNpAosZk,8750
|
|
359
359
|
model_compression_toolkit/qat/__init__.py,sha256=BYKgH1NwB9fqF1TszULQ5tDfLI-GqgZV5sao-lDN9EM,1091
|
|
360
360
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
|
361
361
|
model_compression_toolkit/qat/common/qat_config.py,sha256=kbSxFL6_u28furq5mW_75STWDmyX4clPt-seJAnX3IQ,3445
|
|
362
362
|
model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
363
|
-
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=
|
|
363
|
+
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=hB_VrhSqwOjGnOT8BYpXkh52EMzo7I-62-IEGohg_74,16253
|
|
364
364
|
model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
|
|
365
365
|
model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
|
|
366
366
|
model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
|
|
@@ -372,7 +372,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
|
|
|
372
372
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
|
|
373
373
|
model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
|
|
374
374
|
model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
375
|
-
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=
|
|
375
|
+
model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=Kv2vUKY2cSt6kYB7J2B1SC_PGzSFqecpxDcKWGeRzuQ,12629
|
|
376
376
|
model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
|
|
377
377
|
model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
|
|
378
378
|
model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=GOYRDXvQSGe_iUFVmvDy5BqC952hu_-rQO06n8QCyw0,5491
|
|
@@ -448,8 +448,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
448
448
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
449
449
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
450
450
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
|
|
451
|
-
mct_nightly-1.10.0.
|
|
452
|
-
mct_nightly-1.10.0.
|
|
453
|
-
mct_nightly-1.10.0.
|
|
454
|
-
mct_nightly-1.10.0.
|
|
455
|
-
mct_nightly-1.10.0.
|
|
451
|
+
mct_nightly-1.10.0.20231205.post412.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
452
|
+
mct_nightly-1.10.0.20231205.post412.dist-info/METADATA,sha256=TJLsSdD5_ptD9k1nxSMBU83HMkoDdwh1OGP9_dMrhbM,16232
|
|
453
|
+
mct_nightly-1.10.0.20231205.post412.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
|
454
|
+
mct_nightly-1.10.0.20231205.post412.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
455
|
+
mct_nightly-1.10.0.20231205.post412.dist-info/RECORD,,
|
|
@@ -241,4 +241,7 @@ def compute_kl_divergence(float_tensor: np.ndarray, fxp_tensor: np.ndarray, batc
|
|
|
241
241
|
non_zero_fxp_tensor = fxp_flat.copy()
|
|
242
242
|
non_zero_fxp_tensor[non_zero_fxp_tensor == 0] = EPS
|
|
243
243
|
|
|
244
|
-
|
|
244
|
+
prob_distance = np.where(float_flat != 0, float_flat * np.log(float_flat / non_zero_fxp_tensor), 0)
|
|
245
|
+
# The sum is part of the KL-Divergance function.
|
|
246
|
+
# The mean is to aggregate the distance between each output probability vectors.
|
|
247
|
+
return np.mean(np.sum(prob_distance, axis=-1), axis=-1)
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from copy import deepcopy
|
|
17
17
|
|
|
18
18
|
import io
|
|
19
|
+
import os
|
|
19
20
|
import numpy as np
|
|
20
21
|
from PIL import Image
|
|
21
22
|
from matplotlib.figure import Figure
|
|
@@ -34,6 +35,9 @@ from networkx import topological_sort
|
|
|
34
35
|
from model_compression_toolkit.core import FrameworkInfo
|
|
35
36
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
36
37
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
|
38
|
+
from model_compression_toolkit.logger import Logger
|
|
39
|
+
from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
|
|
40
|
+
WeightsFinalBitwidthConfigVisualizer, ActivationFinalBitwidthConfigVisualizer
|
|
37
41
|
|
|
38
42
|
DEVICE_STEP_STATS = "/device:CPU:0"
|
|
39
43
|
|
|
@@ -486,3 +490,45 @@ class TensorboardWriter(object):
|
|
|
486
490
|
er = self.__get_event_writer_by_tag_name(main_tag_name)
|
|
487
491
|
er.add_event(event)
|
|
488
492
|
er.flush()
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
|
|
496
|
+
"""
|
|
497
|
+
Create a TensorBoardWriter object initialized with the logger dir path if it was set,
|
|
498
|
+
or None otherwise.
|
|
499
|
+
|
|
500
|
+
Args:
|
|
501
|
+
fw_info: FrameworkInfo object.
|
|
502
|
+
|
|
503
|
+
Returns:
|
|
504
|
+
A TensorBoardWriter object.
|
|
505
|
+
"""
|
|
506
|
+
tb_w = None
|
|
507
|
+
if Logger.LOG_PATH is not None:
|
|
508
|
+
tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
|
|
509
|
+
Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
|
|
510
|
+
tb_w = TensorboardWriter(tb_log_dir, fw_info)
|
|
511
|
+
return tb_w
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def finalize_bitwidth_in_tb(tb_w: TensorboardWriter,
|
|
515
|
+
weights_conf_nodes_bitwidth: List,
|
|
516
|
+
activation_conf_nodes_bitwidth: List):
|
|
517
|
+
"""
|
|
518
|
+
Set the final bit-width configuration of the quantized model in the provided TensorBoard object.
|
|
519
|
+
|
|
520
|
+
Args:
|
|
521
|
+
tb_w: A TensorBoard object.
|
|
522
|
+
weights_conf_nodes_bitwidth: Final weights bit-width configuration.
|
|
523
|
+
activation_conf_nodes_bitwidth: Final activation bit-width configuration.
|
|
524
|
+
|
|
525
|
+
"""
|
|
526
|
+
|
|
527
|
+
if len(weights_conf_nodes_bitwidth) > 0:
|
|
528
|
+
visual = WeightsFinalBitwidthConfigVisualizer(weights_conf_nodes_bitwidth)
|
|
529
|
+
figure = visual.plot_config_bitwidth()
|
|
530
|
+
tb_w.add_figure(figure, f'Weights final bit-width config')
|
|
531
|
+
if len(activation_conf_nodes_bitwidth) > 0:
|
|
532
|
+
visual = ActivationFinalBitwidthConfigVisualizer(activation_conf_nodes_bitwidth)
|
|
533
|
+
figure = visual.plot_config_bitwidth()
|
|
534
|
+
tb_w.add_figure(figure, f'Activation final bit-width config')
|
|
@@ -14,11 +14,9 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
import os
|
|
18
17
|
from typing import Callable, Tuple, Any, List, Dict
|
|
19
18
|
|
|
20
19
|
import numpy as np
|
|
21
|
-
from tqdm import tqdm
|
|
22
20
|
|
|
23
21
|
from model_compression_toolkit.core.common import FrameworkInfo
|
|
24
22
|
from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService
|
|
@@ -33,20 +31,14 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_aggrega
|
|
|
33
31
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_functions_mapping import kpi_functions_mapping
|
|
34
32
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi_methods import MpKpiMetric
|
|
35
33
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width
|
|
36
|
-
from model_compression_toolkit.core.common.model_collector import ModelCollector
|
|
37
34
|
from model_compression_toolkit.core.common.network_editors.edit_network import edit_network_graph
|
|
38
35
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
39
|
-
from model_compression_toolkit.core.common.quantization.quantization_analyzer import analyzer_graph
|
|
40
|
-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
|
|
41
|
-
calculate_quantization_params
|
|
42
|
-
from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
|
|
43
|
-
statistics_correction_runner
|
|
44
|
-
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
|
|
45
36
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
46
37
|
from model_compression_toolkit.core.common.visualization.final_config_visualizer import \
|
|
47
38
|
WeightsFinalBitwidthConfigVisualizer, \
|
|
48
39
|
ActivationFinalBitwidthConfigVisualizer
|
|
49
|
-
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
|
|
40
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter, \
|
|
41
|
+
finalize_bitwidth_in_tb
|
|
50
42
|
|
|
51
43
|
|
|
52
44
|
def core_runner(in_model: Any,
|
|
@@ -150,37 +142,11 @@ def core_runner(in_model: Any,
|
|
|
150
142
|
f'Final activation bit-width configuration: {[node_b[1] for node_b in activation_conf_nodes_bitwidth]}')
|
|
151
143
|
|
|
152
144
|
if tb_w is not None:
|
|
153
|
-
|
|
154
|
-
visual = WeightsFinalBitwidthConfigVisualizer(weights_conf_nodes_bitwidth)
|
|
155
|
-
figure = visual.plot_config_bitwidth()
|
|
156
|
-
tb_w.add_figure(figure, f'Weights final bit-width config')
|
|
157
|
-
if len(activation_conf_nodes_bitwidth) > 0:
|
|
158
|
-
visual = ActivationFinalBitwidthConfigVisualizer(activation_conf_nodes_bitwidth)
|
|
159
|
-
figure = visual.plot_config_bitwidth()
|
|
160
|
-
tb_w.add_figure(figure, f'Activation final bit-width config')
|
|
145
|
+
finalize_bitwidth_in_tb(tb_w, weights_conf_nodes_bitwidth, activation_conf_nodes_bitwidth)
|
|
161
146
|
|
|
162
147
|
return tg, bit_widths_config, hessian_info_service
|
|
163
148
|
|
|
164
149
|
|
|
165
|
-
def _init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
|
|
166
|
-
"""
|
|
167
|
-
Create a TensorBoardWriter object initialized with the logger dir path if it was set,
|
|
168
|
-
or None otherwise.
|
|
169
|
-
|
|
170
|
-
Args:
|
|
171
|
-
fw_info: FrameworkInfo object.
|
|
172
|
-
|
|
173
|
-
Returns:
|
|
174
|
-
A TensorBoardWriter object.
|
|
175
|
-
"""
|
|
176
|
-
tb_w = None
|
|
177
|
-
if Logger.LOG_PATH is not None:
|
|
178
|
-
tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
|
|
179
|
-
Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
|
|
180
|
-
tb_w = TensorboardWriter(tb_log_dir, fw_info)
|
|
181
|
-
return tb_w
|
|
182
|
-
|
|
183
|
-
|
|
184
150
|
def _set_final_kpi(graph: Graph,
|
|
185
151
|
final_bit_widths_config: List[int],
|
|
186
152
|
kpi_functions_dict: Dict[KPITarget, Tuple[MpKpiMetric, MpKpiAggregation]],
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from typing import Callable, Tuple
|
|
17
17
|
from packaging import version
|
|
18
18
|
|
|
19
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
19
20
|
from model_compression_toolkit.logger import Logger
|
|
20
21
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
21
22
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
@@ -24,7 +25,7 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import
|
|
|
24
25
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
25
26
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfigV2
|
|
26
27
|
from model_compression_toolkit.core import CoreConfig
|
|
27
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
28
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
28
29
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
29
30
|
from model_compression_toolkit.core.exporter import export_model
|
|
30
31
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
@@ -202,7 +203,7 @@ if FOUND_TF:
|
|
|
202
203
|
Logger.info("Using experimental mixed-precision quantization. "
|
|
203
204
|
"If you encounter an issue please file a bug.")
|
|
204
205
|
|
|
205
|
-
tb_w =
|
|
206
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
206
207
|
|
|
207
208
|
fw_impl = GPTQKerasImplemantation()
|
|
208
209
|
|
|
@@ -15,12 +15,13 @@
|
|
|
15
15
|
from typing import Callable
|
|
16
16
|
from model_compression_toolkit.core import common
|
|
17
17
|
from model_compression_toolkit.constants import FOUND_TORCH
|
|
18
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
18
19
|
from model_compression_toolkit.logger import Logger
|
|
19
20
|
from model_compression_toolkit.constants import PYTORCH
|
|
20
21
|
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
|
|
21
22
|
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
22
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
23
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
24
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
24
25
|
from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
|
|
25
26
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
26
27
|
from model_compression_toolkit.core.exporter import export_model
|
|
@@ -161,7 +162,7 @@ if FOUND_TORCH:
|
|
|
161
162
|
Logger.info("Using experimental mixed-precision quantization. "
|
|
162
163
|
"If you encounter an issue please file a bug.")
|
|
163
164
|
|
|
164
|
-
tb_w =
|
|
165
|
+
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
|
|
165
166
|
|
|
166
167
|
fw_impl = GPTQPytorchImplemantation()
|
|
167
168
|
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Callable, List, Tuple
|
|
17
17
|
|
|
18
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
18
19
|
from model_compression_toolkit.logger import Logger
|
|
19
20
|
from model_compression_toolkit.constants import TENSORFLOW
|
|
20
21
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
@@ -28,7 +29,7 @@ from model_compression_toolkit.core.common.quantization.quantization_config impo
|
|
|
28
29
|
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
|
|
29
30
|
from model_compression_toolkit.core.common.quantization.debug_config import DebugConfig
|
|
30
31
|
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
|
31
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
32
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
32
33
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
33
34
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
34
35
|
from model_compression_toolkit.core.exporter import export_model
|
|
@@ -114,7 +115,7 @@ if FOUND_TF:
|
|
|
114
115
|
network_editor=network_editor)
|
|
115
116
|
)
|
|
116
117
|
|
|
117
|
-
tb_w =
|
|
118
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
118
119
|
|
|
119
120
|
fw_impl = KerasImplementation()
|
|
120
121
|
|
|
@@ -249,7 +250,7 @@ if FOUND_TF:
|
|
|
249
250
|
network_editor=network_editor)
|
|
250
251
|
)
|
|
251
252
|
|
|
252
|
-
tb_w =
|
|
253
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
253
254
|
|
|
254
255
|
fw_impl = KerasImplementation()
|
|
255
256
|
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
from typing import Callable, List, Tuple
|
|
16
16
|
|
|
17
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
17
18
|
from model_compression_toolkit.logger import Logger
|
|
18
19
|
from model_compression_toolkit.constants import PYTORCH
|
|
19
20
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
|
@@ -28,7 +29,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
|
28
29
|
MixedPrecisionQuantizationConfig, DEFAULT_MIXEDPRECISION_CONFIG
|
|
29
30
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
30
31
|
from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
|
|
31
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
32
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
32
33
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
|
33
34
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
34
35
|
from model_compression_toolkit.core.exporter import export_model
|
|
@@ -106,7 +107,7 @@ if FOUND_TORCH:
|
|
|
106
107
|
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
|
|
107
108
|
network_editor=network_editor))
|
|
108
109
|
|
|
109
|
-
tb_w =
|
|
110
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
110
111
|
|
|
111
112
|
fw_impl = PytorchImplementation()
|
|
112
113
|
|
|
@@ -235,7 +236,7 @@ if FOUND_TORCH:
|
|
|
235
236
|
debug_config=DebugConfig(analyze_similarity=analyze_similarity,
|
|
236
237
|
network_editor=network_editor))
|
|
237
238
|
|
|
238
|
-
tb_w =
|
|
239
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
239
240
|
|
|
240
241
|
fw_impl = PytorchImplementation()
|
|
241
242
|
|
|
@@ -17,6 +17,7 @@ from typing import Callable
|
|
|
17
17
|
|
|
18
18
|
from model_compression_toolkit.core import CoreConfig
|
|
19
19
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
20
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
20
21
|
from model_compression_toolkit.logger import Logger
|
|
21
22
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
22
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
@@ -24,7 +25,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
|
24
25
|
MixedPrecisionQuantizationConfigV2
|
|
25
26
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
26
27
|
from model_compression_toolkit.core.exporter import export_model
|
|
27
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
28
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
28
29
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
29
30
|
|
|
30
31
|
if FOUND_TF:
|
|
@@ -130,7 +131,7 @@ if FOUND_TF:
|
|
|
130
131
|
Logger.info("Using experimental mixed-precision quantization. "
|
|
131
132
|
"If you encounter an issue please file a bug.")
|
|
132
133
|
|
|
133
|
-
tb_w =
|
|
134
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
134
135
|
|
|
135
136
|
fw_impl = KerasImplementation()
|
|
136
137
|
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
from typing import Callable
|
|
16
16
|
|
|
17
17
|
from model_compression_toolkit.core import common
|
|
18
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
18
19
|
from model_compression_toolkit.logger import Logger
|
|
19
20
|
from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
|
|
20
21
|
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
|
|
@@ -22,7 +23,7 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import
|
|
|
22
23
|
from model_compression_toolkit.core import CoreConfig
|
|
23
24
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
|
24
25
|
MixedPrecisionQuantizationConfigV2
|
|
25
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
26
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
26
27
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
27
28
|
from model_compression_toolkit.core.exporter import export_model
|
|
28
29
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
@@ -102,7 +103,7 @@ if FOUND_TORCH:
|
|
|
102
103
|
Logger.info("Using experimental mixed-precision quantization. "
|
|
103
104
|
"If you encounter an issue please file a bug.")
|
|
104
105
|
|
|
105
|
-
tb_w =
|
|
106
|
+
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
|
|
106
107
|
|
|
107
108
|
fw_impl = PytorchImplementation()
|
|
108
109
|
|
|
@@ -17,6 +17,7 @@ from typing import Callable
|
|
|
17
17
|
from functools import partial
|
|
18
18
|
|
|
19
19
|
from model_compression_toolkit.core import CoreConfig
|
|
20
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
20
21
|
from model_compression_toolkit.logger import Logger
|
|
21
22
|
from model_compression_toolkit.constants import FOUND_TF
|
|
22
23
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
@@ -25,7 +26,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
|
25
26
|
from mct_quantizers import KerasActivationQuantizationHolder
|
|
26
27
|
from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
|
|
27
28
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
|
|
28
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
29
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
29
30
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
30
31
|
|
|
31
32
|
if FOUND_TF:
|
|
@@ -177,7 +178,7 @@ if FOUND_TF:
|
|
|
177
178
|
Logger.info("Using experimental mixed-precision quantization. "
|
|
178
179
|
"If you encounter an issue please file a bug.")
|
|
179
180
|
|
|
180
|
-
tb_w =
|
|
181
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
181
182
|
|
|
182
183
|
fw_impl = KerasImplementation()
|
|
183
184
|
|
|
@@ -20,6 +20,7 @@ from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
|
|
|
20
20
|
|
|
21
21
|
from model_compression_toolkit.core import CoreConfig
|
|
22
22
|
from model_compression_toolkit.core import common
|
|
23
|
+
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
23
24
|
from model_compression_toolkit.logger import Logger
|
|
24
25
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
25
26
|
from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
|
|
@@ -27,7 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
|
27
28
|
MixedPrecisionQuantizationConfigV2
|
|
28
29
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
|
|
29
30
|
TargetPlatformCapabilities
|
|
30
|
-
from model_compression_toolkit.core.runner import core_runner
|
|
31
|
+
from model_compression_toolkit.core.runner import core_runner
|
|
31
32
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
32
33
|
|
|
33
34
|
if FOUND_TORCH:
|
|
@@ -145,7 +146,7 @@ if FOUND_TORCH:
|
|
|
145
146
|
Logger.info("Using experimental mixed-precision quantization. "
|
|
146
147
|
"If you encounter an issue please file a bug.")
|
|
147
148
|
|
|
148
|
-
tb_w =
|
|
149
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
149
150
|
|
|
150
151
|
fw_impl = PytorchImplementation()
|
|
151
152
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|