mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,381 @@
1
+ # **精度比对工具**
2
+
3
+ 本文主要介绍msprobe的精度比对工具的快速入门和场景化示例。
4
+
5
+ 本文介绍的操作需要安装msprobe工具,详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
6
+
7
+ 本文介绍的操作主要是精度数据dump和精度比对,详细操作指导可参考《[精度数据采集](./dump.md)》和《[CPU或GPU与NPU精度数据比对](./ptdbg_ascend.md)》。
8
+
9
+ ## 快速入门
10
+
11
+ ### 单卡场景精度比对
12
+
13
+ **精度分析建议**
14
+
15
+ PyTorch训练场景的精度问题分析建议参考以下思路进行精度比对和比对结果分析:
16
+
17
+ 1. 整网比对:dump整网数据并进行精度比对,初步定位异常范围。
18
+
19
+ 对于模型数据庞大(比如达到T级别)的场景,不推荐直接dump整网比对,整网dump可能导致磁盘不足,需要预留足够的存储空间或者分多次dump。
20
+
21
+ 2. 缩小范围:根据Accuracy Reached or Not找出不符合精度标准的API。
22
+
23
+ 3. 范围比对:对不符合精度标准的API重新dump详细信息。
24
+
25
+ 4. 分析原因并优化:分析API精度不符合标准的原因并进行优化调整。
26
+
27
+ 5. 整网比对:重新进行整网比对,判断优化后的API是否已符合精度标准以及是否出现新的精度问题。
28
+
29
+ 6. 重复1~5步,直到不存在精度问题为止。
30
+
31
+ **精度分析示例**
32
+
33
+ 1. 修改dump配置文件config.json。
34
+
35
+ ```json
36
+ {
37
+ "task": "tensor",
38
+ "dump_path": "./npu_dump",
39
+ "rank": [],
40
+ "step": [],
41
+ "level": "L1",
42
+ "seed": 1234,
43
+ "is_deterministic": false,
44
+
45
+ "tensor": {
46
+ "scope": [],
47
+ "list": [],
48
+ "data_mode": ["all"],
49
+ "summary_mode": "statistics"
50
+ }
51
+ }
52
+ ```
53
+
54
+ 2. 在训练脚本内添加msprobe工具,dump整网数据。
55
+
56
+ 分别dump CPU或GPU以及NPU数据,在PyTorch训练脚本插入dump接口,示例代码如下(下面以NPU为例,CPU或GPU dump基本相同):
57
+
58
+ ```python
59
+ from msprobe.pytorch import PrecisionDebugger
60
+ debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
61
+ # 请勿将以上初始化流程插入到循环代码中
62
+
63
+ # 模型初始化
64
+ # 下面代码也可以用PrecisionDebugger.start()和PrecisionDebugger.stop()
65
+ debugger.start()
66
+
67
+ # 需要dump的代码片段1
68
+
69
+ debugger.stop()
70
+ debugger.start()
71
+
72
+ # 需要dump的代码片段2
73
+
74
+ debugger.stop()
75
+ debugger.step()
76
+ ```
77
+
78
+ 3. 比对整网数据。
79
+
80
+ 第1步中的NPU dump数据目录为npu_dump,假设GPU dump数据目录为gpu_dump;dump将生成dump.json、stack.json、construct.json文件以及dump数据目录。
81
+
82
+ 创建并配置精度比对脚本,以创建compare.py为例,示例代码如下:
83
+
84
+ ```python
85
+ from msprobe.pytorch import compare
86
+ dump_result_param={
87
+ "npu_json_path": "./npu_dump/dump.json",
88
+ "bench_json_path": "./gpu_dump/dump.json",
89
+ "stack_json_path": "./npu_dump/stack.json",
90
+ "is_print_compare_log": True
91
+ }
92
+ compare(dump_result_param, output_path="./output", stack_mode=True)
93
+ ```
94
+
95
+ 执行比对:
96
+
97
+ ```bash
98
+ python3 compare.py
99
+ ```
100
+
101
+ 在output目录下生成结果文件,包括:`compare_result_{timestamp}.xlsx`和`advisor_{timestamp}.txt`
102
+
103
+ 4. 找出存在问题的API。
104
+
105
+ 1. 根据`advisor_{timestamp}.txt`或打屏信息的提示,可找到存在精度问题的算子(Suspect Nodes)和专家建议(Expert Advice)。
106
+
107
+ ![auto_analyze_log](img/auto_analyze_log.png)
108
+
109
+ 2. 根据第2步结果文件`compare_result_{timestamp}.xlsx`中的Accuracy Reached or No字段显示为NO的API,针对该API执行后续比对操作,分析该API存在的精度问题。
110
+
111
+ 5. (可选)重新比对。
112
+
113
+ 根据第3步的dump数据重新配置compare.py并执行比对,可以对单API模型进行问题复现。
114
+
115
+ **注意**:部分API存在调用嵌套关系,比如functional.batch_norm实际调用torch.batch_norm,该场景会影响kernel init初始化多次,导致功能异常。
116
+
117
+ ### 溢出检测场景
118
+
119
+ 溢出检测是针对NPU的PyTorch API,检测是否存在溢出的情况。当前仅支持识别aicore浮点溢出。
120
+
121
+ 溢出检测原理:针对溢出阶段,开启acl dump模式,重新对溢出阶段执行,落盘数据。
122
+
123
+ 建议按照如下步骤操作:
124
+
125
+ 1. 修改dump配置文件config.json。
126
+
127
+ ```json
128
+ {
129
+ "task": "overflow_check",
130
+ "dump_path": "./npu_dump",
131
+ "rank": [],
132
+ "step": [],
133
+ "level": "L1",
134
+ "seed": 1234,
135
+ "is_deterministic": false,
136
+
137
+ "overflow_check": {
138
+ "overflow_nums": 3
139
+ }
140
+ }
141
+ ```
142
+
143
+ 2. 在NPU训练脚本内添加msprobe工具,执行溢出检测dump。
144
+
145
+ ```python
146
+ from msprobe.pytorch import PrecisionDebugger
147
+ debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
148
+ # 请勿将以上初始化流程插入到循环代码中
149
+
150
+ # 模型初始化
151
+ # 下面代码也可以用PrecisionDebugger.start()和PrecisionDebugger.stop()
152
+ debugger.start()
153
+
154
+ # 需要dump的代码片段1
155
+
156
+ debugger.stop()
157
+ debugger.start()
158
+
159
+ # 需要dump的代码片段2
160
+
161
+ debugger.stop()
162
+ debugger.step()
163
+ ```
164
+
165
+ 多卡使用时各卡单独计算溢出次数。
166
+
167
+ 3. NPU环境下执行训练dump溢出数据。
168
+
169
+ 针对输入正常但输出存在溢出的API,会在训练执行目录下将溢出的API信息dump并保存为`dump.json`通过《[溢出解析工具](./run_overflow_check.md)》对json文件进行解析,输出溢出API为正常溢出还是非正常溢出,从而帮助用户快速判断。
170
+
171
+ 溢出解析工具执行命令如下:
172
+
173
+ ```bash
174
+ msprobe -f pytorch run_overflow_check -api_info ./dump.json
175
+ ```
176
+
177
+ 反向过程溢出的API暂不支持精度预检功能。
178
+
179
+
180
+ 当重复执行溢出检测dump操作时,需要删除上一次dump目录下的溢出检测dump数据,否则将因重名而报错。
181
+
182
+ **注意事项**
183
+
184
+ * (暂不支持)level为L2场景下,会增加npu的内存消耗,请谨慎开启。
185
+ * (暂不支持)l部分API存在调用嵌套关系,比如functional.batch_norm实际调用torch.batch_norm,该场景会影响acl init初始化多次,导致level为L2功能异常。
186
+ * 混合精度动态loss scale场景下,正常训练会有"Gradient overflow. SKipping step"日志,添加溢出检测后日志消失,可以通过设置环境变量export OVERFLOW_DEBUG_MODE_ENABLE=1,并将register_hook位置调整amp.initialize之前解决。此功能需要cann包配套支持,不支持版本执行报错EZ3003。
187
+
188
+ ## 场景化示例
189
+
190
+ ### 多卡场景精度比对
191
+
192
+ 精度工具支持多卡场景的精度比对,多卡场景的dump步骤与单卡场景完全一致,请参见“**单卡场景精度比对**”章节,不同的是多卡数据精度比对时需要使用“compare_distributed”函数进行比对。
193
+
194
+ 如下示例:
195
+
196
+ 说明:多机多卡场景需要每个节点单独执行比对操作。
197
+
198
+ 假设NPU dump 数据目录为npu_dump,GPU dump数据目录为gpu_dump。
199
+
200
+ 1. 创建比对脚本,例如compare_distributed.py,拷贝如下代码。
201
+
202
+ ```python
203
+ from msprobe.pytorch import *
204
+ compare_distributed('./npu_dump/step0', './gpu_dump/step0', './output')
205
+ ```
206
+
207
+ dump数据目录须指定到step级。
208
+
209
+ 2. 执行比对:
210
+
211
+ ```bash
212
+ python3 compare_distributed.py
213
+ ```
214
+
215
+ 两次运行须用相同数量的卡,传入`compare_distributed`的两个文件夹下须有相同个数的rank文件夹,且不包含其他无关文件,否则将无法比对。
216
+
217
+ **多卡set_dump_path注意事项**
218
+
219
+ 多卡一般为多进程,须保证每个进程都正确调用PrecisionDebugger,或把PrecisionDebugger插入到import语句后,如:
220
+
221
+ ```python
222
+ from msprobe.pytorch import PrecisionDebugger
223
+ debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
224
+ ```
225
+
226
+ 如此可保证set_dump_path在每个进程都被调用。
227
+
228
+ ### NPU vs NPU精度比对
229
+
230
+ 对于NPU vs NPU场景,是针对同一模型,进行迭代(模型、API版本升级或设备硬件升级)时存在的精度下降问题,对比相同模型在迭代前后版本的API计算数值,进行问题定位。
231
+
232
+ 一般情况下迭代涉及NPU自定义算子,因此,可以仅dump NPU自定义算子进行比对。比对精度问题分析请参见“**单卡场景精度比对**”章节。
233
+
234
+ 工具当前支持dump NPU自定义算子如下:
235
+
236
+ | 序号 | NPU自定义算子 |
237
+ | :--- | ----------------------------------------------- |
238
+ | 1 | torch_npu.one_ |
239
+ | 2 | torch_npu.npu_sort_v2 |
240
+ | 3 | torch_npu.npu_transpose |
241
+ | 4 | torch_npu.npu_broadcast |
242
+ | 5 | torch_npu.npu_dtype_cast |
243
+ | 6 | torch_npu.empty_with_format |
244
+ | 7 | torch_npu.npu_one_hot |
245
+ | 8 | torch_npu.npu_stride_add |
246
+ | 9 | torch_npu.npu_ps_roi_pooling |
247
+ | 10 | torch_npu.npu_roi_align |
248
+ | 11 | torch_npu.npu_nms_v4 |
249
+ | 12 | torch_npu.npu_iou |
250
+ | 13 | torch_npu.npu_nms_with_mask |
251
+ | 14 | torch_npu.npu_pad |
252
+ | 15 | torch_npu.npu_bounding_box_encode |
253
+ | 16 | torch_npu.npu_bounding_box_decode |
254
+ | 17 | torch_npu.npu_batch_nms |
255
+ | 18 | torch_npu.npu_slice |
256
+ | 19 | torch_npu._npu_dropout |
257
+ | 20 | torch_npu.npu_indexing |
258
+ | 21 | torch_npu.npu_ifmr |
259
+ | 22 | torch_npu.npu_max |
260
+ | 23 | torch_npu.npu_scatter |
261
+ | 24 | torch_npu.npu_layer_norm_eval |
262
+ | 25 | torch_npu.npu_alloc_float_status |
263
+ | 26 | torch_npu.npu_confusion_transpose |
264
+ | 27 | torch_npu.npu_bmmV2 |
265
+ | 28 | torch_npu.fast_gelu |
266
+ | 29 | torch_npu.npu_sub_sample |
267
+ | 30 | torch_npu.npu_deformable_conv2d |
268
+ | 31 | torch_npu.npu_mish |
269
+ | 32 | torch_npu.npu_anchor_response_flags |
270
+ | 33 | torch_npu.npu_yolo_boxes_encode |
271
+ | 34 | torch_npu.npu_grid_assign_positive |
272
+ | 35 | torch_npu.npu_normalize_batch |
273
+ | 36 | torch_npu.npu_masked_fill_range |
274
+ | 37 | torch_npu.npu_linear |
275
+ | 38 | torch_npu.npu_bert_apply_adam |
276
+ | 39 | torch_npu.npu_giou |
277
+ | 40 | torch_npu.npu_ciou |
278
+ | 41 | torch_npu.npu_diou |
279
+ | 42 | torch_npu.npu_sign_bits_pack |
280
+ | 43 | torch_npu.npu_sign_bits_unpack |
281
+ | 44 | torch_npu.npu_flash_attention |
282
+ | 45 | torch_npu.npu_scaled_masked_softmax |
283
+ | 46 | torch_npu.npu_rotary_mul |
284
+ | 47 | torch_npu.npu_roi_align |
285
+ | 48 | torch_npu.npu_roi_alignbk |
286
+ | 49 | torch_npu.npu_ptiou |
287
+ | 50 | torch_npu.npu_fusion_attention |
288
+ | 51 | torch_npu.npu_dropout_with_add_softmax |
289
+ | 52 | torch_npu.npu_random_choice_with_mask |
290
+ | 53 | torch_npu.npu_rotated_iou |
291
+ | 54 | torch_npu.npu_conv2d |
292
+ | 55 | torch_npu.npu_conv3d |
293
+ | 56 | torch_npu.npu_softmax_cross_entropy_with_logits |
294
+ | 57 | torch_npu.npu_all_gather_base_mm |
295
+ | 58 | torch_npu.npu_swiglu |
296
+ | 59 | torch_npu.npu_rms_norm |
297
+ | 60 | torch_npu.npu_mm_reduce_scatter_base |
298
+ | 61 | torch_npu.npu_mm_all_reduce_base |
299
+ | 62 | torch_npu.npu_conv_transpose2d |
300
+ | 63 | torch_npu.npu_convolution |
301
+ | 64 | torch_npu.npu_convolution_transpose |
302
+ | 65 | torch_npu.npu_min |
303
+ | 66 | torch_npu.npu_nms_rotated |
304
+ | 67 | torch_npu.npu_reshape |
305
+ | 68 | torch_npu.npu_rotated_box_decode |
306
+ | 69 | torch_npu.npu_rotated_box_encode |
307
+ | 70 | torch_npu.npu_rotated_overlaps |
308
+ | 71 | torch_npu.npu_silu |
309
+ | 72 | torch_npu.npu_fused_attention_score |
310
+ | 73 | torch_npu.npu_multi_head_attention |
311
+ | 74 | torch_npu.npu_gru |
312
+ | 75 | torch_npu.npu_incre_flash_attention |
313
+ | 76 | torch_npu.npu_prompt_flash_attention |
314
+ | 77 | torch_npu.npu_lstm |
315
+ | 78 | torch_npu.npu_apply_adam |
316
+
317
+ ### 通信API的数据dump
318
+
319
+ 通信类API数据可以使用全量dump方式获取,若只dump通信类API数据,可以使用如下示例:
320
+
321
+ 1. 修改dump配置文件config.json。
322
+
323
+ ```json
324
+ {
325
+ "task": "tensor",
326
+ "dump_path": "./npu_dump",
327
+ "rank": [],
328
+ "step": [],
329
+ "level": "L1",
330
+ "seed": 1234,
331
+ "is_deterministic": false,
332
+
333
+ "tensor": {
334
+ "scope": [],
335
+ "list": ["distributed"],
336
+ "data_mode": ["all"],
337
+ "summary_mode": "statistics"
338
+ }
339
+ }
340
+ ```
341
+
342
+ 2. 在训练脚本内添加msprobe工具,dump整网数据。
343
+
344
+ ```python
345
+ from msprobe.pytorch import PrecisionDebugger
346
+ debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump")
347
+ # 请勿将以上初始化流程插入到循环代码中
348
+
349
+ # 模型初始化
350
+ # 下面代码也可以用PrecisionDebugger.start()和PrecisionDebugger.stop()
351
+ debugger.start()
352
+
353
+ # 需要dump的代码片段1
354
+
355
+ debugger.stop()
356
+ debugger.start()
357
+
358
+ # 需要dump的代码片段2
359
+
360
+ debugger.stop()
361
+ debugger.step()
362
+ ```
363
+
364
+ 通信类API支持列表:
365
+
366
+ | 序号 | Distributed |
367
+ | :--- | -------------------- |
368
+ | 1 | send |
369
+ | 2 | recv |
370
+ | 3 | broadcast |
371
+ | 4 | all_reduce |
372
+ | 5 | reduce |
373
+ | 6 | all_gather |
374
+ | 7 | gather |
375
+ | 8 | isend |
376
+ | 9 | irecv |
377
+ | 10 | scatter |
378
+ | 11 | reduce_scatter |
379
+ | 12 | _reduce_scatter_base |
380
+ | 13 | _all_gather_base |
381
+
@@ -0,0 +1,25 @@
1
+ # **溢出解析工具**
2
+
3
+ 针对训练过程中的溢出检测场景(当《[精度数据采集](./dump.md)》开启溢出检测dump时),对于输入正常但输出存在溢出的API,会在训练执行目录下将溢出的API信息按照前向和反向分类,dump并保存为`dump.json`,前向过程溢出的API可通过该工具对`dump.json`进行解析,输出溢出API为正常溢出还是非正常溢出,从而帮助用户快速判断。
4
+
5
+ 工具支持PyTorch版本:1.11.0/2.0/2.1/2.2。
6
+
7
+ 操作步骤如下:
8
+
9
+ 1. 安装工具。
10
+
11
+ 详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
12
+
13
+ 2. 执行溢出API解析操作。
14
+
15
+ ```bash
16
+ msprobe -f pytorch run_overflow_check -api_info ./dump.json
17
+ ```
18
+
19
+ | 参数名称 | 说明 | 是否必选 |
20
+ | -------------------------- | -------------------------------------------------- | -------- |
21
+ | -api_info或--api_info_file | 指定API信息文件dump.json。 | 是 |
22
+ | -j或--jit_compile | 开启jit编译。 | 否 |
23
+ | -d或--device | 指定Device ID,选择UT代码运行所在的卡,默认值为0。 | 否 |
24
+
25
+ 反向过程溢出的API暂不支持该功能。
@@ -0,0 +1,90 @@
1
+ # **PyTorch NPU在线精度比对工具使用指南**
2
+
3
+ PyTorch NPU在线精度比对是ptdbg_ascend工具实现在PyTorch训练过程中直接完成精度比对并输出比对结果的功能。
4
+
5
+ 在线精度比对实现的是NPU与CPU之间的精度比对。
6
+
7
+ ## PyTorch NPU在线精度比对总体流程
8
+
9
+ 1. 准备NPU训练工程。
10
+
11
+ 2. 在NPU环境下安装ptdbg_ascend工具,参见《[PyTorch精度工具](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/ptdbg_ascend/README.md)》。
12
+
13
+ 3. 在训练脚本内插入ptdbg_ascend工具在线精度比对接口。
14
+
15
+ 4. 执行训练并获取在线精度比对NPU和CPU分别执行后的精度比对结果。
16
+
17
+ 5. 比对结果分析。
18
+
19
+ ## PyTorch NPU在线精度比对
20
+ ### 总体说明
21
+ - 本节主要介绍NPU精度比对所需要的函数以及示例。
22
+ - 在线精度比对工具通过截获PyTorch框架中部分Aten Ir及其输入输出,并将输入数据转到CPU执行,最后将NPU和CPU的执行结果进行精度比对得到比对结果。
23
+
24
+ ### 约束
25
+
26
+ - Pytorch 只支持2.0及其以上版本。
27
+ - 只支持Aten Ir级在线精度比对,所有Aten Ir可以通过dir(torch.ops.aten)查看,其中部分IR不支持在线比对:Aten Ir无对应CPU实现、NPU和CPU同AtenIR实现逻辑不一致,导致同输入不同输出。
28
+ - 正反向不支持同时在线精度比对,不支持跨step在线精度比对。
29
+
30
+
31
+ ### 场景示例
32
+ 1. 在NPU训练脚本中添加在线精度比对接口,示例如下:
33
+
34
+ ```python
35
+ from msprobe.pytorch.common.utils import seed_all
36
+ from msprobe.pytorch.online_dispatch import PtdbgDispatch
37
+
38
+ # 在main函数开始前固定随机数
39
+ seed_all()
40
+
41
+
42
+ ...
43
+
44
+ # 在需要调试精度的正向或反向代码前设置
45
+ # 正向示例
46
+ with PtdbgDispatch(dump_mode="auto", dump_path="/home/dump"):
47
+ output = model_cpu(inputs)
48
+ # 反向示例
49
+ with PtdbgDispatch(dump_mode="auto", dump_path="/home/dump"):
50
+ loss.backward()
51
+ ```
52
+
53
+ 2. 执行训练。
54
+
55
+ 3. 找出精度不达标的Aten IR。
56
+
57
+ 执行过程中会打屏Failed,Failed在比对结果csv中的Accuracy Reached or Not列标记为No,并在Dump目录下存盘精度不达标Aten IR的输入输出。
58
+ ![图片说明](http://image.huawei.com/tiny-lts/v1/images/d83d564e337e80c7cfb557ca3600d0d4_1689x178.png@900-0-90-f.png)
59
+
60
+ ### 计算精度评价指标
61
+
62
+ 1. Cosine < 0.99 且 MaxAbsError > 0.001时,精度不达标;
63
+ 2. Cosine < 0.9,精度不达标;
64
+ 3. MaxAbsError > 1,精度不达标。
65
+
66
+ ### 在线精度比对参数设置说明
67
+
68
+ | 参数名称 | 说明 | 是否必选 |
69
+ | -------- |-------------------------------------------------------------------------------------------------| -------- |
70
+ | dump_mode| dump模式,可取值"all"、"list"、"auto"、"OFF",默认值为OFF(表示不Dump数据)。 | 否 |
71
+ | api_list | dump范围,dump_mode="list"时设置,需要Dump Aten Ir API名称,默认为None,Aten Ir API名称可以通过dir(torch.ops.aten)查看。 | 否 |
72
+ | dump_path| dump文件生成的路径。 | 是 |
73
+ | tag | 传入tag字符串,成为dump文件夹名一部分,默认为None。 | 否 |
74
+ | process_num | 多进程并发数,默认为0。 | 否 |
75
+ | debug | debug信息打印,默认为False。 | 否 |
76
+ ### dump数据存盘说明
77
+ dump数据存盘目录名格式:`msprobe_tag_rankid_{timestamp}`。
78
+
79
+ 子目录下包含1个比对结果csv文件、cpu和npudump数据目录,npu目录下包含Aten IR在NPU上的输入输出的dump数据,由于CPU的输入是直接使用NPU的输入执行,因此cpu目录下只包含执行输出的dump数据。
80
+
81
+ ```bash
82
+ msprobe_rank4_20230911170521
83
+ ├── compare_result_rank4_20230911170521.csv
84
+ ├── cpu
85
+ │   ├── native_batch_norm_backward_10_output.0.npy
86
+ │ ............
87
+ └── npu
88
+ ├── native_batch_norm_backward_10_input.0.npy
89
+ ............
90
+ ```
@@ -0,0 +1,8 @@
1
+ from msprobe.core.common.log import logger
2
+ from msprobe.core.common.exceptions import FreeBenchmarkException
3
+ from msprobe.core.common.const import Const
4
+
5
+ from .main import FreeBenchmarkCheck
6
+ from .common.params import UnequalRow
7
+
8
+ __all__ = [FreeBenchmarkCheck, UnequalRow]
File without changes
@@ -0,0 +1,67 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold
6
+ from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd
7
+
8
+
9
+ class CommonField:
10
+ DEVICE = "device"
11
+ META = "meta"
12
+ FUZZ_TENSOR = "fuzz_tensor"
13
+ REQUIRES_GRAD = "requires_grad"
14
+ HOLD_PLACE = "hold_place"
15
+ DISTRIBUTED_OP = "torch.distributed"
16
+ GRADSAVER = "grad_saver"
17
+
18
+
19
+ class ThresholdConfig:
20
+ PERTURBATION_VALUE_DICT: Dict = {
21
+ torch.bfloat16: FuzzThreshold.BF16_THD,
22
+ torch.float16: FuzzThreshold.F16_THD,
23
+ torch.float32: FuzzThreshold.F32_THD,
24
+ torch.float64: FuzzThreshold.F64_THD,
25
+ }
26
+
27
+ ABS_TOL_VALUE_DICT: Dict = {
28
+ torch.bfloat16: FuzzThreshold.BF16_THD,
29
+ torch.float16: FuzzThreshold.F16_THD,
30
+ torch.float32: FuzzThreshold.F32_THD,
31
+ torch.float64: FuzzThreshold.F64_THD,
32
+ }
33
+
34
+ # bit翻转需要匹配到等长或更长的整型
35
+ PERTURBATION_BIT_DICT = {
36
+ torch.bfloat16: torch.int16,
37
+ torch.float16: torch.int16,
38
+ torch.float32: torch.int32,
39
+ torch.float64: torch.int64,
40
+ }
41
+
42
+ # 输入噪声下界
43
+ NOISE_INPUT_LOWER_BOUND = 1e-8
44
+ COMP_CONSISTENT = 1.0
45
+ COMP_NAN = np.nan
46
+ SYMBOL_FLIPPING = "symbol_flipping"
47
+ BACKWARD_OUTPUT_LOWER_BOUND = 1e-3
48
+ SMALL_VALUE = 1.0
49
+ # 预热初始阈值
50
+ PREHEAT_INITIAL_THD = 2.05
51
+ API_THD_STEP = 2.0
52
+
53
+ DTYPE_PER_THD = {
54
+ torch.float16: 1.002,
55
+ torch.float32: 1.0002,
56
+ }
57
+ BENCHMARK_THD_DICT = {
58
+ torch.float32: BenchmarkThd(2**-14, 1.0, 2**-14, 1e-4),
59
+ torch.float16: BenchmarkThd(2**-11, 1.0, 2**-11, 1e-4),
60
+ torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4),
61
+ }
62
+
63
+
64
+ class PreheatConfig:
65
+ IF_PREHEAT = "if_preheat"
66
+ PREHEAT_STEP = "preheat_step"
67
+ MAX_SAMPLE = "max_sample"
@@ -0,0 +1,72 @@
1
+ from collections import defaultdict
2
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
3
+
4
+
5
+ class PreheatCounter:
6
+ def __init__(self) -> None:
7
+ self.api_called_time: dict = defaultdict(int)
8
+ self.api_sample_time: dict = defaultdict(int)
9
+ self.one_step_used_api: dict = defaultdict(int)
10
+ self.api_thd: dict = defaultdict(dict)
11
+ self.preheat_record: dict = defaultdict(dict)
12
+ self.dtype_map: dict = {}
13
+ self.if_preheat: dict = defaultdict(dict)
14
+ self.step = 0
15
+
16
+ def clear_step(self):
17
+ self.preheat_record.clear()
18
+ self.api_called_time.clear()
19
+ self.api_sample_time.clear()
20
+
21
+ def check_step(self, current_step):
22
+ if current_step != self.step:
23
+ self.clear_step()
24
+ self.step = current_step
25
+
26
+ def add_api_called_time(self, api_name: str):
27
+ self.api_called_time[api_name] += 1
28
+
29
+ def get_api_called_time(self, api_name: str) -> int:
30
+ return self.api_called_time[api_name]
31
+
32
+ def add_api_sample_time(self, api_name: str):
33
+ self.api_sample_time[api_name] += 1
34
+
35
+ def get_api_sample_time(self, api_name: str) -> int:
36
+ return self.api_sample_time[api_name]
37
+
38
+ def add_one_step_used_api(self, api_name: str):
39
+ self.one_step_used_api[api_name] += 1
40
+
41
+ def get_one_step_used_api(self, api_name: str):
42
+ return self.one_step_used_api[api_name]
43
+
44
+ def update_preheat_record(self, api_name, dtype, cmp_result):
45
+ # 记录预热阶段CPU标杆比对的结果
46
+ if str(dtype) not in self.preheat_record[api_name].keys():
47
+ self.preheat_record[api_name][str(dtype)] = list()
48
+ self.preheat_record[api_name][str(dtype)].append(cmp_result)
49
+ self.dtype_map[str(dtype)] = dtype
50
+
51
+ def update_api_thd(self, api_name, dtype, threshold, dthreshold):
52
+ self.api_thd[api_name][str(dtype)] = (
53
+ threshold if threshold > dthreshold else dthreshold
54
+ )
55
+
56
+ def get_api_thd(self, api_name, dtype):
57
+ if not str(dtype) in self.api_thd[api_name]:
58
+ self.api_thd[api_name][str(dtype)] = ThresholdConfig.PREHEAT_INITIAL_THD
59
+ self.dtype_map[str(dtype)] = dtype
60
+ return self.api_thd[api_name][str(dtype)]
61
+
62
+ def set_api_preheat(self, api_name, dtype_str, is_preheat=True):
63
+ # 标记cpu不一致的dtype 不再进行预热
64
+ self.if_preheat[api_name][dtype_str] = is_preheat
65
+
66
+ def get_api_preheat(self, api_name, dtype):
67
+ # 标记cpu不一致的dtype 不再进行预热
68
+ if str(dtype) not in self.if_preheat[api_name]:
69
+ return True
70
+ return self.if_preheat[api_name][str(dtype)]
71
+
72
+ preheat_counter = PreheatCounter()
@@ -0,0 +1,37 @@
1
+ class PerturbationMode:
2
+ ADD_NOISE = "add_noise"
3
+ CHANGE_VALUE = "change_value"
4
+ IMPROVE_PRECISION = "improve_precision"
5
+ NO_CHANGE = "no_change"
6
+ BIT_NOISE = "bit_noise"
7
+ TO_CPU = "to_cpu"
8
+
9
+
10
+ class DeviceType:
11
+ NPU = "npu"
12
+ CPU = "cpu"
13
+
14
+
15
+ class FuzzThreshold:
16
+ BF16_THD = 1e-4
17
+ F16_THD = 1e-6
18
+ F32_THD = 1e-8
19
+ F64_THD = 1e-16
20
+
21
+
22
+ class NormType:
23
+ ONE_NORM = (1, "one_norm")
24
+ TWO_NORM = (2, "two_norm")
25
+ ENDLESS_NORM = (3, "endless_norm")
26
+
27
+
28
+ class HandlerType:
29
+ CHECK = "check"
30
+ PREHEAT = "preheat"
31
+ FIX = "fix"
32
+
33
+
34
+ class FuzzLevel:
35
+ BASE_LEVEL = "L1"
36
+ ADV_LEVEL = "L2"
37
+ REAL_LEVEL = "L3"