mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/README.md
CHANGED
|
@@ -44,6 +44,7 @@ export MSPROBE_LOG_LEVEL={x}
|
|
|
44
44
|
|
|
45
45
|
- msprobe支持AscendPyTorch 1.11.0或更高版本,支持的PyTorch和CANN以及PyTorch和python软件版本配套关系请参见《[Ascend Extension for PyTorch插件](https://gitee.com/ascend/pytorch)》。
|
|
46
46
|
- msprobe支持MindSpore 2.4.0或更高版本,支持的MindSpore和CANN以及MindSpore和python软件版本配套关系请参见《[MindSpore版本发布列表](https://www.mindspore.cn/versions)》。
|
|
47
|
+
- msprobe支持MSAdapter 2.1.0。
|
|
47
48
|
- msprobe支持的固件驱动版本与配套CANN软件支持的固件驱动版本相同,开发者可通过“[昇腾社区-固件与驱动](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fhardware%2Ffirmware-drivers%2Fcommunity%3Fproduct%3D2%26model%3D28%26cann%3D8.0.RC3.alpha003%26driver%3D1.0.25.alpha)”页面根据产品型号与CANN软件版本获取配套的固件与驱动。
|
|
48
49
|
|
|
49
50
|
|
|
@@ -69,35 +70,37 @@ export MSPROBE_LOG_LEVEL={x}
|
|
|
69
70
|
|
|
70
71
|
### 1 数据采集
|
|
71
72
|
|
|
72
|
-
msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump
|
|
73
|
+
msprobe 通过在训练脚本中添加 PrecisionDebugger 接口的方式对 API 执行精度数据 dump 操作。对应 config.json 中的 "statistics" 或 "tensor" task。
|
|
73
74
|
|
|
74
75
|
[PyTorch 场景的数据采集](./docs/05.data_dump_PyTorch.md)
|
|
75
76
|
|
|
76
77
|
[MindSpore 场景的数据采集](./docs/06.data_dump_MindSpore.md)
|
|
77
78
|
|
|
79
|
+
[MSAdapter 场景的数据采集](./docs/29.data_dump_MSAdapter.md)
|
|
80
|
+
|
|
78
81
|
### 2 精度预检
|
|
79
82
|
|
|
80
|
-
精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 task
|
|
83
|
+
精度预检旨在昇腾 NPU 上扫描训练模型中的所有 API 进行 API 复现,给出精度情况的诊断和分析。对应 config.json 中的 "run_ut" task。
|
|
81
84
|
|
|
82
85
|
PyTorch 场景的[离线预检](./docs/07.accuracy_checker_PyTorch.md)和[在线预检](./docs/08.accuracy_checker_online_PyTorch.md)
|
|
83
86
|
|
|
84
87
|
MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.md)
|
|
85
88
|
|
|
86
|
-
### 3
|
|
89
|
+
### 3 分级可视化构图比对
|
|
87
90
|
|
|
88
|
-
|
|
91
|
+
该功能将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
|
|
89
92
|
|
|
90
|
-
[PyTorch
|
|
93
|
+
[PyTorch 场景的分级可视化构图比对](./docs/21.visualization_PyTorch.md)
|
|
91
94
|
|
|
92
|
-
[MindSpore
|
|
95
|
+
[MindSpore 场景的分级可视化构图比对](./docs/22.visualization_MindSpore.md)
|
|
93
96
|
|
|
94
|
-
### 4
|
|
97
|
+
### 4 精度比对
|
|
95
98
|
|
|
96
|
-
|
|
99
|
+
该功能进行 PyTorch 整网 API 粒度的数据 dump、精度比对,进而定位训练场景下的精度问题。
|
|
97
100
|
|
|
98
|
-
[PyTorch
|
|
101
|
+
[PyTorch 场景的精度比对](./docs/10.accuracy_compare_PyTorch.md)
|
|
99
102
|
|
|
100
|
-
[MindSpore
|
|
103
|
+
[MindSpore 场景的精度比对](./docs/11.accuracy_compare_MindSpore.md)
|
|
101
104
|
|
|
102
105
|
### 5 数据解析
|
|
103
106
|
|
|
@@ -129,26 +132,28 @@ MindSpore 动态图场景的[离线预检](./docs/09.accuracy_checker_MindSpore.
|
|
|
129
132
|
|
|
130
133
|
[兼容 PyTorch 和 MindSpore 框架的训练状态监控](./docs/19.monitor.md)
|
|
131
134
|
|
|
132
|
-
### 10
|
|
135
|
+
### 10 单算子API自动生成脚本
|
|
133
136
|
|
|
134
|
-
该功能将msprobe工具dump
|
|
137
|
+
该功能将msprobe工具dump的精度数据进行解析,自动生成单API脚本,用于复现整网中出现的算子问题,降低用户复现问题的成本,供开发分析算子问题。
|
|
135
138
|
|
|
136
|
-
[PyTorch
|
|
139
|
+
[PyTorch 单算子API自动生成脚本](./docs/23.generate_operator_PyTorch.md)
|
|
137
140
|
|
|
138
|
-
|
|
141
|
+
### 11 数码关联
|
|
139
142
|
|
|
143
|
+
该功能只支持 MindSpore 静态图场景,用于将IR图与dump数据进行关联,获取dump数据和代码调用栈的关联关系。
|
|
140
144
|
|
|
141
|
-
|
|
145
|
+
[MindSpore 场景的数码关联](./docs/24.code_mapping_Mindspore.md)
|
|
142
146
|
|
|
143
|
-
|
|
147
|
+
### 12 溢出检测与解析
|
|
144
148
|
|
|
145
|
-
|
|
149
|
+
溢出检测用于采集溢出 API 或 模块的精度数据,而溢出解析则是通过对溢出数据的分析,进一步判断是否为正常溢出。对应 config.json 中的 "overflow_check" task。
|
|
150
|
+
推荐直接使用[数据采集](#1-数据采集)功能采集统计量信息,检测溢出问题。
|
|
146
151
|
|
|
147
|
-
|
|
152
|
+
[PyTorch 场景的溢出检测与解析](./docs/12.overflow_check_PyTorch.md)
|
|
148
153
|
|
|
149
|
-
|
|
154
|
+
[MindSpore 场景的溢出检测](./docs/13.overflow_check_MindSpore.md)
|
|
150
155
|
|
|
151
|
-
[
|
|
156
|
+
[MSAdapter 场景的溢出检测](./docs/30.overflow_check_MSAdapter.md)
|
|
152
157
|
|
|
153
158
|
## 📑 补充材料
|
|
154
159
|
|
msprobe/core/common/const.py
CHANGED
|
@@ -51,7 +51,10 @@ class Const:
|
|
|
51
51
|
FOUR_SEGMENT = 4
|
|
52
52
|
SIX_SEGMENT = 6
|
|
53
53
|
SEVEN_SEGMENT = 7
|
|
54
|
+
|
|
54
55
|
MAX_DEPTH = 10
|
|
56
|
+
CPU_QUARTER = 4
|
|
57
|
+
DUMP_MAX_DEPTH = 50
|
|
55
58
|
|
|
56
59
|
# dump mode
|
|
57
60
|
ALL = "all"
|
|
@@ -230,6 +233,92 @@ class Const:
|
|
|
230
233
|
|
|
231
234
|
TENSOR_STAT_LEN = 2
|
|
232
235
|
|
|
236
|
+
SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
|
|
237
|
+
|
|
238
|
+
PT_API_TYPE_FUNCTIONAL = "functional"
|
|
239
|
+
PT_API_TYPE_TENSOR = "tensor"
|
|
240
|
+
PT_API_TYPE_TORCH = "torch"
|
|
241
|
+
PT_API_TYPE_VF = "_VF"
|
|
242
|
+
PT_API_TYPE_NPU = "torch_npu"
|
|
243
|
+
PT_API_TYPE_ATEN = "aten"
|
|
244
|
+
PT_API_TYPE_DIST = "distributed"
|
|
245
|
+
PT_API_TYPE_NPU_DIST = "npu_distributed"
|
|
246
|
+
|
|
247
|
+
MS_API_TYPE_OPS = "ops"
|
|
248
|
+
MS_API_TYPE_TENSOR = "tensor"
|
|
249
|
+
MS_API_TYPE_STUB_TENSOR = "stubtensor"
|
|
250
|
+
MS_API_TYPE_MINT = "mint.ops"
|
|
251
|
+
MS_API_TYPE_MINT_FUNC = "mint.nn.functional"
|
|
252
|
+
MS_API_TYPE_COM = "communication.comm_func"
|
|
253
|
+
|
|
254
|
+
FUNCTIONAL_API_TYPE_PREFIX = "Functional"
|
|
255
|
+
TENSOR_API_TYPE_PREFIX = "Tensor"
|
|
256
|
+
DIST_API_TYPE_PREFIX = "Distributed"
|
|
257
|
+
|
|
258
|
+
TORCH_API_TYPE_PREFIX = "Torch"
|
|
259
|
+
NPU_API_TYPE_PREFIX = "NPU"
|
|
260
|
+
ATEN_API_TYPE_PREFIX = "Aten"
|
|
261
|
+
VF_API_TYPE_PREFIX = "VF"
|
|
262
|
+
|
|
263
|
+
MINT_API_TYPE_PREFIX = "Mint"
|
|
264
|
+
MINT_FUNC_API_TYPE_PREFIX = "MintFunctional"
|
|
265
|
+
|
|
266
|
+
SUPPORT_API_DICT_KEY_MAP = {
|
|
267
|
+
PT_FRAMEWORK: {
|
|
268
|
+
PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
|
|
269
|
+
PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR,
|
|
270
|
+
PT_API_TYPE_TORCH: PT_API_TYPE_TORCH,
|
|
271
|
+
PT_API_TYPE_VF: PT_API_TYPE_VF,
|
|
272
|
+
PT_API_TYPE_NPU: PT_API_TYPE_NPU,
|
|
273
|
+
PT_API_TYPE_ATEN: PT_API_TYPE_ATEN,
|
|
274
|
+
PT_API_TYPE_DIST: PT_API_TYPE_DIST,
|
|
275
|
+
PT_API_TYPE_NPU_DIST: PT_API_TYPE_NPU_DIST
|
|
276
|
+
},
|
|
277
|
+
MS_FRAMEWORK: {
|
|
278
|
+
MS_API_TYPE_OPS: MS_API_TYPE_OPS,
|
|
279
|
+
MS_API_TYPE_TENSOR: MS_API_TYPE_TENSOR,
|
|
280
|
+
MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR,
|
|
281
|
+
MS_API_TYPE_MINT: MS_API_TYPE_MINT,
|
|
282
|
+
MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC,
|
|
283
|
+
MS_API_TYPE_COM: MS_API_TYPE_COM
|
|
284
|
+
},
|
|
285
|
+
MT_FRAMEWORK: {
|
|
286
|
+
PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
|
|
287
|
+
PT_API_TYPE_TENSOR: PT_API_TYPE_TENSOR,
|
|
288
|
+
PT_API_TYPE_TORCH: PT_API_TYPE_TORCH,
|
|
289
|
+
PT_API_TYPE_NPU: PT_API_TYPE_NPU,
|
|
290
|
+
PT_API_TYPE_DIST: PT_API_TYPE_DIST
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
API_DATA_PREFIX = {
|
|
295
|
+
PT_FRAMEWORK: {
|
|
296
|
+
PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
|
|
297
|
+
PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
|
|
298
|
+
PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX,
|
|
299
|
+
PT_API_TYPE_VF: VF_API_TYPE_PREFIX,
|
|
300
|
+
PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX,
|
|
301
|
+
PT_API_TYPE_ATEN: ATEN_API_TYPE_PREFIX,
|
|
302
|
+
PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX,
|
|
303
|
+
PT_API_TYPE_NPU_DIST: DIST_API_TYPE_PREFIX
|
|
304
|
+
},
|
|
305
|
+
MS_FRAMEWORK: {
|
|
306
|
+
MS_API_TYPE_OPS: FUNCTIONAL_API_TYPE_PREFIX,
|
|
307
|
+
MS_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
|
|
308
|
+
MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX,
|
|
309
|
+
MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX,
|
|
310
|
+
MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX,
|
|
311
|
+
MS_API_TYPE_COM: DIST_API_TYPE_PREFIX
|
|
312
|
+
},
|
|
313
|
+
MT_FRAMEWORK: {
|
|
314
|
+
PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
|
|
315
|
+
PT_API_TYPE_TENSOR: TENSOR_API_TYPE_PREFIX,
|
|
316
|
+
PT_API_TYPE_TORCH: TORCH_API_TYPE_PREFIX,
|
|
317
|
+
PT_API_TYPE_NPU: NPU_API_TYPE_PREFIX,
|
|
318
|
+
PT_API_TYPE_DIST: DIST_API_TYPE_PREFIX
|
|
319
|
+
}
|
|
320
|
+
}
|
|
321
|
+
|
|
233
322
|
|
|
234
323
|
class CompareConst:
|
|
235
324
|
"""
|
|
@@ -256,6 +345,7 @@ class CompareConst:
|
|
|
256
345
|
MEAN_DIFF = "Mean diff"
|
|
257
346
|
NORM_DIFF = "L2norm diff"
|
|
258
347
|
COSINE = "Cosine"
|
|
348
|
+
EUC_DIST = "EucDist"
|
|
259
349
|
MAX_ABS_ERR = "MaxAbsErr"
|
|
260
350
|
MAX_RELATIVE_ERR = "MaxRelativeErr"
|
|
261
351
|
MIN_RELATIVE_ERR = "MinRelativeErr"
|
|
@@ -330,8 +420,8 @@ class CompareConst:
|
|
|
330
420
|
ULP_ERR_STATUS = "ulp_err_status"
|
|
331
421
|
|
|
332
422
|
COMPARE_RESULT_HEADER = [
|
|
333
|
-
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE,
|
|
334
|
-
ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
|
|
423
|
+
NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, EUC_DIST,
|
|
424
|
+
MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
|
|
335
425
|
NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE
|
|
336
426
|
]
|
|
337
427
|
|
|
@@ -357,18 +447,16 @@ class CompareConst:
|
|
|
357
447
|
Const.MD5: MD5_COMPARE_RESULT_HEADER
|
|
358
448
|
}
|
|
359
449
|
|
|
360
|
-
ALL_COMPARE_INDEX = [COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO,
|
|
450
|
+
ALL_COMPARE_INDEX = [COSINE, EUC_DIST, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO,
|
|
451
|
+
FIVE_THOUSANDTHS_ERR_RATIO]
|
|
361
452
|
SUMMARY_COMPARE_INDEX = [MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
|
|
362
453
|
MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR]
|
|
363
454
|
|
|
364
455
|
# dtype match
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
TORCH_TYPE = [
|
|
370
|
-
[Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
|
|
371
|
-
[Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
|
|
456
|
+
|
|
457
|
+
DTYPE_MATCH_GROUPS = [
|
|
458
|
+
{Const.FLOAT16, Const.FLOAT32, Const.BFLOAT16},
|
|
459
|
+
{Const.TORCH_FLOAT16, Const.TORCH_FLOAT32, Const.TORCH_BFLOAT16}
|
|
372
460
|
]
|
|
373
461
|
|
|
374
462
|
# read_op
|
|
@@ -467,7 +555,7 @@ class CompareConst:
|
|
|
467
555
|
BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: ''
|
|
468
556
|
}
|
|
469
557
|
MS_GRAPH_NPY = {
|
|
470
|
-
COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
|
|
558
|
+
COSINE: None, EUC_DIST: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None,
|
|
471
559
|
FIVE_THOUSANDTHS_ERR_RATIO: None
|
|
472
560
|
}
|
|
473
561
|
MS_GRAPH_STATISTIC = {
|
|
@@ -538,61 +626,6 @@ class OverflowConst:
|
|
|
538
626
|
OVERFLOW_DEBUG_MODE = 1
|
|
539
627
|
|
|
540
628
|
|
|
541
|
-
class MsCompareConst:
|
|
542
|
-
# api_info field
|
|
543
|
-
MINT = "Mint"
|
|
544
|
-
MINT_FUNCTIONAL = "MintFunctional"
|
|
545
|
-
TENSOR_API = "Tensor"
|
|
546
|
-
|
|
547
|
-
API_NAME_STR_LENGTH = 4
|
|
548
|
-
MAX_RECURSION_DEPTH = 20
|
|
549
|
-
|
|
550
|
-
# Mindtorch api_info field
|
|
551
|
-
MINDTORCH_TENSOR = "Tensor"
|
|
552
|
-
MINDTORCH = "Torch"
|
|
553
|
-
MINDTORCH_FUNC = "Functional"
|
|
554
|
-
MINDTORCH_NPU = "NPU"
|
|
555
|
-
MINDTORCH_DIST = "Distributed"
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
MT_VALID_API_TYPES = [
|
|
560
|
-
MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR
|
|
561
|
-
]
|
|
562
|
-
|
|
563
|
-
TASK_FIELD = "task"
|
|
564
|
-
STATISTICS_TASK = "statistics"
|
|
565
|
-
FRAMEWORK = "framework"
|
|
566
|
-
TENSOR_TASK = "tensor"
|
|
567
|
-
DUMP_DATA_DIR_FIELD = "dump_data_dir"
|
|
568
|
-
DATA_FIELD = "data"
|
|
569
|
-
|
|
570
|
-
# supported api yaml
|
|
571
|
-
SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
|
|
572
|
-
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|
|
573
|
-
|
|
574
|
-
# detail_csv
|
|
575
|
-
DETAIL_CSV_API_NAME = "API Name"
|
|
576
|
-
DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
|
|
577
|
-
DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
|
|
578
|
-
DETAIL_CSV_SHAPE = "Shape"
|
|
579
|
-
DETAIL_CSV_PASS_STATUS = "Status"
|
|
580
|
-
DETAIL_CSV_MESSAGE = "Message"
|
|
581
|
-
DETAIL_CSV_FILE_NAME = "accuracy_checking_details"
|
|
582
|
-
|
|
583
|
-
# result_csv
|
|
584
|
-
RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
|
|
585
|
-
RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
|
|
586
|
-
RESULT_CSV_FILE_NAME = "accuracy_checking_result"
|
|
587
|
-
|
|
588
|
-
EPSILON = 1e-8
|
|
589
|
-
|
|
590
|
-
class ProcessStatus:
|
|
591
|
-
SUCCESS = "success"
|
|
592
|
-
API_NOT_FOUND = "api_not_found"
|
|
593
|
-
EXCEPTION_SKIP = "exception_skip"
|
|
594
|
-
|
|
595
|
-
|
|
596
629
|
class MsgConst:
|
|
597
630
|
"""
|
|
598
631
|
Class for log messages const
|
|
@@ -629,6 +662,15 @@ class MonitorConst:
|
|
|
629
662
|
"""
|
|
630
663
|
Class for monitor const
|
|
631
664
|
"""
|
|
665
|
+
|
|
666
|
+
# monitor config set default values
|
|
667
|
+
DEFAULT_GRAD_ACC_STEPS = 1
|
|
668
|
+
DEFAULT_START_ITERATION = 0
|
|
669
|
+
DEFAULT_START_STEP = 0
|
|
670
|
+
DEFAULT_MAX_COLLECT_TIMES = 1e8
|
|
671
|
+
DEFAULT_MIN_COLLECT_TIMES = 0
|
|
672
|
+
DEFAULT_STEP_INTERVAL = 1
|
|
673
|
+
|
|
632
674
|
OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean"]
|
|
633
675
|
MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
|
|
634
676
|
DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
|
|
@@ -674,3 +716,5 @@ class MonitorConst:
|
|
|
674
716
|
CSV = "csv"
|
|
675
717
|
API = "api"
|
|
676
718
|
HEADER_NAME = 'name'
|
|
719
|
+
|
|
720
|
+
MAX_NDIGITS = 20
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import defaultdict
|
|
17
|
+
from functools import wraps
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
20
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
+
from msprobe.core.common.log import logger
|
|
22
|
+
|
|
23
|
+
# 记录工具函数递归的深度
|
|
24
|
+
recursion_depth = defaultdict(int)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def recursion_depth_decorator(func_info, max_depth=Const.MAX_DEPTH):
|
|
28
|
+
"""装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。"""
|
|
29
|
+
def decorator(func):
|
|
30
|
+
@wraps(func)
|
|
31
|
+
def wrapper(*args, **kwargs):
|
|
32
|
+
func_id = id(func)
|
|
33
|
+
recursion_depth[func_id] += 1
|
|
34
|
+
if recursion_depth[func_id] > max_depth:
|
|
35
|
+
msg = f"call {func_info} exceeds the recursion limit."
|
|
36
|
+
logger.error_log_with_exp(
|
|
37
|
+
msg,
|
|
38
|
+
MsprobeException(
|
|
39
|
+
MsprobeException.RECURSION_LIMIT_ERROR, msg
|
|
40
|
+
),
|
|
41
|
+
)
|
|
42
|
+
try:
|
|
43
|
+
result = func(*args, **kwargs)
|
|
44
|
+
finally:
|
|
45
|
+
recursion_depth[func_id] -= 1
|
|
46
|
+
return result
|
|
47
|
+
|
|
48
|
+
return wrapper
|
|
49
|
+
|
|
50
|
+
return decorator
|
|
@@ -28,12 +28,14 @@ class MsprobeException(CodedException):
|
|
|
28
28
|
OVERFLOW_NUMS_ERROR = 1
|
|
29
29
|
RECURSION_LIMIT_ERROR = 2
|
|
30
30
|
INTERFACE_USAGE_ERROR = 3
|
|
31
|
+
UNSUPPORTED_TYPE_ERROR = 4
|
|
31
32
|
|
|
32
33
|
err_strs = {
|
|
33
34
|
INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
|
|
34
35
|
OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
|
|
35
36
|
RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:",
|
|
36
|
-
INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: "
|
|
37
|
+
INTERFACE_USAGE_ERROR: "[msprobe] Invalid interface usage: ",
|
|
38
|
+
UNSUPPORTED_TYPE_ERROR: "[msprobe] Unsupported type: "
|
|
37
39
|
}
|
|
38
40
|
|
|
39
41
|
|
|
@@ -26,6 +26,7 @@ import yaml
|
|
|
26
26
|
import numpy as np
|
|
27
27
|
import pandas as pd
|
|
28
28
|
|
|
29
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
29
30
|
from msprobe.core.common.log import logger
|
|
30
31
|
from msprobe.core.common.exceptions import FileCheckException
|
|
31
32
|
from msprobe.core.common.const import FileCheckConst
|
|
@@ -266,6 +267,7 @@ def make_dir(dir_path):
|
|
|
266
267
|
file_check.common_check()
|
|
267
268
|
|
|
268
269
|
|
|
270
|
+
@recursion_depth_decorator('msprobe.core.common.file_utils.create_directory', max_depth=16)
|
|
269
271
|
def create_directory(dir_path):
|
|
270
272
|
"""
|
|
271
273
|
Function Description:
|
|
@@ -332,6 +334,23 @@ def change_mode(path, mode):
|
|
|
332
334
|
'Failed to change {} authority. {}'.format(path, str(ex))) from ex
|
|
333
335
|
|
|
334
336
|
|
|
337
|
+
@recursion_depth_decorator('msprobe.core.common.file_utils.recursive_chmod')
|
|
338
|
+
def recursive_chmod(path):
|
|
339
|
+
"""
|
|
340
|
+
递归地修改目录及其子目录和文件的权限,文件修改为640,路径修改为750
|
|
341
|
+
|
|
342
|
+
:param path: 要修改权限的目录路径
|
|
343
|
+
"""
|
|
344
|
+
for _, dirs, files in os.walk(path):
|
|
345
|
+
for file_name in files:
|
|
346
|
+
file_path = os.path.join(path, file_name)
|
|
347
|
+
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
348
|
+
for dir_name in dirs:
|
|
349
|
+
dir_path = os.path.join(path, dir_name)
|
|
350
|
+
change_mode(dir_path, FileCheckConst.DATA_DIR_AUTHORITY)
|
|
351
|
+
recursive_chmod(dir_path)
|
|
352
|
+
|
|
353
|
+
|
|
335
354
|
def path_len_exceeds_limit(file_path):
|
|
336
355
|
return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \
|
|
337
356
|
len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH
|
|
@@ -632,7 +651,7 @@ def os_walk_for_files(path, depth):
|
|
|
632
651
|
return res
|
|
633
652
|
|
|
634
653
|
|
|
635
|
-
def check_crt_valid(pem_path):
|
|
654
|
+
def check_crt_valid(pem_path, is_public_key=False):
|
|
636
655
|
"""
|
|
637
656
|
Check the validity of the SSL certificate.
|
|
638
657
|
|
|
@@ -641,6 +660,7 @@ def check_crt_valid(pem_path):
|
|
|
641
660
|
|
|
642
661
|
Parameters:
|
|
643
662
|
pem_path (str): The file path of the SSL certificate.
|
|
663
|
+
is_public_key (bool): The file is public key or not.
|
|
644
664
|
|
|
645
665
|
Raises:
|
|
646
666
|
RuntimeError: If the SSL certificate is invalid or expired.
|
|
@@ -649,7 +669,10 @@ def check_crt_valid(pem_path):
|
|
|
649
669
|
try:
|
|
650
670
|
with FileOpen(pem_path, "r") as f:
|
|
651
671
|
pem_data = f.read()
|
|
652
|
-
|
|
672
|
+
if is_public_key:
|
|
673
|
+
cert = OpenSSL.crypto.load_publickey(OpenSSL.crypto.FILETYPE_PEM, pem_data)
|
|
674
|
+
else:
|
|
675
|
+
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
|
|
653
676
|
pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
|
|
654
677
|
pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
|
|
655
678
|
logger.info(f"The SSL certificate passes the verification and the validity period "
|
msprobe/core/common/utils.py
CHANGED
|
@@ -18,9 +18,7 @@ import os
|
|
|
18
18
|
import re
|
|
19
19
|
import subprocess
|
|
20
20
|
import time
|
|
21
|
-
from collections import defaultdict
|
|
22
21
|
from datetime import datetime, timezone
|
|
23
|
-
from functools import wraps
|
|
24
22
|
|
|
25
23
|
import numpy as np
|
|
26
24
|
|
|
@@ -75,6 +73,7 @@ class MsprobeBaseException(Exception):
|
|
|
75
73
|
MERGE_COMPARE_RESULT_ERROR = 33
|
|
76
74
|
NAMES_STRUCTS_MATCH_ERROR = 34
|
|
77
75
|
INVALID_STATE_ERROR = 35
|
|
76
|
+
INVALID_API_NAME_ERROR = 36
|
|
78
77
|
|
|
79
78
|
def __init__(self, code, error_info: str = ""):
|
|
80
79
|
super(MsprobeBaseException, self).__init__()
|
|
@@ -247,6 +246,10 @@ def md5_find(data):
|
|
|
247
246
|
|
|
248
247
|
|
|
249
248
|
def detect_framework_by_dump_json(file_path):
|
|
249
|
+
json_data = load_json(file_path)
|
|
250
|
+
framework = json_data.get("framework", None)
|
|
251
|
+
if framework in [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]:
|
|
252
|
+
return framework
|
|
250
253
|
pattern_ms = r'"type":\s*"mindspore'
|
|
251
254
|
pattern_pt = r'"type":\s*"torch'
|
|
252
255
|
with FileOpen(file_path, 'r') as file:
|
|
@@ -279,7 +282,7 @@ def set_dump_path(input_param):
|
|
|
279
282
|
npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
|
|
280
283
|
bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
|
|
281
284
|
if not npu_path_valid or not bench_path_valid:
|
|
282
|
-
logger.error(f"Please check the json path is valid
|
|
285
|
+
logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
|
|
283
286
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
284
287
|
input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
|
|
285
288
|
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
|
|
@@ -424,6 +427,15 @@ def get_real_step_or_rank(step_or_rank_input, obj):
|
|
|
424
427
|
return real_step_or_rank
|
|
425
428
|
|
|
426
429
|
|
|
430
|
+
def check_init_step(step):
|
|
431
|
+
if not is_int(step):
|
|
432
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
433
|
+
f"{step} must be an integer")
|
|
434
|
+
if not step >= 0:
|
|
435
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
436
|
+
f"{step} must be greater than or equal to 0")
|
|
437
|
+
|
|
438
|
+
|
|
427
439
|
def check_seed_all(seed, mode, rm_dropout):
|
|
428
440
|
if is_int(seed):
|
|
429
441
|
if seed < 0 or seed > Const.MAX_SEED_VALUE:
|
|
@@ -467,36 +479,6 @@ def safe_get_value(container, index, container_name, key=None):
|
|
|
467
479
|
raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e
|
|
468
480
|
|
|
469
481
|
|
|
470
|
-
# 记录工具函数递归的深度
|
|
471
|
-
recursion_depth = defaultdict(int)
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
# 装饰一个函数,当函数递归调用超过限制时,抛出异常并打印函数信息。
|
|
475
|
-
def recursion_depth_decorator(func_info):
|
|
476
|
-
def decorator(func):
|
|
477
|
-
@wraps(func)
|
|
478
|
-
def wrapper(*args, **kwargs):
|
|
479
|
-
func_id = id(func)
|
|
480
|
-
recursion_depth[func_id] += 1
|
|
481
|
-
if recursion_depth[func_id] > Const.MAX_DEPTH:
|
|
482
|
-
msg = f"call {func_info} exceeds the recursion limit."
|
|
483
|
-
logger.error_log_with_exp(
|
|
484
|
-
msg,
|
|
485
|
-
MsprobeException(
|
|
486
|
-
MsprobeException.RECURSION_LIMIT_ERROR, msg
|
|
487
|
-
),
|
|
488
|
-
)
|
|
489
|
-
try:
|
|
490
|
-
result = func(*args, **kwargs)
|
|
491
|
-
finally:
|
|
492
|
-
recursion_depth[func_id] -= 1
|
|
493
|
-
return result
|
|
494
|
-
|
|
495
|
-
return wrapper
|
|
496
|
-
|
|
497
|
-
return decorator
|
|
498
|
-
|
|
499
|
-
|
|
500
482
|
def check_str_param(param):
|
|
501
483
|
if not re.match(Const.REGEX_PREFIX_PATTERN, param):
|
|
502
484
|
logger.error('The parameter {} contains special characters.'.format(param))
|
|
@@ -509,4 +491,18 @@ class DumpPathAggregation:
|
|
|
509
491
|
construct_file_path = None
|
|
510
492
|
dump_tensor_data_dir = None
|
|
511
493
|
free_benchmark_file_path = None
|
|
512
|
-
debug_file_path = None
|
|
494
|
+
debug_file_path = None
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def is_save_variable_valid(variable, valid_special_types, depth=0):
|
|
498
|
+
if depth > Const.DUMP_MAX_DEPTH:
|
|
499
|
+
return False
|
|
500
|
+
if isinstance(variable, valid_special_types):
|
|
501
|
+
return True
|
|
502
|
+
elif isinstance(variable, (list, tuple)):
|
|
503
|
+
return all(is_save_variable_valid(item, valid_special_types, depth + 1) for item in variable)
|
|
504
|
+
elif isinstance(variable, dict):
|
|
505
|
+
return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1)
|
|
506
|
+
for key, value in variable.items())
|
|
507
|
+
else:
|
|
508
|
+
return False
|