mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -0,0 +1,235 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ from copy import deepcopy
19
+ from dataclasses import dataclass
20
+ from typing import ClassVar, Dict, List, Optional, Tuple
21
+
22
+ import yaml
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import save_yaml
25
+ from msprobe.core.common.log import logger
26
+ from msprobe.core.common.utils import CompareException, add_time_with_yaml
27
+ from msprobe.core.compare.layer_mapping.postprocess_pass import postprocess_pass
28
+
29
+
30
@dataclass
class DumpDataItem:
    """Scope information derived for a single entry of a dump file.

    An instance is created with only ``framework`` set; ``set`` then fills the
    remaining fields from the entry name, its construct (parent) name and its
    recorded call stack.
    """

    framework: str
    data_name: Optional[str] = None
    api_type: Optional[str] = None
    api_name: Optional[str] = None
    type_name: Optional[str] = None
    full_scope: str = ""
    layer_scope: str = ""
    stack_scope: str = ""
    frame_stack_scope: str = ""
    user_stack_scope: str = ""
    # Holds whatever construct name was passed to set_layer_scope (may be None).
    construct_scope: Optional[str] = ""
    scope_direction: Optional[str] = None
    # Stored as the raw string field split out of the construct name.
    scope_id: Optional[str] = None

    # Class-level lookup tables (ClassVar keeps them out of the dataclass fields).
    framework2layername: ClassVar[Dict[str, str]] = {
        Const.MS_FRAMEWORK: Const.CELL, Const.PT_FRAMEWORK: Const.MODULE}
    # Per framework: (start pattern, end pattern) handed to find_regard_scope.
    framework2stack_sign: ClassVar[Dict[str, Tuple[str, str]]] = {
        Const.MS_FRAMEWORK: ("Template", "construct"),
        Const.PT_FRAMEWORK: ("Template", r"in (for|back)ward,")
    }

    @staticmethod
    def check_stack_valid(stack_info):
        """Raise CompareException unless ``stack_info`` is None or a list of str."""
        if stack_info is None:
            return
        if isinstance(stack_info, list) and all(isinstance(stack, str) for stack in stack_info):
            return
        logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
        raise CompareException(CompareException.INVALID_DATA_ERROR)

    def set(self, data_name: str, construct_info: Optional[str], stack_info: Optional[List[str]]) -> None:
        """Populate every derived field, in dependency order.

        Args:
            data_name: Name of the dump entry, fields separated by ``Const.SEP``.
            construct_info: Construct (parent) name of the entry, or None.
            stack_info: Call-stack lines recorded for the entry, or None.

        Raises:
            CompareException: If the name, construct name or stack is malformed.
        """
        self.set_name(data_name)
        self.set_layer_scope(construct_info)
        self.set_stack_scope(stack_info)
        self.set_full_scope()

    def set_name(self, data_name: str) -> None:
        """Split ``data_name`` and derive ``api_type``, ``type_name`` and ``api_name``."""
        self.data_name = data_name
        data_name_list = data_name.split(Const.SEP)
        # str.split with an explicit separator never returns an empty list, so
        # the length comparison alone covers the malformed-name case.
        if len(data_name_list) < abs(Const.LAYER_NAME_INDEX):
            logger.error(
                f"The dump data does not comply with the format specification and "
                f"must contain no less than four fields. "
                f"The current data is {data_name}"
            )
            raise CompareException(CompareException.INVALID_DATA_ERROR)

        self.api_type = data_name_list[Const.API_TYPE_INDEX]
        self.type_name = data_name_list[Const.TYPE_NAME_INDEX]
        if self.api_type == self.framework2layername.get(self.framework):
            # Cell/Module entries take their api name from the layer-name field.
            self.api_name = data_name_list[Const.LAYER_NAME_INDEX]
        else:
            self.api_name = self.type_name

    def set_layer_scope(self, construct_info: Optional[str]) -> None:
        """Derive ``layer_scope``, ``scope_id`` and ``scope_direction``.

        Raises:
            CompareException: If ``construct_info`` is given but has too few fields.
        """
        self.construct_scope = construct_info
        if self.api_type == self.framework2layername.get(self.framework):
            # remove api name
            data_list = self.data_name.split(Const.SEP)
            data_list = data_list[:Const.LAYER_NAME_INDEX] + data_list[Const.TYPE_NAME_INDEX:]
        elif construct_info:
            data_list = construct_info.split(Const.SEP)
        else:
            data_list = []

        if data_list:
            self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
        else:
            # Nothing to derive a scope from: fall back to the framework's layer name.
            self.layer_scope = self.framework2layername.get(self.framework)

        if construct_info:
            construct_list = construct_info.split(Const.SEP)
            if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
                logger.error(
                    f"The construct data does not comply with the format specification and "
                    f"must contain no less than four fields. "
                    f"The current data is {construct_info}"
                )
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            self.scope_id = construct_list[Const.SCOPE_ID_INDEX]
            self.scope_direction = construct_list[Const.SCOPE_DIRECTION_INDEX]

    def set_stack_scope(self, stack_info: Optional[List[str]]) -> None:
        """Derive ``frame_stack_scope`` and ``user_stack_scope`` from the stack lines.

        Raises:
            CompareException: If ``stack_info`` is not a list of strings.
        """
        # Cell/Module has no stack info
        if self.api_type == self.framework2layername.get(self.framework):
            return

        if self.api_type in Const.DATA_TYPE_SKIP_LIST or not stack_info:
            return

        start_sign, end_sign = self.framework2stack_sign.get(self.framework)
        self.check_stack_valid(stack_info)
        start_pos, end_pos = find_regard_scope(stack_info, start_sign, end_sign)
        # Keep only the stack lines strictly between the start and end markers.
        regard_scope = stack_info[start_pos + 1:end_pos]
        frame_func_stack_list, user_func_stack_list = find_stack_func_list(regard_scope)
        self.frame_stack_scope = Const.SEP.join(frame_func_stack_list)
        self.user_stack_scope = Const.SEP.join(user_func_stack_list)

    def set_full_scope(self, use_user_func_scope=False, use_frame_func_scope=True) -> None:
        """Join layer scope, the selected stack scopes and the api name into ``full_scope``."""
        scope_list = [self.layer_scope]
        if use_user_func_scope and self.user_stack_scope:
            scope_list.append(self.user_stack_scope)
        if use_frame_func_scope and self.frame_stack_scope:
            scope_list.append(self.frame_stack_scope)
        scope_list.append(self.api_name)
        self.full_scope = Const.SEP.join(scope_list)
141
+
142
+
143
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the start and end markers of the relevant stack region.

    Args:
        lines: Sequence of stack-line strings.
        start_sign: Regex pattern marking the start of the region.
        end_sign: Regex pattern marking the end of the region.

    Returns:
        ``(start_pos, end_pos)``. ``start_pos`` is the index of the last line
        matching ``start_sign`` seen before the scan stops, or -1 if none.
        ``end_pos`` is the index of the first line matching ``end_sign`` after
        a start was found, or ``len(lines)`` if none.
    """
    begin = -1
    finish = len(lines)
    for position, text in enumerate(lines):
        if re.search(start_sign, text):
            # A later start marker overrides an earlier one.
            begin = position
            continue
        if begin < 0:
            # End markers are ignored until a start marker has been seen.
            continue
        if re.search(end_sign, text):
            finish = position
            break
    return begin, finish
154
+
155
+
156
def find_stack_func_list(lines, record_user=True):
    """Split stack lines into a framework entry function and user functions.

    Each line is split on ``Const.COMMA``; the element at
    ``Const.STACK_FILE_INDEX`` is treated as the file part. Lines whose file
    part contains any entry of ``Const.FRAME_FILE_LIST`` are framework lines;
    the last such line seen before the first non-framework line is kept.

    Args:
        lines: Stack-line strings to inspect.
        record_user: When true, non-framework lines are collected as user lines.

    Returns:
        ``(frame_func, user_func)``, each the list of function names produced
        by ``get_stack_in_lines``.
    """
    user_stack = []
    frame_stack = None
    no_entrance = True
    for line in lines:
        ele_list = line.split(Const.COMMA)
        file_ele = ele_list[Const.STACK_FILE_INDEX]
        # Framework line seen while no non-framework line has been met yet.
        if any(ii in file_ele for ii in Const.FRAME_FILE_LIST) and no_entrance:
            frame_stack = line  # keep only the most recent candidate
        else:
            if record_user:
                user_stack.append(line)
            # NOTE(review): the original nesting of this assignment is ambiguous
            # in the flattened source; both readings behave the same for the
            # default record_user=True. Confirm for record_user=False.
            no_entrance = False

    # get_stack_in_lines skips falsy entries, so a missing frame line yields [].
    frame_func = get_stack_in_lines([frame_stack])
    user_func = get_stack_in_lines(user_stack)
    return (frame_func, user_func)
180
+
181
+
182
def get_stack_in_lines(simplified: List[str]):
    """Extract function names from stack lines, returned in reverse order.

    Falsy lines are ignored. A line is also ignored when its file element
    contains any entry of ``Const.FILE_SKIP_LIST`` or its function element
    contains any entry of ``Const.FUNC_SKIP_LIST``.

    Args:
        simplified: Stack-line strings (entries may be None or empty).

    Returns:
        The whitespace-separated token at ``Const.STACK_FUNC_ELE_INDEX`` of
        each kept line's function element, in reverse input order.
    """
    collected = []
    for entry in simplified or []:
        if not entry:
            continue

        fields = entry.split(Const.COMMA)
        file_field = fields[Const.STACK_FILE_INDEX]
        skip_by_file = any(marker in file_field for marker in Const.FILE_SKIP_LIST)
        if skip_by_file:
            continue

        func_field = fields[Const.STACK_FUNC_INDEX]
        skip_by_func = any(marker in func_field for marker in Const.FUNC_SKIP_LIST)
        if skip_by_func:
            continue

        collected.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])

    collected.reverse()
    return collected
205
+
206
+
207
def dumpdata_representer(dumper, data):
    """Represent ``data`` as a mapping of its attributes, minus ``data_name``.

    The attribute dictionary is deep-copied first, so ``data`` is left untouched.

    Args:
        dumper: Object exposing ``represent_dict``.
        data: Object whose ``__dict__`` is serialised.

    Returns:
        Whatever ``dumper.represent_dict`` returns for the filtered mapping.

    Raises:
        KeyError: If ``data`` has no ``data_name`` attribute.
    """
    payload = deepcopy(vars(data))
    del payload["data_name"]
    return dumper.represent_dict(payload)
211
+
212
+
213
def get_dump_data_items(dump, stack, construct, framework, output_path=None):
    """Build a ``DumpDataItem`` for every entry under ``dump["data"]``.

    Args:
        dump: Parsed dump mapping; entry names are read from its ``"data"`` key.
        stack: Mapping from entry name to stack lines.
        construct: Mapping from entry name to construct (parent) name.
        framework: Framework identifier passed to each ``DumpDataItem``.
        output_path: Optional directory; when truthy, the name-to-item mapping
            is also saved there as YAML.

    Returns:
        The list of items after ``postprocess_pass``, or an empty list when
        ``stack`` or ``construct`` is empty.
    """
    if not (stack and construct):
        return []

    data_items = []
    name2item = {}
    for data_name in dump.get("data", {}):
        item = DumpDataItem(framework)
        item.set(data_name, construct.get(data_name, None), stack.get(data_name, None))
        name2item[data_name] = item
        data_items.append(item)

    postprocess_pass(data_items, name2item)

    if output_path:
        yaml.add_representer(DumpDataItem, dumpdata_representer)
        target_dir = os.path.realpath(output_path)
        file_path = os.path.join(target_dir, add_time_with_yaml(f"{framework}_data"))
        save_yaml(file_path, name2item)
    return data_items
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.const import CompareConst, Const
19
+ from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
20
+ from msprobe.core.common.utils import (add_time_with_yaml,
21
+ detect_framework_by_dump_json,
22
+ get_stack_construct_by_dump_json_path)
23
+ from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
24
+ from msprobe.core.compare.utils import read_op
25
+
26
+
27
class LayerTrie:
    """Trie keyed by scope-name segments, holding data items at each node."""

    def __init__(self, type_name, framework=None):
        self.type_name = type_name
        self.data_items = []
        self.children = {}
        self.framework = framework

    def __repr__(self):
        return f"Layer(type_name={self.type_name}, data_number={len(self.data_items)})"

    def get(self, name):
        """Return the child node called ``name``, or None if there is none."""
        return self.children.get(name)

    def insert(self, data_item):
        """Insert ``data_item`` along the path given by its ``full_scope``.

        The scope is split on ``Const.SEP`` and the segments from
        ``Const.RIGHT_MOVE_INDEX`` onward are walked, creating missing nodes.
        The final node stores the item and takes over the item's ``type_name``.
        """
        parts = data_item.full_scope.split(Const.SEP)
        node = self
        scope_name_list = parts[Const.RIGHT_MOVE_INDEX:]

        for name in scope_name_list:
            if name not in node.children:
                node.children[name] = LayerTrie(name, data_item.framework)
            node = node.children[name]
        node.data_items.append(data_item)
        node.type_name = data_item.type_name

    def query_data(self, scope, index, default_value=None):
        """Return the ``index``-th data item stored at ``scope``.

        Args:
            scope: Scope string; its first ``Const.SEP`` segment is skipped.
            index: Position in the target node's ``data_items``.
            default_value: Returned when the path or the index does not exist.

        Returns:
            The matching data item, or ``default_value``.
        """
        parts = scope.split(Const.SEP)
        node = self
        scope_name_list = parts[1:]

        for name in scope_name_list:
            if name not in node.children:
                return default_value
            node = node.children[name]
        # A negative index (e.g. the -1 "not found" sentinel from convert_scope)
        # must not wrap around to the last item or raise on an empty node.
        if index < 0 or index >= len(node.data_items):
            return default_value
        return node.data_items[index]

    def save_to_yaml(self, output_path):
        """Write the tree, converted to nested dicts, to a YAML file in ``output_path``."""
        result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
        file_name = add_time_with_yaml(f"{self.framework}_tree")
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        save_yaml(file_path, result)

    def convert_to_dict(self, node):
        """Recursively convert ``node`` and its descendants into nested dicts.

        Each dict lists the node's item names under ``"data_item"`` and maps
        ``"<child key> @ <child repr>"`` to the child's own dict.
        """
        # Distinct loop name: the original reused ``node`` inside the comprehension.
        result = {"data_item": [item.data_name for item in node.data_items]}
        for child_key, child_node in node.children.items():
            key = f"{child_key} @ {child_node}"
            result[key] = self.convert_to_dict(child_node)
        return result
78
+
79
+
80
def convert_scope(layer_trie, data_item, mapping=None):
    """Translate an item's scope through the layer mapping.

    Walks ``layer_trie`` along ``data_item.full_scope``. At each node the rules
    registered for the node's ``type_name`` and the next segment are tried in
    order; the first rule whose origin equals the next ``level`` segments
    replaces them with its target. An identity rule is always tried last.

    Args:
        layer_trie: Root node of the trie that contains ``data_item``.
        data_item: Item whose ``full_scope`` and ``data_name`` are read.
        mapping: ``{type_name: {prefix: [(origin, target, level), ...]}}``,
            as produced by ``preprocess_layer_mapping``; None means no rules.

    Returns:
        ``(new_scope, index)``: the translated scope, starting with
        ``Const.TOP_LAYER``, and the position of the item within its node's
        ``data_items`` (-1 if no entry has the same ``data_name``).
    """
    if not mapping:
        mapping = {}
    new_scope = Const.TOP_LAYER
    scope_list = data_item.full_scope.split(Const.SEP)
    cur_node = layer_trie

    idx = 0
    while idx < len(scope_list) - 1:
        child_name = scope_list[idx + 1]
        prefix_mapping = mapping.get(cur_node.type_name, {})
        # Copy before appending the identity fallback: the stored rule list is
        # shared across calls and must not grow with every lookup.
        mapping_list = list(prefix_mapping.get(child_name, []))
        mapping_list.append((child_name, child_name, 1))
        step = 1
        for origin, target, level in mapping_list:
            if Const.SEP.join(scope_list[idx + 1: idx + level + 1]) == origin:
                new_scope = new_scope + Const.SEP + target
                step = level
                break
        # Descend past every segment consumed by the matched rule.
        for _ in range(step):
            cur_node = cur_node.get(scope_list[idx + 1])
            idx += 1

    index = -1
    for position, child in enumerate(cur_node.data_items):
        if data_item.data_name == child.data_name:
            index = position
    return new_scope, index
109
+
110
+
111
def get_data_items_and_tree(dump_json_path, output_path):
    """Load a dump file and build both its item list and its scope trie.

    Args:
        dump_json_path: Path of the dump JSON file.
        output_path: Directory for intermediate YAML output; when falsy the
            trie is not saved.

    Returns:
        ``(dump_data_items, root)``: the parsed items and the ``LayerTrie``
        root into which they were all inserted.
    """
    framework = detect_framework_by_dump_json(dump_json_path)
    stack, construct = get_stack_construct_by_dump_json_path(dump_json_path)
    dump_content = load_json(dump_json_path)
    items = get_dump_data_items(dump_content, stack, construct, framework, output_path)

    trie_root = LayerTrie(Const.TOP_LAYER, framework)
    for item in items:
        trie_root.insert(item)

    if output_path:
        trie_root.save_to_yaml(output_path)
    return items, trie_root
122
+
123
+
124
def convert_data_item(npu_tree, bench_tree, npu_data_item, mapping):
    """Find the bench item matching ``npu_data_item``.

    The item's scope is translated with ``convert_scope`` on ``npu_tree`` and
    the result is looked up in ``bench_tree``.

    Returns:
        The value returned by ``bench_tree.query_data`` for the translated scope.
    """
    translated_scope, position = convert_scope(npu_tree, npu_data_item, mapping)
    return bench_tree.query_data(translated_scope, position)
128
+
129
+
130
def update_keys_in_place(d):
    """
    Keep mappings written for the old version compatible with the new one.
    The old version used 'Cell' as the top layer name; the new version uses
    'TopLayer'. The ``Const.CELL`` key is removed from ``d`` and, when its value
    is not None, that value is stored under ``Const.TOP_LAYER`` instead.
    """
    legacy_value = d.pop(Const.CELL, None)
    if legacy_value is None:
        return
    d[Const.TOP_LAYER] = legacy_value
139
+
140
+
141
def preprocess_layer_mapping(mapping):
    """
    Group the rules of each type by the first segment of their key and order
    each group by descending number of key segments. ``mapping`` is first
    passed through ``update_keys_in_place``.

    before:
    {'A': {'a.b.c': 'new_c',
           'a.demo': 'new_demo',
           'z': 'new_z',
           'd.e': 'e'}}
    after:
    {'A': {'a': [('a.b.c', 'new_c', 3), ('a.demo', 'new_demo', 2)],
           'z': [('z', 'new_z', 1)],
           'd': [('d.e', 'e', 2)]}}
    """
    update_keys_in_place(mapping)
    final_mapping = {}

    for type_name, name_map in mapping.items():
        grouped = {}
        final_mapping[type_name] = grouped

        for key, value in name_map.items():
            segments = key.split('.')
            # Rules are bucketed by their first segment.
            grouped.setdefault(segments[0], []).append((key, value, len(segments)))

        # Order each bucket by rule length, longest first.
        for rules in grouped.values():
            rules.sort(key=lambda rule: -rule[-1])

    return final_mapping
172
+
173
+
174
def convert_data_items(npu_tree, bench_tree, npu_data_items, mapping):
    """Map every NPU item name to the name of its bench counterpart.

    Args:
        npu_tree: Trie built from the NPU dump.
        bench_tree: Trie built from the bench dump.
        npu_data_items: Items of the NPU dump.
        mapping: Raw layer mapping; preprocessed here before use.

    Returns:
        Dict from NPU ``data_name`` to bench ``data_name``, or to
        ``CompareConst.N_A`` when no bench item was found.
    """
    prepared = preprocess_layer_mapping(mapping)
    api_mapping = {}
    for npu_item in npu_data_items:
        bench_item = convert_data_item(npu_tree, bench_tree, npu_item, prepared)
        if bench_item:
            matched_name = bench_item.data_name
        else:
            matched_name = CompareConst.N_A
        api_mapping[npu_item.data_name] = matched_name
    return api_mapping
183
+
184
+
185
def generate_api_mapping_by_layer_mapping(npu_json_path, bench_json_path, layer_mapping_path=None, output_path=None):
    """Build the NPU-to-bench API name mapping from two dump files.

    Args:
        npu_json_path: Path of the NPU dump JSON.
        bench_json_path: Path of the bench dump JSON.
        layer_mapping_path: Path of a layer-mapping YAML file; anything that is
            not a str means no layer mapping is applied.
        output_path: Optional directory; when truthy the result is also saved
            there as YAML.

    Returns:
        Dict from NPU data name to bench data name.
    """
    npu_data_items, npu_root = get_data_items_and_tree(npu_json_path, output_path)
    _, bench_root = get_data_items_and_tree(bench_json_path, output_path)

    mapping = load_yaml(layer_mapping_path) if isinstance(layer_mapping_path, str) else {}
    api_mapping = convert_data_items(npu_root, bench_root, npu_data_items, mapping)

    if output_path:
        target_dir = os.path.realpath(output_path)
        save_yaml(os.path.join(target_dir, add_time_with_yaml("api_mapping")), api_mapping)
    return api_mapping
198
+
199
+
200
def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path=None):
    """Expand an API-level mapping into a mapping between full op names.

    For each NPU/bench API pair, the full op names are read from both dumps
    and paired by the suffix that follows the API name.

    Args:
        npu_json_path: Path of the NPU dump JSON.
        bench_json_path: Path of the bench dump JSON.
        api_mapping: Dict from NPU API name to bench API name; entries with a
            falsy NPU name are skipped.
        output_path: Optional directory; when truthy the result is also saved
            there as YAML.

    Returns:
        Dict from NPU full op name to bench full op name, or to
        ``CompareConst.N_A`` when no bench name has the same suffix.
    """

    def read_full_op_names(data, op_name):
        # Collect the 'full_op_name' field of every parsed op entry.
        parsed_ops = read_op(data.get(op_name, {}), op_name)
        return [parsed.get('full_op_name') for parsed in parsed_ops]

    def generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names):
        # Index bench names by the suffix after the op name; a later duplicate wins.
        bench_offset = len(bench_op_name)
        suffix_to_full_op_name = {full_name[bench_offset:]: full_name for full_name in bench_full_op_names}
        npu_offset = len(npu_op_name)
        return {
            full_name: suffix_to_full_op_name.get(full_name[npu_offset:], CompareConst.N_A)
            for full_name in npu_full_op_names
        }

    npu_data = load_json(npu_json_path).get("data", {})
    bench_data = load_json(bench_json_path).get("data", {})

    data_mapping = {}
    for npu_op_name, bench_op_name in api_mapping.items():
        if not npu_op_name:
            continue
        npu_names = read_full_op_names(npu_data, npu_op_name)
        bench_names = read_full_op_names(bench_data, bench_op_name)
        data_mapping.update(generate_op_data_mapping(npu_op_name, npu_names, bench_op_name, bench_names))

    if output_path:
        target_dir = os.path.realpath(output_path)
        save_yaml(os.path.join(target_dir, add_time_with_yaml("data_mapping")), data_mapping)
    return data_mapping
233
+
234
+
235
def generate_data_mapping_by_layer_mapping(input_param, layer_mapping_path=None, output_path=None):
    """Produce the full-op-name mapping for the dumps named in ``input_param``.

    Args:
        input_param: Mapping with ``"npu_json_path"`` and ``"bench_json_path"``.
        layer_mapping_path: Optional layer-mapping YAML path.
        output_path: Optional directory, passed only to the data-mapping step;
            the intermediate API mapping is built without an output path.

    Returns:
        The dict returned by ``generate_data_mapping``.
    """
    npu_json_path = input_param.get("npu_json_path")
    bench_json_path = input_param.get("bench_json_path")
    api_mapping = generate_api_mapping_by_layer_mapping(npu_json_path, bench_json_path, layer_mapping_path)
    return generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path)
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+ import math
17
+
18
+ from msprobe.core.common.const import Const
19
+
20
+
21
def postprocess_pass(data_items, name2item):
    """Run the post-processing passes over ``data_items`` in a fixed order.

    First the backward pass (which consults ``name2item``), then the index
    renumbering pass for the ``layers`` member of ``ParallelTransformer``.
    Both passes work by mutating the items; nothing is returned.
    """
    backward_pass(data_items, name2item)
    renumber_index_pass(data_items, type_name="ParallelTransformer", suffix="layers")
24
+
25
+
26
def backward_pass(data_items, name2item):
    """Copy scope information from forward items onto their backward twins.

    Backward data carries no stack information of its own, so each backward
    item reuses the scopes of the forward item with the corresponding name.

    Args:
        data_items: Iterable of items exposing ``data_name`` and the writable
            attributes ``stack_scope``, ``full_scope`` and ``layer_scope``.
        name2item: Mapping from data name to item, used to find the forward
            counterpart.
    """
    for item in data_items:
        name_parts = item.data_name.split(Const.SEP)
        direction_parts = name_parts[Const.SCOPE_DIRECTION_INDEX:]
        # Only names carrying the backward marker in their direction section
        # need a forward counterpart.
        if not name_parts or Const.BACKWARD not in direction_parts:
            continue

        forward_parts = name_parts[:Const.SCOPE_DIRECTION_INDEX] + [
            part.replace(Const.BACKWARD, Const.FORWARD) for part in direction_parts
        ]
        counterpart = name2item.get(Const.SEP.join(forward_parts), None)
        if counterpart:
            item.stack_scope = counterpart.stack_scope
            item.full_scope = counterpart.full_scope
            item.layer_scope = counterpart.layer_scope
43
+
44
+
45
def extract_next_item_last_number(data, prefix, default_result=None):
    """Return the last integer embedded in the component that follows ``prefix``.

    ``data`` must start with ``prefix`` followed by a dot; the component after
    that dot (up to the next dot or the end of the string) is searched for
    runs of digits and the final run is returned as an ``int``.

    Args:
        data: Dotted scope string to inspect.
        prefix: Literal leading part of ``data`` (escaped before matching).
        default_result: Value returned when ``data`` does not start with
            ``prefix`` plus a dot, or when the next component has no digits.
    """
    found = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", data)
    if found is None:
        return default_result

    digit_runs = re.findall(r"\d+", found.group(1))
    return int(digit_runs[-1]) if digit_runs else default_result
54
+
55
+
56
def replace_next_item_index(full_scope, prefix, index):
    """Replace the component that follows ``prefix`` in ``full_scope`` with ``index``.

    ``full_scope`` must start with ``prefix`` followed by a dot; the component
    after that dot (up to the next dot or the end of the string) is swapped
    for the string form of ``index``. Everything else is left untouched.

    Args:
        full_scope: Dotted scope string to rewrite.
        prefix: Literal leading part of ``full_scope`` (escaped before matching).
        index: New value for the component. A non-finite value (``inf`` or
            ``nan``) means "nothing to renumber" and leaves the scope as is.

    Returns:
        The rewritten scope, or ``full_scope`` unchanged when ``index`` is not
        finite or ``full_scope`` does not start with ``prefix`` plus a dot.
    """
    # isfinite (rather than isinf) also rejects NaN, which arises from
    # ``inf - inf`` when neither the item nor its siblings carry a number.
    if not math.isfinite(index):
        return full_scope

    match = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", full_scope)
    if match is None:
        return full_scope

    # Splice by position instead of re.sub: a replacement template would
    # interpret backslashes contained in ``prefix`` as escape sequences.
    start, end = match.span(1)
    return f"{full_scope[:start]}{index}{full_scope[end:]}"
67
+
68
+
69
def renumber_index_pass(data_items, type_name, suffix=None):
    """
    Renumber the child indices below layers of type ``type_name`` so they start at zero.

    This resolves comparison mismatches caused by inconsistent numbering under
    parallel partitioning. For example, when MindSpore's ParallelTransformer
    layer is split by pipeline parallelism (PP), the members of ``layers`` are
    numbered globally in MindSpore but locally in PyTorch. To adapt to this
    scenario the indices of the specified layer are renumbered so that the
    sequence numbers line up in later processing stages.

    Args:
        data_items: Re-iterable collection of items exposing ``type_name``,
            ``layer_scope`` and a writable ``full_scope``.
        type_name: Items with this ``type_name`` define the prefixes to process.
        suffix: When truthy, the prefix is ``"<full_scope>.<suffix>"`` of the
            matching item; otherwise the item's ``layer_scope`` is used.
    """
    # Map each prefix contributed by an item of type ``type_name`` to the
    # smallest index found directly below it (inf until one is found).
    prefix_dict = {}
    for data_item in data_items:
        if data_item.type_name == type_name:
            prefix = f"{data_item.full_scope}.{suffix}" if suffix else data_item.layer_scope
            prefix_dict[prefix] = math.inf

    # Find the minimum index per prefix before any scope is rewritten.
    for prefix in prefix_dict:
        prefix_dict[prefix] = min(
            (extract_next_item_last_number(item.full_scope, prefix, math.inf) for item in data_items),
            default=math.inf,
        )

    # Renumber relative to the minimum.
    for prefix, min_index in prefix_dict.items():
        if math.isinf(min_index):
            # No numbered child below this prefix. Subtracting would give
            # ``inf - inf == nan`` and write ".nan" into the scope.
            continue
        for data_item in data_items:
            abs_index = extract_next_item_last_number(data_item.full_scope, prefix, math.inf)
            if math.isinf(abs_index):
                # Item is not below this prefix or has no number to shift.
                continue
            data_item.full_scope = replace_next_item_index(
                data_item.full_scope, prefix, abs_index - min_index
            )
@@ -1,9 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  import multiprocessing
3
17
  from dataclasses import dataclass
4
- from functools import partial
5
- import numpy as np
6
18
  import pandas as pd
19
+ from tqdm import tqdm
7
20
  from msprobe.core.common.log import logger
8
21
  from msprobe.core.common.utils import CompareException
9
22
  from msprobe.core.common.const import CompareConst
@@ -29,11 +42,19 @@ def _handle_multi_process(func, input_parma, result_df, lock):
29
42
  except OSError as e:
30
43
  logger.error("pool terminate failed")
31
44
 
45
+ progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
46
+
47
+ def update_progress(size, progress_lock):
48
+ with progress_lock:
49
+ progress_bar.update(size)
50
+
32
51
  for process_idx, df_chunk in enumerate(df_chunks):
33
52
  idx = df_chunk_size * process_idx
53
+ chunk_size = len(df_chunk)
34
54
  result = pool.apply_async(func,
35
55
  args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
36
- error_callback=err_call)
56
+ error_callback=err_call,
57
+ callback=update_progress(chunk_size, lock))
37
58
  results.append(result)
38
59
  final_results = [r.get() for r in results]
39
60
  pool.close()
@@ -42,7 +63,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
42
63
 
43
64
 
44
65
  def _ms_graph_handle_multi_process(func, result_df, mode):
45
- process_num = int((multiprocessing.cpu_count() + 1) // 2)
66
+ process_num = int((multiprocessing.cpu_count() + 1) // 4)
46
67
  df_chunk_size = len(result_df) // process_num
47
68
  if df_chunk_size > 0:
48
69
  df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
@@ -84,7 +105,8 @@ def read_dump_data(result_df):
84
105
  except IndexError as e:
85
106
  logger.error('result dataframe elements can not be access.')
86
107
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
87
-
108
+
109
+
88
110
  @dataclass
89
111
  class ComparisonResult:
90
112
  cos_result: list
@@ -116,9 +138,12 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
116
138
  result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
117
139
  result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
118
140
  result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
119
- result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(result.cos_result[i], result.max_err_result[i])
120
- result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[i]
121
- result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[i]
141
+ result_df.loc[process_index, CompareConst.ACCURACY] = (
142
+ check_accuracy(result.cos_result[i], result.max_err_result[i]))
143
+ result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
144
+ result.one_thousand_err_ratio_result)[i]
145
+ result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
146
+ result.five_thousand_err_ratio_result)[i]
122
147
  return result_df
123
148
  except ValueError as e:
124
149
  logger.error('result dataframe is not found.')