mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
@@ -32,7 +32,6 @@ from model_compression_toolkit.core.common.substitutions.apply_substitutions imp
|
|
32
32
|
def _collect_and_assign_act_threshold(graph: Graph,
|
33
33
|
representative_data_gen: Callable,
|
34
34
|
core_config: CoreConfig,
|
35
|
-
fw_info: FrameworkInfo,
|
36
35
|
fw_impl: FrameworkImplementation):
|
37
36
|
"""
|
38
37
|
Collect statistics after second moment correction and assign new thresholds to activations.
|
@@ -41,14 +40,12 @@ def _collect_and_assign_act_threshold(graph: Graph,
|
|
41
40
|
representative_data_gen (Callable): Dataset used for calibration.
|
42
41
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be
|
43
42
|
quantized, including mixed precision parameters.
|
44
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
45
43
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
46
44
|
"""
|
47
45
|
|
48
46
|
mi = ModelCollector(graph,
|
49
47
|
fw_impl,
|
50
|
-
|
51
|
-
core_config.quantization_config) # Mark points for statistics collection
|
48
|
+
core_config.quantization_config) # Mark points for statistics collection
|
52
49
|
|
53
50
|
for _data in tqdm(representative_data_gen()):
|
54
51
|
mi.infer(_data)
|
@@ -63,14 +60,12 @@ def _collect_and_assign_act_threshold(graph: Graph,
|
|
63
60
|
|
64
61
|
|
65
62
|
def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
|
66
|
-
fw_info: FrameworkInfo,
|
67
63
|
fw_impl: Any):
|
68
64
|
"""
|
69
65
|
Build a framework model from a graph for second moment correction.
|
70
66
|
|
71
67
|
Args:
|
72
|
-
graph: Graph to build
|
73
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
68
|
+
graph: Graph to build from.
|
74
69
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
75
70
|
|
76
71
|
Returns:
|
@@ -79,15 +74,13 @@ def quantized_model_builder_for_second_moment_correction(graph: common.Graph,
|
|
79
74
|
quantized_tg = quantize_graph_weights(graph)
|
80
75
|
|
81
76
|
quantized_model, user_info = fw_impl.model_builder(quantized_tg,
|
82
|
-
mode=ModelBuilderMode.FLOAT
|
83
|
-
fw_info=fw_info)
|
77
|
+
mode=ModelBuilderMode.FLOAT)
|
84
78
|
return quantized_model
|
85
79
|
|
86
80
|
|
87
81
|
def apply_second_moment_correction_to_graph(graph: Graph,
|
88
82
|
representative_data_gen: Callable,
|
89
83
|
core_config: CoreConfig,
|
90
|
-
fw_info: FrameworkInfo,
|
91
84
|
fw_impl: FrameworkImplementation) -> Graph:
|
92
85
|
"""
|
93
86
|
Apply second moment correction on graph.
|
@@ -96,15 +89,14 @@ def apply_second_moment_correction_to_graph(graph: Graph,
|
|
96
89
|
representative_data_gen (Callable): Dataset used for calibration.
|
97
90
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be
|
98
91
|
quantized, including mixed precision parameters.
|
99
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
100
92
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
101
93
|
|
102
94
|
Returns:
|
103
95
|
Graph after second moment correction.
|
104
96
|
"""
|
105
|
-
semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph,
|
97
|
+
semi_quantized_model = quantized_model_builder_for_second_moment_correction(graph, fw_impl)
|
106
98
|
fw_impl.apply_second_moment_correction(semi_quantized_model, core_config, representative_data_gen, graph)
|
107
99
|
graph = substitute(graph, fw_impl.get_substitutions_after_second_moment_correction(core_config.quantization_config))
|
108
|
-
_collect_and_assign_act_threshold(graph, representative_data_gen, core_config,
|
100
|
+
_collect_and_assign_act_threshold(graph, representative_data_gen, core_config, fw_impl)
|
109
101
|
|
110
102
|
return graph
|
@@ -64,7 +64,6 @@ def calculate_bin_centers(bin_edges: np.ndarray) -> np.ndarray:
|
|
64
64
|
|
65
65
|
def compute_activation_bias_correction(graph: Graph,
|
66
66
|
quant_config: QuantizationConfig,
|
67
|
-
fw_info: FrameworkInfo,
|
68
67
|
fw_impl: FrameworkImplementation,
|
69
68
|
linear_node: BaseNode,
|
70
69
|
prev_node: BaseNode,
|
@@ -76,7 +75,6 @@ def compute_activation_bias_correction(graph: Graph,
|
|
76
75
|
Args:
|
77
76
|
graph: Graph with nodes to compute the activation bias correction for each node's final activation quantization configuration.
|
78
77
|
quant_config: QuantizationConfig of how the model should be quantized.
|
79
|
-
fw_info: Framework info like lists of nodes their kernel should quantized.
|
80
78
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
81
79
|
linear_node: Node to compute the activation bias correction for.
|
82
80
|
prev_node: Node to compute the activation error caused by his activation quantization.
|
@@ -127,19 +125,18 @@ def compute_activation_bias_correction(graph: Graph,
|
|
127
125
|
if normalized_bias < quant_config.activation_bias_correction_threshold:
|
128
126
|
return graph
|
129
127
|
|
130
|
-
kernel = linear_node.get_weights_by_keys(
|
128
|
+
kernel = linear_node.get_weights_by_keys(linear_node.kernel_attr)
|
131
129
|
|
132
130
|
# Compute the activation bias correction by applying the quantization error to the kernel, resulting in an output
|
133
131
|
# size matching the number of output channels.
|
134
132
|
if kernel is not None:
|
135
133
|
|
136
134
|
# Get the axes that are not the output channel.
|
137
|
-
output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
|
138
135
|
axis_not_output_channel = list(range(len(kernel.shape)))
|
139
|
-
axis_not_output_channel.remove(
|
136
|
+
axis_not_output_channel.remove(linear_node.channel_axis.output)
|
140
137
|
|
141
138
|
# Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters.
|
142
|
-
if
|
139
|
+
if linear_node.channel_axis.output == linear_node.channel_axis.input:
|
143
140
|
axis_not_output_channel.remove(3) # 3 is the depth multiplier index.
|
144
141
|
|
145
142
|
activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
|
@@ -150,7 +147,6 @@ def compute_activation_bias_correction(graph: Graph,
|
|
150
147
|
|
151
148
|
def compute_activation_bias_correction_of_graph(graph: Graph,
|
152
149
|
quant_config: QuantizationConfig,
|
153
|
-
fw_info: FrameworkInfo,
|
154
150
|
fw_impl: FrameworkImplementation,
|
155
151
|
activation_bias_correction_node_matchers: Callable,
|
156
152
|
kernel_size: str) -> Graph:
|
@@ -160,7 +156,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
|
|
160
156
|
Args:
|
161
157
|
graph: Graph with nodes to compute the activation bias correction.
|
162
158
|
quant_config: QuantizationConfig of how the model should be quantized.
|
163
|
-
fw_info: Framework info like lists of nodes their kernel should quantized.
|
164
159
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
165
160
|
activation_bias_correction_node_matchers: Function to match the layers for activation bias correction.
|
166
161
|
kernel_size: The framework specific attribute name of the convolution layer's kernel size.
|
@@ -177,7 +172,6 @@ def compute_activation_bias_correction_of_graph(graph: Graph,
|
|
177
172
|
if prev_node is not None:
|
178
173
|
graph = compute_activation_bias_correction(graph=graph,
|
179
174
|
quant_config=quant_config,
|
180
|
-
fw_info=fw_info,
|
181
175
|
fw_impl=fw_impl,
|
182
176
|
linear_node=n,
|
183
177
|
prev_node=prev_node,
|
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py
CHANGED
@@ -18,7 +18,6 @@ from typing import Any
|
|
18
18
|
import numpy as np
|
19
19
|
|
20
20
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
21
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
22
21
|
from model_compression_toolkit.core.common import BaseNode, Graph
|
23
22
|
from model_compression_toolkit.core.common.quantization.quantize_node import get_quantized_weights_attr_by_qc
|
24
23
|
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
|
@@ -26,7 +25,6 @@ from model_compression_toolkit.logger import Logger
|
|
26
25
|
|
27
26
|
|
28
27
|
def compute_bias_correction_of_graph(graph: Graph,
|
29
|
-
fw_info: FrameworkInfo,
|
30
28
|
fw_impl: FrameworkImplementation) -> Graph:
|
31
29
|
"""
|
32
30
|
For each node in a graph, and for each candidate weights quantization configuration,
|
@@ -35,7 +33,6 @@ def compute_bias_correction_of_graph(graph: Graph,
|
|
35
33
|
Args:
|
36
34
|
graph: Graph with nodes to compute the bias correction for
|
37
35
|
each node's weights quantization configuration candidates.
|
38
|
-
fw_info: Framework info like lists of nodes their kernel should quantized.
|
39
36
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
40
37
|
|
41
38
|
Returns:
|
@@ -46,17 +43,15 @@ def compute_bias_correction_of_graph(graph: Graph,
|
|
46
43
|
for n in graph.nodes:
|
47
44
|
# Bias correction is computed based on the quantized kernel, so we need to get the specific kernel attribute
|
48
45
|
# name out of all the weights attributes of the node.
|
49
|
-
if
|
50
|
-
|
51
|
-
if n.is_weights_quantization_enabled(kernel_attr):
|
46
|
+
if n.is_kernel_op:
|
47
|
+
if n.is_weights_quantization_enabled(n.kernel_attr):
|
52
48
|
# Bias correction is not applied to layers with constant inputs.
|
53
49
|
if n.has_positional_weights:
|
54
50
|
for candidate_qc in n.candidates_quantization_cfg:
|
55
51
|
candidate_qc.weights_quantization_cfg.weights_bias_correction = False
|
56
52
|
else:
|
57
53
|
_compute_bias_correction_per_candidate_qc(n,
|
58
|
-
kernel_attr,
|
59
|
-
fw_info,
|
54
|
+
n.kernel_attr,
|
60
55
|
graph.get_in_stats_collector(n),
|
61
56
|
fw_impl=fw_impl)
|
62
57
|
return graph
|
@@ -64,7 +59,6 @@ def compute_bias_correction_of_graph(graph: Graph,
|
|
64
59
|
|
65
60
|
def _compute_bias_correction_per_candidate_qc(node: BaseNode,
|
66
61
|
kernel_attr: str,
|
67
|
-
fw_info: FrameworkInfo,
|
68
62
|
node_in_stats_collector: BaseStatsCollector,
|
69
63
|
fw_impl: FrameworkImplementation):
|
70
64
|
"""
|
@@ -74,7 +68,6 @@ def _compute_bias_correction_per_candidate_qc(node: BaseNode,
|
|
74
68
|
Args:
|
75
69
|
node: Node to compute the bias correction for its different candidates.
|
76
70
|
kernel_attr: The name of the kernel attribute of the node.
|
77
|
-
fw_info: Framework info like lists of nodes their kernel should quantized.
|
78
71
|
node_in_stats_collector: Statistics collector of the node for the mean per-channel.
|
79
72
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
80
73
|
|
@@ -32,7 +32,6 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
|
|
32
32
|
|
33
33
|
def statistics_correction_runner(transformed_graph: Graph,
|
34
34
|
core_config: CoreConfig,
|
35
|
-
fw_info: FrameworkInfo,
|
36
35
|
fw_impl: FrameworkImplementation,
|
37
36
|
tb_w: TensorboardWriter = None, ) -> Graph:
|
38
37
|
"""
|
@@ -41,7 +40,6 @@ def statistics_correction_runner(transformed_graph: Graph,
|
|
41
40
|
transformed_graph: Graph to add statistics correction.
|
42
41
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be
|
43
42
|
quantized, including mixed precision parameters.
|
44
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
45
43
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
46
44
|
tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
|
47
45
|
|
@@ -59,7 +57,6 @@ def statistics_correction_runner(transformed_graph: Graph,
|
|
59
57
|
# Compute bias correction to nodes' config candidates
|
60
58
|
########################################################
|
61
59
|
tg_with_bias = compute_bias_correction_of_graph(tg_with_bias,
|
62
|
-
fw_info,
|
63
60
|
fw_impl)
|
64
61
|
|
65
62
|
if tb_w is not None:
|
@@ -71,7 +68,6 @@ def statistics_correction_runner(transformed_graph: Graph,
|
|
71
68
|
def apply_statistics_correction(transformed_graph: Graph,
|
72
69
|
representative_data_gen: Callable,
|
73
70
|
core_config: CoreConfig,
|
74
|
-
fw_info: FrameworkInfo,
|
75
71
|
fw_impl: FrameworkImplementation,
|
76
72
|
tb_w: TensorboardWriter = None, ) -> Graph:
|
77
73
|
"""
|
@@ -81,7 +77,6 @@ def apply_statistics_correction(transformed_graph: Graph,
|
|
81
77
|
representative_data_gen (Callable): Dataset used for calibration.
|
82
78
|
core_config (CoreConfig): Configuration object containing parameters of how the model should be
|
83
79
|
quantized, including mixed precision parameters.
|
84
|
-
fw_info: FrameworkInfo object with information about the specific framework's model.
|
85
80
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
86
81
|
tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
|
87
82
|
|
@@ -94,7 +89,7 @@ def apply_statistics_correction(transformed_graph: Graph,
|
|
94
89
|
#############################################
|
95
90
|
if core_config.quantization_config.weights_second_moment_correction:
|
96
91
|
transformed_graph = apply_second_moment_correction_to_graph(transformed_graph, representative_data_gen,
|
97
|
-
core_config,
|
92
|
+
core_config, fw_impl)
|
98
93
|
|
99
94
|
#############################################
|
100
95
|
# Apply Bias Correction
|
@@ -97,10 +97,9 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
|
|
97
97
|
# This feature disabled for models with weights quantization method of Power of 2
|
98
98
|
for qc in source_node.candidates_quantization_cfg:
|
99
99
|
# this feature is relevant only for layers with kernel op
|
100
|
-
kernel_attr
|
101
|
-
if kernel_attr is None:
|
100
|
+
if source_node.kernel_attr is None:
|
102
101
|
Logger.error(f"Can't preform BatchNorm reconstruction on a node {source_node.name} without a kernel op.")
|
103
|
-
if (qc.weights_quantization_cfg.get_attr_config(kernel_attr
|
102
|
+
if (qc.weights_quantization_cfg.get_attr_config(source_node.kernel_attr).weights_quantization_method
|
104
103
|
== QuantizationMethod.POWER_OF_TWO):
|
105
104
|
Logger.warning("Second moment statistics correction feature disabled for models with weights "
|
106
105
|
"quantization method of Power of 2")
|
@@ -157,7 +157,7 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
|
|
157
157
|
graph.remove_node(bn_node)
|
158
158
|
graph.remove_node(source_node)
|
159
159
|
|
160
|
-
self._calc_weights_quantization_params(conv_bn, weights_scale
|
160
|
+
self._calc_weights_quantization_params(conv_bn, weights_scale)
|
161
161
|
|
162
162
|
assert num_nodes_before_substitution - len(graph.nodes) == 1
|
163
163
|
assert num_edges_before_substitution - len(graph.edges) == 1
|
@@ -165,18 +165,15 @@ class BatchNormalizationRefusing(common.BaseSubstitution):
|
|
165
165
|
|
166
166
|
def _calc_weights_quantization_params(self,
|
167
167
|
conv_bn: BaseNode,
|
168
|
-
weights_scale: np.ndarray
|
169
|
-
fw_info):
|
168
|
+
weights_scale: np.ndarray):
|
170
169
|
"""
|
171
170
|
Update node weights quantization params.
|
172
171
|
Args:
|
173
172
|
conv_bn: Convolution node to update the weights quantization params.
|
174
173
|
weights_scale: Weight scale factor in which to multiply the conv node's weight.
|
175
|
-
fw_info: FrameworkInfo object with information about the specific framework's model
|
176
174
|
"""
|
177
175
|
# Conv layer is ensured to have a kernel attribute
|
178
|
-
|
179
|
-
conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(kernel_attr)
|
176
|
+
conv_bn_kernel_cfg = conv_bn.final_weights_quantization_cfg.get_attr_config(conv_bn.kernel_attr)
|
180
177
|
# In case of SYMMETRIC weight quantization method, we update the threshold by weights_scale
|
181
178
|
if conv_bn_kernel_cfg.weights_quantization_method == QuantizationMethod.SYMMETRIC:
|
182
179
|
original_threshold = conv_bn_kernel_cfg.weights_quantization_params[THRESHOLD]
|
@@ -20,8 +20,6 @@ import scipy
|
|
20
20
|
|
21
21
|
from model_compression_toolkit.core import common
|
22
22
|
from model_compression_toolkit.core.common import Graph, BaseNode
|
23
|
-
from model_compression_toolkit.defaultdict import DefaultDict
|
24
|
-
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
25
23
|
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
26
24
|
|
27
25
|
|
@@ -77,7 +75,6 @@ def fixed_second_moment_after_relu(mu: np.ndarray,
|
|
77
75
|
|
78
76
|
def scale_reshaping(scale: np.ndarray,
|
79
77
|
op2d: common.BaseNode,
|
80
|
-
kernel_channel_mapping: DefaultDict,
|
81
78
|
kernel_str: str,
|
82
79
|
in_channels: bool = True) -> np.ndarray:
|
83
80
|
"""
|
@@ -89,7 +86,6 @@ def scale_reshaping(scale: np.ndarray,
|
|
89
86
|
Args:
|
90
87
|
scale: Scale factor to scale the kernel channels by.
|
91
88
|
op2d: Node to scale its kernel.
|
92
|
-
kernel_channel_mapping: Mapping from a layer to a tuple of indices of its output/input kernel channels.
|
93
89
|
kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
|
94
90
|
in_channels: Kernel's index of input channels.
|
95
91
|
|
@@ -99,12 +95,11 @@ def scale_reshaping(scale: np.ndarray,
|
|
99
95
|
|
100
96
|
op_ndims = op2d.get_weights_by_keys(kernel_str).ndim
|
101
97
|
reshape_target = np.ones(op_ndims, dtype=np.int32)
|
102
|
-
reshape_target[
|
98
|
+
reshape_target[op2d.channel_axis.input if in_channels else op2d.channel_axis.output] = -1
|
103
99
|
return np.reshape(scale, reshape_target)
|
104
100
|
|
105
101
|
|
106
|
-
def update_linear_nodes(
|
107
|
-
first_op2d_node: BaseNode,
|
102
|
+
def update_linear_nodes(first_op2d_node: BaseNode,
|
108
103
|
second_op2d_node: BaseNode,
|
109
104
|
scale_factor: np.ndarray,
|
110
105
|
kernel_str: str,
|
@@ -116,7 +111,6 @@ def update_linear_nodes(fw_info: FrameworkInfo,
|
|
116
111
|
The scale factor contain a scale value per-channel.
|
117
112
|
|
118
113
|
Args:
|
119
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
120
114
|
groups of layers by how they should be quantized, etc.)
|
121
115
|
first_op2d_node: Node to multiply its kernel by the scale factor.
|
122
116
|
second_op2d_node: Node to divide its kernel by the scale factor.
|
@@ -125,15 +119,12 @@ def update_linear_nodes(fw_info: FrameworkInfo,
|
|
125
119
|
kernel_str: The framework specific attribute name of the convolution layer's weight/kernel.
|
126
120
|
|
127
121
|
"""
|
128
|
-
|
129
122
|
w2_fixed = second_op2d_node.get_weights_by_keys(kernel_str) / scale_reshaping(scale_factor,
|
130
123
|
second_op2d_node,
|
131
|
-
fw_info.kernel_channels_mapping,
|
132
124
|
kernel_str)
|
133
125
|
|
134
126
|
w1_fixed = first_op2d_node.get_weights_by_keys(kernel_str) * scale_reshaping(scale_factor,
|
135
127
|
first_op2d_node,
|
136
|
-
fw_info.kernel_channels_mapping,
|
137
128
|
kernel_str,
|
138
129
|
in_channels=False)
|
139
130
|
|
@@ -168,8 +159,7 @@ def calculate_scale_correction(first_op2d_node: BaseNode) -> tuple:
|
|
168
159
|
return scale_factor
|
169
160
|
|
170
161
|
|
171
|
-
def scale_equalization_lnl(
|
172
|
-
first_op2d_node: BaseNode,
|
162
|
+
def scale_equalization_lnl(first_op2d_node: BaseNode,
|
173
163
|
second_op2d_node: BaseNode,
|
174
164
|
kernel_str: str,
|
175
165
|
bias_str: str):
|
@@ -179,7 +169,6 @@ def scale_equalization_lnl(fw_info: FrameworkInfo,
|
|
179
169
|
follows the activation node to get the same expected output without the scaling.
|
180
170
|
|
181
171
|
Args:
|
182
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
183
172
|
groups of layers by how they should be quantized, etc.)
|
184
173
|
first_op2d_node: Node to multiply its kernel by the scale factor.
|
185
174
|
second_op2d_node: Node to divide its kernel by the scale factor.
|
@@ -189,8 +178,7 @@ def scale_equalization_lnl(fw_info: FrameworkInfo,
|
|
189
178
|
"""
|
190
179
|
scale_factor = calculate_scale_correction(first_op2d_node)
|
191
180
|
|
192
|
-
update_linear_nodes(
|
193
|
-
first_op2d_node,
|
181
|
+
update_linear_nodes(first_op2d_node,
|
194
182
|
second_op2d_node,
|
195
183
|
scale_factor,
|
196
184
|
kernel_str,
|
@@ -206,7 +194,6 @@ class BaseScaleEqualization(common.BaseSubstitution):
|
|
206
194
|
|
207
195
|
def __init__(self,
|
208
196
|
quant_config: QuantizationConfig,
|
209
|
-
fw_info: FrameworkInfo,
|
210
197
|
matcher_instance,
|
211
198
|
kernel_str: str,
|
212
199
|
bias_str: str):
|
@@ -214,13 +201,11 @@ class BaseScaleEqualization(common.BaseSubstitution):
|
|
214
201
|
Initialize a ScaleEqualization object.
|
215
202
|
Args:
|
216
203
|
quant_config: QuantizationConfig containing parameters of how the model should be quantized.
|
217
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
218
204
|
groups of layers by how they should be quantized, etc.)
|
219
205
|
matcher_instance: Per substitution matcher instance of type WalkMatcher
|
220
206
|
"""
|
221
207
|
|
222
208
|
self.quant_config = quant_config
|
223
|
-
self.fw_info = fw_info
|
224
209
|
self.kernel_str = kernel_str
|
225
210
|
self.bias_str = bias_str
|
226
211
|
super().__init__(matcher_instance=matcher_instance)
|
@@ -243,8 +228,7 @@ class BaseScaleEqualization(common.BaseSubstitution):
|
|
243
228
|
act_node = nodes_list[1]
|
244
229
|
second_op2d_node = nodes_list[-1]
|
245
230
|
if first_op2d_node.prior_info.std_output is not None and act_node.is_activation_quantization_enabled():
|
246
|
-
scale_equalization_lnl(
|
247
|
-
first_op2d_node,
|
231
|
+
scale_equalization_lnl(first_op2d_node,
|
248
232
|
second_op2d_node,
|
249
233
|
self.kernel_str,
|
250
234
|
self.bias_str)
|
@@ -46,7 +46,6 @@ If the linear node pads the input tensor with zeros, we modify the padded value
|
|
46
46
|
|
47
47
|
def op2d_bias_correction(op2d_node: BaseNode,
|
48
48
|
shift_to_correct: float,
|
49
|
-
fw_info: FrameworkInfo,
|
50
49
|
bias_str: str,
|
51
50
|
bias_flag_str: str):
|
52
51
|
"""
|
@@ -57,7 +56,6 @@ def op2d_bias_correction(op2d_node: BaseNode,
|
|
57
56
|
op2d_node: Node to compute its bias correction term.
|
58
57
|
shift_to_correct: Value that was used to shift the output tensor of
|
59
58
|
the non-linear node.
|
60
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
61
59
|
bias_str:
|
62
60
|
bias_flag_str: The framework specific attribute name of the bias flag.
|
63
61
|
"""
|
@@ -76,14 +74,13 @@ def op2d_bias_correction(op2d_node: BaseNode,
|
|
76
74
|
# Each node adds a different noise due to the shifting. It depends on the
|
77
75
|
# dimensions of the kernel, thus the correction term is a function of
|
78
76
|
# the layer type.
|
79
|
-
kernel = op2d_node.get_weights_by_keys(
|
77
|
+
kernel = op2d_node.get_weights_by_keys(op2d_node.kernel_attr)
|
80
78
|
if kernel is not None:
|
81
|
-
output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(op2d_node.type)
|
82
79
|
axis_not_output_channel = list(range(len(kernel.shape)))
|
83
|
-
axis_not_output_channel.remove(
|
80
|
+
axis_not_output_channel.remove(op2d_node.channel_axis.output)
|
84
81
|
|
85
82
|
# special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
|
86
|
-
if
|
83
|
+
if op2d_node.channel_axis.output == op2d_node.channel_axis.input:
|
87
84
|
axis_not_output_channel.remove(3) # 3 is the depth multiplier index
|
88
85
|
|
89
86
|
bias_correction = shift_to_correct * np.sum(kernel, axis=tuple(axis_not_output_channel))
|
@@ -250,7 +247,6 @@ def shift_negative_function(graph: Graph,
|
|
250
247
|
core_config: CoreConfig,
|
251
248
|
non_linear_node: BaseNode,
|
252
249
|
op2d_node: BaseNode,
|
253
|
-
fw_info: FrameworkInfo,
|
254
250
|
create_add_node: Callable,
|
255
251
|
get_padding_values: Callable,
|
256
252
|
create_pad_node: Callable,
|
@@ -276,8 +272,6 @@ def shift_negative_function(graph: Graph,
|
|
276
272
|
non_linear_node: Non-linear node with negative values to shift.
|
277
273
|
op2d_node: Linear node to correct its bias to overcome the expected error due to
|
278
274
|
the shifting.
|
279
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
280
|
-
groups of layers by how they should be quantized, etc.)
|
281
275
|
create_add_node: Function to create an add node.
|
282
276
|
get_padding_values: Function to compute the op2d node's padding values
|
283
277
|
create_pad_node: Function to create an pad node.
|
@@ -299,7 +293,6 @@ def shift_negative_function(graph: Graph,
|
|
299
293
|
# all candidates have same activation config, so taking the first candidate for calculations
|
300
294
|
non_linear_node_cfg_candidate = non_linear_node.candidates_quantization_cfg[0].activation_quantization_cfg
|
301
295
|
|
302
|
-
|
303
296
|
# get the non-linear activation threshold
|
304
297
|
activation_threshold = non_linear_node_cfg_candidate.activation_quantization_params.get(THRESHOLD)
|
305
298
|
|
@@ -390,7 +383,6 @@ def shift_negative_function(graph: Graph,
|
|
390
383
|
first_node=non_linear_node)
|
391
384
|
op2d_bias_correction(op2d_node,
|
392
385
|
shift_value,
|
393
|
-
fw_info,
|
394
386
|
bias_str,
|
395
387
|
bias_flag_str)
|
396
388
|
|
@@ -401,8 +393,7 @@ def shift_negative_function(graph: Graph,
|
|
401
393
|
graph.set_out_stats_collector_to_node(add_node, add_node_stats_collector)
|
402
394
|
graph.shift_stats_collector(add_node, np.array(shift_value))
|
403
395
|
|
404
|
-
set_quantization_configs_to_node(
|
405
|
-
node=add_node,
|
396
|
+
set_quantization_configs_to_node(node=add_node,
|
406
397
|
graph=graph,
|
407
398
|
quant_config=core_config.quantization_config,
|
408
399
|
fqc=graph.fqc,
|
@@ -428,8 +419,7 @@ def shift_negative_function(graph: Graph,
|
|
428
419
|
last_node=op2d_node)
|
429
420
|
|
430
421
|
# Set quantization configuration to node, even though we do not quantize it:
|
431
|
-
set_quantization_configs_to_node(
|
432
|
-
node=pad_node,
|
422
|
+
set_quantization_configs_to_node(node=pad_node,
|
433
423
|
graph=graph,
|
434
424
|
quant_config=core_config.quantization_config,
|
435
425
|
fqc=graph.fqc,
|
@@ -472,7 +462,6 @@ def shift_negative_function(graph: Graph,
|
|
472
462
|
candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
|
473
463
|
|
474
464
|
candidate_qc.activation_quantization_cfg = create_node_activation_qc(core_config.quantization_config,
|
475
|
-
fw_info,
|
476
465
|
add_node_qco[op_qc_idx])
|
477
466
|
|
478
467
|
candidate_qc.activation_quantization_cfg.set_activation_quantization_param({THRESHOLD: activation_threshold,
|
@@ -573,7 +562,6 @@ def get_next_nodes_to_correct(n: BaseNode,
|
|
573
562
|
|
574
563
|
def apply_shift_negative_correction(graph: Graph,
|
575
564
|
core_config: CoreConfig,
|
576
|
-
fw_info: FrameworkInfo,
|
577
565
|
snc_node_types: NodeOperationMatcher,
|
578
566
|
linear_node_types: NodeOperationMatcher,
|
579
567
|
bypass_node_types: NodeOperationMatcher,
|
@@ -593,7 +581,6 @@ def apply_shift_negative_correction(graph: Graph,
|
|
593
581
|
Args:
|
594
582
|
graph: Graph to apply the substitution on.
|
595
583
|
core_config: Quantization configuration to build the substitutions list according to.
|
596
|
-
fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
|
597
584
|
groups of layers by how they should be quantized, etc.)
|
598
585
|
snc_node_types: Types of activation nodes with negative outputs to consider.
|
599
586
|
linear_node_types: Types of linear nodes to consider.
|
@@ -632,7 +619,6 @@ def apply_shift_negative_correction(graph: Graph,
|
|
632
619
|
core_config,
|
633
620
|
n,
|
634
621
|
linear_node,
|
635
|
-
fw_info,
|
636
622
|
create_add_node,
|
637
623
|
get_padding_values,
|
638
624
|
create_pad_node,
|
model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py
CHANGED
@@ -50,9 +50,7 @@ class BaseVirtualActivationWeightsComposition(BaseSubstitution):
|
|
50
50
|
return graph
|
51
51
|
|
52
52
|
# Virtual composed activation-weights node
|
53
|
-
v_node = VirtualActivationWeightsNode(act_node,
|
54
|
-
weights_node,
|
55
|
-
fw_info=graph.fw_info)
|
53
|
+
v_node = VirtualActivationWeightsNode(act_node, weights_node)
|
56
54
|
|
57
55
|
# Update graph
|
58
56
|
graph.add_node(v_node)
|
@@ -50,7 +50,7 @@ class BaseWeightsActivationSplit(BaseSubstitution):
|
|
50
50
|
Graph after applying the substitution.
|
51
51
|
"""
|
52
52
|
# The decomposition works on linear nodes, that is, nodes with kernel ops
|
53
|
-
kernel_attr =
|
53
|
+
kernel_attr = node.kernel_attr
|
54
54
|
if kernel_attr is None:
|
55
55
|
Logger.critical(f"Trying to split node weights and activation, but node "
|
56
56
|
f"{node.name} doesn't have a kernel attribute.")
|
@@ -59,22 +59,19 @@ class NNVisualizer:
|
|
59
59
|
def __init__(self,
|
60
60
|
graph_float: Graph,
|
61
61
|
graph_quantized: Graph,
|
62
|
-
fw_impl: FrameworkImplementation
|
63
|
-
fw_info: FrameworkInfo):
|
62
|
+
fw_impl: FrameworkImplementation):
|
64
63
|
"""
|
65
64
|
Initialize a NNVisualizer object.
|
66
65
|
Args:
|
67
66
|
graph_float: Float version of the graph.
|
68
67
|
graph_quantized: Quantized version of the graph.
|
69
68
|
fw_impl: Framework implementation with framework-specific methods implementation.
|
70
|
-
fw_info: Framework info with framework-specific information.
|
71
69
|
|
72
70
|
"""
|
73
71
|
|
74
72
|
self.graph_float = graph_float
|
75
73
|
self.graph_quantized = graph_quantized
|
76
74
|
self.fw_impl = fw_impl
|
77
|
-
self.fw_info = fw_info
|
78
75
|
|
79
76
|
# Get compare points of two graphs.
|
80
77
|
self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
|
@@ -92,13 +89,11 @@ class NNVisualizer:
|
|
92
89
|
|
93
90
|
self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
|
94
91
|
mode=ModelBuilderMode.QUANTIZED,
|
95
|
-
append2output=self.compare_points
|
96
|
-
fw_info=self.fw_info)
|
92
|
+
append2output=self.compare_points)
|
97
93
|
|
98
94
|
self.float_model, _ = self.fw_impl.model_builder(self.graph_float,
|
99
95
|
mode=ModelBuilderMode.FLOAT,
|
100
|
-
append2output=self.compare_points_float
|
101
|
-
fw_info=self.fw_info)
|
96
|
+
append2output=self.compare_points_float)
|
102
97
|
|
103
98
|
def has_compare_points(self) -> bool:
|
104
99
|
"""
|
@@ -89,20 +89,18 @@ class TensorboardWriter(object):
|
|
89
89
|
Class to log events to display using Tensorboard such as graphs, histograms, images, etc.
|
90
90
|
"""
|
91
91
|
|
92
|
-
def __init__(self, dir_path: str
|
92
|
+
def __init__(self, dir_path: str):
|
93
93
|
"""
|
94
94
|
Initialize a TensorboardWriter object.
|
95
95
|
|
96
96
|
Args:
|
97
97
|
dir_path: Path to save all events to display on Tensorboard.
|
98
|
-
fw_info: FrameworkInfo object (needed for computing nodes' weights memory).
|
99
98
|
|
100
99
|
"""
|
101
100
|
self.dir_path = dir_path
|
102
101
|
# we hold EventWriter per tag name, so events can be gathered by tags (like phases during the quantization
|
103
102
|
# process).
|
104
103
|
self.tag_name_to_event_writer = {}
|
105
|
-
self.fw_info = fw_info
|
106
104
|
|
107
105
|
def close(self):
|
108
106
|
"""
|
@@ -232,7 +230,7 @@ class TensorboardWriter(object):
|
|
232
230
|
if n.final_weights_quantization_cfg is not None:
|
233
231
|
attr.update(n.final_weights_quantization_cfg.__dict__)
|
234
232
|
elif n.candidates_quantization_cfg is not None:
|
235
|
-
attr.update(n.get_unified_weights_candidates_dict(
|
233
|
+
attr.update(n.get_unified_weights_candidates_dict())
|
236
234
|
return attr
|
237
235
|
|
238
236
|
def __get_node_attr(n: BaseNode) -> Dict[str, Any]:
|
@@ -296,7 +294,7 @@ class TensorboardWriter(object):
|
|
296
294
|
|
297
295
|
return NodeExecStats(node_name=n.name,
|
298
296
|
memory=[AllocatorMemoryUsed(
|
299
|
-
total_bytes=int(n.get_memory_bytes(
|
297
|
+
total_bytes=int(n.get_memory_bytes())
|
300
298
|
)])
|
301
299
|
|
302
300
|
graph_def = GraphDef() # GraphDef to add to Tensorboard
|
@@ -526,13 +524,13 @@ class TensorboardWriter(object):
|
|
526
524
|
er.add_event(event)
|
527
525
|
er.flush()
|
528
526
|
|
529
|
-
|
527
|
+
|
528
|
+
def init_tensorboard_writer() -> TensorboardWriter:
|
530
529
|
"""
|
531
530
|
Create a TensorBoardWriter object initialized with the logger dir path if it was set,
|
532
531
|
or None otherwise.
|
533
532
|
|
534
533
|
Args:
|
535
|
-
fw_info: FrameworkInfo object.
|
536
534
|
|
537
535
|
Returns:
|
538
536
|
A TensorBoardWriter object.
|
@@ -541,7 +539,7 @@ def init_tensorboard_writer(fw_info: FrameworkInfo) -> TensorboardWriter:
|
|
541
539
|
if Logger.LOG_PATH is not None:
|
542
540
|
tb_log_dir = os.path.join(os.getcwd(), Logger.LOG_PATH, 'tensorboard_logs')
|
543
541
|
Logger.info(f'To use Tensorboard, please run: tensorboard --logdir {tb_log_dir}')
|
544
|
-
tb_w = TensorboardWriter(tb_log_dir
|
542
|
+
tb_w = TensorboardWriter(tb_log_dir)
|
545
543
|
return tb_w
|
546
544
|
|
547
545
|
|