mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -20,6 +20,7 @@ from typing import Any, Optional, Tuple
|
|
|
20
20
|
import numpy as np
|
|
21
21
|
import torch
|
|
22
22
|
from msprobe.core.common.const import Const
|
|
23
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
23
24
|
from msprobe.pytorch.free_benchmark import logger
|
|
24
25
|
from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
|
|
25
26
|
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
@@ -88,12 +89,6 @@ class FuzzHandler(ABC):
|
|
|
88
89
|
)
|
|
89
90
|
return origin_output_chunks, perturbed_output_chunks
|
|
90
91
|
|
|
91
|
-
@staticmethod
|
|
92
|
-
def convert_overflow_ratio_to_consistent(ratio):
|
|
93
|
-
if math.isnan(ratio) or math.isinf(ratio):
|
|
94
|
-
return ThresholdConfig.COMP_CONSISTENT
|
|
95
|
-
return ratio
|
|
96
|
-
|
|
97
92
|
@abstractmethod
|
|
98
93
|
def get_threshold(self, dtype):
|
|
99
94
|
pass
|
|
@@ -106,49 +101,45 @@ class FuzzHandler(ABC):
|
|
|
106
101
|
self, origin_output, perturbed_output, norm_type, abs_tol
|
|
107
102
|
):
|
|
108
103
|
if norm_type == NormType.ENDLESS_NORM:
|
|
109
|
-
return self.
|
|
104
|
+
return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol)
|
|
110
105
|
return ThresholdConfig.COMP_CONSISTENT
|
|
111
106
|
|
|
112
|
-
def
|
|
107
|
+
def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol):
|
|
113
108
|
origin_output_chunks, perturbed_output_chunks = (
|
|
114
109
|
self.tensor_split_for_error_calculate(origin_output, perturbed_output)
|
|
115
110
|
)
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
111
|
+
if len(origin_output_chunks) != len(perturbed_output_chunks):
|
|
112
|
+
err_msg = (
|
|
113
|
+
f"For {self.params.api_name}, the number of compare tensor chunks is different: "
|
|
114
|
+
f"{len(origin_output_chunks)} != {len(perturbed_output_chunks)}. please check!"
|
|
115
|
+
)
|
|
116
|
+
raise FreeBenchmarkException(
|
|
117
|
+
FreeBenchmarkException.OutputIndexError, err_msg
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
max_ratio = ThresholdConfig.COMP_CONSISTENT
|
|
119
121
|
for i, chunk_origin in enumerate(origin_output_chunks):
|
|
120
122
|
if chunk_origin.nelement() == 0:
|
|
121
123
|
break
|
|
122
124
|
chunk_perturbed = perturbed_output_chunks[i]
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
TorchC.
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
125
|
+
# 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况
|
|
126
|
+
if TorchC.lt(
|
|
127
|
+
TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2)
|
|
128
|
+
):
|
|
129
|
+
return ThresholdConfig.SYMBOL_FLIPPING
|
|
130
|
+
# 求A/B B/A的比值前,将值限制在大于极小值范围内
|
|
131
|
+
clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol)
|
|
132
|
+
clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol)
|
|
133
|
+
# 对于计算结果为nan的情况,认为两者没有差异
|
|
134
|
+
ratio_tensor = TorchC.nan_to_num(
|
|
135
|
+
TorchC.div(clamp_origin, clamp_perturbed),
|
|
136
|
+
nan=ThresholdConfig.COMP_CONSISTENT,
|
|
130
137
|
)
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
),
|
|
137
|
-
1,
|
|
138
|
-
)
|
|
139
|
-
norm_values = TorchC.stack(
|
|
140
|
-
[TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
|
|
141
|
-
)
|
|
142
|
-
max_ratio1, max_ratio2 = norm_values.tolist()
|
|
143
|
-
norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
|
|
144
|
-
norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
|
|
145
|
-
norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
|
|
146
|
-
|
|
147
|
-
if norm3 < 0:
|
|
148
|
-
ratio = ThresholdConfig.SYMBOL_FLIPPING
|
|
149
|
-
else:
|
|
150
|
-
ratio = max(norm1, norm2)
|
|
151
|
-
return ratio
|
|
138
|
+
# 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数
|
|
139
|
+
min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist()
|
|
140
|
+
min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio
|
|
141
|
+
max_ratio = max(max_ratio, min_ratio_reciprocal)
|
|
142
|
+
return max_ratio
|
|
152
143
|
|
|
153
144
|
def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
|
|
154
145
|
try:
|
|
@@ -189,6 +180,7 @@ class FuzzHandler(ABC):
|
|
|
189
180
|
f"[msprobe] Free Benchmark: For {self.params.api_name} "
|
|
190
181
|
f"The compare for output type {type(perturbed_output)} is not supported"
|
|
191
182
|
)
|
|
183
|
+
return True, 1
|
|
192
184
|
|
|
193
185
|
threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
|
|
194
186
|
ratio = self.ratio_calculate(
|
|
@@ -210,10 +202,12 @@ class FuzzHandler(ABC):
|
|
|
210
202
|
)
|
|
211
203
|
npu_consistent = is_consistent
|
|
212
204
|
max_fuzz_ratio = (
|
|
213
|
-
max_fuzz_ratio
|
|
205
|
+
max_fuzz_ratio
|
|
206
|
+
if not isinstance(ratio, (int, float))
|
|
207
|
+
else max(max_fuzz_ratio, ratio)
|
|
214
208
|
)
|
|
215
|
-
data_params.is_consistent = is_consistent
|
|
216
|
-
if not is_consistent
|
|
209
|
+
data_params.is_consistent = is_consistent
|
|
210
|
+
if not is_consistent:
|
|
217
211
|
self.unequal_rows.append(
|
|
218
212
|
make_unequal_row(data_params, self.params, ratio=ratio)
|
|
219
213
|
)
|
|
@@ -225,12 +219,12 @@ class FuzzHandler(ABC):
|
|
|
225
219
|
)
|
|
226
220
|
npu_consistent = npu_consistent and is_consistent
|
|
227
221
|
max_fuzz_ratio = (
|
|
228
|
-
max_fuzz_ratio
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
is_consistent and data_params.is_consistent
|
|
222
|
+
max_fuzz_ratio
|
|
223
|
+
if not isinstance(ratio, (int, float))
|
|
224
|
+
else max(max_fuzz_ratio, ratio)
|
|
232
225
|
)
|
|
233
|
-
|
|
226
|
+
data_params.is_consistent = is_consistent
|
|
227
|
+
if not is_consistent:
|
|
234
228
|
self.unequal_rows.append(
|
|
235
229
|
make_unequal_row(
|
|
236
230
|
data_params, self.params, ratio=ratio, index=index_
|
|
@@ -15,10 +15,11 @@
|
|
|
15
15
|
|
|
16
16
|
from typing import Any
|
|
17
17
|
|
|
18
|
+
from msprobe.core.common.exceptions import FreeBenchmarkException
|
|
19
|
+
from msprobe.pytorch.free_benchmark import logger
|
|
18
20
|
from msprobe.pytorch.free_benchmark.common.params import DataParams
|
|
19
21
|
from msprobe.pytorch.free_benchmark.common.utils import Tools
|
|
20
22
|
from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
|
|
21
|
-
from msprobe.pytorch.free_benchmark import logger
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class FixHandler(FuzzHandler):
|
|
@@ -31,9 +32,9 @@ class FixHandler(FuzzHandler):
|
|
|
31
32
|
return Tools.convert_fuzz_output_to_origin(
|
|
32
33
|
data_params.original_result, data_params.perturbed_result
|
|
33
34
|
)
|
|
34
|
-
except
|
|
35
|
-
logger.
|
|
35
|
+
except FreeBenchmarkException as e:
|
|
36
|
+
logger.warning(
|
|
36
37
|
f"[msprobe] Free Benchmark: For {self.params.api_name} "
|
|
37
|
-
f"Fix output failed
|
|
38
|
+
f"Fix output failed because of: \n{e}"
|
|
38
39
|
)
|
|
39
|
-
|
|
40
|
+
return data_params.original_result
|
|
@@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler):
|
|
|
75
75
|
if self.params.preheat_config.get("preheat_step") <= self.params.step:
|
|
76
76
|
return data_params.original_result
|
|
77
77
|
|
|
78
|
-
if not data_params.grad_unequal_flag:
|
|
79
|
-
data_params.grad_unequal_flag = True
|
|
80
|
-
data_params.is_consistent = False
|
|
81
|
-
return data_params.original_result
|
|
82
78
|
preheat_counter.add_api_called_time(self.pure_name)
|
|
83
79
|
|
|
84
80
|
if not self._is_take_a_sample():
|
|
@@ -1,15 +1,31 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
from collections import defaultdict
|
|
3
18
|
|
|
4
19
|
import torch
|
|
5
|
-
|
|
6
|
-
from torch.optim.optimizer import register_optimizer_step_pre_hook
|
|
7
|
-
from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
|
|
8
|
-
from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
|
|
20
|
+
from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
|
|
9
21
|
from msprobe.core.grad_probe.constant import level_adp
|
|
22
|
+
from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
|
|
10
23
|
from msprobe.pytorch.common.log import logger
|
|
11
|
-
from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
|
|
12
24
|
from msprobe.pytorch.common.utils import get_rank_id, print_rank_0
|
|
25
|
+
from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
|
|
26
|
+
|
|
27
|
+
if int(torch.__version__.split('.')[0]) >= 2:
|
|
28
|
+
from torch.optim.optimizer import register_optimizer_step_pre_hook
|
|
13
29
|
|
|
14
30
|
|
|
15
31
|
class GradientMonitor:
|
|
@@ -75,7 +91,7 @@ class GradientMonitor:
|
|
|
75
91
|
output_lines.append(grad_info)
|
|
76
92
|
if self._level_adp["have_grad_direction"]:
|
|
77
93
|
GradientMonitor.save_grad_direction(param_name, grad,
|
|
78
|
-
|
|
94
|
+
f'{self._output_path}/rank{self._rank}/step{self._step}')
|
|
79
95
|
output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}")
|
|
80
96
|
if not os.path.isdir(output_dirpath):
|
|
81
97
|
create_directory(output_dirpath)
|
|
@@ -87,5 +103,6 @@ class GradientMonitor:
|
|
|
87
103
|
output_lines.insert(0, header_result)
|
|
88
104
|
write_csv(output_lines, output_path)
|
|
89
105
|
logger.info(f"write grad data to {output_path}")
|
|
106
|
+
|
|
90
107
|
if int(torch.__version__.split('.')[0]) >= 2:
|
|
91
108
|
register_optimizer_step_pre_hook(optimizer_pre_step_hook)
|
|
@@ -1,11 +1,27 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from abc import ABC, abstractmethod
|
|
2
17
|
from collections import namedtuple
|
|
3
18
|
import hashlib
|
|
19
|
+
from functools import wraps
|
|
4
20
|
import torch
|
|
5
21
|
from msprobe.core.grad_probe.constant import GradConst
|
|
6
22
|
|
|
7
|
-
|
|
8
|
-
|
|
23
|
+
CsvHeaderInput = namedtuple("CsvHeaderInput", ["bounds"])
|
|
24
|
+
CsvContentInput = namedtuple("CsvContentInput", ["grad", "bounds"])
|
|
9
25
|
|
|
10
26
|
|
|
11
27
|
class GradStatCsv:
|
|
@@ -15,7 +31,7 @@ class GradStatCsv:
|
|
|
15
31
|
def generate_csv_header(level, bounds):
|
|
16
32
|
header = ["param_name"]
|
|
17
33
|
for key in level["header"]:
|
|
18
|
-
csv_header_input =
|
|
34
|
+
csv_header_input = CsvHeaderInput(bounds=bounds)
|
|
19
35
|
header.extend(GradStatCsv.csv[key].generate_csv_header(csv_header_input))
|
|
20
36
|
return header
|
|
21
37
|
|
|
@@ -23,7 +39,7 @@ class GradStatCsv:
|
|
|
23
39
|
def generate_csv_line(param_name, level, grad, bounds):
|
|
24
40
|
line = [param_name]
|
|
25
41
|
for key in level["header"]:
|
|
26
|
-
csv_content_input =
|
|
42
|
+
csv_content_input = CsvContentInput(grad=grad, bounds=bounds)
|
|
27
43
|
line.extend(GradStatCsv.csv[key].generate_csv_content(csv_content_input))
|
|
28
44
|
return line
|
|
29
45
|
|
|
@@ -37,20 +53,24 @@ def register_csv_item(key, cls=None):
|
|
|
37
53
|
|
|
38
54
|
|
|
39
55
|
class CsvItem(ABC):
|
|
56
|
+
@staticmethod
|
|
40
57
|
@abstractmethod
|
|
41
58
|
def generate_csv_header(csv_header_input):
|
|
42
59
|
pass
|
|
43
60
|
|
|
61
|
+
@staticmethod
|
|
44
62
|
@abstractmethod
|
|
45
63
|
def generate_csv_content(csv_content_input):
|
|
46
64
|
pass
|
|
47
65
|
|
|
48
66
|
|
|
49
67
|
@register_csv_item(GradConst.MD5)
|
|
50
|
-
class
|
|
68
|
+
class CsvMd5(CsvItem):
|
|
69
|
+
@staticmethod
|
|
51
70
|
def generate_csv_header(csv_header_input):
|
|
52
71
|
return ["MD5"]
|
|
53
72
|
|
|
73
|
+
@staticmethod
|
|
54
74
|
def generate_csv_content(csv_content_input):
|
|
55
75
|
grad = csv_content_input.grad
|
|
56
76
|
tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
|
|
@@ -59,7 +79,8 @@ class CSV_md5(CsvItem):
|
|
|
59
79
|
|
|
60
80
|
|
|
61
81
|
@register_csv_item(GradConst.DISTRIBUTION)
|
|
62
|
-
class
|
|
82
|
+
class CsvDistribution(CsvItem):
|
|
83
|
+
@staticmethod
|
|
63
84
|
def generate_csv_header(csv_header_input):
|
|
64
85
|
bounds = csv_header_input.bounds
|
|
65
86
|
intervals = []
|
|
@@ -73,6 +94,7 @@ class CSV_distribution(CsvItem):
|
|
|
73
94
|
|
|
74
95
|
return intervals
|
|
75
96
|
|
|
97
|
+
@staticmethod
|
|
76
98
|
def generate_csv_content(csv_content_input):
|
|
77
99
|
grad = csv_content_input.grad
|
|
78
100
|
bounds = csv_content_input.bounds
|
|
@@ -90,40 +112,48 @@ class CSV_distribution(CsvItem):
|
|
|
90
112
|
|
|
91
113
|
|
|
92
114
|
@register_csv_item(GradConst.MAX)
|
|
93
|
-
class
|
|
115
|
+
class CsvMax(CsvItem):
|
|
116
|
+
@staticmethod
|
|
94
117
|
def generate_csv_header(csv_header_input):
|
|
95
118
|
return ["max"]
|
|
96
119
|
|
|
120
|
+
@staticmethod
|
|
97
121
|
def generate_csv_content(csv_content_input):
|
|
98
122
|
grad = csv_content_input.grad
|
|
99
123
|
return [torch.max(grad).cpu().detach().float().numpy().tolist()]
|
|
100
124
|
|
|
101
125
|
|
|
102
126
|
@register_csv_item(GradConst.MIN)
|
|
103
|
-
class
|
|
127
|
+
class CsvMin(CsvItem):
|
|
128
|
+
@staticmethod
|
|
104
129
|
def generate_csv_header(csv_header_input):
|
|
105
130
|
return ["min"]
|
|
106
131
|
|
|
132
|
+
@staticmethod
|
|
107
133
|
def generate_csv_content(csv_content_input):
|
|
108
134
|
grad = csv_content_input.grad
|
|
109
135
|
return [torch.min(grad).cpu().detach().float().numpy().tolist()]
|
|
110
136
|
|
|
111
137
|
|
|
112
138
|
@register_csv_item(GradConst.NORM)
|
|
113
|
-
class
|
|
139
|
+
class CsvNorm(CsvItem):
|
|
140
|
+
@staticmethod
|
|
114
141
|
def generate_csv_header(csv_header_input):
|
|
115
142
|
return ["norm"]
|
|
116
143
|
|
|
144
|
+
@staticmethod
|
|
117
145
|
def generate_csv_content(csv_content_input):
|
|
118
146
|
grad = csv_content_input.grad
|
|
119
147
|
return [torch.norm(grad).cpu().detach().float().numpy().tolist()]
|
|
120
148
|
|
|
121
149
|
|
|
122
150
|
@register_csv_item(GradConst.SHAPE)
|
|
123
|
-
class
|
|
151
|
+
class CsvShape(CsvItem):
|
|
152
|
+
@staticmethod
|
|
124
153
|
def generate_csv_header(csv_header_input):
|
|
125
154
|
return ["shape"]
|
|
126
155
|
|
|
156
|
+
@staticmethod
|
|
127
157
|
def generate_csv_content(csv_content_input):
|
|
128
158
|
grad = csv_content_input.grad
|
|
129
159
|
return [list(grad.shape)]
|
|
@@ -15,17 +15,17 @@
|
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
17
|
import threading
|
|
18
|
+
from collections import defaultdict
|
|
18
19
|
|
|
19
20
|
import torch
|
|
20
21
|
import torch.nn as nn
|
|
21
22
|
import torch.utils.hooks as full_hooks
|
|
22
23
|
|
|
23
|
-
from msprobe.core.common.const import Const
|
|
24
24
|
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class HOOKModule(nn.Module):
|
|
28
|
-
module_count =
|
|
28
|
+
module_count = defaultdict(int)
|
|
29
29
|
inner_stop_hook = {}
|
|
30
30
|
|
|
31
31
|
def __init__(self, build_hook) -> None:
|
|
@@ -41,12 +41,7 @@ class HOOKModule(nn.Module):
|
|
|
41
41
|
if hasattr(self, "prefix_op_name_"):
|
|
42
42
|
self.prefix = self.prefix_op_name_
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
HOOKModule.module_count[self.prefix] = 1
|
|
46
|
-
self.prefix += '0' + Const.SEP
|
|
47
|
-
else:
|
|
48
|
-
HOOKModule.module_count[self.prefix] += 1
|
|
49
|
-
self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
|
|
44
|
+
self.forward_data_collected = False
|
|
50
45
|
forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
|
|
51
46
|
if torch_version_above_or_equal_2:
|
|
52
47
|
self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
|
|
@@ -66,9 +61,17 @@ class HOOKModule(nn.Module):
|
|
|
66
61
|
HOOKModule.inner_stop_hook[self.current_thread] = False
|
|
67
62
|
return result
|
|
68
63
|
|
|
69
|
-
@
|
|
70
|
-
def reset_module_stats(
|
|
71
|
-
|
|
64
|
+
@staticmethod
|
|
65
|
+
def reset_module_stats():
|
|
66
|
+
HOOKModule.module_count = defaultdict(int)
|
|
67
|
+
|
|
68
|
+
@staticmethod
|
|
69
|
+
def add_module_count(name):
|
|
70
|
+
HOOKModule.module_count[name] += 1
|
|
71
|
+
|
|
72
|
+
@staticmethod
|
|
73
|
+
def get_module_count(name):
|
|
74
|
+
return HOOKModule.module_count[name]
|
|
72
75
|
|
|
73
76
|
def _call_func(self, *args, **kwargs):
|
|
74
77
|
full_backward_hooks, non_full_backward_hooks = [], []
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.pytorch.common.log import logger
|
|
19
|
+
|
|
20
|
+
torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
|
|
21
|
+
if torch_version_above_or_equal_2:
|
|
22
|
+
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def register_optimizer_hook(data_collector):
|
|
26
|
+
def optimizer_pre_step_hook(optimizer, args, kwargs):
|
|
27
|
+
data_collector.optimizer_status = Const.OPTIMIZER
|
|
28
|
+
|
|
29
|
+
def optimizer_post_step_hook(optimizer, args, kwargs):
|
|
30
|
+
data_collector.optimizer_status = Const.END_PREFIX + Const.OPTIMIZER
|
|
31
|
+
|
|
32
|
+
def patch_clip_grad(func):
|
|
33
|
+
def wrapper(*args, **kwargs):
|
|
34
|
+
data_collector.optimizer_status = Const.CLIP_GRAD
|
|
35
|
+
func(*args, **kwargs)
|
|
36
|
+
data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
|
|
37
|
+
|
|
38
|
+
return wrapper
|
|
39
|
+
|
|
40
|
+
if torch_version_above_or_equal_2:
|
|
41
|
+
register_optimizer_step_pre_hook(optimizer_pre_step_hook)
|
|
42
|
+
register_optimizer_step_post_hook(optimizer_post_step_hook)
|
|
43
|
+
else:
|
|
44
|
+
logger.info_on_rank_0("Pytorch version is below 2.0, cannot register optimizer hook.")
|
|
45
|
+
|
|
46
|
+
try:
|
|
47
|
+
torch.nn.utils.clip_grad_norm_ = patch_clip_grad(torch.nn.utils.clip_grad_norm_)
|
|
48
|
+
torch.nn.utils.clip_grad_norm = patch_clip_grad(torch.nn.utils.clip_grad_norm)
|
|
49
|
+
torch.nn.utils.clip_grad_value_ = patch_clip_grad(torch.nn.utils.clip_grad_value_)
|
|
50
|
+
except Exception as e:
|
|
51
|
+
logger.info_on_rank_0("Cannot patch clip grad function. detail:%s" % str(e))
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
from megatron.core.optimizer import MegatronOptimizer
|
|
55
|
+
MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
|
|
56
|
+
except ImportError:
|
|
57
|
+
pass
|
|
58
|
+
except Exception as e:
|
|
59
|
+
logger.info_on_rank_0("Cannot patch megatron clip grad function. detail:%s" % str(e))
|
|
@@ -138,6 +138,10 @@ functional:
|
|
|
138
138
|
- fold
|
|
139
139
|
- multi_head_attention_forward
|
|
140
140
|
- scaled_dot_product_attention
|
|
141
|
+
- lp_pool3d
|
|
142
|
+
- dropout1d
|
|
143
|
+
- mish
|
|
144
|
+
- huber_loss
|
|
141
145
|
|
|
142
146
|
tensor:
|
|
143
147
|
- __add__
|
|
@@ -172,6 +176,7 @@ tensor:
|
|
|
172
176
|
- __sub__
|
|
173
177
|
- __truediv__
|
|
174
178
|
- __xor__
|
|
179
|
+
- __pow__
|
|
175
180
|
- abs
|
|
176
181
|
- abs_
|
|
177
182
|
- absolute
|
|
@@ -557,6 +562,27 @@ tensor:
|
|
|
557
562
|
- view_as
|
|
558
563
|
- xlogy
|
|
559
564
|
- xlogy_
|
|
565
|
+
- split
|
|
566
|
+
- stft
|
|
567
|
+
- nan_to_num
|
|
568
|
+
- dsplit
|
|
569
|
+
- orgqr
|
|
570
|
+
- bitwise_left_shift_
|
|
571
|
+
- arctan2
|
|
572
|
+
- histogram
|
|
573
|
+
- q_zero_point
|
|
574
|
+
- adjoint
|
|
575
|
+
- ormqr
|
|
576
|
+
- bitwise_right_shift_
|
|
577
|
+
- nanquantile
|
|
578
|
+
- lu
|
|
579
|
+
- quantile
|
|
580
|
+
- arctan2_
|
|
581
|
+
- qr
|
|
582
|
+
- diagonal_scatter
|
|
583
|
+
- corrcoef
|
|
584
|
+
- vsplit
|
|
585
|
+
- aminmax
|
|
560
586
|
|
|
561
587
|
torch:
|
|
562
588
|
- linalg.norm
|
|
@@ -1130,6 +1156,15 @@ torch_npu:
|
|
|
1130
1156
|
- npu_prompt_flash_attention
|
|
1131
1157
|
- npu_lstm
|
|
1132
1158
|
- npu_apply_adam
|
|
1159
|
+
- npu_apply_adam_w
|
|
1160
|
+
- npu_anti_quant
|
|
1161
|
+
- npu_grouped_matmu
|
|
1162
|
+
- npu_quant_scatter
|
|
1163
|
+
- npu_group_norm_silu
|
|
1164
|
+
- npu_format_cast
|
|
1165
|
+
- npu_moe_finalize_routing
|
|
1166
|
+
- npu_moe_gating_top_k_softmax
|
|
1167
|
+
- npu_trans_quant_param
|
|
1133
1168
|
|
|
1134
1169
|
aten:
|
|
1135
1170
|
- signbit
|
|
@@ -21,7 +21,6 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
|
21
21
|
from msprobe.pytorch.common.utils import torch_device_guard
|
|
22
22
|
from msprobe.core.common.const import Const
|
|
23
23
|
from msprobe.core.common.file_utils import load_yaml
|
|
24
|
-
from msprobe.core.common.inplace_op_checker import InplaceOpChecker
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
@@ -49,17 +48,16 @@ class DistributedOPTemplate(HOOKModule):
|
|
|
49
48
|
self.op_name_ = op_name
|
|
50
49
|
self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
|
|
51
50
|
super().__init__(build_hook)
|
|
52
|
-
if not self.stop_hook
|
|
53
|
-
self.
|
|
51
|
+
if not self.stop_hook:
|
|
52
|
+
self.op_is_distributed = True
|
|
54
53
|
|
|
55
54
|
@torch_device_guard
|
|
56
55
|
def forward(self, *args, **kwargs):
|
|
56
|
+
handle = distributed_func.get(self.op_name_)(*args, **kwargs)
|
|
57
57
|
if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
|
|
58
|
-
handle
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
else:
|
|
62
|
-
return distributed_func.get(self.op_name_)(*args, **kwargs)
|
|
58
|
+
if handle and hasattr(handle, 'wait'):
|
|
59
|
+
handle.wait()
|
|
60
|
+
return handle
|
|
63
61
|
|
|
64
62
|
|
|
65
63
|
def wrap_distributed_op(op_name, hook):
|
|
@@ -23,44 +23,6 @@ from msprobe.pytorch.common.log import logger
|
|
|
23
23
|
from msprobe.core.common.file_utils import load_yaml
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def remove_dropout():
|
|
27
|
-
if torch.__version__ > "1.8":
|
|
28
|
-
logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
|
|
29
|
-
import torch.nn.functional as F
|
|
30
|
-
from torch import _VF
|
|
31
|
-
from torch.overrides import has_torch_function_unary, handle_torch_function
|
|
32
|
-
|
|
33
|
-
def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
34
|
-
inplace: bool = False) -> torch.Tensor:
|
|
35
|
-
if has_torch_function_unary(input):
|
|
36
|
-
return handle_torch_function(
|
|
37
|
-
function_dropout, (input,), input, p=0., training=training, inplace=inplace)
|
|
38
|
-
if p < 0.0 or p > 1.0:
|
|
39
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
40
|
-
return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training)
|
|
41
|
-
|
|
42
|
-
def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
43
|
-
inplace: bool = False) -> torch.Tensor:
|
|
44
|
-
if has_torch_function_unary(input):
|
|
45
|
-
return handle_torch_function(
|
|
46
|
-
function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
|
|
47
|
-
if p < 0.0 or p > 1.0:
|
|
48
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
49
|
-
return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
|
|
50
|
-
|
|
51
|
-
def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True,
|
|
52
|
-
inplace: bool = False) -> torch.Tensor:
|
|
53
|
-
if has_torch_function_unary(input):
|
|
54
|
-
return handle_torch_function(
|
|
55
|
-
function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
|
|
56
|
-
if p < 0.0 or p > 1.0:
|
|
57
|
-
raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
|
|
58
|
-
return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
|
|
59
|
-
|
|
60
|
-
F.dropout = function_dropout
|
|
61
|
-
F.dropout2d = function_dropout2d
|
|
62
|
-
F.dropout3d = function_dropout3d
|
|
63
|
-
|
|
64
26
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
65
27
|
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
|
|
66
28
|
|
|
File without changes
|