mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,149 +1,149 @@
1
-
2
- import multiprocessing
3
- from dataclasses import dataclass
4
- from functools import partial
5
- import numpy as np
6
- import pandas as pd
7
- from msprobe.core.common.log import logger
8
- from msprobe.core.common.utils import CompareException
9
- from msprobe.core.common.const import CompareConst
10
-
11
-
12
def _handle_multi_process(func, input_parma, result_df, lock):
    """
    Run ``func`` over contiguous row chunks of ``result_df`` in a process pool and merge the outputs.

    The pool has ``(cpu_count + 1) // 2`` workers. ``result_df`` is cut into chunks of
    ``len(result_df) // process_num`` rows; the last chunk may be shorter, and there may be one
    more chunk than workers. When that size is 0, the whole DataFrame is submitted as a single chunk.
    Each chunk is submitted as ``func(offset, op_name_mapping_dict, df_chunk, lock, input_parma)``,
    where ``offset`` is the position of the chunk's first row in ``result_df``.

    Args:
        func: picklable worker callable; its return values are concatenated with ``pd.concat``.
        input_parma: extra argument forwarded unchanged to every ``func`` call.
        result_df: DataFrame to process; also passed to ``read_dump_data`` to build the name mapping.
        lock: lock forwarded to every worker (must be shareable with pool processes).

    Returns:
        DataFrame: the chunk results concatenated in submission order, with a fresh integer index.

    Raises:
        Whatever a failing worker raised, re-raised by ``AsyncResult.get``. The pool is terminated
        before the exception propagates. ``read_dump_data`` may raise CompareException.
    """
    process_num = (multiprocessing.cpu_count() + 1) // 2
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]

    def err_call(args):
        # Log only. Terminating the pool from this callback would leave results that are still
        # pending unresolved, and the blocking ``r.get()`` loop below could then wait forever.
        # The failure is re-raised by the matching ``r.get()``; the ``with`` block then
        # terminates the pool.
        logger.error('multiprocess compare failed! Reason: {}'.format(args))

    # Pool.__exit__ calls terminate(), so workers are cleaned up even when r.get() raises.
    with multiprocessing.Pool(process_num) as pool:
        results = [
            pool.apply_async(func,
                             args=(df_chunk_size * process_idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                             error_callback=err_call)
            for process_idx, df_chunk in enumerate(df_chunks)
        ]
        final_results = [r.get() for r in results]
        pool.close()
        pool.join()
    return pd.concat(final_results, ignore_index=True)
42
-
43
-
44
- def _ms_graph_handle_multi_process(func, result_df, mode):
45
- process_num = int((multiprocessing.cpu_count() + 1) // 2)
46
- df_chunk_size = len(result_df) // process_num
47
- if df_chunk_size > 0:
48
- df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
49
- else:
50
- df_chunks = [result_df]
51
-
52
- results = []
53
- pool = multiprocessing.Pool(process_num)
54
-
55
- def err_call(args):
56
- logger.error('multiprocess compare failed! Reason: {}'.format(args))
57
- try:
58
- pool.terminate()
59
- except OSError as e:
60
- logger.error("pool terminate failed")
61
-
62
- for df_chunk in df_chunks:
63
- result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
64
- results.append(result)
65
- final_results = [r.get() for r in results]
66
- pool.close()
67
- pool.join()
68
- return pd.concat(final_results, ignore_index=True)
69
-
70
-
71
def read_dump_data(result_df):
    """
    Build a lookup from each row's first-column value to its last-column value.

    Args:
        result_df: DataFrame whose first column holds the NPU dump name and whose last column
            holds the value to map it to.

    Returns:
        dict: ``{name: [value, value]}``. When a name repeats, the later row wins.
        NOTE(review): both list slots hold the same value here; the two-slot shape presumably
        matches what the worker functions expect. Confirm against the callers.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError
            (for example, a DataFrame with no columns).
    """
    try:
        names = result_df.iloc[:, 0].tolist()
        tensors = result_df.iloc[:, -1].tolist()
        return {name: [tensor, tensor] for name, tensor in zip(names, tensors)}
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
87
-
88
@dataclass
class ComparisonResult:
    """
    Metric lists for one chunk of compared rows.

    The fields are parallel lists: element ``i`` of each list belongs to the ``i``-th row of the
    chunk. ``cos_result`` determines how many rows are read from the others.
    """

    cos_result: list  # cosine similarity per row
    max_err_result: list  # maximum absolute error per row
    max_relative_err_result: list  # maximum relative error per row
    err_msgs: list  # error message text per row
    one_thousand_err_ratio_result: list  # one-thousandth error ratio per row
    five_thousand_err_ratio_result: list  # five-thousandths error ratio per row
96
-
97
-
98
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Write one chunk's comparison metrics into ``result_df`` while holding ``lock``.

    Args:
        offset: index label of the chunk's first row; chunk row ``k`` is written to label ``k + offset``.
        result: ComparisonResult with the chunk's parallel metric lists.
        result_df: DataFrame updated in place via ``.loc``.
        lock: acquired before the update and released afterwards, including on error.

    Returns:
        The same ``result_df`` object after the update.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError
            (for example, a metric list shorter than ``result.cos_result``).
    """
    lock.acquire()
    try:
        for pos, _ in enumerate(result.cos_result):
            row = pos + offset
            cos_value = result.cos_result[pos]
            result_df.loc[row, CompareConst.COSINE] = cos_value
            max_abs = result.max_err_result[pos]
            result_df.loc[row, CompareConst.MAX_ABS_ERR] = max_abs
            result_df.loc[row, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[pos]
            result_df.loc[row, CompareConst.ERROR_MESSAGE] = result.err_msgs[pos]
            # Accuracy verdict is derived from the cosine and max-abs-error just written.
            result_df.loc[row, CompareConst.ACCURACY] = check_accuracy(cos_value, max_abs)
            result_df.loc[row, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[pos]
            result_df.loc[row, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[pos]
        return result_df
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    finally:
        lock.release()
131
-
132
-
133
def check_accuracy(cos, max_abs_err):
    """
    Classify one compared row from its cosine similarity and maximum absolute error.

    Args:
        cos: cosine similarity as a number or numeric string, or one of the markers
            CompareConst.SHAPE_UNMATCH, CompareConst.NONE or "N/A".
        max_abs_err: maximum absolute error, with the same conventions (except SHAPE_UNMATCH).

    Returns:
        - CompareConst.ACCURACY_CHECK_UNMATCH if ``cos`` is the shape-mismatch marker.
        - CompareConst.NONE if either value is the NONE marker or cannot be converted to float.
        - CompareConst.ACCURACY_CHECK_NO if either value is "N/A" or NaN, or if
          ``cos < COS_THRESHOLD and max_abs_err > MAX_ABS_ERR_THRESHOLD``, or if
          ``cos < COS_MAX_THRESHOLD or max_abs_err > MAX_ABS_ERR_MAX_THRESHOLD``.
        - CompareConst.ACCURACY_CHECK_YES otherwise.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos, max_abs_err = float(cos), float(max_abs_err)
    except (TypeError, ValueError):
        # TypeError covers inputs float() rejects outright, e.g. a bare None.
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    if np.isnan(cos) or np.isnan(max_abs_err):
        # Every ordered comparison with NaN is False, so without this guard a NaN metric would
        # slip past both threshold checks below and be reported as passing.
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES
1
+
2
+ import multiprocessing
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ import numpy as np
6
+ import pandas as pd
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.core.common.utils import CompareException
9
+ from msprobe.core.common.const import CompareConst
10
+
11
+
12
def _handle_multi_process(func, input_parma, result_df, lock):
    """
    Run ``func`` over contiguous row chunks of ``result_df`` in a process pool and merge the outputs.

    The pool has ``(cpu_count + 1) // 2`` workers. ``result_df`` is cut into chunks of
    ``len(result_df) // process_num`` rows; the last chunk may be shorter, and there may be one
    more chunk than workers. When that size is 0, the whole DataFrame is submitted as a single chunk.
    Each chunk is submitted as ``func(offset, op_name_mapping_dict, df_chunk, lock, input_parma)``,
    where ``offset`` is the position of the chunk's first row in ``result_df``.

    Args:
        func: picklable worker callable; its return values are concatenated with ``pd.concat``.
        input_parma: extra argument forwarded unchanged to every ``func`` call.
        result_df: DataFrame to process; also passed to ``read_dump_data`` to build the name mapping.
        lock: lock forwarded to every worker (must be shareable with pool processes).

    Returns:
        DataFrame: the chunk results concatenated in submission order, with a fresh integer index.

    Raises:
        Whatever a failing worker raised, re-raised by ``AsyncResult.get``. The pool is terminated
        before the exception propagates. ``read_dump_data`` may raise CompareException.
    """
    process_num = (multiprocessing.cpu_count() + 1) // 2
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        df_chunks = [result_df]

    def err_call(args):
        # Log only. Terminating the pool from this callback would leave results that are still
        # pending unresolved, and the blocking ``r.get()`` loop below could then wait forever.
        # The failure is re-raised by the matching ``r.get()``; the ``with`` block then
        # terminates the pool.
        logger.error('multiprocess compare failed! Reason: {}'.format(args))

    # Pool.__exit__ calls terminate(), so workers are cleaned up even when r.get() raises.
    with multiprocessing.Pool(process_num) as pool:
        results = [
            pool.apply_async(func,
                             args=(df_chunk_size * process_idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                             error_callback=err_call)
            for process_idx, df_chunk in enumerate(df_chunks)
        ]
        final_results = [r.get() for r in results]
        pool.close()
        pool.join()
    return pd.concat(final_results, ignore_index=True)
42
+
43
+
44
+ def _ms_graph_handle_multi_process(func, result_df, mode):
45
+ process_num = int((multiprocessing.cpu_count() + 1) // 2)
46
+ df_chunk_size = len(result_df) // process_num
47
+ if df_chunk_size > 0:
48
+ df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
49
+ else:
50
+ df_chunks = [result_df]
51
+
52
+ results = []
53
+ pool = multiprocessing.Pool(process_num)
54
+
55
+ def err_call(args):
56
+ logger.error('multiprocess compare failed! Reason: {}'.format(args))
57
+ try:
58
+ pool.terminate()
59
+ except OSError as e:
60
+ logger.error("pool terminate failed")
61
+
62
+ for df_chunk in df_chunks:
63
+ result = pool.apply_async(func, args=(df_chunk, mode), error_callback=err_call)
64
+ results.append(result)
65
+ final_results = [r.get() for r in results]
66
+ pool.close()
67
+ pool.join()
68
+ return pd.concat(final_results, ignore_index=True)
69
+
70
+
71
def read_dump_data(result_df):
    """
    Build a lookup from each row's first-column value to its last-column value.

    Args:
        result_df: DataFrame whose first column holds the NPU dump name and whose last column
            holds the value to map it to.

    Returns:
        dict: ``{name: [value, value]}``. When a name repeats, the later row wins.
        NOTE(review): both list slots hold the same value here; the two-slot shape presumably
        matches what the worker functions expect. Confirm against the callers.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError
            (for example, a DataFrame with no columns).
    """
    try:
        names = result_df.iloc[:, 0].tolist()
        tensors = result_df.iloc[:, -1].tolist()
        return {name: [tensor, tensor] for name, tensor in zip(names, tensors)}
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
87
+
88
@dataclass
class ComparisonResult:
    """
    Metric lists for one chunk of compared rows.

    The fields are parallel lists: element ``i`` of each list belongs to the ``i``-th row of the
    chunk. ``cos_result`` determines how many rows are read from the others.
    """

    cos_result: list  # cosine similarity per row
    max_err_result: list  # maximum absolute error per row
    max_relative_err_result: list  # maximum relative error per row
    err_msgs: list  # error message text per row
    one_thousand_err_ratio_result: list  # one-thousandth error ratio per row
    five_thousand_err_ratio_result: list  # five-thousandths error ratio per row
96
+
97
+
98
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Write one chunk's comparison metrics into ``result_df`` while holding ``lock``.

    Args:
        offset: index label of the chunk's first row; chunk row ``k`` is written to label ``k + offset``.
        result: ComparisonResult with the chunk's parallel metric lists.
        result_df: DataFrame updated in place via ``.loc``.
        lock: acquired before the update and released afterwards, including on error.

    Returns:
        The same ``result_df`` object after the update.

    Raises:
        CompareException: INVALID_DATA_ERROR on ValueError, INDEX_OUT_OF_BOUNDS_ERROR on IndexError
            (for example, a metric list shorter than ``result.cos_result``).
    """
    lock.acquire()
    try:
        for pos, _ in enumerate(result.cos_result):
            row = pos + offset
            cos_value = result.cos_result[pos]
            result_df.loc[row, CompareConst.COSINE] = cos_value
            max_abs = result.max_err_result[pos]
            result_df.loc[row, CompareConst.MAX_ABS_ERR] = max_abs
            result_df.loc[row, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[pos]
            result_df.loc[row, CompareConst.ERROR_MESSAGE] = result.err_msgs[pos]
            # Accuracy verdict is derived from the cosine and max-abs-error just written.
            result_df.loc[row, CompareConst.ACCURACY] = check_accuracy(cos_value, max_abs)
            result_df.loc[row, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[pos]
            result_df.loc[row, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[pos]
        return result_df
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    finally:
        lock.release()
131
+
132
+
133
def check_accuracy(cos, max_abs_err):
    """
    Classify one compared row from its cosine similarity and maximum absolute error.

    Args:
        cos: cosine similarity as a number or numeric string, or one of the markers
            CompareConst.SHAPE_UNMATCH, CompareConst.NONE or "N/A".
        max_abs_err: maximum absolute error, with the same conventions (except SHAPE_UNMATCH).

    Returns:
        - CompareConst.ACCURACY_CHECK_UNMATCH if ``cos`` is the shape-mismatch marker.
        - CompareConst.NONE if either value is the NONE marker or cannot be converted to float.
        - CompareConst.ACCURACY_CHECK_NO if either value is "N/A" or NaN, or if
          ``cos < COS_THRESHOLD and max_abs_err > MAX_ABS_ERR_THRESHOLD``, or if
          ``cos < COS_MAX_THRESHOLD or max_abs_err > MAX_ABS_ERR_MAX_THRESHOLD``.
        - CompareConst.ACCURACY_CHECK_YES otherwise.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos, max_abs_err = float(cos), float(max_abs_err)
    except (TypeError, ValueError):
        # TypeError covers inputs float() rejects outright, e.g. a bare None.
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    if np.isnan(cos) or np.isnan(max_abs_err):
        # Every ordered comparison with NaN is False, so without this guard a NaN metric would
        # slip past both threshold checks below and be reported as passing.
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES