mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299):
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -24,15 +24,20 @@ from msprobe.core.common.utils import CompareException
24
24
  from msprobe.core.common.file_utils import get_json_contents, write_csv
25
25
  import torch
26
26
  from msprobe.core.common.const import CompareConst
27
- from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
28
- get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
29
- get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
30
- check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
27
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
28
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.absolute_threshold import AbsolutethdCompare
29
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkCompare
30
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare
31
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare
32
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare
33
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare
34
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput
35
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_err, get_max_abs_err, get_rel_err_ratio, \
36
+ cosine_sim, get_rel_err_origin, get_abs_bench_with_eps, compare_bool_tensor
31
37
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
32
38
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
33
39
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
34
- DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
35
- ulp_standard_api, thousandth_standard_api, apis_threshold
40
+ DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
36
41
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
37
42
  from msprobe.pytorch.common.log import logger
38
43
 
@@ -42,6 +47,7 @@ ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'b
42
47
 
43
48
 
44
49
  INDEX_TEST_RESULT_GROUP = 3
50
+ BACKWARD_RESULT_GROUP = 4
45
51
  INDEX_FIRST_GROUP = 0
46
52
  INDEX_MESSAGE = -1
47
53
 
@@ -66,6 +72,8 @@ class Comparator:
66
72
  self.detail_save_path_list = \
67
73
  [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
68
74
 
75
+ self.registry = self._register_compare_func()
76
+
69
77
  if not is_continue_run_ut:
70
78
  self.write_csv_title()
71
79
  if stack_info_json_path:
@@ -101,22 +109,6 @@ class Comparator:
101
109
  compare_column.error_rate = 0
102
110
  return CompareConst.PASS, compare_column, ""
103
111
 
104
- @staticmethod
105
- def _compare_bool_tensor(bench_output, device_output):
106
- error_nums = (bench_output != device_output).sum()
107
- if bench_output.size == 0:
108
- return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result."
109
- error_rate = float(error_nums / bench_output.size)
110
- result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
111
- return error_rate, result, ""
112
-
113
- @staticmethod
114
- def _get_absolute_threshold_attribute(api_name, dtype):
115
- small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
116
- small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol')
117
- rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
118
- return small_value_threshold, small_value_atol, rtol
119
-
120
112
  @staticmethod
121
113
  def _get_run_ut_detail(test_result):
122
114
  """get run_ut detail before write to csv, called by online run_ut"""
@@ -143,6 +135,36 @@ class Comparator:
143
135
  test_rows.append([subject] + list(test_subject))
144
136
  return test_rows
145
137
 
138
+ @staticmethod
139
+ def _binary_standard_compare(input_data):
140
+ binary_compare = BinaryCompare(input_data)
141
+ binary_compare.compare()
142
+
143
+ @staticmethod
144
+ def _thousandth_standard_compare(input_data):
145
+ thousandth_compare = ThousandthStdCompare(input_data)
146
+ thousandth_compare.compare()
147
+
148
+ @staticmethod
149
+ def _absolute_standard_compare(input_data):
150
+ absolute_compare = AbsolutethdCompare(input_data)
151
+ absolute_compare.compare()
152
+
153
+ @staticmethod
154
+ def _ulp_compare(input_data):
155
+ ulp_compare = UlpCompare(input_data)
156
+ ulp_compare.compare()
157
+
158
+ @staticmethod
159
+ def _benchmark_compare(input_data):
160
+ benchmark_compare = BenchmarkCompare(input_data)
161
+ benchmark_compare.compare()
162
+
163
+ @staticmethod
164
+ def _accumulative_error_compare(input_data):
165
+ accumulative_error_compare = AccumulativeErrorCompare(input_data)
166
+ accumulative_error_compare.compare()
167
+
146
168
  def write_csv_title(self):
147
169
  summary_test_rows = [
148
170
  [self.COLUMN_API_NAME,
@@ -163,6 +185,8 @@ class Comparator:
163
185
  df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
164
186
  if test_result[1] == CompareConst.SKIP:
165
187
  df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
188
+ elif test_result[2] == CompareConst.SKIP:
189
+ df_row.append(test_result[BACKWARD_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
166
190
  if self.stack_info:
167
191
  stack_info = "\n".join(self.stack_info[name])
168
192
  df_row.append(stack_info)
@@ -211,6 +235,7 @@ class Comparator:
211
235
  if backward_message:
212
236
  backward_column = CompareColumn()
213
237
  bwd_compare_alg_results = [backward_column.to_column_value(CompareConst.SKIP, backward_message)]
238
+ bwd_success_status = CompareConst.SKIP
214
239
  else:
215
240
  bwd_success_status = bwd_success_status if bwd_compare_alg_results is not None else CompareConst.SPACE
216
241
  result_info = ResultInfo(full_api_name,
@@ -226,6 +251,16 @@ class Comparator:
226
251
  return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
227
252
  or bwd_success_status == CompareConst.SPACE
228
253
 
254
+ def _register_compare_func(self):
255
+ registry = StandardRegistry()
256
+ registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare)
257
+ registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare)
258
+ registry.register(CompareConst.ULP_COMPARE, self._ulp_compare)
259
+ registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare)
260
+ registry.register(CompareConst.BENCHMARK, self._benchmark_compare)
261
+ registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare)
262
+ return registry
263
+
229
264
  def _compare_core_wrapper(self, api_name, bench_output, device_output):
230
265
  detailed_result_total = []
231
266
  test_final_success = CompareConst.PASS
@@ -308,11 +343,13 @@ class Comparator:
308
343
  return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \
309
344
  f"npu output dtype is {device_output.dtype}, cannot compare."
310
345
  message = ""
346
+ if bench_output.size == 0:
347
+ return CompareConst.ERROR, compare_column, "There is not bench calculation result."
311
348
  if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32,
312
349
  np.int64, np.uint64]:
313
350
  message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \
314
351
  f"Only judged by Error Rate."
315
- err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output)
352
+ err_rate, status, msg = compare_bool_tensor(bench_output, device_output)
316
353
  message += msg + "\n"
317
354
  compare_column.error_rate = err_rate
318
355
  return status, compare_column, message
@@ -321,56 +358,20 @@ class Comparator:
321
358
  compare_column, npu_dtype)
322
359
  return status, compare_column, message
323
360
 
361
+ def _perform_comparison(self, api_name, input_data):
362
+ comparison_func = self.registry.get_comparison_function(api_name, None)
363
+ comparison_func(input_data)
364
+
324
365
  def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype):
325
366
  message = ""
326
- abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
367
+ _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
327
368
  abs_err = get_abs_err(bench_output, device_output)
328
369
  rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
329
- if api_name in thousandth_standard_api:
330
- thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
331
- compare_column.rel_err_thousandth = thousand_res
370
+ input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign)
332
371
  if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
333
- both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
334
- if api_name in binary_standard_api:
335
- err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
336
- compare_column.error_rate = err_rate
337
- elif api_name in absolute_standard_api:
338
- small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
339
- api_name, str(dtype))
340
- rel_err = abs_err / abs_bench_with_eps
341
- small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold)
342
- normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
343
- compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output,
344
- dtype, rtol)
345
- compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
346
- compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
347
- elif api_name in ulp_standard_api:
348
- if bench_output.size == 0:
349
- compare_column.max_ulp_error = 0
350
- compare_column.mean_ulp_error = 0
351
- compare_column.ulp_error_proportion = 0
352
- else:
353
- ulp_err = get_ulp_err(bench_output, device_output, dtype)
354
- compare_column.max_ulp_error = np.max(ulp_err)
355
- compare_column.mean_ulp_error = np.mean(ulp_err)
356
- if dtype == torch.float32:
357
- compare_column.ulp_error_proportion = \
358
- np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
359
- else:
360
- compare_column.ulp_error_proportion = \
361
- np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
362
- else:
363
- dtype_config = precision_configs.get(dtype)
364
- small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
365
- abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
366
- compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
367
- rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
368
- compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
369
- compare_column.eb = get_error_balance(bench_output, device_output)
370
- if rel_err.size == 0:
371
- return CompareConst.ERROR, compare_column, "Relative error result list is empty."
372
- compare_column.max_rel_error = get_max_rel_err(rel_err)
373
- compare_column.mean_rel_error = get_mean_rel_err(rel_err)
372
+ self._perform_comparison(api_name, input_data)
373
+ else:
374
+ message += f"The data type {dtype} is not supported for new precision standard."
374
375
 
375
376
  cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
376
377
  compare_column.cosine_sim = cos_res
@@ -16,9 +16,17 @@
16
16
  # limitations under the License.
17
17
 
18
18
  from msprobe.core.common.const import CompareConst
19
+ from msprobe.pytorch.common.log import logger
19
20
 
20
21
 
21
22
  class CompareColumn:
23
+ __slots__ = [
24
+ 'bench_type', 'npu_type', 'shape', 'cosine_sim', 'max_abs_err', 'rel_err_hundredth',
25
+ 'rel_err_ten_thousandth', 'inf_nan_error_ratio', 'rel_err_ratio', 'abs_err_ratio',
26
+ 'small_value_err_ratio', 'max_rel_error', 'mean_rel_error', 'rmse', 'eb', 'max_ulp_error',
27
+ 'mean_ulp_error', 'ulp_error_proportion', 'error_rate', 'rel_err_thousandth'
28
+ ]
29
+
22
30
  def __init__(self):
23
31
  self.bench_type = CompareConst.SPACE
24
32
  self.npu_type = CompareConst.SPACE
@@ -41,6 +49,24 @@ class CompareColumn:
41
49
  self.mean_ulp_error = CompareConst.SPACE
42
50
  self.ulp_error_proportion = CompareConst.SPACE
43
51
 
52
+ def update(self, metrics):
53
+ """
54
+ Updates the object's attributes with the provided metrics.
55
+
56
+ Args:
57
+ metrics (dict): A dictionary containing attribute names and their corresponding values.
58
+
59
+ Raises:
60
+ AttributeError: If the metric key is not a valid attribute of CompareColumn.
61
+ """
62
+ for key, value in metrics.items():
63
+ if value is None:
64
+ continue
65
+ if key not in self.__slots__:
66
+ logger.error(f"The key '{key}' is not a valid attribute of CompareColumn.")
67
+ continue
68
+ setattr(self, key, value)
69
+
44
70
  def to_column_value(self, is_pass, message):
45
71
  return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
46
72
  self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.eb, self.rmse,
@@ -50,6 +76,16 @@ class CompareColumn:
50
76
 
51
77
 
52
78
  class ApiPrecisionOutputColumn:
79
+ __slots__ = [
80
+ 'api_name', 'small_value_err_ratio', 'small_value_err_status', 'rmse_ratio', 'rmse_status',
81
+ 'max_rel_err_ratio', 'max_rel_err_status', 'mean_rel_err_ratio', 'mean_rel_err_status', 'eb_ratio',
82
+ 'eb_status', 'inf_nan_error_ratio', 'inf_nan_error_ratio_status', 'rel_err_ratio',
83
+ 'rel_err_ratio_status', 'abs_err_ratio', 'abs_err_ratio_status', 'error_rate', 'error_rate_status',
84
+ 'mean_ulp_err', 'ulp_err_proportion', 'ulp_err_proportion_ratio', 'ulp_err_status',
85
+ 'rel_err_thousandth', 'rel_err_thousandth_status', 'compare_result', 'compare_algorithm',
86
+ 'compare_message'
87
+ ]
88
+
53
89
  def __init__(self):
54
90
  self.api_name = CompareConst.SPACE
55
91
  self.small_value_err_ratio = CompareConst.SPACE
@@ -80,6 +116,24 @@ class ApiPrecisionOutputColumn:
80
116
  self.compare_algorithm = CompareConst.SPACE
81
117
  self.compare_message = CompareConst.SPACE
82
118
 
119
+ def update(self, metrics):
120
+ """
121
+ Updates the object's attributes with the provided metrics.
122
+
123
+ Args:
124
+ metrics (dict): A dictionary containing attribute names and their corresponding values.
125
+
126
+ Raises:
127
+ AttributeError: If the metric key is not a valid attribute of CompareColumn.
128
+ """
129
+ for key, value in metrics.items():
130
+ if value is None:
131
+ continue
132
+ if key not in self.__slots__:
133
+ logger.error("The key '%s' is not a valid attribute of CompareColumn.", key)
134
+ continue
135
+ setattr(self, key, value)
136
+
83
137
  def to_column_value(self):
84
138
  return [self.api_name, self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
85
139
  self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import numpy as np
19
+
20
+
21
+ class CompareInput:
22
+ """
23
+ A class to encapsulate the input data required for comparison operations.
24
+
25
+ Attributes:
26
+ bench_output (np.ndarray): The benchmark output values.
27
+ device_output (np.ndarray): The device output values.
28
+ compare_column (class): A clasee to store and update comparison metrics.
29
+ dtype (type, optional): The data type of the outputs. Defaults to None.
30
+ rel_err_orign (float or array-like, optional): The original relative error values. Defaults to None.
31
+
32
+ Methods:
33
+ __init__(bench_output, device_output, compare_column, dtype, rel_err_orign):
34
+ Initializes an instance of CompareInput.
35
+ """
36
+ def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None):
37
+ self.bench_output = bench_output
38
+ self.device_output = device_output
39
+ if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray):
40
+ raise TypeError("The input should be numpy array")
41
+ self.compare_column = compare_column
42
+ self.dtype = dtype
43
+ self.rel_err_orign = rel_err_orign
44
+
45
+
46
+ class PrecisionCompareInput:
47
+ def __init__(self, row_npu, row_gpu, dtype, compare_column):
48
+ self.row_npu = row_npu
49
+ self.row_gpu = row_gpu
50
+ self.dtype = dtype
51
+ self.compare_column = compare_column
@@ -43,10 +43,7 @@ absolute_standard_api = apis.get('AbsoluteThreshStandard')
43
43
  binary_standard_api = apis.get('BinaryCompareStandard')
44
44
  ulp_standard_api = apis.get('ULPStandard')
45
45
  thousandth_standard_api = apis.get('ThousandthStandard')
46
-
47
-
48
- threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
49
- apis_threshold = load_yaml(threshold_yaml_path)
46
+ accumulative_error_standard_api = apis.get('AccumulativeErrorStandard')
50
47
 
51
48
 
52
49
  DETAIL_TEST_ROWS = [
@@ -134,6 +131,7 @@ ULP_PARAMETERS = {
134
131
  class ApiPrecisionCompareColumn:
135
132
  API_NAME = 'API Name'
136
133
  DEVICE_DTYPE = 'DEVICE Dtype'
134
+ SHAPE = 'Shape'
137
135
  SMALL_VALUE_ERROR_RATE = '小值域错误占比'
138
136
  RMSE = '均方根误差'
139
137
  MAX_REL_ERR = '相对误差最大值'
@@ -0,0 +1,9 @@
1
+ {
2
+ "dump_json_path": "./dump.json",
3
+ "api_name": "",
4
+ "extract_api_path": "",
5
+ "propagation": "forward",
6
+ "data_mode": "random_data",
7
+ "random_seed": 1234,
8
+ "iter_times": 1
9
+ }