mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -13,19 +13,63 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import inspect
|
|
16
17
|
import os
|
|
17
18
|
import random
|
|
19
|
+
import types
|
|
18
20
|
|
|
19
21
|
import mindspore as ms
|
|
20
|
-
|
|
21
22
|
from mindspore import ops
|
|
23
|
+
from mindspore.common.jit_config import JitConfig
|
|
22
24
|
from mindspore.mint import nn
|
|
23
25
|
|
|
26
|
+
from msprobe.core.common.const import Const
|
|
27
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
24
28
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
25
29
|
from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
|
|
26
30
|
from msprobe.core.common.log import logger
|
|
27
|
-
from msprobe.core.common.
|
|
28
|
-
from msprobe.
|
|
31
|
+
from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
|
|
32
|
+
from msprobe.mindspore.common.const import Const as MsConst
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from mindspore._c_expression import _set_init_iter
|
|
36
|
+
except ImportError:
|
|
37
|
+
enable_dynamic_kbyk_dump = False
|
|
38
|
+
else:
|
|
39
|
+
enable_dynamic_kbyk_dump = True
|
|
40
|
+
|
|
41
|
+
mindtorch_check_result = None
|
|
42
|
+
register_backward_hook_functions = {}
|
|
43
|
+
kwargs_exist_in_forward_hook = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MsprobeStep(ms.train.Callback):
|
|
47
|
+
def __init__(self, debugger):
|
|
48
|
+
super(MsprobeStep, self).__init__()
|
|
49
|
+
self.debugger = debugger
|
|
50
|
+
|
|
51
|
+
def on_train_begin(self, run_context):
|
|
52
|
+
self.debugger.start()
|
|
53
|
+
if enable_dynamic_kbyk_dump:
|
|
54
|
+
_set_init_iter(0)
|
|
55
|
+
|
|
56
|
+
def on_train_step_begin(self, run_context):
|
|
57
|
+
self.debugger.start()
|
|
58
|
+
|
|
59
|
+
def on_train_step_end(self, run_context):
|
|
60
|
+
self.debugger.stop()
|
|
61
|
+
self.debugger.step()
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class MsprobeInitStep(ms.train.Callback):
|
|
65
|
+
def on_train_begin(self, run_context):
|
|
66
|
+
try:
|
|
67
|
+
from ms._c_expression import _set_init_iter
|
|
68
|
+
except ImportError:
|
|
69
|
+
logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
|
|
70
|
+
return
|
|
71
|
+
cb_params = run_context.original_args()
|
|
72
|
+
_set_init_iter(cb_params.cur_step_num)
|
|
29
73
|
|
|
30
74
|
|
|
31
75
|
def get_rank_if_initialized():
|
|
@@ -58,8 +102,8 @@ def convert_to_int(value):
|
|
|
58
102
|
|
|
59
103
|
|
|
60
104
|
def clean_input_kwargs(cell):
|
|
61
|
-
if hasattr(cell, '
|
|
62
|
-
del cell.
|
|
105
|
+
if hasattr(cell, 'msprobe_input_kwargs'):
|
|
106
|
+
del cell.msprobe_input_kwargs
|
|
63
107
|
|
|
64
108
|
|
|
65
109
|
def list_lowest_level_directories(root_dir):
|
|
@@ -93,20 +137,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
|
93
137
|
remove_dropout()
|
|
94
138
|
|
|
95
139
|
|
|
96
|
-
class MsprobeStep(ms.train.Callback):
|
|
97
|
-
|
|
98
|
-
def __init__(self, debugger):
|
|
99
|
-
super(MsprobeStep, self).__init__()
|
|
100
|
-
self.debugger = debugger
|
|
101
|
-
|
|
102
|
-
def on_train_step_begin(self, run_context):
|
|
103
|
-
self.debugger.start()
|
|
104
|
-
|
|
105
|
-
def on_train_step_end(self, run_context):
|
|
106
|
-
self.debugger.stop()
|
|
107
|
-
self.debugger.step()
|
|
108
|
-
|
|
109
|
-
|
|
110
140
|
class Dropout(ops.Dropout):
|
|
111
141
|
def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
|
|
112
142
|
super().__init__(1., seed0, seed1)
|
|
@@ -142,9 +172,6 @@ def remove_dropout():
|
|
|
142
172
|
nn.functional.dropout = dropout_ext
|
|
143
173
|
|
|
144
174
|
|
|
145
|
-
mindtorch_check_result = None
|
|
146
|
-
|
|
147
|
-
|
|
148
175
|
def is_mindtorch():
|
|
149
176
|
global mindtorch_check_result
|
|
150
177
|
if mindtorch_check_result is None:
|
|
@@ -159,17 +186,17 @@ def is_mindtorch():
|
|
|
159
186
|
return mindtorch_check_result
|
|
160
187
|
|
|
161
188
|
|
|
162
|
-
register_backward_hook_functions = {}
|
|
163
|
-
|
|
164
|
-
|
|
165
189
|
def set_register_backward_hook_functions():
|
|
166
190
|
global register_backward_hook_functions
|
|
191
|
+
if register_backward_hook_functions:
|
|
192
|
+
return
|
|
193
|
+
|
|
167
194
|
if is_mindtorch():
|
|
168
195
|
import torch
|
|
169
196
|
from msprobe.mindspore.mindtorch import (_call_impl,
|
|
170
197
|
register_full_backward_pre_hook,
|
|
171
198
|
register_full_backward_hook)
|
|
172
|
-
if not hasattr(torch, "register_full_backward_hook"):
|
|
199
|
+
if not hasattr(torch.nn.Module, "register_full_backward_hook"):
|
|
173
200
|
setattr(torch.nn.Module, "_call_impl", _call_impl)
|
|
174
201
|
setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
|
|
175
202
|
setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
|
|
@@ -182,9 +209,11 @@ def set_register_backward_hook_functions():
|
|
|
182
209
|
|
|
183
210
|
def check_save_param(variable, name, save_backward):
|
|
184
211
|
# try catch this api to skip invalid call
|
|
185
|
-
|
|
212
|
+
valid_data_types = (ms.Tensor, int, float, str)
|
|
213
|
+
if not is_save_variable_valid(variable, valid_data_types):
|
|
214
|
+
valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
|
|
186
215
|
logger.warning("PrecisionDebugger.save variable type not valid, "
|
|
187
|
-
"should be one of
|
|
216
|
+
f"should be one of {valid_data_types_with_nested_types}"
|
|
188
217
|
"Skip current save process.")
|
|
189
218
|
raise ValueError
|
|
190
219
|
if not isinstance(name, str):
|
|
@@ -196,4 +225,103 @@ def check_save_param(variable, name, save_backward):
|
|
|
196
225
|
logger.warning("PrecisionDebugger.save_backward name not valid, "
|
|
197
226
|
"should be bool. "
|
|
198
227
|
"Skip current save process.")
|
|
199
|
-
raise ValueError
|
|
228
|
+
raise ValueError
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def is_graph_mode_cell_dump_allowed(config):
|
|
232
|
+
if config.task not in [Const.TENSOR, Const.STATISTICS] or is_mindtorch() or not hasattr(ops, 'DumpGradient'):
|
|
233
|
+
return False
|
|
234
|
+
valid_mix_level = [MsConst.CELL_AND_API, Const.LEVEL_MIX]
|
|
235
|
+
if config.level in valid_mix_level and config.execution_mode == MsConst.PYNATIVE_MODE:
|
|
236
|
+
return True
|
|
237
|
+
return config.level == MsConst.CELL or config.level == Const.LEVEL_L0
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@recursion_depth_decorator('msprobe.mindspore.common.utils.is_decorated_by_jit')
|
|
241
|
+
def is_decorated_by_jit(func):
|
|
242
|
+
closure = getattr(func, '__closure__', [])
|
|
243
|
+
if closure:
|
|
244
|
+
for obj in closure:
|
|
245
|
+
if isinstance(obj.cell_contents, JitConfig):
|
|
246
|
+
return True
|
|
247
|
+
elif isinstance(obj.cell_contents, types.FunctionType) and hasattr(obj.cell_contents, '__closure__'):
|
|
248
|
+
if is_decorated_by_jit(obj.cell_contents):
|
|
249
|
+
return True
|
|
250
|
+
return False
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
@recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names')
|
|
254
|
+
def get_cells_and_names(model, cells_set=None, name_prefix=''):
|
|
255
|
+
cells_set = cells_set if cells_set else set()
|
|
256
|
+
if model in cells_set:
|
|
257
|
+
return
|
|
258
|
+
|
|
259
|
+
cells_set.add(model)
|
|
260
|
+
jit_decorated = is_decorated_by_jit(model.construct)
|
|
261
|
+
yield name_prefix, model, jit_decorated
|
|
262
|
+
if jit_decorated:
|
|
263
|
+
return
|
|
264
|
+
|
|
265
|
+
children_cells = getattr(model, '_cells')
|
|
266
|
+
for name, cell in children_cells.items():
|
|
267
|
+
if cell:
|
|
268
|
+
cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name
|
|
269
|
+
jit_decorated = is_decorated_by_jit(model.construct)
|
|
270
|
+
if jit_decorated:
|
|
271
|
+
yield cells_name_prefix, cell, jit_decorated
|
|
272
|
+
else:
|
|
273
|
+
for ele in get_cells_and_names(cell, cells_set, cells_name_prefix):
|
|
274
|
+
yield ele
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def get_cells_and_names_with_index(models):
|
|
278
|
+
cells_with_index_in_pynative_mode = {}
|
|
279
|
+
cells_with_index_in_graph_mode = {}
|
|
280
|
+
|
|
281
|
+
def distinguish_cells(cells):
|
|
282
|
+
cells_in_pynative_mode = []
|
|
283
|
+
cells_in_graph_mode = []
|
|
284
|
+
for name, cell, jit_decorated in cells:
|
|
285
|
+
if jit_decorated:
|
|
286
|
+
cells_in_graph_mode.append((name, cell))
|
|
287
|
+
else:
|
|
288
|
+
cells_in_pynative_mode.append((name, cell))
|
|
289
|
+
return cells_in_pynative_mode, cells_in_graph_mode
|
|
290
|
+
|
|
291
|
+
if is_mindtorch():
|
|
292
|
+
if isinstance(models, (list, tuple)):
|
|
293
|
+
for index, model in enumerate(models):
|
|
294
|
+
cells_with_index_in_pynative_mode[str(index)] = model.named_modules()
|
|
295
|
+
else:
|
|
296
|
+
cells_with_index_in_pynative_mode["-1"] = models.named_modules()
|
|
297
|
+
else:
|
|
298
|
+
if isinstance(models, (list, tuple)):
|
|
299
|
+
for index, model in enumerate(models):
|
|
300
|
+
cells = get_cells_and_names(model)
|
|
301
|
+
cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells)
|
|
302
|
+
cells_with_index_in_pynative_mode[str(index)] = cells_in_pynative_mode
|
|
303
|
+
cells_with_index_in_graph_mode[str(index)] = cells_in_graph_mode
|
|
304
|
+
else:
|
|
305
|
+
cells = get_cells_and_names(models)
|
|
306
|
+
cells_in_pynative_mode, cells_in_graph_mode = distinguish_cells(cells)
|
|
307
|
+
cells_with_index_in_pynative_mode["-1"] = cells_in_pynative_mode
|
|
308
|
+
cells_with_index_in_graph_mode["-1"] = cells_in_graph_mode
|
|
309
|
+
|
|
310
|
+
return cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def has_kwargs_in_forward_hook():
|
|
314
|
+
global kwargs_exist_in_forward_hook
|
|
315
|
+
|
|
316
|
+
if kwargs_exist_in_forward_hook is None:
|
|
317
|
+
if is_mindtorch():
|
|
318
|
+
kwargs_exist_in_forward_hook = True
|
|
319
|
+
return kwargs_exist_in_forward_hook
|
|
320
|
+
|
|
321
|
+
try:
|
|
322
|
+
func_params = inspect.signature(nn.Cell.register_forward_hook).parameters
|
|
323
|
+
kwargs_exist_in_forward_hook = 'with_kwargs' in func_params
|
|
324
|
+
except Exception:
|
|
325
|
+
kwargs_exist_in_forward_hook = False
|
|
326
|
+
|
|
327
|
+
return kwargs_exist_in_forward_hook
|
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.import functools
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import multiprocessing
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Dict, List, Tuple, Optional, Any
|
|
20
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
21
|
+
from functools import partial
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
import pandas as pd
|
|
25
|
+
import numpy as np
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
|
|
28
|
+
from msprobe.core.common.log import logger
|
|
29
|
+
from msprobe.core.common.utils import CompareException
|
|
30
|
+
from msprobe.core.common.exceptions import FileCheckException
|
|
31
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, write_df_to_csv, create_directory, \
|
|
32
|
+
check_path_before_create, load_npy
|
|
33
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst
|
|
34
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply
|
|
35
|
+
from msprobe.core.compare.multiprocessing_compute import check_accuracy
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def common_dir_compare(input_params: Dict, output_dir: str) -> Optional[pd.DataFrame]:
|
|
39
|
+
"""
|
|
40
|
+
高级目录比对函数,完全镜像输入目录结构
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
input_params: 包含npu_path和bench_path的字典
|
|
44
|
+
output_dir: 输出根目录
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
当输入目录是平铺npy文件时返回DataFrame,否则返回None
|
|
48
|
+
"""
|
|
49
|
+
npu_root = Path(input_params.get('npu_path'))
|
|
50
|
+
bench_root = Path(input_params.get('bench_path'))
|
|
51
|
+
name_map_dict = input_params.get('map_dict', {})
|
|
52
|
+
file_tree = build_mirror_file_tree(npu_root, bench_root)
|
|
53
|
+
|
|
54
|
+
# 处理文件比对
|
|
55
|
+
with ProcessPoolExecutor() as executor:
|
|
56
|
+
results = list(tqdm(
|
|
57
|
+
executor.map(
|
|
58
|
+
partial(process_directory_pair, name_map_dict=name_map_dict, output_dir=output_dir),
|
|
59
|
+
file_tree.items()
|
|
60
|
+
),
|
|
61
|
+
total=len(file_tree),
|
|
62
|
+
desc="Processing directories"
|
|
63
|
+
))
|
|
64
|
+
return
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def process_directory_pair(item: Tuple[Path, Tuple[Path, Path]], name_map_dict: Dict, output_dir: str):
|
|
68
|
+
"""
|
|
69
|
+
处理一个目录对
|
|
70
|
+
|
|
71
|
+
Args:
|
|
72
|
+
item: (相对路径, (npu目录, bench目录))元组
|
|
73
|
+
output_dir: 输出根目录
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
比对结果的DataFrame(仅平铺结构时返回)
|
|
77
|
+
"""
|
|
78
|
+
rel_path, (npu_dir, bench_dir) = item
|
|
79
|
+
|
|
80
|
+
# 创建镜像输出目录
|
|
81
|
+
output_path = Path(output_dir) / rel_path
|
|
82
|
+
create_directory(output_path)
|
|
83
|
+
|
|
84
|
+
# 生成文件映射
|
|
85
|
+
npu_files = find_npy_files(npu_dir)
|
|
86
|
+
bench_files = find_npy_files(bench_dir)
|
|
87
|
+
map_dict = generate_map_dict(npu_files, bench_files, name_map_dict)
|
|
88
|
+
|
|
89
|
+
if not map_dict:
|
|
90
|
+
logger.warning(f"No file pairs found in {rel_path}")
|
|
91
|
+
return None
|
|
92
|
+
|
|
93
|
+
# 执行比对
|
|
94
|
+
result_df = do_multi_process(process_chunk, map_dict)
|
|
95
|
+
check_path_before_create(output_path)
|
|
96
|
+
# 保存结果
|
|
97
|
+
result_path = os.path.join(output_path, 'result.csv')
|
|
98
|
+
write_df_to_csv(result_df, result_path)
|
|
99
|
+
logger.info(f"Results saved to {result_path}")
|
|
100
|
+
return None
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def build_mirror_file_tree(npu_root: Path, bench_root: Path) -> Dict[Path, Tuple[Path, Path]]:
|
|
104
|
+
"""
|
|
105
|
+
构建镜像文件树,键为相对路径,值为(npu_path, bench_path)元组
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
npu_root: NPU数据根目录
|
|
109
|
+
bench_root: 基准数据根目录
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
文件树字典
|
|
113
|
+
"""
|
|
114
|
+
file_tree = {}
|
|
115
|
+
|
|
116
|
+
# 遍历NPU目录构建树结构
|
|
117
|
+
for npu_path in npu_root.rglob('*.npy'):
|
|
118
|
+
dir_path = npu_path.relative_to(npu_root).parent
|
|
119
|
+
npu_dir_pair = os.path.join(npu_root, dir_path)
|
|
120
|
+
bench_dir_pair = os.path.join(bench_root, dir_path)
|
|
121
|
+
try:
|
|
122
|
+
check_file_or_directory_path(bench_dir_pair, isdir=True)
|
|
123
|
+
except FileCheckException:
|
|
124
|
+
continue
|
|
125
|
+
# 添加到文件树
|
|
126
|
+
if dir_path not in file_tree:
|
|
127
|
+
file_tree[dir_path] = (npu_dir_pair, bench_dir_pair)
|
|
128
|
+
|
|
129
|
+
return file_tree
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def find_npy_files(directory):
|
|
133
|
+
npy_files_dict = {}
|
|
134
|
+
for root, _, files in os.walk(directory):
|
|
135
|
+
for file in files:
|
|
136
|
+
if file.endswith(".npy"):
|
|
137
|
+
# 分割文件名并去掉最后两个元素
|
|
138
|
+
file_name = file.split('_')
|
|
139
|
+
if len(file_name) < 2:
|
|
140
|
+
continue
|
|
141
|
+
key = '_'.join(file_name[:-2])
|
|
142
|
+
# 文件的完整路径
|
|
143
|
+
value = os.path.join(root, file)
|
|
144
|
+
# 添加到字典中
|
|
145
|
+
if not npy_files_dict.get(key):
|
|
146
|
+
npy_files_dict[key] = []
|
|
147
|
+
npy_files_dict[key].append(value)
|
|
148
|
+
return npy_files_dict
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
|
|
152
|
+
for k, npu_file_list in npu_file_dict.items():
|
|
153
|
+
bench_file_list = bench_file_dict.get(k)
|
|
154
|
+
if not bench_file_list and k in name_map_dict:
|
|
155
|
+
bench_file_list = bench_file_dict.get(name_map_dict.get(k))
|
|
156
|
+
bench_length = len(bench_file_list)
|
|
157
|
+
if not (bench_file_list and bench_length):
|
|
158
|
+
continue
|
|
159
|
+
result_dict = {}
|
|
160
|
+
for i, npu_file in enumerate(npu_file_list):
|
|
161
|
+
if i >= bench_length:
|
|
162
|
+
break
|
|
163
|
+
bench_file = bench_file_list[i]
|
|
164
|
+
result_dict[f"{k}_{i}"] = (npu_file, bench_file)
|
|
165
|
+
return result_dict
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def do_multi_process(func, map_dict):
    """Run ``func`` over ``map_dict`` in a process pool and concatenate the results.

    ``map_dict`` and a pre-allocated result DataFrame (one row per entry) are
    cut into chunks of the same size. Each chunk is submitted as
    ``func(df_chunk, start_idx, map_chunk, lock)``, and a progress bar advances
    as chunks finish.

    Args:
        func: Worker callable, e.g. ``process_chunk``; must return a DataFrame.
        map_dict: dict name -> (npu_file, bench_file).

    Returns:
        The per-chunk DataFrames concatenated with ``ignore_index=True``, or an
        empty DataFrame if the main process hit an exception.
    """
    lock = multiprocessing.Manager().RLock()
    result_len = len(map_dict)
    process_num = max(int((multiprocessing.cpu_count() + 1) // 4), 1)
    # Rows per submitted task; 0 when there are fewer entries than processes.
    df_chunk_size = result_len // process_num

    # Pre-allocate one result row per map_dict entry.
    result_df = initialize_result_df(result_len)
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]
        process_num = 1
    logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}")

    # Split the dict with the same chunk size so map_chunks[i] lines up with df_chunks[i].
    map_chunks = split_dict(map_dict, df_chunk_size)

    pool = multiprocessing.Pool(process_num)

    progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)

    def update_progress(size, progress_lock, extra_param=None):
        # extra_param receives the task's return value from apply_async; unused.
        with progress_lock:
            progress_bar.update(size)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.close()
        except OSError as e:
            logger.error(f'pool terminate failed: {str(e)}')

    results = []
    try:
        # Submit one task per chunk.
        for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
            start_idx = df_chunk_size * process_idx
            result = pool.apply_async(
                func,
                args=(df_chunk, start_idx, map_chunk, lock),
                error_callback=err_call,
                callback=partial(update_progress, len(map_chunk), lock)
            )
            results.append(result)

        final_results = [r.get() for r in results]
        # Wait for all workers to exit.
        pool.close()
        pool.join()
        return pd.concat(final_results, ignore_index=True)
    except Exception as e:
        logger.error(f"\nMain process error: {str(e)}")
        pool.terminate()
        return pd.DataFrame({})
    finally:
        pool.close()
        # Release the progress bar on every exit path (it was previously never closed).
        progress_bar.close()
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def initialize_result_df(total_size):
    """Pre-allocate the result DataFrame.

    Returns a DataFrame with ``total_size`` empty rows (index ``0..total_size-1``)
    and the comparison columns in their fixed output order.
    """
    identity_cols = [
        CompareConst.NAME,
        CompareConst.NPU_DTYPE,
        CompareConst.BENCH_DTYPE,
        CompareConst.NPU_SHAPE,
        CompareConst.BENCH_SHAPE,
    ]
    metric_cols = [
        CompareConst.COSINE,
        CompareConst.EUC_DIST,
        CompareConst.MAX_ABS_ERR,
        CompareConst.MAX_RELATIVE_ERR,
        CompareConst.ONE_THOUSANDTH_ERR_RATIO,
        CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
    ]
    npu_stat_cols = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
    bench_stat_cols = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM]
    trailing_cols = [CompareConst.ACCURACY, CompareConst.ERROR_MESSAGE, CompareConst.DATA_NAME]

    all_cols = identity_cols + metric_cols + npu_stat_cols + bench_stat_cols + trailing_cols
    return pd.DataFrame(index=range(total_size), columns=all_cols)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def split_dict(input_dict, chunk_size):
    """Split ``input_dict`` into consecutive sub-dicts of ``chunk_size`` items.

    Insertion order is preserved and the last chunk may be shorter. A
    non-positive ``chunk_size`` returns ``[input_dict]`` (the same object).
    """
    if chunk_size <= 0:
        return [input_dict]
    pairs = list(input_dict.items())
    chunks = []
    for start in range(0, len(pairs), chunk_size):
        chunks.append(dict(pairs[start:start + chunk_size]))
    return chunks
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_tensor_stats(tensor: np.ndarray) -> Tuple[float, float, float, float]:
    """Return the (max, min, mean, L2 norm) statistics of ``tensor``.

    For a zero-size array ``np.max``/``np.min`` would raise ``ValueError`` and
    abort the whole worker chunk. Max, min and mean are therefore reported as
    NaN in that case; the L2 norm of an empty array is 0.0.
    """
    t_l2norm = np.linalg.norm(tensor)
    if np.size(tensor) == 0:
        return np.nan, np.nan, np.nan, t_l2norm
    t_max = np.max(tensor)
    t_min = np.min(tensor)
    t_mean = np.mean(tensor)
    return t_max, t_min, t_mean, t_l2norm
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def process_chunk(df, start_idx, map_chunk, lock):
    """Compare every file pair in one chunk and write the rows into ``df``.

    Args:
        df: Pre-allocated result DataFrame slice for this chunk.
        start_idx: Row label in ``df`` where this chunk's first result goes.
        map_chunk: dict name -> (npu_file, bench_file).
        lock: Lock passed through to ``_save_part_df``.

    Returns:
        ``df`` with the chunk's rows filled in (as returned by ``_save_part_df``).
    """
    results = [
        _compare_file_pair(name, npu_file, bench_file)
        for name, (npu_file, bench_file) in map_chunk.items()
    ]
    return _save_part_df(df, start_idx, results, lock)


def _compare_file_pair(name, npu_file, bench_file):
    """Load one NPU/bench ``.npy`` pair, compute its metrics and build a ComparisonResult."""
    n_value = load_npy(npu_file)
    # Cross-framework comparison would need a load_pt-style loader for the bench side.
    b_value = load_npy(bench_file)
    error_flag = False

    err_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, "")
    cos_sim, euc_dist, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio = err_list
    a_max, a_min, a_mean, a_l2norm = get_tensor_stats(n_value)
    b_max, b_min, b_mean, b_l2norm = get_tensor_stats(b_value)

    return ComparisonResult(
        name=name,
        npu_dtype=n_value.dtype,
        bench_dtype=b_value.dtype,
        npu_shape=n_value.shape,
        bench_shape=b_value.shape,
        cosine=cos_sim,
        euc_dist=euc_dist,
        max_abs_err=max_abs_err,
        max_relative_err=max_relative_err,
        one_thousandth_err_ratio=one_thousand_err_ratio,
        five_thousandth_err_ratio=five_thousand_err_ratio,
        npu_max=a_max,
        npu_min=a_min,
        npu_mean=a_mean,
        npu_norm=a_l2norm,
        bench_max=b_max,
        bench_min=b_min,
        bench_mean=b_mean,
        bench_norm=b_l2norm,
        accuracy=check_accuracy(cos_sim, max_abs_err),
        error_message=err_msg,
        data_name=[npu_file, bench_file]
    )
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
@dataclass
class ComparisonResult:
    """One row of the comparison output for a single NPU/bench file pair.

    Each field is written to the result DataFrame column named in its comment.
    """
    name: str  # CompareConst.NAME
    npu_dtype: Any  # CompareConst.NPU_DTYPE
    bench_dtype: Any  # CompareConst.BENCH_DTYPE
    npu_shape: Tuple[int, ...]  # CompareConst.NPU_SHAPE
    bench_shape: Tuple[int, ...]  # CompareConst.BENCH_SHAPE
    cosine: float  # CompareConst.COSINE
    euc_dist: float  # CompareConst.EUC_DIST
    max_abs_err: float  # CompareConst.MAX_ABS_ERR
    max_relative_err: float  # CompareConst.MAX_RELATIVE_ERR
    one_thousandth_err_ratio: float  # CompareConst.ONE_THOUSANDTH_ERR_RATIO
    five_thousandth_err_ratio: float  # CompareConst.FIVE_THOUSANDTHS_ERR_RATIO
    npu_max: float  # CompareConst.NPU_MAX
    npu_min: float  # CompareConst.NPU_MIN
    npu_mean: float  # CompareConst.NPU_MEAN
    npu_norm: float  # CompareConst.NPU_NORM
    bench_max: float  # CompareConst.BENCH_MAX
    bench_min: float  # CompareConst.BENCH_MIN
    bench_mean: float  # CompareConst.BENCH_MEAN
    bench_norm: float  # CompareConst.BENCH_NORM
    accuracy: bool  # CompareConst.ACCURACY (NOTE(review): confirm check_accuracy really returns bool)
    error_message: str  # CompareConst.ERROR_MESSAGE
    data_name: List[str]  # CompareConst.DATA_NAME: [npu_file, bench_file]
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def _save_part_df(df, start_idx, results, lock):
    """Write ``results`` into ``df`` at consecutive row labels starting from ``start_idx``.

    The whole write happens while ``lock`` is held. Shapes and the data-name
    list are stored as their ``str()`` form.

    Returns:
        The same ``df`` object, filled in.

    Raises:
        CompareException: INVALID_DATA_ERROR when pandas raises ValueError,
            INDEX_OUT_OF_BOUNDS_ERROR when it raises IndexError.
    """
    # (column, ComparisonResult attribute, optional converter), in write order.
    column_spec = (
        (CompareConst.NAME, 'name', None),
        (CompareConst.NPU_DTYPE, 'npu_dtype', None),
        (CompareConst.BENCH_DTYPE, 'bench_dtype', None),
        (CompareConst.NPU_SHAPE, 'npu_shape', str),  # tuples are stored as strings
        (CompareConst.BENCH_SHAPE, 'bench_shape', str),
        (CompareConst.COSINE, 'cosine', None),
        (CompareConst.EUC_DIST, 'euc_dist', None),
        (CompareConst.MAX_ABS_ERR, 'max_abs_err', None),
        (CompareConst.MAX_RELATIVE_ERR, 'max_relative_err', None),
        (CompareConst.ONE_THOUSANDTH_ERR_RATIO, 'one_thousandth_err_ratio', None),
        (CompareConst.FIVE_THOUSANDTHS_ERR_RATIO, 'five_thousandth_err_ratio', None),
        (CompareConst.NPU_MAX, 'npu_max', None),
        (CompareConst.NPU_MIN, 'npu_min', None),
        (CompareConst.NPU_MEAN, 'npu_mean', None),
        (CompareConst.NPU_NORM, 'npu_norm', None),
        (CompareConst.BENCH_MAX, 'bench_max', None),
        (CompareConst.BENCH_MIN, 'bench_min', None),
        (CompareConst.BENCH_MEAN, 'bench_mean', None),
        (CompareConst.BENCH_NORM, 'bench_norm', None),
        (CompareConst.ACCURACY, 'accuracy', None),
        (CompareConst.ERROR_MESSAGE, 'error_message', None),
        (CompareConst.DATA_NAME, 'data_name', str),  # list is stored as a string
    )
    with lock:
        try:
            for offset, res in enumerate(results):
                row = start_idx + offset
                for column, attr, convert in column_spec:
                    value = getattr(res, attr)
                    if convert is not None:
                        value = convert(value)
                    df.loc[row, column] = value
            return df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
@@ -13,41 +13,17 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
from msprobe.core.common.utils import CompareException
|
|
18
17
|
from msprobe.core.common.file_utils import create_directory
|
|
19
18
|
from msprobe.core.common.exceptions import FileCheckException
|
|
20
19
|
from msprobe.mindspore.common.log import logger
|
|
21
20
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
22
|
-
from msprobe.core.compare.utils import
|
|
21
|
+
from msprobe.core.compare.utils import compare_distributed_inner
|
|
23
22
|
from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Compare two distributed MindSpore dump directories.

    Thin wrapper: delegates to ``compare_distributed_inner`` with ``ms_compare``
    as the per-rank comparison function, forwarding ``kwargs`` unchanged.
    """
    per_rank_compare = ms_compare
    compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, per_rank_compare, **kwargs)
|
|
51
27
|
|
|
52
28
|
|
|
53
29
|
def ms_graph_compare(inputs, outputs):
|