mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/core/data_dump/data_collector.py:

@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import atexit
 import os
 
-from msprobe.core.data_dump.scope import
+from msprobe.core.data_dump.scope import ScopeFactory
 from msprobe.core.data_dump.json_writer import DataWriter
 from msprobe.core.common.log import logger
 from msprobe.core.common.const import Const
@@ -27,7 +28,6 @@ def build_data_collector(config):
 
 
 class DataCollector:
-    multi_output_apis = ["_sort_", "npu_flash_attention"]
     tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
     level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]
 
@@ -37,13 +37,10 @@ class DataCollector:
         self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
         self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
         self.module_count = {}
-
-
-
-
-
-    def __del__(self):
-        self.write_json()
+        self.scope = ScopeFactory(self.config).build_scope()
+        self.backward_module_names = {}
+        self.optimizer_status = ""
+        atexit.register(self.write_json)
 
     @property
     def dump_data_dir(self):
@@ -57,10 +54,6 @@
     def check_scope_and_pid(scope, name, pid):
         return (not scope or scope.check(name)) and pid == os.getpid()
 
-    @staticmethod
-    def is_inplace(module):
-        return getattr(module, "op_is_inplace", False)
-
     def if_return_forward_new_output(self):
         return self.data_processor.if_return_forward_new_output()
 
@@ -84,36 +77,54 @@
         logger.debug(msg)
         self.data_writer.update_data(data_info)
 
-    def
-
-
-        self.
-
+    def forward_input_data_collect(self, name, module, pid, module_input_output):
+        if self.config.task == Const.FREE_BENCHMARK:
+            backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
+            if self.check_scope_and_pid(self.scope, backward_name, pid):
+                self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
+            return
+
+        if not self.check_scope_and_pid(self.scope, name, pid):
+            return
+
+        data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+        if self.config.level == Const.LEVEL_L2:
             return
-        logger.info(f"API {name} is inplace.")
-        data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
-    def
+    def forward_output_data_collect(self, name, module, pid, module_input_output):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
-
-
-        else:
-            data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
-        if self.config.level == "L2":
+        data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
+        if self.config.level == Const.LEVEL_L2:
             return
         self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
+    def forward_data_collect(self, name, module, pid, module_input_output):
+        self.update_construct(name)
+        if not self.check_scope_and_pid(self.scope, name, pid):
+            return
+
+        data_info = self.data_processor.analyze_forward(name, module, module_input_output)
+        self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
+        self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
+
     def backward_data_collect(self, name, module, pid, module_input_output):
         self.update_construct(name)
         if not self.check_scope_and_pid(self.scope, name, pid):
             return
 
         data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+        if self.config.level == Const.LEVEL_L2:
+            return
+        # Get the name of the module that ran the backward pass
+        if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
+            module_name = name.rsplit(Const.SEP, 2)[0]
+            # Record the module name so gradient collection can tell whether its gradients need to be collected
+            self.backward_module_names[module_name] = True
         self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
 
     def backward_input_data_collect(self, name, module, pid, module_input_output):
@@ -134,12 +145,17 @@
 
     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
-            self.
+            if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
+                self.data_writer.update_construct({name: self.optimizer_status})
+            else:
+                self.data_writer.update_construct({name: self.module_processor.api_parent_node})
             self.data_writer.update_construct(self.module_processor.module_node)
 
     def handle_data(self, name, data_info, flush=False):
         if data_info:
             self.update_data(name, data_info)
+        if self.config.async_dump:
+            return
         if not flush:
             self.data_writer.flush_data_periodically()
         else:
@@ -147,7 +163,23 @@
 
     def update_dump_paths(self, *args):
         self.data_writer.update_dump_paths(*args)
-
+
+    def initialize_json_file(self, framework=Const.UNKNOWN_FRAMEWORK):
+        self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level, framework=framework)
 
     def update_iter(self, current_iter):
         self.data_processor.update_iter(current_iter)
+
+    def params_data_collect(self, name, param_name, pid, data):
+        grad_name = name + Const.SEP + Const.PARAMS_GRAD
+        # Check the scope and pid, and whether this name went through a backward pass
+        if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+            # If there was no backward pass, remove the grad placeholder written earlier
+            if self.data_writer.cache_data.get("data"):
+                self.data_writer.cache_data.get("data").pop(grad_name, None)
+            return
+        data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+        self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+
+    def fill_stack_tensor_data(self):
+        self.data_writer.fill_stack_tensor_data()
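The collector now flushes its JSON output through an `atexit` hook rather than a `__del__` finalizer, whose execution at interpreter shutdown is not guaranteed. A minimal standalone sketch of that pattern (the `Collector` class and file handling here are illustrative, not taken from the package):

```python
import atexit
import json
import os
import tempfile


class Collector:
    """Minimal sketch: flush buffered records once, at interpreter exit."""

    def __init__(self, path):
        self.path = path
        self.records = []
        # Registered callbacks run at normal interpreter shutdown, unlike
        # __del__, which may never run or may run after the modules it
        # needs have already been torn down.
        atexit.register(self.write_json)

    def add(self, record):
        self.records.append(record)

    def write_json(self):
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.records, f)


collector = Collector(os.path.join(tempfile.gettempdir(), "collector_example.json"))
collector.add({"name": "Functional.relu.0.forward"})
# No explicit flush needed: write_json runs when the process exits.
```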
msprobe/core/data_dump/data_processor/base.py:

@@ -1,7 +1,7 @@
 # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
-# Licensed under the Apache License, Version 2.0
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -15,10 +15,11 @@
 
 import inspect
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from typing import Tuple, Dict, Optional, Any
 
 import numpy as np
+
 from msprobe.core.common.const import Const
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import convert_tuple, CompareException
@@ -38,9 +39,8 @@ class ModuleForwardInputsOutputs:
     def output_tuple(self):
         return convert_tuple(self.output)
 
-    def
-
-        return args
+    def update_output_with_args_and_kwargs(self):
+        self.output = self.args + tuple(self.kwargs.values())
 
 
 @dataclass
@@ -76,11 +76,12 @@ class ModuleBackwardOutputs:
 
 
 class TensorStatInfo:
-    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
+    def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
         self.max = max_val
         self.min = min_val
         self.mean = mean_val
         self.norm = norm_val
+        self.stack_tensor_stat = stack_tensor_stat
 
 
 class BaseDataProcessor:
@@ -101,6 +102,9 @@ class BaseDataProcessor:
         self.current_iter = 0
         self._return_forward_new_output = False
         self._forward_new_output = None
+        self.save_name = None
+        if hasattr(config, "data_mode"):
+            self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)
 
     @property
     def data_path(self):
@@ -182,6 +186,18 @@ class BaseDataProcessor:
     def _analyze_numpy(value, numpy_type):
         return {"type": numpy_type, "value": value}
 
+    @staticmethod
+    def _get_allowed_data_mode(data_mode):
+        if Const.ALL in data_mode:
+            allowed_data_mode = [Const.FORWARD, Const.BACKWARD, Const.INPUT, Const.OUTPUT]
+        else:
+            allowed_data_mode = list(set(data_mode))
+            if Const.FORWARD not in allowed_data_mode and Const.BACKWARD not in allowed_data_mode:
+                allowed_data_mode += [Const.FORWARD, Const.BACKWARD]
+            if Const.INPUT not in allowed_data_mode and Const.OUTPUT not in allowed_data_mode:
+                allowed_data_mode += [Const.INPUT, Const.OUTPUT]
+        return allowed_data_mode
+
     @classmethod
     def get_special_types(cls):
         return cls.special_type
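The added `_get_allowed_data_mode` helper expands a partial `data_mode` setting so that naming only a direction (forward/backward) or only a category (input/output) still permits both values on the other axis. A self-contained sketch of the same normalization, using plain strings in place of the package's `Const` values:

```python
FORWARD, BACKWARD, INPUT, OUTPUT, ALL = "forward", "backward", "input", "output", "all"


def get_allowed_data_mode(data_mode):
    """Expand a user-supplied data_mode list into explicit directions and categories."""
    if ALL in data_mode:
        return [FORWARD, BACKWARD, INPUT, OUTPUT]
    allowed = list(set(data_mode))
    # If neither direction was named, allow both.
    if FORWARD not in allowed and BACKWARD not in allowed:
        allowed += [FORWARD, BACKWARD]
    # If neither category was named, allow both.
    if INPUT not in allowed and OUTPUT not in allowed:
        allowed += [INPUT, OUTPUT]
    return allowed


print(sorted(get_allowed_data_mode(["all"])))      # all four values
print(sorted(get_allowed_data_mode(["input"])))    # input plus both directions
print(sorted(get_allowed_data_mode(["forward"])))  # forward plus both categories
```

`is_dump_for_data_mode(forward_backward, input_output)` then reduces to two membership tests against this expanded list.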
msprobe/core/data_dump/data_processor/base.py (continued):

@@ -194,25 +210,42 @@
         if isinstance(args, cls.get_special_types()):
             arg_transform = transform(args, cls._recursive_key_stack)
             return arg_transform
+        elif isinstance(args, tuple) and hasattr(args, '_fields'):
+            # namedtuple to dict
+            args_dict = {field: getattr(args, field) for field in args._fields}
+            return cls.apply_transform_dict(args_dict, transform, depth)
+        elif is_dataclass(args):
+            # dataclass to dict
+            args_dict = {field: getattr(args, field) for field in args.__dataclass_fields__}
+            return cls.apply_transform_dict(args_dict, transform, depth)
         elif isinstance(args, (list, tuple)):
-            result_list =
-            for i, arg in enumerate(args):
-                cls._recursive_key_stack.append(str(i))
-                result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
-                cls._recursive_key_stack.pop()
+            result_list = cls.apply_transform_list(args, transform, depth)
             return type(args)(result_list)
         elif isinstance(args, dict):
-
-            for k, arg in args.items():
-                cls._recursive_key_stack.append(str(k))
-                result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
-                cls._recursive_key_stack.pop()
-            return result_dict
+            return cls.apply_transform_dict(args, transform, depth)
         elif args is not None:
-            logger.
+            logger.debug(f"Data type {type(args)} is not supported.")
             return None
         else:
             return None
+
+    @classmethod
+    def apply_transform_dict(cls, args, transform, depth):
+        result_dict = {}
+        for k, arg in args.items():
+            cls._recursive_key_stack.append(str(k))
+            result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
+            cls._recursive_key_stack.pop()
+        return result_dict
+
+    @classmethod
+    def apply_transform_list(cls, args, transform, depth):
+        result_list = []
+        for i, arg in enumerate(args):
+            cls._recursive_key_stack.append(str(i))
+            result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
+            cls._recursive_key_stack.pop()
+        return result_list
 
     def if_return_forward_new_output(self):
         return self._return_forward_new_output
@@ -239,17 +272,12 @@
         Return:
             bool: True if the parameters are in data_mode or data_mode is all, False otherwise.
         """
-        return
-                forward_backward in self.config.data_mode or
-                input_output in self.config.data_mode)
-
-    def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
-        pass
+        return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode
 
     def analyze_element(self, element):
         return self.recursive_apply_transform(element, self.analyze_single_element)
 
-    def
+    def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         api_info_struct = {}
         # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -261,16 +289,22 @@
             kwargs_info_list = self.analyze_element(module_input_output.kwargs)
             api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
 
-
+        return api_info_struct
+
+    def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        api_info_struct = {}
+        # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
-            api_info_struct[name] =
+            api_info_struct[name] = {}
             self.api_data_category = Const.OUTPUT
             output_info_list = self.analyze_element(module_input_output.output_tuple)
             api_info_struct[name][Const.OUTPUT] = output_info_list
+
         return api_info_struct
 
-    def
+    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         api_info_struct = {}
+        # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
             api_info_struct[name] = {}
             self.api_data_category = Const.INPUT
@@ -279,16 +313,18 @@
             self.api_data_category = Const.KWARGS
             kwargs_info_list = self.analyze_element(module_input_output.kwargs)
             api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
-        return api_info_struct
 
-
-        concat_args = module_input_output.concat_args_and_kwargs()
-        api_info_struct = {}
+        # check whether data_mode contains forward or output
         if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
-            api_info_struct[name] = {}
+            api_info_struct[name] = api_info_struct.get(name, {})
             self.api_data_category = Const.OUTPUT
-            output_info_list = self.analyze_element(
+            output_info_list = self.analyze_element(module_input_output.output_tuple)
             api_info_struct[name][Const.OUTPUT] = output_info_list
+
+        if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
+            self.api_data_category = Const.PARAMS
+            api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
+
         return api_info_struct
 
     def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -329,9 +365,21 @@
             api_info_struct[name][Const.OUTPUT] = output_info_list
         return api_info_struct
 
+    def analyze_params(self, name, param_name, grad):
+        api_info_struct = {}
+        self.save_name = name + Const.SEP + param_name
+        data_info = self.analyze_element(grad)
+        grad_info_dict = {param_name: [data_info]}
+        api_info_struct[name] = grad_info_dict
+        return api_info_struct
+
     def get_save_file_path(self, suffix):
         file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
-
-
+        if self.save_name is not None:
+            dump_data_name = (self.save_name + file_format)
+            self.save_name = None
+        else:
+            dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
+                              suffix + file_format)
         file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
         return dump_data_name, file_path
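The reworked `recursive_apply_transform` now unpacks namedtuples (via `_fields`) and dataclasses (via `__dataclass_fields__`) into dictionaries before recursing, and delegates list/dict handling to the new `apply_transform_list`/`apply_transform_dict` helpers. A simplified sketch of that dispatch, without the key-stack bookkeeping the real implementation maintains:

```python
from collections import namedtuple
from dataclasses import dataclass, is_dataclass


def to_leaf(value):
    # Stand-in for the per-element analysis; just records the type.
    return {"type": type(value).__name__, "value": value}


def recursive_transform(args):
    # namedtuple: detected by _fields, unpacked field by field
    if isinstance(args, tuple) and hasattr(args, "_fields"):
        return {f: recursive_transform(getattr(args, f)) for f in args._fields}
    # dataclass instance: unpacked via __dataclass_fields__
    if is_dataclass(args) and not isinstance(args, type):
        return {f: recursive_transform(getattr(args, f)) for f in args.__dataclass_fields__}
    if isinstance(args, (list, tuple)):
        return type(args)(recursive_transform(a) for a in args)
    if isinstance(args, dict):
        return {k: recursive_transform(v) for k, v in args.items()}
    return to_leaf(args)


Point = namedtuple("Point", ["x", "y"])


@dataclass
class Sample:
    data: tuple
    label: int


print(recursive_transform(Sample(data=Point(1, 2), label=3)))
```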
msprobe/core/data_dump/data_processor/factory.py:

@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,7 +56,7 @@ class DataProcessorFactory:
             FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
             KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
         )
-        from msprobe.pytorch.module_processer import ModuleProcesser
+        from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
         cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
         cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
         cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
@@ -67,10 +67,12 @@ class DataProcessorFactory:
         from msprobe.core.data_dump.data_processor.mindspore_processor import (
             StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
             TensorDataProcessor as MindsporeTensorDataProcessor,
-            OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
+            OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
+            KernelDumpDataProcessor as MindsporeKernelDumpDataProcessor
         )
         from msprobe.mindspore.cell_processor import CellProcessor
         cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
         cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
         cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
+        cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
         cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
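The factory hunks only extend the existing registration table: MindSpore gains a `KernelDumpDataProcessor` keyed by the same (framework, task) pair scheme used for the other tasks. A minimal sketch of that registry pattern (class and task names here are illustrative, not the package's own):

```python
class ProcessorFactory:
    # (framework, task) -> processor class
    _registry = {}

    @classmethod
    def register(cls, framework, task, processor_cls):
        cls._registry[(framework, task)] = processor_cls

    @classmethod
    def create(cls, framework, task, *args, **kwargs):
        processor_cls = cls._registry.get((framework, task))
        if processor_cls is None:
            raise ValueError(f"No processor registered for {framework}/{task}")
        return processor_cls(*args, **kwargs)


class MindsporeKernelDump:
    def __init__(self, config=None):
        self.config = config


ProcessorFactory.register("mindspore", "kernel_dump", MindsporeKernelDump)
processor = ProcessorFactory.create("mindspore", "kernel_dump", config={"path": "./dump"})
```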
msprobe/core/data_dump/data_processor/mindspore_processor.py:

@@ -1,4 +1,4 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright 2024-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
 import zlib
 
 import mindspore as ms
-from mindspore import mint, ops
+from mindspore import mint, ops, hal
 from mindspore._c_expression.typing import Number
 import numpy as np
 
@@ -28,6 +28,12 @@ from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register
 
+has_adump = True
+try:
+    from msprobe.lib import _msprobe_c
+except ImportError:
+    has_adump = False
+
 
 class MindsporeDataProcessor(BaseDataProcessor):
     mindspore_special_type = tuple([ms.Tensor, Number])
@@ -37,6 +43,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
         self.mindspore_object_key = {
             "dtype": self.analyze_dtype_in_kwargs
         }
+        self._async_dump_cache = {}
 
     @staticmethod
     def get_md5_for_tensor(x):
@@ -49,15 +56,10 @@
     def analyze_dtype_in_kwargs(element):
         return {"type": "mindspore.dtype", "value": str(element)}
 
-    @
-    def
-        return super().get_special_types() + cls.mindspore_special_type
-
-    def get_stat_info(self, data):
+    @staticmethod
+    def get_stat_info_sync(data):
         tensor_stat = TensorStatInfo()
-        if data.
-            return tensor_stat
-        elif data.dtype == ms.bool_:
+        if data.dtype == ms.bool_:
             data_np = data.asnumpy()
             tensor_stat.max = np.max(data_np).item()
             tensor_stat.min = np.min(data_np).item()
@@ -70,7 +72,7 @@
             tensor_stat.mean = np.mean(data_abs).item()
             tensor_stat.norm = np.linalg.norm(data_abs).item()
         else:
-            if not ops.is_floating_point(data):
+            if not ops.is_floating_point(data) or data.dtype == ms.float64:
                 data = data.to(ms.float32)
             api_register.norm_inner_op_set_ori_func()
             get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
@@ -87,6 +89,47 @@
             api_register.norm_inner_op_set_hook_func()
         return tensor_stat
 
+    @staticmethod
+    def get_stat_info_async(data):
+        tensor_stat = TensorStatInfo()
+        stack_method = api_register.functional_ori_attr.get("stack", ms.ops.stack)
+        if data.dtype == ms.complex64 or data.dtype == ms.complex128:
+            logger.warning("Async dump do not support complex data!")
+            return tensor_stat
+        elif data.dtype == ms.bool_:
+            tensor_stat.stack_tensor_stat = (["Max", "Min"], stack_method([data.any(), data.all()]))
+        elif not data.shape:
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method([data, data, data, data]))
+        else:
+            if not ops.is_floating_point(data) or data.dtype == ms.float64:
+                data = data.to(ms.float32)
+            api_register.norm_inner_op_set_ori_func()
+            get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
+            get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
+            get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
+            if hasattr(mint, "norm"):
+                get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
+            else:
+                get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
+            tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], stack_method(
+                [get_max_value(data), get_min_value(data), get_mean_value(data), get_norm_value(data)]))
+            api_register.norm_inner_op_set_hook_func()
+        return tensor_stat
+
+    @classmethod
+    def get_special_types(cls):
+        return super().get_special_types() + cls.mindspore_special_type
+
+    def get_stat_info(self, data):
+        tensor_stat = TensorStatInfo()
+        if data.numel() == 0:
+            return tensor_stat
+        else:
+            if self.config.async_dump:
+                return MindsporeDataProcessor.get_stat_info_async(data)
+            else:
+                return MindsporeDataProcessor.get_stat_info_sync(data)
+
     def analyze_single_element(self, element, suffix_stack):
         if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
             return self.mindspore_object_key[suffix_stack[-1]](element)
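With `async_dump` enabled, statistics are kept as a single stacked device tensor (`stack_tensor_stat`) instead of being converted to Python scalars one by one, which would force a device synchronization per value. A framework-neutral sketch of the idea, with NumPy arrays standing in for device tensors (an assumption about intent, not code from the package):

```python
import numpy as np


def stat_sync(data):
    """Eager path: each statistic becomes a host scalar right away."""
    return {
        "Max": float(np.max(data)),
        "Min": float(np.min(data)),
        "Mean": float(np.mean(data)),
        "Norm": float(np.linalg.norm(data)),
    }


def stat_async(data):
    """Deferred path: keep one stacked array; unpack it later in one step."""
    names = ["Max", "Min", "Mean", "Norm"]
    stacked = np.stack([np.max(data), np.min(data), np.mean(data), np.linalg.norm(data)])
    return names, stacked


data = np.arange(6, dtype=np.float32)
print(stat_sync(data))

names, stacked = stat_async(data)
# Later, e.g. when writing dump.json, resolve everything at once.
print(dict(zip(names, stacked.tolist())))
```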
msprobe/core/data_dump/data_processor/mindspore_processor.py (continued):

@@ -107,13 +150,17 @@
         tensor_json = {
             'type': 'mindspore.Tensor',
             'dtype': str(tensor.dtype),
-            'shape': tensor.shape
-            'Max': self.transfer_type(tensor_stat.max),
-            'Min': self.transfer_type(tensor_stat.min),
-            'Mean': self.transfer_type(tensor_stat.mean),
-            'Norm': self.transfer_type(tensor_stat.norm),
+            'shape': tensor.shape
         }
-
+
+        if tensor_stat.stack_tensor_stat is None:
+            tensor_json.update({'Max': self.transfer_type(tensor_stat.max)})
+            tensor_json.update({'Min': self.transfer_type(tensor_stat.min)})
+            tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)})
+            tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)})
+        else:
+            tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat})
+        if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
             tensor_md5 = self.get_md5_for_tensor(tensor)
             tensor_json.update({Const.MD5: tensor_md5})
         return tensor_json
@@ -124,11 +171,19 @@ class StatisticsDataProcessor(MindsporeDataProcessor):
 
 
 class TensorDataProcessor(MindsporeDataProcessor):
+    def dump_async_data(self):
+        for file_path, tensor in self._async_dump_cache.items():
+            save_tensor_as_npy(tensor, file_path)
+        self._async_dump_cache.clear()
+
     def _analyze_tensor(self, tensor, suffix):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         single_arg = super()._analyze_tensor(tensor, suffix)
         single_arg.update({"data_name": dump_data_name})
-
+        if self.config.async_dump:
+            self._async_dump_cache[file_path] = tensor.copy()
+        else:
+            save_tensor_as_npy(tensor, file_path)
         return single_arg
 
 
@@ -138,6 +193,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
     def __init__(self, config, data_writer):
         super().__init__(config, data_writer)
         self.has_overflow = False
+        self.cached_api_info = {}
         self.cached_tensors_and_file_paths = {}
         self.real_overflow_nums = 0
         self.overflow_nums = config.overflow_nums
@@ -150,6 +206,20 @@
             return True
         return False
 
+    def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        self.has_overflow = False
+        self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
+        return None
+
+    def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        api_info_struct = super().analyze_forward_output(name, module, module_input_output)
+        if name in self.cached_api_info and name in api_info_struct:
+            self.cached_api_info[name].update(api_info_struct[name])
+        elif name in api_info_struct:
+            self.cached_api_info = api_info_struct
+        self.maybe_save_overflow_data()
+        return self.cached_api_info if self.has_overflow else None
+
     def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
         self.has_overflow = False
         api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -161,6 +231,12 @@
         api_info_struct = super().analyze_backward(name, module, module_input_output)
         self.maybe_save_overflow_data()
         return api_info_struct if self.has_overflow else None
+
+    def analyze_params(self, name, param_name, grad):
+        self.has_overflow = False
+        api_info_struct = super().analyze_params(name, param_name, grad)
+        self.maybe_save_overflow_data()
+        return api_info_struct if self.has_overflow else None
 
     def maybe_save_overflow_data(self):
         if self.has_overflow:
@@ -190,3 +266,61 @@
             self._analyze_maybe_overflow_tensor(single_arg)
         single_arg.update({"data_name": dump_data_name})
         return single_arg
+
+
+class KernelDumpDataProcessor(MindsporeDataProcessor):
+    def __init__(self, config, data_writer):
+        super().__init__(config, data_writer)
+        self.enable_kernel_dump = True
+
+    @staticmethod
+    def start_kernel_dump(config_path):
+        hal.synchronize()
+        _msprobe_c.init_dump()
+        _msprobe_c.set_dump(config_path)
+        hal.synchronize()
+
+    @staticmethod
+    def stop_kernel_dump():
+        hal.synchronize()
+        _msprobe_c.finalize_dump()
+        hal.synchronize()
+
+    @staticmethod
+    def _print_unsupported_log(api_name):
+        logger.warning(f"The kernel dump does not support the {api_name} API.")
+
+    def analyze_forward_input(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if not has_adump:
+            logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+    def analyze_forward_output(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def analyze_backward_input(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        if not has_adump:
+            logger.warning("The current msprobe package does not compile adump, and kernel dump cannot be used.")
+            self.enable_kernel_dump = False
+            return
+        self.start_kernel_dump(self.config.kernel_config_path)
+
+    def analyze_backward(self, name, module, module_input_output):
+        if not self.enable_kernel_dump:
+            return
+        self.enable_kernel_dump = False
+        self.stop_kernel_dump()
+        logger.info(f"The kernel data of {name} is dumped successfully.")
+
+    def reset_status(self):
+        self.enable_kernel_dump = True