mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,10 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
from collections import defaultdict
|
|
17
|
+
import os
|
|
18
|
+
import types
|
|
18
19
|
|
|
19
20
|
import mindspore
|
|
21
|
+
from mindspore import nn
|
|
20
22
|
from mindspore._c_expression import PyNativeExecutor_
|
|
21
23
|
try:
|
|
22
24
|
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
@@ -25,7 +27,9 @@ except ImportError:
|
|
|
25
27
|
|
|
26
28
|
from msprobe.core.common.log import logger
|
|
27
29
|
from msprobe.core.common.const import Const
|
|
30
|
+
from msprobe.core.common.runtime import Runtime
|
|
28
31
|
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
|
|
32
|
+
from msprobe.mindspore.common.const import Const as MsConst
|
|
29
33
|
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
|
|
30
34
|
|
|
31
35
|
|
|
@@ -34,24 +38,20 @@ _api_register = get_api_register()
|
|
|
34
38
|
|
|
35
39
|
def dump_jit(name, in_feat, out_feat, is_forward):
|
|
36
40
|
pid = os.getpid()
|
|
37
|
-
|
|
38
|
-
index = ori_args.find("<")
|
|
39
|
-
if index != 0 and index != -1:
|
|
40
|
-
result = ori_args[0:index]
|
|
41
|
-
elif name is not None and "<" not in str(name):
|
|
42
|
-
result = str(name)
|
|
43
|
-
else:
|
|
44
|
-
result = "JitFunction"
|
|
41
|
+
name = name if name else "JitFunction"
|
|
45
42
|
if JitDump.need_dump():
|
|
46
43
|
if is_forward:
|
|
47
|
-
JitDump.jit_count
|
|
48
|
-
|
|
49
|
-
|
|
44
|
+
if name in JitDump.jit_count:
|
|
45
|
+
JitDump.jit_count[name] += 1
|
|
46
|
+
else:
|
|
47
|
+
JitDump.jit_count[name] = 0
|
|
48
|
+
name_template = (Const.JIT + Const.SEP + name + Const.SEP +
|
|
49
|
+
str(JitDump.jit_count[name]) + Const.SEP + Const.FORWARD)
|
|
50
50
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
51
51
|
module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
|
|
52
52
|
JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
|
|
53
53
|
else:
|
|
54
|
-
name_template = Const.JIT + Const.SEP +
|
|
54
|
+
name_template = Const.JIT + Const.SEP + name + Const.SEP + str(JitDump.jit_count[name]) + Const.SEP + \
|
|
55
55
|
Const.BACKWARD
|
|
56
56
|
JitDump.data_collector.update_api_or_module_name(name_template)
|
|
57
57
|
module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
|
|
@@ -74,11 +74,11 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
74
74
|
def __call__(self, *args, **kwargs):
|
|
75
75
|
_api_register.restore_all_api()
|
|
76
76
|
out = super().__call__(*args, **kwargs)
|
|
77
|
-
if JitDump.jit_dump_switch and len(args) > 0:
|
|
78
|
-
if self.name
|
|
77
|
+
if JitDump.jit_dump_switch and len(args) > 0 and self.name:
|
|
78
|
+
if self.name != "construct":
|
|
79
79
|
dump_jit(self.name, args, out, True)
|
|
80
|
-
|
|
81
|
-
dump_jit(args[0], args, out, True)
|
|
80
|
+
elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(args[0], nn.Cell):
|
|
81
|
+
dump_jit(args[0].__class__.__name__, args, out, True)
|
|
82
82
|
JitDump.jit_enable = True
|
|
83
83
|
elif len(args) == 0:
|
|
84
84
|
logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
|
|
@@ -109,6 +109,9 @@ class JitDump(_MindsporeFunctionExecutor):
|
|
|
109
109
|
else:
|
|
110
110
|
output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
|
|
111
111
|
if JitDump.jit_dump_switch and JitDump.jit_enable:
|
|
112
|
-
|
|
112
|
+
if isinstance(obj, types.FunctionType):
|
|
113
|
+
dump_jit(obj.__name__, args, None, False)
|
|
114
|
+
elif Runtime.run_mode != MsConst.PYNATIVE_GRAPH_MODE and isinstance(obj, nn.Cell):
|
|
115
|
+
dump_jit(obj.__class__.__name__, args, None, False)
|
|
113
116
|
_api_register.register_all_api()
|
|
114
117
|
return output
|
|
@@ -39,9 +39,12 @@ class KernelKbykDump:
|
|
|
39
39
|
common_set["input_output"] = 0
|
|
40
40
|
common_set["kernels"] = []
|
|
41
41
|
common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
|
|
42
|
-
e2e_set =
|
|
43
|
-
|
|
44
|
-
|
|
42
|
+
e2e_set = {
|
|
43
|
+
"enable": not config.async_dump,
|
|
44
|
+
"trans_flag": True,
|
|
45
|
+
"stat_calc_mode": config.stat_cal_mode,
|
|
46
|
+
"device_stat_precision_mode": config.device_stat_precision_mode,
|
|
47
|
+
}
|
|
45
48
|
|
|
46
49
|
if config.list:
|
|
47
50
|
common_set["dump_mode"] = 1
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#include "hook_dynamic_loader.h"
|
|
18
|
+
#include <sys/stat.h>
|
|
19
|
+
#include <cstdlib>
|
|
20
|
+
#include <cstring>
|
|
21
|
+
#include <pybind11/embed.h>
|
|
22
|
+
#include "utils/log_adapter.h"
|
|
23
|
+
|
|
24
|
+
namespace py = pybind11;
|
|
25
|
+
|
|
26
|
+
HookDynamicLoader &HookDynamicLoader::GetInstance()
|
|
27
|
+
{
|
|
28
|
+
static HookDynamicLoader instance;
|
|
29
|
+
return instance;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName) {
|
|
33
|
+
void *func = dlsym(handle, functionName.c_str());
|
|
34
|
+
if (!func) {
|
|
35
|
+
MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
|
|
36
|
+
return false;
|
|
37
|
+
}
|
|
38
|
+
funcMap_[functionName] = func;
|
|
39
|
+
return true;
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
bool HookDynamicLoader::LoadLibrary()
|
|
43
|
+
{
|
|
44
|
+
std::string msprobePath = "";
|
|
45
|
+
// 获取gil锁
|
|
46
|
+
py::gil_scoped_acquire acquire;
|
|
47
|
+
try {
|
|
48
|
+
py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
|
|
49
|
+
if (!py::hasattr(msprobeMod, "__file__")) {
|
|
50
|
+
MS_LOG(WARNING) << "Adump mod not found";
|
|
51
|
+
return false;
|
|
52
|
+
}
|
|
53
|
+
msprobePath = msprobeMod.attr("__file__").cast<std::string>();
|
|
54
|
+
} catch (const std::exception& e) {
|
|
55
|
+
MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
|
|
56
|
+
return false;
|
|
57
|
+
}
|
|
58
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
59
|
+
if (handle_) {
|
|
60
|
+
MS_LOG(WARNING) << "Hook library already loaded!";
|
|
61
|
+
return false;
|
|
62
|
+
}
|
|
63
|
+
if (msprobePath == "") {
|
|
64
|
+
MS_LOG(WARNING) << "Adump path not loaded";
|
|
65
|
+
return false;
|
|
66
|
+
}
|
|
67
|
+
handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
|
68
|
+
if (!handle_) {
|
|
69
|
+
MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
|
|
70
|
+
return false;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
for (const auto &functionName : functionList_) {
|
|
74
|
+
if (!LoadFunction(handle_, functionName)) {
|
|
75
|
+
MS_LOG(WARNING) << "Failed to load adump function";
|
|
76
|
+
dlclose(handle_);
|
|
77
|
+
handle_ = nullptr;
|
|
78
|
+
return false;
|
|
79
|
+
}
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
MS_LOG(INFO) << "Hook library loaded successfully.";
|
|
83
|
+
return true;
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
bool HookDynamicLoader::UnloadLibrary()
|
|
87
|
+
{
|
|
88
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
89
|
+
if (!handle_) {
|
|
90
|
+
MS_LOG(WARNING) << "Hook library hasn't been loaded.";
|
|
91
|
+
return false;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
dlclose(handle_);
|
|
95
|
+
handle_ = nullptr;
|
|
96
|
+
funcMap_.clear();
|
|
97
|
+
MS_LOG(INFO) << "Library unloaded successfully.";
|
|
98
|
+
return true;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
void *HookDynamicLoader::GetHooker(const std::string &funcName)
|
|
102
|
+
{
|
|
103
|
+
std::lock_guard<std::mutex> lock(mutex_);
|
|
104
|
+
auto iter = funcMap_.find(funcName);
|
|
105
|
+
if (iter == funcMap_.end()) {
|
|
106
|
+
MS_LOG(WARNING) << "Function not found: " << funcName;
|
|
107
|
+
return nullptr;
|
|
108
|
+
}
|
|
109
|
+
return iter->second;
|
|
110
|
+
}
|
|
@@ -27,26 +27,26 @@ constexpr auto kHookBegin = "MS_DbgOnStepBegin";
|
|
|
27
27
|
constexpr auto kHookEnd = "MS_DbgOnStepEnd";
|
|
28
28
|
|
|
29
29
|
class HookDynamicLoader {
|
|
30
|
-
|
|
31
|
-
|
|
30
|
+
public:
|
|
31
|
+
static HookDynamicLoader &GetInstance();
|
|
32
32
|
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
HookDynamicLoader(const HookDynamicLoader &) = delete;
|
|
34
|
+
HookDynamicLoader &operator=(const HookDynamicLoader &) = delete;
|
|
35
35
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
36
|
+
bool LoadLibrary();
|
|
37
|
+
bool UnloadLibrary();
|
|
38
|
+
void *GetHooker(const std::string &funcName);
|
|
39
39
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
40
|
+
private:
|
|
41
|
+
// Helper functions
|
|
42
|
+
bool LoadFunction(void *handle, const std::string &functionName);
|
|
43
43
|
|
|
44
|
-
|
|
44
|
+
HookDynamicLoader() = default;
|
|
45
45
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
46
|
+
void *handle_ = nullptr;
|
|
47
|
+
std::vector<std::string> functionList_ = {kHookBegin, kHookEnd};
|
|
48
|
+
std::map<std::string, void *> funcMap_;
|
|
49
|
+
std::mutex mutex_;
|
|
50
50
|
};
|
|
51
51
|
|
|
52
52
|
#endif // HOOK_DYNAMIC_LOADER_H
|
|
@@ -23,6 +23,8 @@ import mindspore as ms
|
|
|
23
23
|
from msprobe.core.common.const import Const
|
|
24
24
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
25
25
|
from msprobe.core.common.file_utils import check_path_length, load_yaml
|
|
26
|
+
from msprobe.core.common.runtime import Runtime
|
|
27
|
+
from msprobe.core.hook_manager import HookSet
|
|
26
28
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
27
29
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
28
30
|
from msprobe.mindspore.common.log import logger
|
|
@@ -35,7 +37,6 @@ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
|
35
37
|
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
36
38
|
from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory
|
|
37
39
|
from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory
|
|
38
|
-
from msprobe.mindspore.runtime import Runtime
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
_api_register = get_api_register()
|
|
@@ -75,7 +76,7 @@ class ApiPyNativeSelfCheck:
|
|
|
75
76
|
ret = None
|
|
76
77
|
|
|
77
78
|
if not need_wrapper_func():
|
|
78
|
-
del cell.
|
|
79
|
+
del cell.msprobe_input_kwargs
|
|
79
80
|
return ret
|
|
80
81
|
|
|
81
82
|
api_name_with_id = api_name_with_id[:-1]
|
|
@@ -84,9 +85,9 @@ class ApiPyNativeSelfCheck:
|
|
|
84
85
|
api_name_with_id[api_name_with_id.find(Const.SEP) + 1:api_name_with_id.rfind(Const.SEP)])
|
|
85
86
|
if api_name in self.api_list:
|
|
86
87
|
ret = check_self(api_name_with_id, output_data, self.ori_func.get(api_name),
|
|
87
|
-
*input_data, **cell.
|
|
88
|
+
*input_data, **cell.msprobe_input_kwargs)
|
|
88
89
|
|
|
89
|
-
del cell.
|
|
90
|
+
del cell.msprobe_input_kwargs
|
|
90
91
|
return ret
|
|
91
92
|
|
|
92
93
|
def backward_hook(cell, grad_input, grad_output):
|
|
@@ -105,8 +106,13 @@ class ApiPyNativeSelfCheck:
|
|
|
105
106
|
|
|
106
107
|
def pre_backward_hook(cell, grad_input):
|
|
107
108
|
return None
|
|
108
|
-
|
|
109
|
-
return
|
|
109
|
+
|
|
110
|
+
return HookSet(
|
|
111
|
+
forward_hook=wrap_forward_hook,
|
|
112
|
+
forward_pre_hook=pre_hook,
|
|
113
|
+
backward_hook=wrap_backward_hook,
|
|
114
|
+
backward_pre_hook=pre_backward_hook
|
|
115
|
+
)
|
|
110
116
|
|
|
111
117
|
def store_original_func(self):
|
|
112
118
|
for api_name in self.api_list:
|
|
@@ -19,10 +19,10 @@ from typing import Any, Optional
|
|
|
19
19
|
import mindspore as ms
|
|
20
20
|
from mindspore import Tensor, ops
|
|
21
21
|
|
|
22
|
+
from msprobe.core.common.runtime import Runtime
|
|
22
23
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
23
24
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
24
25
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
25
|
-
from msprobe.mindspore.runtime import Runtime
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class Tools:
|
|
@@ -41,8 +41,12 @@ class GlobalContext:
|
|
|
41
41
|
def __new__(cls, *args, **kwargs):
|
|
42
42
|
if cls._instance is None:
|
|
43
43
|
cls._instance_lock.acquire()
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
try:
|
|
45
|
+
cls._instance = object.__new__(cls)
|
|
46
|
+
except Exception as e:
|
|
47
|
+
raise RuntimeError("grad_probe global context init failed") from e
|
|
48
|
+
finally:
|
|
49
|
+
cls._instance_lock.release()
|
|
46
50
|
return cls._instance
|
|
47
51
|
|
|
48
52
|
def init_context(self, config_dict: Dict):
|
|
@@ -69,6 +73,7 @@ class GlobalContext:
|
|
|
69
73
|
create_directory(self._setting.get(GradConst.OUTPUT_PATH))
|
|
70
74
|
else:
|
|
71
75
|
logger.warning("The output_path exists, the data will be covered.")
|
|
76
|
+
|
|
72
77
|
self._setting[GradConst.TIME_STAMP] = str(int(time.time()))
|
|
73
78
|
|
|
74
79
|
def get_context(self, key: str):
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import hashlib
|
|
17
17
|
from abc import ABC, abstractmethod
|
|
18
|
+
import zlib
|
|
18
19
|
|
|
19
20
|
import mindspore
|
|
20
21
|
from mindspore import ops
|
|
@@ -76,8 +77,8 @@ class CsvMd5(CsvItem):
|
|
|
76
77
|
def generate_csv_content(csv_input):
|
|
77
78
|
grad = csv_input.grad
|
|
78
79
|
tensor_bytes = grad.float().numpy().tobytes()
|
|
79
|
-
md5_hash =
|
|
80
|
-
return [md5_hash
|
|
80
|
+
md5_hash = f"{zlib.crc32(tensor_bytes):08x}"
|
|
81
|
+
return [md5_hash]
|
|
81
82
|
|
|
82
83
|
|
|
83
84
|
@register_csv_item(GradConst.DISTRIBUTION)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
import mindspore as ms
|
|
18
|
+
from mindspore.ops.primitive import Primitive
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.utils import Const
|
|
21
|
+
from msprobe.core.service import BaseService
|
|
22
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
23
|
+
from msprobe.mindspore.common.log import logger
|
|
24
|
+
from msprobe.mindspore.common.utils import (
|
|
25
|
+
get_rank_if_initialized,
|
|
26
|
+
is_mindtorch,
|
|
27
|
+
get_cells_and_names_with_index
|
|
28
|
+
)
|
|
29
|
+
from msprobe.mindspore.dump.hook_cell.api_register import get_api_register, ApiTemplate
|
|
30
|
+
from msprobe.mindspore.dump.hook_cell.ms_hook_manager import MindsproeHookManager
|
|
31
|
+
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
32
|
+
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
36
|
+
except ImportError:
|
|
37
|
+
pijit_label = False
|
|
38
|
+
else:
|
|
39
|
+
pijit_label = True
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class MindsporeService(BaseService):
|
|
43
|
+
@property
|
|
44
|
+
def _get_framework_type(self):
|
|
45
|
+
return Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
46
|
+
|
|
47
|
+
@staticmethod
|
|
48
|
+
def _get_current_rank():
|
|
49
|
+
return get_rank_if_initialized()
|
|
50
|
+
|
|
51
|
+
def empty(self, *args, **kwargs):
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
def reset_status(self):
|
|
55
|
+
self._reset_status()
|
|
56
|
+
|
|
57
|
+
def _init_specific_components(self):
|
|
58
|
+
self.logger = logger
|
|
59
|
+
self.api_register = get_api_register()
|
|
60
|
+
self.primitive_hook_service = PrimitiveHookService(self)
|
|
61
|
+
self.cell_processor = CellProcessor(self.data_collector.scope)
|
|
62
|
+
self.hook_manager = MindsproeHookManager(self.data_collector, self.config)
|
|
63
|
+
self._setup_jit_context()
|
|
64
|
+
self.api_template = ApiTemplate
|
|
65
|
+
|
|
66
|
+
def _setup_jit_context(self):
|
|
67
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
68
|
+
JitDump.set_config(self.config)
|
|
69
|
+
JitDump.set_data_collector(self.data_collector)
|
|
70
|
+
if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
|
|
71
|
+
ms.common.api._MindsporeFunctionExecutor = JitDump
|
|
72
|
+
else:
|
|
73
|
+
ms.common.api._JitExecutor = JitDump
|
|
74
|
+
ms.common.api._PyNativeExecutor.grad = JitDump.grad
|
|
75
|
+
if pijit_label:
|
|
76
|
+
PIJitCaptureContext.__enter__ = self.empty
|
|
77
|
+
PIJitCaptureContext.__exit__ = self.empty
|
|
78
|
+
|
|
79
|
+
def _register_module_hook(self):
|
|
80
|
+
self.cell_processor.register_cell_hook(self.model, self.build_hook, self.config)
|
|
81
|
+
self.logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
|
|
82
|
+
|
|
83
|
+
def _register_hook(self):
|
|
84
|
+
self._register_primitive_hook()
|
|
85
|
+
|
|
86
|
+
def _register_primitive_hook(self):
|
|
87
|
+
if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
88
|
+
return
|
|
89
|
+
if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
|
|
90
|
+
return
|
|
91
|
+
|
|
92
|
+
primitive_set = set()
|
|
93
|
+
cells_and_names_with_index, _ = get_cells_and_names_with_index(self.model)
|
|
94
|
+
for cells_and_names in cells_and_names_with_index.values():
|
|
95
|
+
for _, cell in cells_and_names:
|
|
96
|
+
for attribute, value in vars(cell).items():
|
|
97
|
+
if isinstance(value, Primitive):
|
|
98
|
+
primitive_set.add((attribute, value))
|
|
99
|
+
|
|
100
|
+
for pname, primitive in primitive_set:
|
|
101
|
+
primitive_class_name = primitive.__class__.__name__
|
|
102
|
+
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
103
|
+
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
104
|
+
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
105
|
+
primitive_combined_name)})
|
|
106
|
+
primitive.__class__ = new_primitive
|
|
107
|
+
|
|
108
|
+
def _reset_status(self):
|
|
109
|
+
super()._reset_status()
|
|
110
|
+
self.primitive_hook_service.primitive_counters.clear()
|
|
111
|
+
JitDump.jit_count = defaultdict(int)
|
|
112
|
+
|
|
113
|
+
def _change_jit_switch(self, status):
|
|
114
|
+
JitDump.jit_dump_switch = status
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from mindspore import nn
|
|
18
|
+
from mindspore import communication
|
|
19
|
+
from msprobe.mindspore.monitor.utils import logger
|
|
20
|
+
from msprobe.mindspore.common.utils import is_mindtorch
|
|
21
|
+
if is_mindtorch():
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_valid_instance(model):
|
|
26
|
+
return isinstance(model, torch.nn.Module) if is_mindtorch() else isinstance(model, nn.Cell)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_submodules(model):
|
|
30
|
+
if not is_valid_instance(model):
|
|
31
|
+
logger.info("Counter invalid model, nothing to hook")
|
|
32
|
+
return {}
|
|
33
|
+
return model.named_modules() if is_mindtorch() else model.cells_and_names()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_parameters(model):
|
|
37
|
+
if not is_valid_instance(model):
|
|
38
|
+
return {}
|
|
39
|
+
if is_mindtorch():
|
|
40
|
+
return model.named_parameters()
|
|
41
|
+
else:
|
|
42
|
+
return model.parameters_and_names()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_rank():
|
|
46
|
+
if comm_is_initialized():
|
|
47
|
+
return communication.get_rank()
|
|
48
|
+
return 0
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def comm_is_initialized():
|
|
52
|
+
return communication.GlobalComm.INITED
|