mct-nightly 1.1.0.7012022.post2611__py3-none-any.whl → 1.1.0.07122021-002414__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/METADATA +3 -3
  2. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/RECORD +72 -76
  3. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/common/__init__.py +2 -2
  5. model_compression_toolkit/common/bias_correction/apply_bias_correction_to_graph.py +2 -2
  6. model_compression_toolkit/common/bias_correction/compute_bias_correction_of_graph.py +9 -9
  7. model_compression_toolkit/common/collectors/mean_collector.py +2 -3
  8. model_compression_toolkit/common/collectors/min_max_per_channel_collector.py +3 -6
  9. model_compression_toolkit/common/constants.py +0 -1
  10. model_compression_toolkit/common/framework_implementation.py +6 -22
  11. model_compression_toolkit/common/framework_info.py +7 -39
  12. model_compression_toolkit/common/graph/__init__.py +1 -1
  13. model_compression_toolkit/common/graph/base_graph.py +34 -34
  14. model_compression_toolkit/common/graph/edge.py +3 -3
  15. model_compression_toolkit/common/graph/graph_matchers.py +3 -3
  16. model_compression_toolkit/common/graph/graph_searches.py +4 -4
  17. model_compression_toolkit/common/graph/graph_vis.py +116 -0
  18. model_compression_toolkit/common/graph/{base_node.py → node.py} +27 -11
  19. model_compression_toolkit/common/mixed_precision/bit_width_setter.py +9 -10
  20. model_compression_toolkit/common/mixed_precision/mixed_precision_search_manager.py +1 -1
  21. model_compression_toolkit/common/model_collector.py +12 -14
  22. model_compression_toolkit/common/network_editors/actions.py +23 -19
  23. model_compression_toolkit/common/post_training_quantization.py +7 -20
  24. model_compression_toolkit/common/quantization/node_quantization_config.py +5 -13
  25. model_compression_toolkit/common/quantization/quantization_analyzer.py +7 -11
  26. model_compression_toolkit/common/quantization/quantization_config.py +6 -6
  27. model_compression_toolkit/common/quantization/quantization_params_fn_selection.py +3 -2
  28. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_activations_computation.py +7 -13
  29. model_compression_toolkit/common/quantization/quantization_params_generation/qparams_computation.py +20 -17
  30. model_compression_toolkit/common/quantization/quantize_node.py +2 -2
  31. model_compression_toolkit/common/quantization/set_node_quantization_config.py +36 -39
  32. model_compression_toolkit/common/{collectors/statistics_collector.py → statistics_collector.py} +30 -26
  33. model_compression_toolkit/common/visualization/tensorboard_writer.py +8 -11
  34. model_compression_toolkit/keras/back2framework/instance_builder.py +4 -4
  35. model_compression_toolkit/keras/back2framework/model_builder.py +34 -47
  36. model_compression_toolkit/keras/constants.py +0 -3
  37. model_compression_toolkit/keras/default_framework_info.py +7 -33
  38. model_compression_toolkit/keras/gradient_ptq/graph_info.py +2 -2
  39. model_compression_toolkit/keras/gradient_ptq/graph_update.py +1 -7
  40. model_compression_toolkit/keras/gradient_ptq/training_wrapper.py +1 -0
  41. model_compression_toolkit/keras/graph_substitutions/substitutions/activation_decomposition.py +8 -10
  42. model_compression_toolkit/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -2
  43. model_compression_toolkit/keras/graph_substitutions/substitutions/input_scaling.py +2 -2
  44. model_compression_toolkit/keras/graph_substitutions/substitutions/mark_activation.py +3 -3
  45. model_compression_toolkit/keras/graph_substitutions/substitutions/relu_bound_correction.py +4 -3
  46. model_compression_toolkit/keras/graph_substitutions/substitutions/remove_relu_upper_bound.py +2 -2
  47. model_compression_toolkit/keras/graph_substitutions/substitutions/scale_equalization.py +9 -9
  48. model_compression_toolkit/keras/graph_substitutions/substitutions/separableconv_decomposition.py +19 -19
  49. model_compression_toolkit/keras/graph_substitutions/substitutions/shift_negative_activation.py +45 -64
  50. model_compression_toolkit/keras/keras_implementation.py +8 -28
  51. model_compression_toolkit/keras/mixed_precision/sensitivity_evaluation.py +2 -2
  52. model_compression_toolkit/keras/quantization_facade.py +1 -5
  53. model_compression_toolkit/keras/quantizer/fake_quant_builder.py +4 -4
  54. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer.py +2 -3
  55. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_quantizer_gptq_config.py +4 -8
  56. model_compression_toolkit/keras/quantizer/gradient_ptq/activation_weight_quantizer_gptq_config.py +4 -9
  57. model_compression_toolkit/keras/quantizer/gradient_ptq/config_factory.py +10 -9
  58. model_compression_toolkit/keras/quantizer/gradient_ptq/weight_quantizer_gptq_config.py +1 -9
  59. model_compression_toolkit/keras/quantizer/mixed_precision/quantization_config_factory.py +1 -1
  60. model_compression_toolkit/keras/quantizer/mixed_precision/selective_weights_quantize_config.py +1 -6
  61. model_compression_toolkit/keras/reader/common.py +11 -9
  62. model_compression_toolkit/keras/reader/connectivity_handler.py +9 -15
  63. model_compression_toolkit/keras/reader/nested_model/edges_merger.py +6 -6
  64. model_compression_toolkit/keras/reader/nested_model/nested_model_handler.py +2 -2
  65. model_compression_toolkit/keras/reader/nested_model/nodes_merger.py +3 -3
  66. model_compression_toolkit/keras/reader/nested_model/outputs_merger.py +2 -2
  67. model_compression_toolkit/keras/reader/node_builder.py +15 -65
  68. model_compression_toolkit/keras/reader/reader.py +5 -5
  69. model_compression_toolkit/keras/tensor_marking.py +113 -0
  70. model_compression_toolkit/keras/visualization/nn_visualizer.py +2 -2
  71. model_compression_toolkit/common/collectors/statistics_collector_generator.py +0 -43
  72. model_compression_toolkit/common/graph/functional_node.py +0 -59
  73. model_compression_toolkit/common/model_validation.py +0 -43
  74. model_compression_toolkit/common/node_prior_info.py +0 -29
  75. model_compression_toolkit/keras/keras_model_validation.py +0 -38
  76. model_compression_toolkit/keras/keras_node_prior_info.py +0 -60
  77. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/LICENSE +0 -0
  78. {mct_nightly-1.1.0.7012022.post2611.dist-info → mct_nightly-1.1.0.7122021.post2414.dist-info}/top_level.txt +0 -0
@@ -1,59 +0,0 @@
1
- from typing import Dict, Any, Tuple, List
2
-
3
- from model_compression_toolkit.common.graph.base_node import BaseNode
4
- import numpy as np
5
-
6
- class FunctionalNode(BaseNode):
7
- """
8
- Node that represents function ops with arguments to pass when building back the model.
9
- """
10
-
11
- def __init__(self,
12
- name: str,
13
- framework_attr: Dict[str, Any],
14
- input_shape: Tuple[Any],
15
- output_shape: Tuple[Any],
16
- weights: Dict[str, np.ndarray],
17
- layer_class: type,
18
- op_call_args: List[Any] = None,
19
- op_call_kwargs: Dict[str, Any] = None,
20
- reuse: bool = False,
21
- reuse_group: str = None,
22
- quantization_attr: Dict[str, Any] = None,
23
- functional_op: Any = None,
24
- inputs_as_list: bool = False):
25
- """
26
- Init a FunctionalNode object.
27
-
28
- Args:
29
- name: Node's name
30
- framework_attr: Framework attributes the layer had which the node holds.
31
- input_shape: Input tensor shape of the node.
32
- output_shape: Input tensor shape of the node.
33
- weights: Dictionary from a variable name to the weights with that name in the layer the node represents.
34
- layer_class: Class path of the layer this node represents.
35
- op_call_args: Arguments list to pass when calling the layer.
36
- op_call_kwargs: Key-Word Arguments dictionary with values to pass when calling the layer.
37
- reuse: Whether this node was duplicated and represents a reused layer.
38
- reuse_group: Name of group of nodes from the same reused layer.
39
- quantization_attr: Attributes the node holds regarding how it should be quantized.
40
- functional_op: The op the node implements.
41
- inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
42
-
43
- """
44
-
45
- super().__init__(name,
46
- framework_attr,
47
- input_shape,
48
- output_shape,
49
- weights,
50
- layer_class,
51
- reuse,
52
- reuse_group,
53
- quantization_attr)
54
-
55
- self.op_call_kwargs = op_call_kwargs
56
- self.op_call_args = op_call_args
57
- self.functional_op = functional_op
58
- self.inputs_as_list = inputs_as_list
59
-
@@ -1,43 +0,0 @@
1
- from abc import abstractmethod
2
- from typing import Any
3
-
4
- from model_compression_toolkit import FrameworkInfo
5
-
6
-
7
- class ModelValidation:
8
- """
9
- Class to define validation methods in order to validate the received model to quantize.
10
- """
11
-
12
- def __init__(self,
13
- model: Any,
14
- fw_info:FrameworkInfo):
15
- """
16
- Initialize a ModelValidation object.
17
-
18
- Args:
19
- model: Model to check its validity.
20
- fw_info: Information about the specific framework of the model.
21
- """
22
- self.model = model
23
- self.fw_info = fw_info
24
-
25
- @abstractmethod
26
- def validate_output_channel_consistency(self):
27
- """
28
-
29
- Validate that output channels index in all layers of the model are the same.
30
- If the model has layers with different output channels index, it should throw an exception.
31
-
32
- """
33
- raise NotImplemented(f'Framework validation class did not implement validate_output_channel_consistency')
34
-
35
- def validate(self):
36
- """
37
-
38
- Run all validation methods before the quantization process starts.
39
-
40
- """
41
- self.validate_output_channel_consistency()
42
-
43
-
@@ -1,29 +0,0 @@
1
- from model_compression_toolkit.common.collectors.statistics_collector import is_number
2
-
3
-
4
- class NodePriorInfo:
5
- """
6
- Class to wrap all prior information we have on a node.
7
- """
8
-
9
- def __init__(self,
10
- min_output: float = None,
11
- max_output: float = None):
12
- """
13
- Initialize a NodePriorInfo object.
14
-
15
- Args:
16
- min_output: Minimal output value of the node.
17
- max_output: Maximal output value of the node.
18
- """
19
-
20
- self.min_output = min_output
21
- self.max_output = max_output
22
-
23
- def is_output_bounded(self) -> bool:
24
- """
25
-
26
- Returns: Whether the node's output is bounded within a known range or not.
27
-
28
- """
29
- return is_number(self.min_output) and is_number(self.max_output)
@@ -1,38 +0,0 @@
1
- from tensorflow.keras.models import Model
2
-
3
- from model_compression_toolkit import FrameworkInfo
4
- from model_compression_toolkit.common.framework_info import ChannelAxis
5
- from model_compression_toolkit.common.model_validation import ModelValidation
6
- from model_compression_toolkit.keras.constants import CHANNELS_FORMAT, CHANNELS_FORMAT_LAST, CHANNELS_FORMAT_FIRST
7
-
8
-
9
- class KerasModelValidation(ModelValidation):
10
- """
11
- Class to define validation methods in order to validate the received Keras model to quantize.
12
- """
13
-
14
- def __init__(self, model: Model, fw_info: FrameworkInfo):
15
- """
16
- Initialize a KerasModelValidation object.
17
-
18
- Args:
19
- model: Keras model to check its validity.
20
- fw_info: Information about the framework of the model (Keras).
21
- """
22
-
23
- super(KerasModelValidation, self).__init__(model=model,
24
- fw_info=fw_info)
25
-
26
- def validate_output_channel_consistency(self):
27
- """
28
-
29
- Validate that output channels index in all layers of the model are the same.
30
- If the model has layers with different output channels index, an exception is thrown.
31
-
32
- """
33
- for layer in self.model.layers:
34
- data_format = layer.get_config().get(CHANNELS_FORMAT)
35
- if data_format is not None:
36
- assert (data_format == CHANNELS_FORMAT_LAST and self.fw_info.output_channel_index == ChannelAxis.NHWC
37
- or data_format == CHANNELS_FORMAT_FIRST and self.fw_info.output_channel_index == ChannelAxis.NCHW), \
38
- f'Model can not have layers with different data formats.'
@@ -1,60 +0,0 @@
1
- from typing import Any, Tuple
2
-
3
- import tensorflow as tf
4
- if tf.__version__ < "2.6":
5
- from tensorflow.keras.layers import Activation, ReLU
6
- else:
7
- from keras.layers import Activation, ReLU
8
-
9
- from model_compression_toolkit import FrameworkInfo
10
- from model_compression_toolkit.common import BaseNode
11
- from model_compression_toolkit.common.node_prior_info import NodePriorInfo
12
- from model_compression_toolkit.keras.constants import ACTIVATION, RELU_MAX_VALUE, NEGATIVE_SLOPE, THRESHOLD
13
-
14
-
15
- def create_node_prior_info(node: BaseNode,
16
- fw_info: FrameworkInfo):
17
- """
18
- Create a NodePriorInfo object for a given node.
19
-
20
- Args:
21
- node: Node to create its prior info.
22
- fw_info: Information about a specific framework the node was generated from.
23
-
24
- Returns:
25
- NodePriorInfo object with info about the node.
26
- """
27
-
28
- min_output, max_output = _get_min_max_outputs(node=node,
29
- fw_info=fw_info)
30
- return NodePriorInfo(min_output=min_output,
31
- max_output=max_output)
32
-
33
-
34
- def _get_min_max_outputs(node: BaseNode,
35
- fw_info: FrameworkInfo) -> Tuple[Any,Any]:
36
- """
37
- Return the min/max output values of a node if known.
38
- If one of them (or both of them) is unknown - return None instead of a value.
39
- Args:
40
- node: Node to create its prior info.
41
- fw_info: Information about a specific framework the node was generated from.
42
-
43
- Returns:
44
- Min/max output values if known.
45
- """
46
- min_output, max_output = None, None
47
-
48
- if node.layer_class == ReLU:
49
- min_output = node.framework_attr[THRESHOLD] if node.framework_attr[NEGATIVE_SLOPE] == 0 else None
50
- max_output = node.framework_attr[RELU_MAX_VALUE]
51
-
52
- elif fw_info.layers_has_min_max(node.layer_class):
53
- min_output, max_output = fw_info.layer_min_max_mapping[node.layer_class]
54
-
55
- elif node.layer_class == Activation and fw_info.activation_has_min_max(node.framework_attr[ACTIVATION]):
56
- min_output, max_output = fw_info.activation_min_max_mapping[node.framework_attr[ACTIVATION]]
57
-
58
- return min_output, max_output
59
-
60
-