mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
5
|
# you may not use this file except in compliance with the License.
|
|
6
6
|
# You may obtain a copy of the License at
|
|
7
7
|
#
|
|
@@ -18,9 +18,11 @@ from typing import Any
|
|
|
18
18
|
import mindspore as ms
|
|
19
19
|
from mindspore import Tensor, ops
|
|
20
20
|
|
|
21
|
-
from msprobe.
|
|
21
|
+
from msprobe.core.common.const import Const
|
|
22
22
|
from msprobe.mindspore.common.log import logger
|
|
23
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
23
24
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
25
|
+
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
24
26
|
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
25
27
|
|
|
26
28
|
|
|
@@ -40,10 +42,15 @@ class ImprovePrecisionPerturbation(BasePerturbation):
|
|
|
40
42
|
def handle(self, params: HandlerParams) -> Any:
|
|
41
43
|
args = self.improve_tensor_precision(params.args)
|
|
42
44
|
kwargs = self.improve_tensor_precision(params.kwargs)
|
|
43
|
-
fuzzed_value = args
|
|
44
|
-
if self.api_name in Const.COMMUNICATION_API_LIST:
|
|
45
|
-
params.fuzzed_value = fuzzed_value
|
|
46
45
|
if not self.is_fuzzed:
|
|
47
|
-
logger.warning(f"{self.
|
|
46
|
+
logger.warning(f"{self.api_name_with_id} can not improve precision.")
|
|
48
47
|
return False
|
|
48
|
+
|
|
49
|
+
if Config.stage == Const.BACKWARD:
|
|
50
|
+
fuzzed_result = Tools.get_grad(params.original_func, *args, **kwargs)
|
|
51
|
+
if fuzzed_result is not None:
|
|
52
|
+
return fuzzed_result
|
|
53
|
+
else:
|
|
54
|
+
return False
|
|
55
|
+
|
|
49
56
|
return params.original_func(*args, **kwargs)
|
|
@@ -36,9 +36,9 @@ class PerturbationFactory:
|
|
|
36
36
|
}
|
|
37
37
|
|
|
38
38
|
@staticmethod
|
|
39
|
-
def create(
|
|
39
|
+
def create(api_name_with_id: str):
|
|
40
40
|
perturbation = PerturbationFactory.perturbations.get(Config.pert_type)
|
|
41
41
|
if perturbation:
|
|
42
|
-
return perturbation(
|
|
42
|
+
return perturbation(api_name_with_id)
|
|
43
43
|
else:
|
|
44
44
|
raise Exception(f'{Config.pert_type} is a invalid perturbation type')
|
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
from msprobe.mindspore.common.const import Const
|
|
17
17
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
18
|
-
from msprobe.mindspore.free_benchmark.api_pynative_self_check import
|
|
18
|
+
from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class SelfCheckToolFactory:
|
|
@@ -28,7 +28,7 @@ class SelfCheckToolFactory:
|
|
|
28
28
|
Const.API: {
|
|
29
29
|
Const.GRAPH_KBYK_MODE: None,
|
|
30
30
|
Const.GRAPH_GE_MODE: None,
|
|
31
|
-
Const.PYNATIVE_MODE:
|
|
31
|
+
Const.PYNATIVE_MODE: ApiPyNativeSelfCheck
|
|
32
32
|
},
|
|
33
33
|
Const.KERNEL: {
|
|
34
34
|
Const.GRAPH_KBYK_MODE: None,
|
|
@@ -1,15 +1,30 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import threading
|
|
3
18
|
from typing import Dict, Union, Tuple
|
|
4
19
|
|
|
5
|
-
from msprobe.core.
|
|
20
|
+
from msprobe.core.common.utils import is_int
|
|
21
|
+
from msprobe.core.common.file_utils import create_directory, check_path_before_create
|
|
6
22
|
from msprobe.core.grad_probe.constant import GradConst
|
|
23
|
+
from msprobe.core.grad_probe.utils import check_str, check_bounds_element, check_param_element
|
|
7
24
|
from msprobe.mindspore.common.log import logger
|
|
8
|
-
from msprobe.core.common.file_utils import create_directory, check_path_before_create
|
|
9
25
|
|
|
10
26
|
|
|
11
27
|
class GlobalContext:
|
|
12
|
-
|
|
13
28
|
_instance = None
|
|
14
29
|
_instance_lock = threading.Lock()
|
|
15
30
|
_setting = {
|
|
@@ -37,10 +52,10 @@ class GlobalContext:
|
|
|
37
52
|
else:
|
|
38
53
|
raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")
|
|
39
54
|
|
|
40
|
-
self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
|
|
55
|
+
self._set_input_list(config_dict, GradConst.PARAM_LIST, (str,), element_check=check_param_element)
|
|
41
56
|
self._set_input_list(config_dict, GradConst.BOUNDS, (float, int), element_check=check_bounds_element)
|
|
42
|
-
self._set_input_list(config_dict, GradConst.STEP, int)
|
|
43
|
-
self._set_input_list(config_dict, GradConst.RANK, int)
|
|
57
|
+
self._set_input_list(config_dict, GradConst.STEP, (int,))
|
|
58
|
+
self._set_input_list(config_dict, GradConst.RANK, (int,))
|
|
44
59
|
|
|
45
60
|
output_path = config_dict.get(GradConst.OUTPUT_PATH)
|
|
46
61
|
check_str(output_path, variable_name="output_path in yaml")
|
|
@@ -88,13 +103,18 @@ class GlobalContext:
|
|
|
88
103
|
if value and isinstance(value, list):
|
|
89
104
|
for val in value:
|
|
90
105
|
if not isinstance(val, dtype):
|
|
91
|
-
logger.warning(f"Invalid {name} which must be None or list of {type_str}")
|
|
106
|
+
logger.warning(f"Invalid {name} which must be None or list of {type_str}, use default value.")
|
|
107
|
+
return
|
|
108
|
+
elif isinstance(val, int) and not is_int(val):
|
|
109
|
+
logger.warning(f"Invalid {name} which must be None or list of int, use default value.")
|
|
92
110
|
return
|
|
93
111
|
if element_check and not element_check(val):
|
|
94
|
-
logger.warning(f"Given {name} violates some rules.")
|
|
112
|
+
logger.warning(f"Given {name} violates some rules, use default value.")
|
|
95
113
|
return
|
|
114
|
+
|
|
96
115
|
self._setting[name] = value
|
|
97
116
|
else:
|
|
98
117
|
logger.warning(f"{name} is None or not a list with valid items, use default value.")
|
|
99
118
|
|
|
119
|
+
|
|
100
120
|
grad_context = GlobalContext()
|
|
@@ -1,23 +1,48 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import multiprocessing
|
|
1
17
|
import os
|
|
2
18
|
import time
|
|
3
|
-
from
|
|
4
|
-
import multiprocessing
|
|
19
|
+
from dataclasses import dataclass
|
|
5
20
|
from multiprocessing import Process
|
|
21
|
+
from typing import List
|
|
6
22
|
|
|
7
|
-
import numpy as np
|
|
8
23
|
import mindspore as ms
|
|
9
|
-
|
|
10
|
-
from mindspore.ops import operations as P
|
|
24
|
+
import numpy as np
|
|
11
25
|
from mindspore.common.parameter import Parameter
|
|
26
|
+
from mindspore.communication import get_rank
|
|
12
27
|
|
|
13
|
-
from msprobe.core.grad_probe.utils import ListCache
|
|
14
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
15
|
-
from msprobe.mindspore.common.log import logger
|
|
16
28
|
from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path,
|
|
17
29
|
write_csv, remove_path, move_file, load_npy)
|
|
30
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
31
|
+
from msprobe.core.grad_probe.utils import ListCache
|
|
32
|
+
from msprobe.mindspore.common.log import logger
|
|
18
33
|
from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext
|
|
19
34
|
|
|
20
35
|
|
|
36
|
+
@dataclass
|
|
37
|
+
class GradDumpConfig:
|
|
38
|
+
dump_dir: str
|
|
39
|
+
g_name: str
|
|
40
|
+
dump_step: Parameter
|
|
41
|
+
grad: ms.Tensor
|
|
42
|
+
level: str
|
|
43
|
+
bounds: List
|
|
44
|
+
|
|
45
|
+
|
|
21
46
|
def get_rank_id():
|
|
22
47
|
try:
|
|
23
48
|
rank_id = get_rank()
|
|
@@ -27,35 +52,35 @@ def get_rank_id():
|
|
|
27
52
|
|
|
28
53
|
|
|
29
54
|
@ms.jit
|
|
30
|
-
def grad_dump(
|
|
31
|
-
|
|
55
|
+
def grad_dump(config: GradDumpConfig):
|
|
56
|
+
"""
|
|
32
57
|
Dump gradient statistic data.
|
|
33
58
|
level0: [step, max, min, norm, shape_dim, shape]
|
|
34
59
|
level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
|
|
35
60
|
level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
|
|
36
|
-
|
|
37
|
-
dump_path = os.path.join(dump_dir, g_name)
|
|
61
|
+
"""
|
|
62
|
+
dump_path = os.path.join(config.dump_dir, config.g_name)
|
|
38
63
|
dump_dir_path = dump_path + "_dir"
|
|
39
64
|
save_op = ms.ops.TensorDump()
|
|
40
65
|
|
|
41
|
-
grad_flat = grad.reshape(-1)
|
|
66
|
+
grad_flat = config.grad.reshape(-1)
|
|
42
67
|
max_val = grad_flat.max(axis=0).float()
|
|
43
68
|
min_val = grad_flat.min(axis=0).float()
|
|
44
69
|
norm_val = grad_flat.norm(ord=2).float()
|
|
45
|
-
shape = grad.shape
|
|
46
|
-
extrem_list = [dump_step[0].float(), max_val, min_val, norm_val]
|
|
70
|
+
shape = config.grad.shape
|
|
71
|
+
extrem_list = [config.dump_step[0].float(), max_val, min_val, norm_val]
|
|
47
72
|
extrem_stat = ms.ops.stack(extrem_list)
|
|
48
73
|
shape_list = [len(shape)] + list(shape)
|
|
49
74
|
shape_stat = ms.Tensor(shape_list).float()
|
|
50
75
|
level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
|
|
51
76
|
level_stat = level0_stat
|
|
52
77
|
|
|
53
|
-
if level == GradConst.LEVEL2:
|
|
54
|
-
zero_grad = (grad == 0).sum()
|
|
55
|
-
dist_dim = ms.Tensor([len(bounds) + 2]).float()
|
|
56
|
-
bucket_result = ms.ops.bucketize(grad.float(), bounds)
|
|
78
|
+
if config.level == GradConst.LEVEL2:
|
|
79
|
+
zero_grad = (config.grad == 0).sum()
|
|
80
|
+
dist_dim = ms.Tensor([len(config.bounds) + 2]).float()
|
|
81
|
+
bucket_result = ms.ops.bucketize(config.grad.float(), config.bounds)
|
|
57
82
|
bucket_result = bucket_result.astype(ms.int8)
|
|
58
|
-
dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
|
|
83
|
+
dist_stat = [(bucket_result == i).sum() for i in range(len(config.bounds) + 1)]
|
|
59
84
|
dist_stat.append(zero_grad)
|
|
60
85
|
dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty
|
|
61
86
|
dist_stat = ms.ops.stack(dist_stat, axis=0).float()
|
|
@@ -63,8 +88,8 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
|
|
|
63
88
|
level_stat = level2_stat
|
|
64
89
|
|
|
65
90
|
save_op(dump_path, level_stat)
|
|
66
|
-
if level == GradConst.LEVEL1 or level == GradConst.LEVEL2:
|
|
67
|
-
grad_direction = grad > 0
|
|
91
|
+
if config.level == GradConst.LEVEL1 or config.level == GradConst.LEVEL2:
|
|
92
|
+
grad_direction = config.grad > 0
|
|
68
93
|
save_op(dump_dir_path, grad_direction)
|
|
69
94
|
|
|
70
95
|
|
|
@@ -182,7 +207,7 @@ class CSVGenerator(Process):
|
|
|
182
207
|
shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
|
|
183
208
|
file_name = os.path.basename(file_path)
|
|
184
209
|
prefix_idx = len(file_name.split("_")[0])
|
|
185
|
-
param_name = file_name[(prefix_idx + 1)
|
|
210
|
+
param_name = file_name[(prefix_idx + 1): -(len(GradConst.NPY_SUFFIX) + 1)]
|
|
186
211
|
if not param_name:
|
|
187
212
|
raise RuntimeError("Invalid gradient statistic file name.")
|
|
188
213
|
csv_line = [param_name]
|
|
@@ -224,8 +249,9 @@ class CSVGenerator(Process):
|
|
|
224
249
|
if i == 0:
|
|
225
250
|
intervals.append(f"(-inf, {self.bounds[i]}]")
|
|
226
251
|
else:
|
|
227
|
-
intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]")
|
|
252
|
+
intervals.append(f"({self.bounds[i - 1]}, {self.bounds[i]}]")
|
|
228
253
|
intervals.extend([f"({self.bounds[-1]}, inf)", "=0"])
|
|
229
254
|
return intervals
|
|
230
255
|
|
|
256
|
+
|
|
231
257
|
csv_generator = CSVGenerator()
|
|
@@ -1,7 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
1
17
|
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
2
18
|
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
3
19
|
from msprobe.mindspore.grad_probe.hook import hook_optimizer
|
|
4
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
5
20
|
|
|
6
21
|
|
|
7
22
|
class GradientMonitor:
|
|
@@ -1,8 +1,23 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import hashlib
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
3
18
|
|
|
4
19
|
import mindspore
|
|
5
|
-
from mindspore import ops
|
|
20
|
+
from mindspore import ops
|
|
6
21
|
from msprobe.core.grad_probe.constant import GradConst
|
|
7
22
|
|
|
8
23
|
|
|
@@ -12,6 +27,7 @@ class CsvInput:
|
|
|
12
27
|
self.grad = grad
|
|
13
28
|
self.bounds = bounds
|
|
14
29
|
|
|
30
|
+
|
|
15
31
|
class GradStatCsv:
|
|
16
32
|
csv = {}
|
|
17
33
|
|
|
@@ -52,9 +68,11 @@ class CsvItem(ABC):
|
|
|
52
68
|
|
|
53
69
|
@register_csv_item(GradConst.MD5)
|
|
54
70
|
class CsvMd5(CsvItem):
|
|
71
|
+
@staticmethod
|
|
55
72
|
def generate_csv_header(csv_input):
|
|
56
73
|
return ["MD5"]
|
|
57
74
|
|
|
75
|
+
@staticmethod
|
|
58
76
|
def generate_csv_content(csv_input):
|
|
59
77
|
grad = csv_input.grad
|
|
60
78
|
tensor_bytes = grad.float().numpy().tobytes()
|
|
@@ -64,19 +82,21 @@ class CsvMd5(CsvItem):
|
|
|
64
82
|
|
|
65
83
|
@register_csv_item(GradConst.DISTRIBUTION)
|
|
66
84
|
class CsvDistribution(CsvItem):
|
|
85
|
+
@staticmethod
|
|
67
86
|
def generate_csv_header(csv_input):
|
|
68
87
|
bounds = csv_input.bounds
|
|
69
88
|
intervals = []
|
|
70
89
|
if bounds:
|
|
71
90
|
intervals.append(f"(-inf, {bounds[0]}]")
|
|
72
91
|
for i in range(1, len(bounds)):
|
|
73
|
-
intervals.append(f"({bounds[i-1]}, {bounds[i]}]")
|
|
92
|
+
intervals.append(f"({bounds[i - 1]}, {bounds[i]}]")
|
|
74
93
|
if intervals:
|
|
75
94
|
intervals.append(f"({bounds[-1]}, inf)")
|
|
76
95
|
intervals.append("=0")
|
|
77
|
-
|
|
96
|
+
|
|
78
97
|
return intervals
|
|
79
98
|
|
|
99
|
+
@staticmethod
|
|
80
100
|
def generate_csv_content(csv_input):
|
|
81
101
|
grad = csv_input.grad
|
|
82
102
|
bounds = csv_input.bounds
|
|
@@ -94,9 +114,11 @@ class CsvDistribution(CsvItem):
|
|
|
94
114
|
|
|
95
115
|
@register_csv_item(GradConst.MAX)
|
|
96
116
|
class CsvMax(CsvItem):
|
|
117
|
+
@staticmethod
|
|
97
118
|
def generate_csv_header(csv_input):
|
|
98
119
|
return ["max"]
|
|
99
120
|
|
|
121
|
+
@staticmethod
|
|
100
122
|
def generate_csv_content(csv_input):
|
|
101
123
|
grad = csv_input.grad
|
|
102
124
|
return [ops.amax(grad).float().numpy().tolist()]
|
|
@@ -104,9 +126,11 @@ class CsvMax(CsvItem):
|
|
|
104
126
|
|
|
105
127
|
@register_csv_item(GradConst.MIN)
|
|
106
128
|
class CsvMin(CsvItem):
|
|
129
|
+
@staticmethod
|
|
107
130
|
def generate_csv_header(csv_input):
|
|
108
131
|
return ["min"]
|
|
109
132
|
|
|
133
|
+
@staticmethod
|
|
110
134
|
def generate_csv_content(csv_input):
|
|
111
135
|
grad = csv_input.grad
|
|
112
136
|
return [ops.amin(grad).float().numpy().tolist()]
|
|
@@ -114,9 +138,11 @@ class CsvMin(CsvItem):
|
|
|
114
138
|
|
|
115
139
|
@register_csv_item(GradConst.NORM)
|
|
116
140
|
class CsvNorm(CsvItem):
|
|
141
|
+
@staticmethod
|
|
117
142
|
def generate_csv_header(csv_input):
|
|
118
143
|
return ["norm"]
|
|
119
144
|
|
|
145
|
+
@staticmethod
|
|
120
146
|
def generate_csv_content(csv_input):
|
|
121
147
|
grad = csv_input.grad
|
|
122
148
|
return [ops.norm(grad).float().numpy().tolist()]
|
|
@@ -124,9 +150,11 @@ class CsvNorm(CsvItem):
|
|
|
124
150
|
|
|
125
151
|
@register_csv_item(GradConst.SHAPE)
|
|
126
152
|
class CsvShape(CsvItem):
|
|
153
|
+
@staticmethod
|
|
127
154
|
def generate_csv_header(csv_input):
|
|
128
155
|
return ["shape"]
|
|
129
156
|
|
|
157
|
+
@staticmethod
|
|
130
158
|
def generate_csv_content(csv_input):
|
|
131
159
|
grad = csv_input.grad
|
|
132
|
-
return [list(grad.shape)]
|
|
160
|
+
return [list(grad.shape)]
|
|
@@ -1,32 +1,51 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
1
15
|
|
|
2
16
|
import os
|
|
3
17
|
|
|
4
18
|
import mindspore
|
|
5
19
|
import mindspore as ms
|
|
6
20
|
from mindspore.common.api import jit
|
|
7
|
-
from mindspore.nn.optim.optimizer import Optimizer
|
|
8
|
-
from mindspore.common.parameter import Parameter
|
|
9
21
|
from mindspore.common.initializer import initializer
|
|
10
|
-
|
|
22
|
+
from mindspore.common.parameter import Parameter
|
|
23
|
+
from mindspore.nn.optim.optimizer import Optimizer
|
|
24
|
+
from msprobe.core.common.file_utils import remove_path, write_csv, create_directory
|
|
11
25
|
from msprobe.core.grad_probe.constant import GradConst
|
|
12
26
|
from msprobe.mindspore.common.log import logger
|
|
13
|
-
|
|
14
|
-
from msprobe.core.common.file_utils import remove_path, write_csv, create_directory
|
|
15
27
|
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
16
|
-
from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
|
|
17
28
|
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
29
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id, GradDumpConfig
|
|
18
30
|
from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
|
|
19
31
|
from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level
|
|
20
32
|
|
|
21
|
-
class HookInput:
|
|
22
33
|
|
|
34
|
+
class HookInput:
|
|
23
35
|
'''
|
|
24
36
|
HookInput is a class wrapping all the variables used for hooking optimizer
|
|
25
37
|
'''
|
|
26
38
|
|
|
27
39
|
def __init__(self, opt) -> None:
|
|
28
40
|
self.func = opt.construct
|
|
29
|
-
|
|
41
|
+
if hasattr(opt, "_parameters"):
|
|
42
|
+
parameter_list = opt._parameters
|
|
43
|
+
elif hasattr(opt, "parameters"):
|
|
44
|
+
parameter_list = opt.parameters
|
|
45
|
+
else:
|
|
46
|
+
logger.error_log_with_exp("Given optimizer has no attributes: '_parameters' or 'parameters'. \
|
|
47
|
+
Please check the type of the given optimizer.", ValueError)
|
|
48
|
+
self.g_names = [param.name for param in parameter_list]
|
|
30
49
|
self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
|
|
31
50
|
self.rank_id = get_rank_id()
|
|
32
51
|
output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
|
|
@@ -40,14 +59,17 @@ class HookInput:
|
|
|
40
59
|
self.bounds = grad_context.get_context(GradConst.BOUNDS)
|
|
41
60
|
self.mode = mindspore.get_context("mode")
|
|
42
61
|
|
|
62
|
+
|
|
43
63
|
def hook_graph_mode_optimizer(opt, hook_input):
|
|
44
64
|
@jit
|
|
45
65
|
def new_construct(self, gradients):
|
|
46
66
|
for index, grad_value in enumerate(gradients):
|
|
47
67
|
if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list:
|
|
48
68
|
continue
|
|
49
|
-
|
|
50
|
-
|
|
69
|
+
conf = GradDumpConfig(dump_dir=hook_input.dump_dir, g_name=hook_input.g_names[index],
|
|
70
|
+
dump_step=self.dump_step, grad=grad_value, level=hook_input.level,
|
|
71
|
+
bounds=hook_input.bounds)
|
|
72
|
+
grad_dump(conf)
|
|
51
73
|
ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
|
|
52
74
|
self.assignadd(self.dump_step, self.global_step_increase_tensor)
|
|
53
75
|
out = hook_input.func(gradients)
|
|
@@ -57,11 +79,12 @@ def hook_graph_mode_optimizer(opt, hook_input):
|
|
|
57
79
|
opt.construct = new_construct.__get__(opt, type(opt))
|
|
58
80
|
csv_generator.start()
|
|
59
81
|
|
|
82
|
+
|
|
60
83
|
def hook_pynative_optimizer(opt, hook_input):
|
|
61
84
|
level_adapted = get_adapted_level(hook_input.level)
|
|
62
85
|
|
|
63
|
-
def hook_fn(cell,
|
|
64
|
-
gradients, =
|
|
86
|
+
def hook_fn(cell, input_data):
|
|
87
|
+
gradients, = input_data
|
|
65
88
|
cur_step = grad_context.get_context(GradConst.CURRENT_STEP)
|
|
66
89
|
if grad_context.step_need_dump(cur_step) and grad_context.rank_need_dump(hook_input.rank_id):
|
|
67
90
|
create_directory(hook_input.save_dir)
|
|
@@ -1,12 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
|
|
3
18
|
import mindspore
|
|
4
|
-
from msprobe.core.grad_probe.constant import level_adp
|
|
5
|
-
from msprobe.core.grad_probe.utils import check_param
|
|
6
19
|
from msprobe.core.common.file_utils import (create_directory,
|
|
7
|
-
check_path_before_create,
|
|
8
20
|
check_file_or_directory_path,
|
|
9
21
|
save_npy)
|
|
22
|
+
from msprobe.core.grad_probe.constant import level_adp
|
|
23
|
+
from msprobe.core.grad_probe.utils import check_param
|
|
10
24
|
|
|
11
25
|
|
|
12
26
|
def save_grad_direction(param_name, grad, save_path):
|
|
@@ -15,7 +29,6 @@ def save_grad_direction(param_name, grad, save_path):
|
|
|
15
29
|
check_file_or_directory_path(save_path, isdir=True)
|
|
16
30
|
check_param(param_name)
|
|
17
31
|
save_filepath = os.path.join(save_path, f"{param_name}.npy")
|
|
18
|
-
check_path_before_create(save_filepath)
|
|
19
32
|
|
|
20
33
|
if grad.dtype == mindspore.bfloat16:
|
|
21
34
|
grad = grad.to(mindspore.float32)
|
|
@@ -27,4 +40,4 @@ def save_grad_direction(param_name, grad, save_path):
|
|
|
27
40
|
|
|
28
41
|
def get_adapted_level(level: str):
|
|
29
42
|
level_adapted = level_adp.get(level)
|
|
30
|
-
return level_adapted
|
|
43
|
+
return level_adapted
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from .mindtorch_adaptor import (_call_impl,
|
|
17
|
+
register_full_backward_pre_hook,
|
|
18
|
+
register_full_backward_hook)
|