mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/README.md
CHANGED
|
@@ -51,15 +51,21 @@ export MSPROBE_LOG_LEVEL={x}
|
|
|
51
51
|
|
|
52
52
|
**1. Pytorch 框架下,工具暂不支持 Fully Sharded Data Parallel(FSDP)。**
|
|
53
53
|
|
|
54
|
+
**2. 工具读写的所有路径,如config_path、dump_path等,只允许包含大小写字母、数字、下划线、斜杠、点和短横线。**
|
|
55
|
+
|
|
54
56
|
## ⚙️ [安装](./docs/01.installation.md)
|
|
55
57
|
|
|
58
|
+
## 🌟 新版本特性
|
|
59
|
+
|
|
60
|
+
请参见[特性变更说明](./docs/01.installation.md#特性变更说明)。
|
|
61
|
+
|
|
56
62
|
## 🛠️ config.json [介绍](./docs/02.config_introduction.md) 和 [示例](./docs/03.config_examples.md)
|
|
57
63
|
|
|
58
64
|
## 🧰 主要功能
|
|
59
65
|
|
|
60
66
|
### 0 用前必看
|
|
61
67
|
|
|
62
|
-
使用工具前,建议先浏览[**工具功能模块简介、适用场景和当前版本局限性**](./docs/23.tool_function_introduction.md),了解功能特性。
|
|
68
|
+
使用工具前,建议先浏览[**工具功能模块简介、适用场景和当前版本局限性**](./docs/25.tool_function_introduction.md),了解功能特性。
|
|
63
69
|
|
|
64
70
|
### 1 数据采集
|
|
65
71
|
|
|
@@ -131,29 +137,18 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
|
|
|
131
137
|
|
|
132
138
|
[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
|
|
133
139
|
|
|
134
|
-
## 🌟 新版本特性
|
|
135
140
|
|
|
136
|
-
|
|
141
|
+
### 11 单算子API自动生成脚本
|
|
142
|
+
|
|
143
|
+
该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
|
|
137
144
|
|
|
138
|
-
|
|
139
|
-
- 支持 config.json 中的 step 传入范围;
|
|
140
|
-
- 优化了指定 step 的机制,指定 step 结束后工具不再采集数据,但训练会继续运行。工具结束运行后,日志提示信息如下:
|
|
141
|
-
```bash
|
|
142
|
-
****************************************
|
|
143
|
-
* msprobe ends successfully. *
|
|
144
|
-
****************************************
|
|
145
|
-
```
|
|
146
|
-
注:在多卡场景,每张卡进程训练到指定 step 之后都会打印一次上述信息。
|
|
145
|
+
[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
|
|
147
146
|
|
|
148
|
-
|
|
149
|
-
- 在 PyTorch 场景,支持部分 NPU 融合算子预检。
|
|
147
|
+
### 12 数码关联
|
|
150
148
|
|
|
151
|
-
|
|
152
|
-
- 解决了使用 MindSpore 需要安装 PyTorch 的问题。
|
|
149
|
+
该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
|
|
153
150
|
|
|
154
|
-
|
|
155
|
-
- 补充在 PyTorch 场景的性能基线报告;
|
|
156
|
-
- 支持 MindSpore 场景的 change_value 扰动模式。
|
|
151
|
+
[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
|
|
157
152
|
|
|
158
153
|
## 📑 补充材料
|
|
159
154
|
|
msprobe/config.json
CHANGED
msprobe/core/common/const.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -29,6 +29,7 @@ class Const:
|
|
|
29
29
|
SEP = "."
|
|
30
30
|
REGEX_PREFIX_MAX_LENGTH = 20
|
|
31
31
|
REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
|
|
32
|
+
REGEX_FORWARD_BACKWARD = r'\.(forward|backward)\.'
|
|
32
33
|
FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
|
|
33
34
|
STRING_BLACKLIST = r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]"
|
|
34
35
|
COMMA = ","
|
|
@@ -65,6 +66,7 @@ class Const:
|
|
|
65
66
|
ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
|
|
66
67
|
SUMMARY = "summary"
|
|
67
68
|
MD5 = "md5"
|
|
69
|
+
VALUE = "value"
|
|
68
70
|
SUMMARY_MODE = [ALL, SUMMARY, MD5]
|
|
69
71
|
|
|
70
72
|
WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
|
|
@@ -73,6 +75,7 @@ class Const:
|
|
|
73
75
|
|
|
74
76
|
PKL_SUFFIX = ".pkl"
|
|
75
77
|
NUMPY_SUFFIX = ".npy"
|
|
78
|
+
NUMPY_PATTERN = "*.npy"
|
|
76
79
|
PT_SUFFIX = ".pt"
|
|
77
80
|
ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
78
81
|
TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
@@ -87,6 +90,8 @@ class Const:
|
|
|
87
90
|
INPUT_KWARGS = 'input_kwargs'
|
|
88
91
|
GRAD_INPUT = 'grad_input'
|
|
89
92
|
GRAD_OUTPUT = 'grad_output'
|
|
93
|
+
PARAMS = 'parameters'
|
|
94
|
+
PARAMS_GRAD = 'parameters_grad'
|
|
90
95
|
START = "start"
|
|
91
96
|
STOP = "stop"
|
|
92
97
|
ENV_ENABLE = "1"
|
|
@@ -112,6 +117,7 @@ class Const:
|
|
|
112
117
|
DATA = "data"
|
|
113
118
|
PT_FRAMEWORK = "pytorch"
|
|
114
119
|
MS_FRAMEWORK = "mindspore"
|
|
120
|
+
MT_FRAMEWORK = "mindtorch"
|
|
115
121
|
UNKNOWN_FRAMEWORK = "unknown"
|
|
116
122
|
DIRECTORY_LENGTH = 4096
|
|
117
123
|
FILE_NAME_LENGTH = 255
|
|
@@ -122,9 +128,12 @@ class Const:
|
|
|
122
128
|
NPU_LOWERCASE = 'npu'
|
|
123
129
|
CPU_LOWERCASE = 'cpu'
|
|
124
130
|
CUDA_LOWERCASE = 'cuda'
|
|
131
|
+
DEVICE = 'device'
|
|
125
132
|
DISTRIBUTED = 'Distributed'
|
|
126
|
-
DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
|
|
133
|
+
DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
|
|
127
134
|
"Aten", "VF", "NPU", "Jit"]
|
|
135
|
+
MODULE_PREFIX = ["Module", "Cell"]
|
|
136
|
+
FORWARD_NAME_SUFFIX = ".forward"
|
|
128
137
|
|
|
129
138
|
# struct json param
|
|
130
139
|
ORIGIN_DATA = "origin_data"
|
|
@@ -145,10 +154,13 @@ class Const:
|
|
|
145
154
|
SCOPE_ID_INDEX = -1
|
|
146
155
|
SCOPE_DIRECTION_INDEX = -2
|
|
147
156
|
TYPE_NAME_INDEX = -3
|
|
157
|
+
PARAMS_GRAD_TYPE_NAME_INDEX = -2
|
|
148
158
|
LAYER_NAME_INDEX = -4
|
|
159
|
+
PARAMS_GRAD_NAME_INDEX = -3
|
|
149
160
|
API_TYPE_INDEX = 0
|
|
150
161
|
LEFT_MOVE_INDEX = -1
|
|
151
162
|
RIGHT_MOVE_INDEX = 1
|
|
163
|
+
LAST_INDEX = -1
|
|
152
164
|
|
|
153
165
|
TOP_LAYER = "TopLayer"
|
|
154
166
|
CELL = "Cell"
|
|
@@ -162,12 +174,16 @@ class Const:
|
|
|
162
174
|
|
|
163
175
|
CONVERT = {
|
|
164
176
|
"int32_to_int64": ["torch.int32", "torch.int64"],
|
|
177
|
+
"int64_to_fp32": ["torch.int64", "torch.float32"]
|
|
165
178
|
}
|
|
166
179
|
|
|
167
180
|
CONVERT_API = {
|
|
168
|
-
"int32_to_int64": ["cross_entropy"]
|
|
181
|
+
"int32_to_int64": ["cross_entropy"],
|
|
182
|
+
"int64_to_fp32": ["histc"]
|
|
169
183
|
}
|
|
170
184
|
|
|
185
|
+
FA_SPECIAL_SPARSE_MODE = [2, 3, 4]
|
|
186
|
+
|
|
171
187
|
FILL_CHAR_NUMS = 50
|
|
172
188
|
TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
|
|
173
189
|
WITHOUT_CALL_STACK = "The call stack retrieval failed."
|
|
@@ -179,6 +195,8 @@ class Const:
|
|
|
179
195
|
STEP_RANK_MAXIMUM_VALUE = int(1e6)
|
|
180
196
|
|
|
181
197
|
# data type const
|
|
198
|
+
TORCH_INT_DTYPE = ["torch.int8", "torch.int32", "torch.int64"]
|
|
199
|
+
TORCH_FLOAT_DTYPE = ["torch.bfloat16", "torch.float16", "torch.float32", "torch.float64"]
|
|
182
200
|
FLOAT16 = "Float16"
|
|
183
201
|
FLOAT32 = "Float32"
|
|
184
202
|
BFLOAT16 = "BFloat16"
|
|
@@ -193,6 +211,23 @@ class Const:
|
|
|
193
211
|
MEAN = 'Mean'
|
|
194
212
|
NORM = 'Norm'
|
|
195
213
|
|
|
214
|
+
CODE_STACK = 'Code Stack'
|
|
215
|
+
OP_NAME = 'Op Name'
|
|
216
|
+
SCOPE_NAME = 'Scope Name'
|
|
217
|
+
CODE_STACKS = 'Code Stacks'
|
|
218
|
+
FILE_PATH = 'File Path'
|
|
219
|
+
NEW_LINE = '\n'
|
|
220
|
+
CSV_NEWLINE_SEPARATOR = ',\n'
|
|
221
|
+
# 分隔符常量
|
|
222
|
+
SCOPE_SEPARATOR = "/"
|
|
223
|
+
REPLACEMENT_CHARACTER = "_"
|
|
224
|
+
|
|
225
|
+
OPTIMIZER = "optimizer"
|
|
226
|
+
CLIP_GRAD = "clip_grad"
|
|
227
|
+
END_PREFIX = "end_"
|
|
228
|
+
|
|
229
|
+
TENSOR_STAT_LEN = 2
|
|
230
|
+
|
|
196
231
|
|
|
197
232
|
class CompareConst:
|
|
198
233
|
"""
|
|
@@ -239,13 +274,58 @@ class CompareConst:
|
|
|
239
274
|
INPUT_STRUCT = "input_struct"
|
|
240
275
|
KWARGS_STRUCT = "kwargs_struct"
|
|
241
276
|
OUTPUT_STRUCT = "output_struct"
|
|
277
|
+
PARAMS_STRUCT = "params_struct"
|
|
278
|
+
PARAMS_GRAD_STRUCT = "params_grad_struct"
|
|
242
279
|
SUMMARY = "summary"
|
|
280
|
+
COMPARE_RESULT = "compare_result"
|
|
281
|
+
COMPARE_MESSAGE = "compare_message"
|
|
243
282
|
MAX_EXCEL_LENGTH = 1048576
|
|
244
283
|
YES = "Yes"
|
|
245
284
|
NO = "No"
|
|
246
285
|
STATISTICS_INDICATOR_NUM = 4
|
|
247
286
|
EPSILON = 1e-10
|
|
248
287
|
COMPARE_ENDS_SUCCESSFULLY = "msprobe compare ends successfully."
|
|
288
|
+
DEFAULT_RATIO_VALUE = 10000
|
|
289
|
+
THOUSANDTH_PASS_VALUE = 0.999
|
|
290
|
+
ZERO_SHAPE = '(0,)'
|
|
291
|
+
|
|
292
|
+
BENCHMARK_COMPARE_ALGORITHM_NAME = "标杆比对法"
|
|
293
|
+
ULP_COMPARE_ALGORITHM_NAME = "ULP误差比对法"
|
|
294
|
+
BINARY_CONSISTENCY_ALGORITHM_NAME = "二进制一致法"
|
|
295
|
+
ABSOLUTE_THRESHOLD_ALGORITHM_NAME = "绝对阈值法"
|
|
296
|
+
THOUSANDTH_STANDARD_ALGORITHM_NAME = "双千指标法"
|
|
297
|
+
ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME = "累积误差比对法"
|
|
298
|
+
|
|
299
|
+
ABSOLUTE_THRESHOLD = 'absolute_threshold'
|
|
300
|
+
BINARY_CONSISTENCY = 'binary_consistency'
|
|
301
|
+
ULP_COMPARE = 'ulp_compare'
|
|
302
|
+
THOUSANDTH_STANDARD = 'thousandth_threshold'
|
|
303
|
+
BENCHMARK = 'benchmark'
|
|
304
|
+
ACCUMULATIVE_ERROR_COMPARE = 'accumulative_error_compare'
|
|
305
|
+
|
|
306
|
+
SMALL_VALUE_ERR_RATIO = "small_value_err_ratio"
|
|
307
|
+
RMSE_RATIO = "rmse_ratio"
|
|
308
|
+
MAX_REL_ERR_RATIO = "max_rel_err_ratio"
|
|
309
|
+
MEAN_REL_ERR_RATIO = "mean_rel_err_ratio"
|
|
310
|
+
EB_RATIO = "eb_ratio"
|
|
311
|
+
|
|
312
|
+
SMALL_VALUE = "small_value"
|
|
313
|
+
RMSE = "rmse"
|
|
314
|
+
MAX_REL_ERR = "max_rel_err"
|
|
315
|
+
MEAN_REL_ERR = "mean_rel_err"
|
|
316
|
+
EB = "eb"
|
|
317
|
+
|
|
318
|
+
SMALL_VALUE_ERR_STATUS = "small_value_err_status"
|
|
319
|
+
RMSE_STATUS = "rmse_status"
|
|
320
|
+
MAX_REL_ERR_STATUS = "max_rel_err_status"
|
|
321
|
+
MEAN_REL_ERR_STATUS = "mean_rel_err_status"
|
|
322
|
+
EB_STATUS = "eb_status"
|
|
323
|
+
|
|
324
|
+
MEAN_ULP_ERR = "mean_ulp_err"
|
|
325
|
+
ULP_ERR_PROPORTION = "ulp_err_proportion"
|
|
326
|
+
ULP_ERR_PROPORTION_RATIO = "ulp_err_proportion_ratio"
|
|
327
|
+
|
|
328
|
+
ULP_ERR_STATUS = "ulp_err_status"
|
|
249
329
|
|
|
250
330
|
COMPARE_RESULT_HEADER = [
|
|
251
331
|
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
|
|
@@ -263,12 +343,57 @@ class CompareConst:
|
|
|
263
343
|
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
|
|
264
344
|
]
|
|
265
345
|
|
|
346
|
+
COMPARE_RESULT_HEADER_STACK = COMPARE_RESULT_HEADER + [STACK]
|
|
347
|
+
|
|
348
|
+
SUMMARY_COMPARE_RESULT_HEADER_STACK = SUMMARY_COMPARE_RESULT_HEADER + [STACK]
|
|
349
|
+
|
|
350
|
+
MD5_COMPARE_RESULT_HEADER_STACK = MD5_COMPARE_RESULT_HEADER + [STACK]
|
|
351
|
+
|
|
266
352
|
HEAD_OF_COMPARE_MODE = {
|
|
267
353
|
Const.ALL: COMPARE_RESULT_HEADER,
|
|
268
354
|
Const.SUMMARY: SUMMARY_COMPARE_RESULT_HEADER,
|
|
269
355
|
Const.MD5: MD5_COMPARE_RESULT_HEADER
|
|
270
356
|
}
|
|
271
357
|
|
|
358
|
+
ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO]
|
|
359
|
+
SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
|
|
360
|
+
MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
|
|
361
|
+
|
|
362
|
+
# dtype match
|
|
363
|
+
MS_TYPE = [
|
|
364
|
+
[Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
|
|
365
|
+
[Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
|
|
366
|
+
]
|
|
367
|
+
TORCH_TYPE = [
|
|
368
|
+
[Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
|
|
369
|
+
[Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
|
|
370
|
+
]
|
|
371
|
+
|
|
372
|
+
# read_op
|
|
373
|
+
IO_NAME_MAPPING = {
|
|
374
|
+
Const.INPUT_ARGS: '.input',
|
|
375
|
+
Const.INPUT_KWARGS: '.input',
|
|
376
|
+
Const.INPUT: '.input',
|
|
377
|
+
Const.OUTPUT: '.output',
|
|
378
|
+
Const.PARAMS: '.parameters'
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
# state to struct mapping
|
|
382
|
+
STATE_TO_STRUCT_MAPPING = {
|
|
383
|
+
Const.INPUT: INPUT_STRUCT,
|
|
384
|
+
Const.KWARGS: INPUT_STRUCT,
|
|
385
|
+
Const.OUTPUT: OUTPUT_STRUCT,
|
|
386
|
+
Const.PARAMS: PARAMS_STRUCT,
|
|
387
|
+
Const.PARAMS_GRAD: PARAMS_GRAD_STRUCT
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
STRUCT_COMPARE_KEY = [
|
|
391
|
+
INPUT_STRUCT,
|
|
392
|
+
OUTPUT_STRUCT,
|
|
393
|
+
PARAMS_STRUCT,
|
|
394
|
+
PARAMS_GRAD_STRUCT
|
|
395
|
+
]
|
|
396
|
+
|
|
272
397
|
# compare standard
|
|
273
398
|
HUNDRED_RATIO_THRESHOLD = 0.01
|
|
274
399
|
THOUSAND_RATIO_THRESHOLD = 0.001
|
|
@@ -350,6 +475,8 @@ class CompareConst:
|
|
|
350
475
|
INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP
|
|
351
476
|
KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP
|
|
352
477
|
OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP
|
|
478
|
+
PARAMS_PATTERN = Const.SEP + Const.PARAMS + Const.SEP
|
|
479
|
+
PARAMS_GRAD_PATTERN = Const.SEP + Const.PARAMS_GRAD + Const.SEP
|
|
353
480
|
COMPARE_KEY = 'compare_key'
|
|
354
481
|
COMPARE_SHAPE = 'compare_shape'
|
|
355
482
|
INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
|
|
@@ -372,13 +499,17 @@ class FileCheckConst:
|
|
|
372
499
|
JSON_SUFFIX = ".json"
|
|
373
500
|
PT_SUFFIX = ".pt"
|
|
374
501
|
CSV_SUFFIX = ".csv"
|
|
502
|
+
XLSX_SUFFIX = ".xlsx"
|
|
375
503
|
YAML_SUFFIX = ".yaml"
|
|
504
|
+
IR_SUFFIX = ".ir"
|
|
376
505
|
MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
377
506
|
MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
378
507
|
MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
379
508
|
MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
|
|
380
509
|
MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
510
|
+
MAX_XLSX_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
381
511
|
MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
512
|
+
MAX_IR_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
|
|
382
513
|
COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
|
|
383
514
|
DIR = "dir"
|
|
384
515
|
FILE = "file"
|
|
@@ -390,7 +521,9 @@ class FileCheckConst:
|
|
|
390
521
|
JSON_SUFFIX: MAX_JSON_SIZE,
|
|
391
522
|
PT_SUFFIX: MAX_PT_SIZE,
|
|
392
523
|
CSV_SUFFIX: MAX_CSV_SIZE,
|
|
393
|
-
|
|
524
|
+
XLSX_SUFFIX: MAX_XLSX_SIZE,
|
|
525
|
+
YAML_SUFFIX: MAX_YAML_SIZE,
|
|
526
|
+
IR_SUFFIX: MAX_IR_SIZE
|
|
394
527
|
}
|
|
395
528
|
CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
|
|
396
529
|
|
|
@@ -437,6 +570,11 @@ class MsCompareConst:
|
|
|
437
570
|
|
|
438
571
|
EPSILON = 1e-8
|
|
439
572
|
|
|
573
|
+
class ProcessStatus:
|
|
574
|
+
SUCCESS = "success"
|
|
575
|
+
API_NOT_FOUND = "api_not_found"
|
|
576
|
+
EXCEPTION_SKIP = "exception_skip"
|
|
577
|
+
|
|
440
578
|
|
|
441
579
|
class MsgConst:
|
|
442
580
|
"""
|
|
@@ -474,15 +612,20 @@ class MonitorConst:
|
|
|
474
612
|
"""
|
|
475
613
|
Class for monitor const
|
|
476
614
|
"""
|
|
477
|
-
OP_LIST = ["
|
|
615
|
+
OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"]
|
|
478
616
|
MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
|
|
479
617
|
DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
|
|
480
618
|
DATABASE = "database"
|
|
481
619
|
EMAIL = "email"
|
|
482
620
|
OPT_TY = ['Megatron_DistributedOptimizer', 'Megatron_Float16OptimizerWithFloat16Params']
|
|
483
|
-
DEEPSPEED_OPT_TY = (
|
|
621
|
+
DEEPSPEED_OPT_TY = (
|
|
622
|
+
"DeepSpeedZeroOptimizer_Stage0",
|
|
623
|
+
"DeepSpeedZeroOptimizer_Stage1_or_2",
|
|
624
|
+
"DeepSpeedZeroOptimizer_Stage3"
|
|
625
|
+
)
|
|
484
626
|
RULE_NAME = ['AnomalyTurbulence']
|
|
485
627
|
|
|
628
|
+
SLICE_SIZE = 20480
|
|
486
629
|
DOT = "."
|
|
487
630
|
VPP_SEP = ":"
|
|
488
631
|
ACTV_IN = "input"
|
|
@@ -491,12 +634,18 @@ class MonitorConst:
|
|
|
491
634
|
ACTVGRAD_OUT = "output_grad"
|
|
492
635
|
POST_GRAD = "post_grad"
|
|
493
636
|
PRE_GRAD = "pre_grad"
|
|
637
|
+
ACC_GRAD = "acc_grad"
|
|
494
638
|
PREFIX_POST = "post"
|
|
495
639
|
PREFIX_PRE = "pre"
|
|
640
|
+
OUTPUT_DIR_PATTERN = r"([\w-]{0,20})-rank(\d{1,5})-"
|
|
496
641
|
|
|
642
|
+
EXP_AVG = "exp_avg"
|
|
643
|
+
EFXP_AVG_SQ = "efxp_avg_sq"
|
|
497
644
|
|
|
498
645
|
ANOMALY_JSON = "anomaly.json"
|
|
499
646
|
ANALYSE_JSON = "anomaly_analyse.json"
|
|
500
647
|
TENSORBOARD = "tensorboard"
|
|
501
648
|
CSV = "csv"
|
|
502
649
|
API = "api"
|
|
650
|
+
OPS_START_INDEX = 3
|
|
651
|
+
HEADER_NAME_INDEX = 1
|
msprobe/core/common/exceptions.py
CHANGED
|
@@ -27,11 +27,13 @@ class MsprobeException(CodedException):
|
|
|
27
27
|
INVALID_PARAM_ERROR = 0
|
|
28
28
|
OVERFLOW_NUMS_ERROR = 1
|
|
29
29
|
RECURSION_LIMIT_ERROR = 2
|
|
30
|
+
INTERFACE_USAGE_ERROR = 3
|
|
30
31
|
|
|
31
32
|
err_strs = {
|
|
32
33
|
INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
|
|
33
34
|
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
|
|
34
|
-
RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:"
|
|
35
|
+
RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
|
|
36
|
+
INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
|
|
35
37
|
}
|
|
36
38
|
|
|
37
39
|
|
msprobe/core/common/file_utils.py
CHANGED
|
@@ -22,7 +22,6 @@ import re
|
|
|
22
22
|
import shutil
|
|
23
23
|
from datetime import datetime, timezone
|
|
24
24
|
from dateutil import parser
|
|
25
|
-
import OpenSSL
|
|
26
25
|
import yaml
|
|
27
26
|
import numpy as np
|
|
28
27
|
import pandas as pd
|
|
@@ -419,20 +418,36 @@ def save_yaml(yaml_path, data):
|
|
|
419
418
|
|
|
420
419
|
|
|
421
420
|
def save_excel(path, data):
|
|
421
|
+
def validate_data(data):
|
|
422
|
+
"""Validate that the data is a DataFrame or a list of (DataFrame, sheet_name) pairs."""
|
|
423
|
+
if isinstance(data, pd.DataFrame):
|
|
424
|
+
return "single"
|
|
425
|
+
elif isinstance(data, list):
|
|
426
|
+
if all(isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], pd.DataFrame) for item in data):
|
|
427
|
+
return "list"
|
|
428
|
+
raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")
|
|
429
|
+
|
|
422
430
|
check_path_before_create(path)
|
|
423
431
|
path = os.path.realpath(path)
|
|
432
|
+
|
|
433
|
+
# 验证数据类型
|
|
434
|
+
data_type = validate_data(data)
|
|
435
|
+
|
|
424
436
|
try:
|
|
425
|
-
if
|
|
437
|
+
if data_type == "single":
|
|
426
438
|
data.to_excel(path, index=False)
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
439
|
+
elif data_type == "list":
|
|
440
|
+
with pd.ExcelWriter(path) as writer:
|
|
441
|
+
for data_df, sheet_name in data:
|
|
442
|
+
data_df.to_excel(writer, sheet_name=sheet_name, index=False)
|
|
430
443
|
except Exception as e:
|
|
431
444
|
logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
|
|
432
445
|
raise RuntimeError(f"Save excel file {path} failed.") from e
|
|
433
446
|
change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
434
447
|
|
|
435
448
|
|
|
449
|
+
|
|
450
|
+
|
|
436
451
|
def move_file(src_path, dst_path):
|
|
437
452
|
check_file_or_directory_path(src_path)
|
|
438
453
|
check_path_before_create(dst_path)
|
|
@@ -522,11 +537,11 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
|
|
|
522
537
|
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
523
538
|
|
|
524
539
|
|
|
525
|
-
def read_csv(filepath, as_pd=True):
|
|
540
|
+
def read_csv(filepath, as_pd=True, header='infer'):
|
|
526
541
|
check_file_or_directory_path(filepath)
|
|
527
542
|
try:
|
|
528
543
|
if as_pd:
|
|
529
|
-
csv_data = pd.read_csv(filepath)
|
|
544
|
+
csv_data = pd.read_csv(filepath, header=header)
|
|
530
545
|
else:
|
|
531
546
|
with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
|
|
532
547
|
csv_reader = csv.reader(f, delimiter=',')
|
|
@@ -630,6 +645,7 @@ def check_crt_valid(pem_path):
|
|
|
630
645
|
Raises:
|
|
631
646
|
RuntimeError: If the SSL certificate is invalid or expired.
|
|
632
647
|
"""
|
|
648
|
+
import OpenSSL
|
|
633
649
|
try:
|
|
634
650
|
with FileOpen(pem_path, "r") as f:
|
|
635
651
|
pem_data = f.read()
|
|
@@ -645,3 +661,13 @@ def check_crt_valid(pem_path):
|
|
|
645
661
|
now_utc = datetime.now(tz=timezone.utc)
|
|
646
662
|
if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
|
|
647
663
|
raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def read_xlsx(file_path):
|
|
667
|
+
check_file_or_directory_path(file_path)
|
|
668
|
+
try:
|
|
669
|
+
result_df = pd.read_excel(file_path, keep_default_na=False)
|
|
670
|
+
except Exception as e:
|
|
671
|
+
logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
|
|
672
|
+
raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
|
|
673
|
+
return result_df
|
msprobe/core/common/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -29,6 +29,7 @@ from msprobe.core.common.const import Const, CompareConst
|
|
|
29
29
|
from msprobe.core.common.log import logger
|
|
30
30
|
from msprobe.core.common.exceptions import MsprobeException
|
|
31
31
|
|
|
32
|
+
|
|
32
33
|
device = collections.namedtuple('device', ['type', 'index'])
|
|
33
34
|
prefixes = ['api_stack', 'list', 'range', 'acl']
|
|
34
35
|
|
|
@@ -71,6 +72,9 @@ class MsprobeBaseException(Exception):
|
|
|
71
72
|
BACKWARD_DATA_COLLECTION_ERROR = 30
|
|
72
73
|
INVALID_KEY_ERROR = 31
|
|
73
74
|
MISSING_HEADER_ERROR = 32
|
|
75
|
+
MERGE_COMPARE_RESULT_ERROR = 33
|
|
76
|
+
NAMES_STRUCTS_MATCH_ERROR = 34
|
|
77
|
+
INVALID_STATE_ERROR = 35
|
|
74
78
|
|
|
75
79
|
def __init__(self, code, error_info: str = ""):
|
|
76
80
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -109,7 +113,7 @@ def is_json_file(file_path):
|
|
|
109
113
|
return False
|
|
110
114
|
|
|
111
115
|
|
|
112
|
-
def check_compare_param(input_param, output_path, dump_mode):
|
|
116
|
+
def check_compare_param(input_param, output_path, dump_mode, stack_mode):
|
|
113
117
|
if not isinstance(input_param, dict):
|
|
114
118
|
logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
|
|
115
119
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
@@ -127,7 +131,8 @@ def check_compare_param(input_param, output_path, dump_mode):
|
|
|
127
131
|
|
|
128
132
|
check_json_path("npu_json_path")
|
|
129
133
|
check_json_path("bench_json_path")
|
|
130
|
-
|
|
134
|
+
if stack_mode:
|
|
135
|
+
check_json_path("stack_json_path")
|
|
131
136
|
|
|
132
137
|
if dump_mode == Const.ALL:
|
|
133
138
|
check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
|
|
@@ -135,9 +140,12 @@ def check_compare_param(input_param, output_path, dump_mode):
|
|
|
135
140
|
check_file_or_directory_path(output_path, True)
|
|
136
141
|
|
|
137
142
|
with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
|
|
138
|
-
FileOpen(input_param.get("bench_json_path"), "r") as bench_json
|
|
139
|
-
|
|
140
|
-
|
|
143
|
+
FileOpen(input_param.get("bench_json_path"), "r") as bench_json:
|
|
144
|
+
_check_json(npu_json, input_param.get("npu_json_path"))
|
|
145
|
+
_check_json(bench_json, input_param.get("bench_json_path"))
|
|
146
|
+
if stack_mode:
|
|
147
|
+
with FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
|
|
148
|
+
_check_json(stack_json, input_param.get("stack_json_path"))
|
|
141
149
|
|
|
142
150
|
|
|
143
151
|
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
|
|
@@ -395,20 +403,23 @@ def get_real_step_or_rank(step_or_rank_input, obj):
|
|
|
395
403
|
if not is_int(element) and not isinstance(element, str):
|
|
396
404
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
397
405
|
f"{obj} element {element} must be an integer or string.")
|
|
398
|
-
if
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
406
|
+
if is_int(element):
|
|
407
|
+
if not Const.STEP_RANK_MINIMUM_VALUE <= element <= Const.STEP_RANK_MAXIMUM_VALUE:
|
|
408
|
+
raise MsprobeException(
|
|
409
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
410
|
+
f"Each element of {obj} must be between {Const.STEP_RANK_MINIMUM_VALUE} and "
|
|
411
|
+
f"{Const.STEP_RANK_MAXIMUM_VALUE}, currently it is {element}."
|
|
412
|
+
)
|
|
402
413
|
real_step_or_rank.append(element)
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
414
|
+
continue
|
|
415
|
+
continual_step_or_rank = get_step_or_rank_from_string(element, obj)
|
|
416
|
+
real_step_or_rank.extend(continual_step_or_rank)
|
|
406
417
|
real_step_or_rank = list(set(real_step_or_rank))
|
|
407
418
|
real_step_or_rank.sort()
|
|
408
419
|
return real_step_or_rank
|
|
409
420
|
|
|
410
421
|
|
|
411
|
-
def check_seed_all(seed, mode):
|
|
422
|
+
def check_seed_all(seed, mode, rm_dropout):
|
|
412
423
|
if is_int(seed):
|
|
413
424
|
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
414
425
|
logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
|
|
@@ -419,6 +430,9 @@ def check_seed_all(seed, mode):
|
|
|
419
430
|
if not isinstance(mode, bool):
|
|
420
431
|
logger.error("seed_all mode must be bool.")
|
|
421
432
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
433
|
+
if not isinstance(rm_dropout, bool):
|
|
434
|
+
logger.error("The rm_dropout parameter must be bool.")
|
|
435
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
422
436
|
|
|
423
437
|
|
|
424
438
|
def safe_get_value(container, index, container_name, key=None):
|
msprobe/core/common_config.py
CHANGED
|
@@ -27,6 +27,7 @@ class CommonConfig:
|
|
|
27
27
|
self.step = get_real_step_or_rank(json_config.get('step'), Const.STEP)
|
|
28
28
|
self.level = json_config.get('level')
|
|
29
29
|
self.enable_dataloader = json_config.get('enable_dataloader', False)
|
|
30
|
+
self.async_dump = json_config.get("async_dump", False)
|
|
30
31
|
self._check_config()
|
|
31
32
|
|
|
32
33
|
def _check_config(self):
|
|
@@ -42,6 +43,11 @@ class CommonConfig:
|
|
|
42
43
|
if not isinstance(self.enable_dataloader, bool):
|
|
43
44
|
logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
|
|
44
45
|
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
46
|
+
if not isinstance(self.async_dump, bool):
|
|
47
|
+
logger.error_log_with_exp("async_dump is invalid, it should be a boolean",
|
|
48
|
+
MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
|
|
49
|
+
elif self.async_dump:
|
|
50
|
+
logger.warning("async_dump is True, it may cause OOM when dumping large tensor.")
|
|
45
51
|
|
|
46
52
|
|
|
47
53
|
class BaseConfig:
|