mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/core/data_dump/scope.py
CHANGED
|
@@ -14,36 +14,48 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from abc import ABC, abstractmethod
|
|
17
|
+
import re
|
|
17
18
|
|
|
18
19
|
from msprobe.core.common.const import Const
|
|
19
20
|
from msprobe.core.common.exceptions import ScopeException
|
|
20
21
|
|
|
21
22
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
scope =
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
return
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
23
|
+
class ScopeFactory:
    """Pick and build the scope object that matches a dump configuration.

    The factory copies ``task``, ``level``, ``scope`` and ``list`` from the
    given config object and uses them to decide which scope class to build.
    """

    def __init__(self, config):
        self.task = config.task
        self.level = config.level
        self.scope = config.scope
        self.api_list = config.list

    def build_scope(self):
        """Return the scope object for this configuration.

        Returns:
            None when neither ``scope`` nor ``list`` is configured, a
            ListScope for the free-benchmark task, otherwise one of the
            range scopes chosen by ``_build_range_scope``.
        """
        if not (self.scope or self.api_list):
            return None
        # At least one of the two is set; replace a missing one with an
        # empty list so the scope classes always receive lists/strings.
        self.scope = [] if self.scope is None else self.scope
        self.api_list = [] if self.api_list is None else self.api_list
        if self.task != Const.FREE_BENCHMARK:
            return self._build_range_scope()
        return ListScope(self.scope, self.api_list)

    def _build_range_scope(self):
        """Choose between the API, module and mix range scopes.

        Raises:
            ScopeException: when the configured scope is valid for both the
                API and the module range scope, or for neither of them.
        """
        ctor_args = (self.scope, self.api_list, self.level)
        # All three candidates are constructed up front, in this order, so
        # any validation done by their constructors runs before selection.
        by_api = APIRangeScope(*ctor_args)
        by_module = ModuleRangeScope(*ctor_args)
        mixed = MixRangeScope(*ctor_args)

        if self.level == Const.LEVEL_MIX:
            return mixed
        if not self.scope:
            return by_api

        api_ok = by_api.is_valid
        module_ok = by_module.is_valid
        if api_ok and module_ok:
            raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}.")
        if api_ok:
            return by_api
        if module_ok:
            return by_module
        raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}")
|
|
47
59
|
|
|
48
60
|
|
|
49
61
|
class BaseScope(ABC):
|
|
@@ -51,7 +63,8 @@ class BaseScope(ABC):
|
|
|
51
63
|
Module_Type_API = "api"
|
|
52
64
|
module_type = ["Module", "Cell"]
|
|
53
65
|
|
|
54
|
-
def __init__(self, scope, api_list):
|
|
66
|
+
def __init__(self, scope, api_list, level=None):
|
|
67
|
+
self.level = level
|
|
55
68
|
scope, api_list = self.rectify_args(scope, api_list)
|
|
56
69
|
self.scope = scope
|
|
57
70
|
self.api_list = api_list
|
|
@@ -60,21 +73,21 @@ class BaseScope(ABC):
|
|
|
60
73
|
def rectify_args(scope, api_list):
    """Validate ``scope`` and ``api_list`` and return them as a pair.

    A string ``scope`` is wrapped into a one-element list; a list is
    returned unchanged once every element has been checked.

    Raises:
        ScopeException: when ``api_list`` is not a list of strings, or when
            ``scope`` is neither a string nor a list of strings.
    """
    if not isinstance(api_list, list):
        raise ScopeException(ScopeException.InvalidApiStr,
                             f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
    for api in api_list:
        if isinstance(api, str):
            continue
        raise ScopeException(ScopeException.InvalidApiStr,
                             f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")

    # A bare string is shorthand for a single-entry scope list.
    if isinstance(scope, str):
        return [scope], api_list

    if not isinstance(scope, list):
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
    for s in scope:
        if isinstance(s, str):
            continue
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
    return scope, api_list
|
|
79
92
|
|
|
80
93
|
@abstractmethod
|
|
@@ -95,7 +108,7 @@ class ListScope(BaseScope):
|
|
|
95
108
|
def rectify_args(scope, api_list):
    """Reject configurations that set both arguments, then validate them.

    Raises:
        ScopeException: with ``ArgConflict`` when ``scope`` and ``api_list``
            are both non-empty; otherwise whatever the parent validation
            raises.
    """
    if not (scope and api_list):
        # Only one (or none) of the two is configured: defer to the parent
        # class for the type checks.
        return super(ListScope, ListScope).rectify_args(scope, api_list)
    raise ScopeException(ScopeException.ArgConflict,
                         f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
|
|
100
113
|
|
|
101
114
|
def check(self, name):
|
|
@@ -109,17 +122,37 @@ class RangeScope(BaseScope, ABC):
|
|
|
109
122
|
def __init__(self, *args):
    """Initialise the range-tracking state on top of the parent initialiser.

    Args:
        *args: forwarded unchanged to the parent ``__init__``.
    """
    # Parent initialiser runs first so the validated arguments are stored
    # before any of the state below is derived from them.
    super().__init__(*args)
    # Flags and bookkeeping toggled while walking through the dumped names.
    self.in_scope = False
    self.in_list = False
    self.start_name_set = set()
    # Computed last: the subclass hook may read the attributes set above.
    self.is_valid = self.check_scope_is_valid()
|
|
113
128
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
129
|
+
def check_name_pattern(self, name):
    """Check that ``name`` has the naming format required by ``self.level``.

    Level L1 requires the API pattern, level L0 the module pattern, and the
    mix level accepts either one. Other levels are not checked.

    Raises:
        ScopeException: with ``InvalidScope`` when ``name`` does not match
            the pattern(s) required for the current level.
    """
    prefix_alternatives = "|".join(map(re.escape, Const.DUMP_PREFIX))
    api_regex = rf"^({prefix_alternatives})\..*\.\d+\.(forward|backward)$"
    module_regex = r"^(Cell|Module)\..*\.(forward|backward)\.\d+$"

    if self.level == Const.LEVEL_L1 and not re.match(api_regex, name):
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")

    if self.level == Const.LEVEL_L0 and not re.match(module_regex, name):
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")

    if self.level == Const.LEVEL_MIX:
        fits_either = re.match(api_regex, name) or re.match(module_regex, name)
        if not fits_either:
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
|
|
148
|
+
|
|
149
|
+
def rectify_args(self, scope, api_list):
    """Validate the arguments of a range scope.

    Runs the parent validation first, then requires a non-empty ``scope``
    to hold exactly two entries and checks each entry's naming pattern.

    Raises:
        ScopeException: with ``InvalidScope`` when a non-empty ``scope``
            does not have length 2, plus anything raised by the parent
            validation or by ``check_name_pattern``.
    """
    scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
    wrong_length = bool(scope) and len(scope) != 2
    if wrong_length:
        raise ScopeException(ScopeException.InvalidScope,
                             f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
    for endpoint in scope:
        self.check_name_pattern(endpoint)
    return scope, api_list
|
|
124
157
|
|
|
125
158
|
@abstractmethod
|
|
@@ -192,3 +225,50 @@ class ModuleRangeScope(RangeScope):
|
|
|
192
225
|
if not self.scope or self.in_scope:
|
|
193
226
|
return self.check_api_list(name)
|
|
194
227
|
return False
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class MixRangeScope(RangeScope):
    """Range scope that combines a start/end range with a keyword list.

    ``in_scope`` tracks whether the current name lies between the two scope
    endpoints; ``in_list`` tracks whether a module whose name contains one
    of the ``api_list`` keywords is currently open.
    """

    def check_scope_is_valid(self):
        """Report whether a scope has been configured."""
        return bool(self.scope)

    def begin_module(self, module_name):
        """Update the range and list flags when a module starts."""
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True
        if any(keyword in module_name for keyword in self.api_list):
            self.in_list = True
            # Remember every module name that switched in_list on.
            self.start_name_set.add(module_name)

    def end_module(self, module_name):
        """Update the range and list flags when a module ends."""
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False
        # Drop this module name from the set of open modules.
        open_modules = self.start_name_set
        open_modules.discard(module_name)
        if not open_modules:
            # Nothing left open: this was the last module that had switched
            # in_list on, so switch it off again.
            self.in_list = False

    def check_api_list(self, api_name):
        """Return True when no keyword list is set or a keyword occurs in ``api_name``."""
        return not self.api_list or any(keyword in api_name for keyword in self.api_list)

    def check(self, name):
        """
        Entry point called while dumping: decide from scope and api_list
        whether ``name`` should be dumped.
        """
        if self.scope and name == self.scope[0]:
            self.in_scope = True

        if self.scope and not self.in_scope:
            # A range is configured and we are currently outside of it.
            should_dump = False
        elif self.in_list:
            should_dump = True
        else:
            should_dump = self.check_api_list(name)

        if self.scope and name == self.scope[1]:
            self.in_scope = False
        return should_dump
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
1
15
|
|
|
2
16
|
class GradConst:
|
|
3
17
|
|
|
@@ -60,16 +74,16 @@ class GradConst:
|
|
|
60
74
|
NORM = "norm"
|
|
61
75
|
|
|
62
76
|
# Per-level settings keyed by level name ("L0", "L1", "L2").
#   "header": ordered list of GradConst keys for that level
#             (presumably the statistic columns written out -- confirm
#             against the code that consumes this table).
#   "have_grad_direction": boolean flag; False for L0, True for L1 and L2.
level_adp = {
    "L0": {
        "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        "have_grad_direction": False
    },
    "L1": {
        "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        "have_grad_direction": True
    },
    "L2": {
        "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        "have_grad_direction": True
    },
}
|
|
@@ -1,10 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
from typing import List
|
|
3
18
|
|
|
4
19
|
from tqdm import tqdm
|
|
5
20
|
import matplotlib.pyplot as plt
|
|
6
21
|
|
|
7
|
-
from msprobe.core.common.file_utils import create_directory,
|
|
22
|
+
from msprobe.core.common.file_utils import create_directory, check_file_or_directory_path
|
|
8
23
|
from msprobe.core.common.log import logger
|
|
9
24
|
from msprobe.core.common.file_utils import remove_path, load_npy, write_csv, read_csv
|
|
10
25
|
from msprobe.core.grad_probe.constant import GradConst
|
|
@@ -33,6 +48,8 @@ class GradComparator:
|
|
|
33
48
|
|
|
34
49
|
@classmethod
|
|
35
50
|
def compare_distributed(cls, path1: str, path2: str, output_dir: str):
|
|
51
|
+
check_file_or_directory_path(path1, isdir=True)
|
|
52
|
+
check_file_or_directory_path(path2, isdir=True)
|
|
36
53
|
ranks = cls._get_matched_dirs(path1, path2, "rank")
|
|
37
54
|
logger.info(f"the following ranks will be compared: {ranks}")
|
|
38
55
|
if not ranks:
|
msprobe/core/grad_probe/utils.py
CHANGED
|
@@ -1,8 +1,24 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import re
|
|
2
17
|
from msprobe.core.grad_probe.constant import GradConst
|
|
3
18
|
from msprobe.core.common.log import logger
|
|
4
19
|
from msprobe.core.common.file_utils import write_csv, check_path_before_create, change_mode
|
|
5
20
|
from msprobe.core.common.const import FileCheckConst
|
|
21
|
+
from msprobe.core.common.utils import is_int
|
|
6
22
|
import matplotlib.pyplot as plt
|
|
7
23
|
|
|
8
24
|
|
|
@@ -26,13 +42,24 @@ def check_str(string, variable_name):
|
|
|
26
42
|
if not isinstance(string, str):
|
|
27
43
|
raise ValueError(f'The variable: "{variable_name}" is not a string.')
|
|
28
44
|
|
|
45
|
+
|
|
29
46
|
def check_bounds_element(bound):
    """Return whether ``bound`` lies between GradConst.BOUNDS_MINIMUM and GradConst.BOUNDS_MAXIMUM, inclusive."""
    lower, upper = GradConst.BOUNDS_MINIMUM, GradConst.BOUNDS_MAXIMUM
    return lower <= bound <= upper
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def check_param_element(param):
    """Return True when ``param`` matches GradConst.PARAM_VALID_PATTERN at its start, else False."""
    return re.match(GradConst.PARAM_VALID_PATTERN, param) is not None
|
|
55
|
+
|
|
31
56
|
|
|
32
57
|
def check_bounds(bounds):
|
|
58
|
+
if not isinstance(bounds, list):
|
|
59
|
+
raise Exception(f"bounds must be a list")
|
|
33
60
|
prev = GradConst.BOUNDS_MINIMUM - 1
|
|
34
61
|
for element in bounds:
|
|
35
|
-
if not
|
|
62
|
+
if not is_int(element) and not isinstance(element, float):
|
|
36
63
|
raise Exception("bounds element is not int or float")
|
|
37
64
|
if not check_bounds_element(element):
|
|
38
65
|
raise Exception("bounds element is out of int64 range")
|
|
@@ -40,6 +67,7 @@ def check_bounds(bounds):
|
|
|
40
67
|
raise Exception("bounds list is not ascending")
|
|
41
68
|
prev = element
|
|
42
69
|
|
|
70
|
+
|
|
43
71
|
class ListCache(list):
|
|
44
72
|
threshold = 1000
|
|
45
73
|
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import List, Dict, Union, Any
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from msprobe.core.overflow_check.api_info import APIInfo
|
|
21
|
+
from msprobe.core.overflow_check.level import OverflowLevel
|
|
22
|
+
from msprobe.core.overflow_check.utils import has_nan_inf
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnomalyScene:
    """Base class for an anomaly scene evaluated on the data of one API call."""

    def __init__(self, api_info: APIInfo):
        # Keep the name directly accessible and retain the full record for the checks below.
        self.api_name = api_info.api_name
        self.api_data = api_info

    @property
    def rank(self) -> OverflowLevel:
        """Severity level of the scene; concrete scenes must override this."""
        raise NotImplementedError

    @staticmethod
    def _has_anomaly(data: Union[Dict, Any]) -> bool:
        """Check whether ``data`` contains an anomalous value.

        A dict is delegated to ``has_nan_inf``. A list is anomalous as soon as one of its
        items is (items are checked recursively). Any other value counts as normal.
        """
        if isinstance(data, dict):
            return has_nan_inf(data)
        if isinstance(data, list):
            for item in data:
                if AnomalyScene._has_anomaly(item):
                    return True
            return False
        return False

    def get_details(self) -> Dict:
        """Collect a description of the scene.

        Returns:
            A dict holding the API name, the rank value, the scene class name, and the
            positions/keys of the anomalous input args, input kwargs and outputs.
        """
        details = {
            'api_name': self.api_name,
            'rank': self.rank.value,
            'scene_type': type(self).__name__,
        }
        details['input_args_anomaly_indices'] = self._get_anomaly_indices_from_list(self.api_data.input_args)
        details['input_kwargs_anomaly_keys'] = self._get_anomaly_keys_from_dict(self.api_data.input_kwargs)
        details['output_anomaly_indices'] = self._get_anomaly_indices_from_list(self.api_data.output_data)
        return details

    def matches(self) -> bool:
        """Decide whether the API call fits this scene; to be implemented by subclasses."""
        raise NotImplementedError

    def _get_anomaly_indices_from_list(self, data_list: List[Dict]) -> List[int]:
        """Return the positions in ``data_list`` whose entries are anomalous."""
        anomalous_positions = []
        for position, entry in enumerate(data_list):
            if self._has_anomaly(entry):
                anomalous_positions.append(position)
        return anomalous_positions

    def _get_anomaly_keys_from_dict(self, data_dict: Dict) -> List[str]:
        """Return the keys of ``data_dict`` whose values are anomalous."""
        anomalous_keys = []
        for name, entry in data_dict.items():
            if self._has_anomaly(entry):
                anomalous_keys.append(name)
        return anomalous_keys
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class InputOutputAnomalyScene(AnomalyScene):
    """Base class for scenes that classify a call by where anomalies appear: inputs or outputs."""

    def has_input_anomaly(self) -> bool:
        """Check the inputs for anomalies, covering both positional args and kwargs."""
        # Both groups are always inspected before the results are combined.
        positional_hit = any(map(self._has_anomaly, self.api_data.input_args))
        keyword_hit = any(map(self._has_anomaly, self.api_data.input_kwargs.values()))
        return positional_hit or keyword_hit

    def has_output_anomaly(self) -> bool:
        """Check the outputs for anomalies."""
        for item in self.api_data.output_data:
            if self._has_anomaly(item):
                return True
        return False

    def matches(self) -> bool:
        """Decide whether the call fits the scene; to be implemented by subclasses."""
        raise NotImplementedError
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class InputAnomalyOutputNormalScene(InputOutputAnomalyScene):
    """Scene in which the inputs are anomalous while the outputs are normal."""

    @property
    def rank(self) -> OverflowLevel:
        """This scene is rated ``OverflowLevel.MEDIUM``."""
        return OverflowLevel.MEDIUM

    def matches(self) -> bool:
        """Match when an input anomaly exists and no output anomaly does."""
        input_hit = self.has_input_anomaly()
        # The outputs are only inspected once an input anomaly has been found.
        return input_hit and not self.has_output_anomaly()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class InputAnomalyOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene in which both the inputs and the outputs are anomalous."""

    @property
    def rank(self) -> OverflowLevel:
        """This scene is rated ``OverflowLevel.HIGH``."""
        return OverflowLevel.HIGH

    def matches(self) -> bool:
        """Match when an input anomaly and an output anomaly are both present."""
        input_hit = self.has_input_anomaly()
        if not input_hit:
            # No input anomaly: the outputs are not inspected at all.
            return input_hit
        return self.has_output_anomaly()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class InputNormalOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene in which the inputs are normal while the outputs are anomalous."""

    @property
    def rank(self) -> OverflowLevel:
        """This scene is rated ``OverflowLevel.CRITICAL``."""
        return OverflowLevel.CRITICAL

    def matches(self) -> bool:
        """Match when no input anomaly exists but an output anomaly does."""
        if self.has_input_anomaly():
            # An anomalous input rules this scene out; the outputs are not inspected.
            return False
        return self.has_output_anomaly()
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class NumericalMutationScene(AnomalyScene):
    """Scene that detects a sudden change of numerical scale across an API call.

    The ``Norm`` values of the tensors in the input args and kwargs and of the outputs are
    gathered; the largest output norm is compared with the largest input norm, and the call
    is flagged when the ratio exceeds ``threshold``.
    """

    def __init__(self, api_info: APIInfo, threshold: float = 100.0):
        super().__init__(api_info)
        self.threshold = threshold

    @property
    def rank(self) -> OverflowLevel:
        """This scene is rated ``OverflowLevel.HIGH``."""
        return OverflowLevel.HIGH

    @staticmethod
    def _get_tensor_norms(data_list: List[Dict]) -> List[float]:
        """Gather the usable ``Norm`` values from a sequence of data entries.

        Only dict entries whose ``type`` is ``'torch.Tensor'`` are considered, and norms
        that are missing or NaN are left out.
        """
        collected = []
        for entry in data_list:
            if not isinstance(entry, dict) or entry.get('type') != 'torch.Tensor':
                continue
            value = entry.get('Norm')
            if value is None or np.isnan(value):
                continue
            collected.append(value)
        return collected

    @staticmethod
    def _get_kwargs_norms(data_dict: Dict) -> List[float]:
        """Gather the usable ``Norm`` values of the tensors stored as values of ``data_dict``.

        Args:
            data_dict: Mapping of kwarg name to data entry; only the values are examined.

        Returns:
            The norms selected by the same rules as ``_get_tensor_norms``.
        """
        return NumericalMutationScene._get_tensor_norms(data_dict.values())

    def matches(self) -> bool:
        """Run the numerical-mutation check.

        Returns:
            False when no input norm or no output norm is available. Otherwise True when the
            largest output norm divided by the largest input norm exceeds the threshold; if
            the largest input norm is 0, the largest output norm itself is compared instead.
        """
        call_data = self.api_data
        # Norms of every input, positional first, then keyword.
        input_norms = self._get_tensor_norms(call_data.input_args) + self._get_kwargs_norms(call_data.input_kwargs)
        # Norms of every output.
        output_norms = self._get_tensor_norms(call_data.output_data)

        if not (input_norms and output_norms):
            return False

        peak_in = max(input_norms)
        peak_out = max(output_norms)

        # Avoid dividing by zero: fall back to the raw output norm.
        measured = peak_out if peak_in == 0 else peak_out / peak_in
        return measured > self.threshold

    def get_details(self) -> Dict:
        """Extend the base details with the threshold and the outcome of ``matches``."""
        details = super().get_details()
        details['threshold'] = self.threshold
        details['scale_change_detected'] = self.matches()
        return details
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
|
|
18
|
+
from typing import Dict, List
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class APIInfo:
    """Data recorded for a single API call.

    ``__init__`` is written by hand (``@dataclass`` keeps an ``__init__`` the class defines
    itself), because ``torch_api_name`` is derived from ``api_name`` rather than passed in.
    """
    api_name: str
    torch_api_name: str
    input_args: List[Dict]
    input_kwargs: Dict
    output_data: List[Dict]

    def __init__(self, api_name, input_args=None, input_kwargs=None, output_data=None):
        """Store the call data and derive ``torch_api_name``.

        Args:
            api_name: Full name of the API call.
            input_args: Positional input entries; ``None`` is stored as an empty list.
            input_kwargs: Keyword input entries; ``None`` is stored as an empty dict.
            output_data: Output entries; ``None`` is stored as an empty list.
        """
        self.api_name = api_name
        # Omitted containers become empty ones, so the fields honour their declared
        # List/Dict types and can be iterated without a None check. A fresh container is
        # created per instance; nothing mutable is shared through the defaults.
        self.input_args = [] if input_args is None else input_args
        self.input_kwargs = {} if input_kwargs is None else input_kwargs
        self.output_data = [] if output_data is None else output_data
        self.torch_api_name = self.extract_torch_api(self.api_name)

    @staticmethod
    def extract_torch_api(api_name) -> str:
        """Reduce an API name to its first two ``Const.SEP``-separated fields, lowercased.

        Args:
            api_name: The API name to process.

        Returns:
            ``""`` when ``api_name`` is empty or whitespace only; the lowercased name when
            it has a single field; otherwise the first two fields joined by ``Const.SEP``
            and lowercased.
        """
        # Empty string checking
        if not api_name.strip():
            return ""

        # str.split with a separator always yields at least one part, so slicing to two
        # parts covers both the single-field and the multi-field case.
        leading_fields = api_name.split(Const.SEP)[:2]
        return Const.SEP.join(leading_fields).lower()
|