mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,6 +1,4 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
2
  # All rights reserved.
5
3
  #
6
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,8 +14,12 @@
16
14
  # limitations under the License.
17
15
 
18
16
 
19
- from .debugger.precision_debugger import PrecisionDebugger
20
- from .common.utils import seed_all
17
+ import torch
21
18
  from .compare.distributed_compare import compare_distributed
22
19
  from .compare.pt_compare import compare
23
- from .functional.module_dump import module_dump, module_dump_end
20
+ from .common.utils import seed_all
21
+ from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end
22
+
23
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
24
+ if torch_version_above_or_equal_2:
25
+ from msprobe.pytorch.monitor.module_hook import TrainerMon
@@ -16,10 +16,18 @@
16
16
  # limitations under the License.
17
17
 
18
18
  import os
19
+ from collections import namedtuple
19
20
  from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
21
+ from msprobe.core.common.utils import is_int
20
22
  from msprobe.pytorch.pt_config import RunUTConfig
21
23
 
22
24
 
25
+ RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
26
+ 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
27
+ 'black_list', 'error_data_path', 'online_config'])
28
+ OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
29
+
30
+
23
31
  class Config:
24
32
  def __init__(self, yaml_file):
25
33
  check_file_or_directory_path(yaml_file, False)
@@ -50,6 +58,8 @@ class Config:
50
58
  raise ValueError(f"{key} must be one of {validators.keys()}")
51
59
  if not isinstance(value, validators.get(key)):
52
60
  raise ValueError(f"{key} must be {validators[key].__name__} type")
61
+ if key == 'precision' and not is_int(value):
62
+ raise ValueError("precision must be an integer")
53
63
  if key == 'precision' and (value < 0 or value > 20):
54
64
  raise ValueError("precision must be greater than or equal to 0 and less than 21")
55
65
  if key == 'white_list':
@@ -68,3 +78,55 @@ class Config:
68
78
  cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
69
79
  yaml_path = os.path.join(cur_path, "config.yaml")
70
80
  msCheckerConfig = Config(yaml_path)
81
+
82
+
83
+ class CheckerConfig:
84
+ def __init__(self, task_config=None):
85
+ self.white_list = msCheckerConfig.white_list
86
+ self.black_list = msCheckerConfig.black_list
87
+ self.error_data_path = msCheckerConfig.error_data_path
88
+ self.is_online = msCheckerConfig.is_online
89
+ self.nfs_path = msCheckerConfig.nfs_path
90
+ self.host = msCheckerConfig.host
91
+ self.port = msCheckerConfig.port
92
+ self.rank_list = msCheckerConfig.rank_list
93
+ self.tls_path = msCheckerConfig.tls_path
94
+
95
+ if task_config:
96
+ self.load_config(task_config)
97
+
98
+ def load_config(self, task_config):
99
+ self.white_list = task_config.white_list
100
+ self.black_list = task_config.black_list
101
+ self.error_data_path = task_config.error_data_path
102
+ self.is_online = task_config.is_online
103
+ self.nfs_path = task_config.nfs_path
104
+ self.host = task_config.host
105
+ self.port = task_config.port
106
+ self.rank_list = task_config.rank_list
107
+ self.tls_path = task_config.tls_path
108
+
109
+ def get_online_config(self):
110
+ return OnlineConfig(
111
+ is_online=self.is_online,
112
+ nfs_path=self.nfs_path,
113
+ host=self.host,
114
+ port=self.port,
115
+ rank_list=self.rank_list,
116
+ tls_path=self.tls_path
117
+ )
118
+
119
+ def get_run_ut_config(self, **config_params):
120
+ return RunUtConfig(
121
+ forward_content=config_params.get('forward_content'),
122
+ backward_content=config_params.get('backward_content'),
123
+ result_csv_path=config_params.get('result_csv_path'),
124
+ details_csv_path=config_params.get('details_csv_path'),
125
+ save_error_data=config_params.get('save_error_data'),
126
+ is_continue_run_ut=config_params.get('is_continue_run_ut'),
127
+ real_data_path=config_params.get('real_data_path'),
128
+ white_list=self.white_list,
129
+ black_list=self.black_list,
130
+ error_data_path=config_params.get('error_data_path'),
131
+ online_config=self.get_online_config()
132
+ )
@@ -72,38 +72,53 @@ def check_need_convert(api_name):
72
72
  return convert_type
73
73
 
74
74
 
75
- def api_info_preprocess(api_name, api_info_dict):
75
+ def cross_entropy_process(api_info_dict):
76
76
  """
77
77
  Function Description:
78
- Preprocesses the API information.
78
+ Preprocesses the cross_entropy API information.
79
79
  Parameter:
80
- api_name: Name of the API.
81
80
  api_info_dict: argument of the API.
82
81
  Return api_info_dict:
83
- convert_type: Type of conversion.
84
82
  api_info_dict: Processed argument of the API.
85
83
  """
86
- convert_type = check_need_convert(api_name)
87
- if api_name == 'cross_entropy':
88
- api_info_dict = cross_entropy_process(api_info_dict)
89
- return convert_type, api_info_dict
84
+ if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
85
+ and 'Min' in api_info_dict['input_args'][1]:
86
+ if api_info_dict['input_args'][1]['Min'] <= 0:
87
+ # The second argument in cross_entropy should be -100 or not less than 0
88
+ api_info_dict['input_args'][1]['Min'] = 0
89
+ return api_info_dict
90
90
 
91
91
 
92
- def cross_entropy_process(api_info_dict):
92
+ def histc_process(api_info_dict):
93
+ input_args = api_info_dict['input_args']
94
+ if input_args and input_args[0].get('dtype'):
95
+ dtype = input_args[0]['dtype']
96
+ if dtype in Const.TORCH_INT_DTYPE:
97
+ api_info_dict['input_args'][0]['dtype'] = Const.TORCH_FLOAT32
98
+ return api_info_dict
99
+
100
+
101
+ API_PROCESS_MAP = {
102
+ 'cross_entropy': cross_entropy_process,
103
+ 'histc': histc_process
104
+ }
105
+
106
+
107
+ def api_info_preprocess(api_name, api_info_dict):
93
108
  """
94
109
  Function Description:
95
- Preprocesses the cross_entropy API information.
110
+ Preprocesses the API information.
96
111
  Parameter:
112
+ api_name: Name of the API.
97
113
  api_info_dict: argument of the API.
98
114
  Return api_info_dict:
115
+ convert_type: Type of conversion.
99
116
  api_info_dict: Processed argument of the API.
100
117
  """
101
- if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
102
- and 'Min' in api_info_dict['input_args'][1]:
103
- if api_info_dict['input_args'][1]['Min'] <= 0:
104
- # The second argument in cross_entropy should be -100 or not less than 0
105
- api_info_dict['input_args'][1]['Min'] = 0
106
- return api_info_dict
118
+ convert_type = check_need_convert(api_name)
119
+ if api_name in API_PROCESS_MAP:
120
+ api_info_dict = API_PROCESS_MAP[api_name](api_info_dict)
121
+ return convert_type, api_info_dict
107
122
 
108
123
 
109
124
  def initialize_save_path(save_path, dir_name):
@@ -16,10 +16,12 @@
16
16
  # limitations under the License.
17
17
 
18
18
  # 定义比对算法及比对标准
19
+ import math
19
20
  import torch
20
21
  import numpy as np
21
22
 
22
23
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
24
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
23
25
  from msprobe.core.common.const import CompareConst
24
26
 
25
27
 
@@ -179,13 +181,13 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
179
181
 
180
182
  def check_small_value(abs_err, small_value_mask, small_value_atol):
181
183
  '''
182
- 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
184
+ 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
183
185
  输入:
184
- rel_err:npu输出和golden输出的相对误差
186
+ abs_err:npu输出和golden输出的绝对误差
185
187
  normal_value_mask:npu输出和golden输出的正常值mask
186
- rtol:相对误差的阈值
188
+ atol:绝对误差的阈值
187
189
  输出:
188
- rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
190
+ abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
189
191
  '''
190
192
  greater_mask = np.greater(abs_err, small_value_atol)
191
193
  err_mask = np.logical_and(greater_mask, small_value_mask)
@@ -195,13 +197,13 @@ def check_small_value(abs_err, small_value_mask, small_value_atol):
195
197
 
196
198
  def check_norm_value(normal_value_mask, rel_err, rtol):
197
199
  '''
198
- 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
200
+ 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
199
201
  输入:
200
- abs_err:npu输出和golden输出的绝对误差
202
+ rel_err:npu输出和golden输出的相对误差
201
203
  normal_value_mask:npu输出和golden输出的正常值mask
202
- atol:绝对误差的阈值
204
+ rtol:相对误差的阈值
203
205
  输出:
204
- abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
206
+ rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
205
207
  '''
206
208
  err_mask = np.greater(rel_err, rtol)
207
209
  err_mask = np.logical_and(err_mask, normal_value_mask)
@@ -228,3 +230,34 @@ def get_ulp_err(bench_output, device_output, dtype):
228
230
  def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
229
231
  return (device_output.astype(data_type) - bench_output).astype(data_type) * \
230
232
  np.exp2(-eb + exponent_num).astype(data_type)
233
+
234
+
235
+ def calc_ratio(x, y, dtype):
236
+ """
237
+ Calculate the ratio between NPU and GPU statistical values.
238
+
239
+ Args:
240
+ x (float): Statistical value from the NPU side
241
+ y (float): Statistical value from the GPU side
242
+ dtype: Data type used to determine the minimum error value
243
+
244
+ Returns:
245
+ float: The ratio of NPU to GPU statistical values
246
+
247
+ Notes:
248
+ - Takes absolute values of both x and y for calculation
249
+ - Uses StandardConfig.get_minmum_err(dtype) to get minimum error for the specified dtype
250
+ - Prevents division by zero by ensuring denominator is not less than minimum error
251
+ - Returns |x| / max(|y|, minimum_error)
252
+ """
253
+ x, y = abs(x), abs(y)
254
+ minmum_err = StandardConfig.get_minmum_err(dtype)
255
+ err_y = max(y, minmum_err)
256
+ return x / err_y
257
+
258
+
259
+ def compare_bool_tensor(bench_output, device_output):
260
+ error_nums = (bench_output != device_output).sum()
261
+ error_rate = float(error_nums / bench_output.size)
262
+ result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
263
+ return error_rate, result, ""