mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,9 @@
1
1
  #!/usr/bin/env python3
2
2
  # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
7
  # you may not use this file except in compliance with the License.
7
8
  # You may obtain a copy of the License at
8
9
  #
@@ -13,10 +14,11 @@
13
14
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
15
  # See the License for the specific language governing permissions and
15
16
  # limitations under the License.
16
- """
17
+
17
18
  import os
18
19
  import re
19
20
  from collections import namedtuple
21
+ import importlib
20
22
 
21
23
  import torch
22
24
 
@@ -96,7 +98,8 @@ def cross_entropy_process(api_info_dict):
96
98
  Return api_info_dict:
97
99
  api_info_dict: Processed argument of the API.
98
100
  """
99
- if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 and 'Min' in api_info_dict['input_args'][1]:
101
+ if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
102
+ and 'Min' in api_info_dict['input_args'][1]:
100
103
  if api_info_dict['input_args'][1]['Min'] <= 0:
101
104
  # The second argument in cross_entropy should be -100 or not less than 0
102
105
  api_info_dict['input_args'][1]['Min'] = 0
@@ -109,18 +112,6 @@ def initialize_save_path(save_path, dir_name):
109
112
  return data_path
110
113
 
111
114
 
112
- def get_real_data_path(file_path):
113
- targets = ['forward_real_data', 'backward_real_data', 'ut_error_data\d+']
114
- pattern = re.compile(r'({})'.format('|'.join(targets)))
115
- match = pattern.search(file_path)
116
- if match:
117
- target_index = match.start()
118
- target_path = file_path[target_index:]
119
- return target_path
120
- else:
121
- raise DumpException(DumpException.INVALID_PATH_ERROR)
122
-
123
-
124
115
  def get_full_data_path(data_path, real_data_path):
125
116
  if not data_path:
126
117
  return data_path
@@ -137,7 +128,10 @@ class UtDataProcessor:
137
128
  self.index = 0
138
129
  self._save_recursive(api_name, element)
139
130
 
140
- def _save_recursive(self, api_name, element):
131
+ def _save_recursive(self, api_name, element, depth=0):
132
+ if depth > Const.MAX_DEPTH:
133
+ logger.error(f"Maximum depth of {Const.MAX_DEPTH} exceeded for {api_name}")
134
+ raise DumpException(DumpException.RECURSION_LIMIT_ERROR)
141
135
  if isinstance(element, torch.Tensor):
142
136
  api_args = api_name + Const.SEP + str(self.index)
143
137
  create_directory(self.save_path)
@@ -153,10 +147,10 @@ class UtDataProcessor:
153
147
  self.index += 1
154
148
  elif isinstance(element, (list, tuple)):
155
149
  for item in element:
156
- self._save_recursive(api_name, item)
150
+ self._save_recursive(api_name, item, depth=depth+1)
157
151
  elif isinstance(element, dict):
158
152
  for value in element.values():
159
- self._save_recursive(api_name, value)
153
+ self._save_recursive(api_name, value, depth=depth+1)
160
154
  else:
161
155
  self.index += 1
162
156
 
@@ -211,4 +205,42 @@ def extract_detailed_api_segments(full_api_name_with_direction_status):
211
205
  else:
212
206
  full_api_name = None
213
207
  return api_name, full_api_name, direction_status
214
-
208
+
209
+
210
+ def get_module_and_atttribute_name(attribute):
211
+ '''
212
+ Function Description:
213
+ Get the module and attribute name.
214
+ Parameter:
215
+ name: Attribute of a module. Example: torch.float16
216
+ Return:
217
+ module_name: Name of the module. Example: torch.
218
+ attribute_name: Name of the attribute. Example: float16.
219
+ '''
220
+ try:
221
+ module_name, attribute_name = attribute.split(Const.SEP)
222
+ except ValueError as e:
223
+ logger.error(f"Failed to get module and attribute name from {attribute}")
224
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
225
+ return module_name, attribute_name
226
+
227
+
228
+ def get_attribute(module_name, attribute_name):
229
+ '''
230
+ Function Description:
231
+ Get the attribute of the module.
232
+ Parameter:
233
+ module_name: Name of the module.
234
+ attribute_name: Name of the attribute.
235
+ '''
236
+ attribute = None
237
+ if module_name not in Const.MODULE_WHITE_LIST:
238
+ logger.error(f"Module {module_name} is not in white list")
239
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
240
+ try:
241
+ module = importlib.import_module(module_name)
242
+ attribute = getattr(module, attribute_name)
243
+ except (ImportError, AttributeError) as e:
244
+ logger.error(f"Failed to get attribute {attribute_name} from module {module_name}: {e}")
245
+ raise CompareException(CompareException.INVALID_ATTRIBUTE_ERROR) from e
246
+ return attribute
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  # 定义比对算法及比对标准
2
19
  import torch
3
20
  import numpy as np
@@ -142,7 +159,7 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
142
159
  输出:
143
160
  inf_nan_err_ratio:npu输出和golden输出的inf、nan不一致的比例
144
161
  '''
145
- abs_gpu, abs_gpu_with_eps = get_abs_bench_with_eps(bench_output, dtype)
162
+ _, abs_gpu_with_eps = get_abs_bench_with_eps(bench_output, dtype)
146
163
  golden_same_dtype = bench_output.astype(device_output.dtype)
147
164
  a_min = np.finfo(device_output.dtype).min if dtype != torch.bfloat16 else CompareConst.BFLOAT16_MIN
148
165
  a_max = np.finfo(device_output.dtype).max if dtype != torch.bfloat16 else CompareConst.BFLOAT16_MAX
@@ -209,5 +226,5 @@ def get_ulp_err(bench_output, device_output, dtype):
209
226
 
210
227
 
211
228
  def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
212
- return (device_output.astype(data_type) - bench_output).astype(data_type) * \
229
+ return (device_output.astype(data_type) - bench_output).astype(data_type) * \
213
230
  np.exp2(-eb + exponent_num).astype(data_type)
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import argparse
2
19
  import math
3
20
  import os
@@ -7,7 +24,7 @@ from collections import namedtuple
7
24
  import torch
8
25
  import pandas as pd
9
26
 
10
- from msprobe.core.common.file_utils import write_csv
27
+ from msprobe.core.common.file_utils import write_csv, read_csv
11
28
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
29
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
30
  API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
@@ -17,18 +34,18 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECI
17
34
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
35
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
19
36
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
20
- from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
37
+ from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
21
38
  from msprobe.pytorch.common.log import logger
22
39
  from msprobe.core.common.utils import CompareException
23
40
  from msprobe.core.common.const import Const, CompareConst, FileCheckConst
24
41
 
25
42
  CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
26
- BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
43
+ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_value_inf_nan_consistency',
27
44
  'rmse_inf_nan_consistency',
28
45
  'max_rel_inf_nan_consistency',
29
46
  'mean_rel_inf_nan_consistency',
30
47
  'eb_inf_nan_consistency'])
31
- unsupported_message = 'This data type does not support benchmark compare.'
48
+ UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'
32
49
 
33
50
  DEFAULT_THRESHOLD = 1
34
51
 
@@ -154,11 +171,11 @@ class BenchmarkStandard(Standard):
154
171
  self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
155
172
  else CompareConst.ERROR
156
173
  self.check_result_list.append(self.rmse_status)
157
- self.max_rel_err_status = self._get_status(self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency \
158
- else CompareConst.ERROR
174
+ self.max_rel_err_status = self._get_status(
175
+ self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency else CompareConst.ERROR
159
176
  self.check_result_list.append(self.max_rel_err_status)
160
- self.mean_rel_err_status = self._get_status(self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency \
161
- else CompareConst.ERROR
177
+ self.mean_rel_err_status = self._get_status(
178
+ self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency else CompareConst.ERROR
162
179
  self.check_result_list.append(self.mean_rel_err_status)
163
180
  self.eb_status = self._get_status(self.eb_ratio, 'eb')
164
181
  if CompareConst.ERROR in self.check_result_list:
@@ -187,7 +204,8 @@ class BenchmarkStandard(Standard):
187
204
  self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
188
205
  self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
189
206
  self.compare_message += max_rel_message
190
- self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(ApiPrecisionCompareColumn.MEAN_REL_ERR,
207
+ self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
208
+ ApiPrecisionCompareColumn.MEAN_REL_ERR,
191
209
  self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
192
210
  self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
193
211
  self.compare_message += mean_rel_message
@@ -196,8 +214,9 @@ class BenchmarkStandard(Standard):
196
214
  self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
197
215
  self.compare_message += eb_message
198
216
 
199
- return BenchmarkInf_Nan_Consistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
200
- max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency, eb_inf_nan_consistency)
217
+ return BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
218
+ max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
219
+ eb_inf_nan_consistency)
201
220
 
202
221
 
203
222
  class ULPStandard(Standard):
@@ -269,12 +288,12 @@ def api_precision_compare(config):
269
288
  logger.info(f"Compare task result will be saved in {config.result_csv_path}")
270
289
  logger.info(f"Compare task detail will be saved in {config.details_csv_path}")
271
290
  try:
272
- npu_data = pd.read_csv(config.npu_csv_path)
291
+ npu_data = read_csv(config.npu_csv_path)
273
292
  except Exception as err:
274
293
  logger.error(f"Open npu csv Error: %s" % str(err))
275
294
  check_csv_columns(npu_data.columns, "npu_csv")
276
295
  try:
277
- gpu_data = pd.read_csv(config.gpu_csv_path)
296
+ gpu_data = read_csv(config.gpu_csv_path)
278
297
  except Exception as err:
279
298
  logger.error(f"Open gpu csv Error: %s" % str(err))
280
299
  check_csv_columns(gpu_data.columns, "gpu_csv")
@@ -292,8 +311,10 @@ def api_precision_compare(config):
292
311
 
293
312
  def online_api_precision_compare(online_config):
294
313
  rank = online_config.rank
295
- result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace("_rank*.csv", f"_rank{rank}.csv")
296
- details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace("_rank*.csv", f"_rank{rank}.csv")
314
+ result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
315
+ "_rank*.csv", f"_rank{rank}.csv")
316
+ details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
317
+ "_rank*.csv", f"_rank{rank}.csv")
297
318
  detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
298
319
  result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
299
320
  if not os.path.exists(result_csv_path):
@@ -315,6 +336,7 @@ def online_api_precision_compare(online_config):
315
336
  def analyse_csv(npu_data, gpu_data, config):
316
337
  forward_status, backward_status = [], []
317
338
  last_api_name, last_api_dtype, last_api_full_name = None, None, None
339
+ last_api_skip_message = ''
318
340
  for _, row_npu in npu_data.iterrows():
319
341
  message = ''
320
342
  compare_column = ApiPrecisionOutputColumn()
@@ -328,7 +350,7 @@ def analyse_csv(npu_data, gpu_data, config):
328
350
  compare_column.compare_result = CompareConst.SKIP
329
351
  compare_column.compare_message = err_message
330
352
  write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
331
- write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]],
353
+ write_csv([[full_api_name_with_direction_status, CompareConst.SKIP, CompareConst.SKIP, err_message]],
332
354
  config.result_csv_path)
333
355
  continue
334
356
  if row_gpu.empty:
@@ -355,19 +377,19 @@ def analyse_csv(npu_data, gpu_data, config):
355
377
 
356
378
  if last_api_name is not None and api_full_name != last_api_name:
357
379
  if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
358
- message = unsupported_message
380
+ message = UNSUPPORTED_MESSAGE
359
381
  write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
360
382
  print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
361
- forward_status, backward_status = [], []
362
- message = ''
363
383
  else:
364
384
  forward_result = get_api_checker_result(forward_status)
365
385
  backward_result = get_api_checker_result(backward_status)
366
386
  message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
387
+ message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
367
388
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
368
389
  print_test_success(last_api_name, forward_result, backward_result)
369
- forward_status, backward_status = [], []
370
- message = ''
390
+ last_api_skip_message = ''
391
+ forward_status, backward_status = [], []
392
+ message = ''
371
393
 
372
394
  is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
373
395
  last_api_name = api_full_name
@@ -378,6 +400,8 @@ def analyse_csv(npu_data, gpu_data, config):
378
400
 
379
401
  if direction_status == 'forward':
380
402
  forward_status.append(new_status)
403
+ last_api_skip_message = str(row_npu[ApiPrecisionCompareColumn.MESSAGE]) if new_status == CompareConst.SKIP \
404
+ else ''
381
405
  elif direction_status == 'backward':
382
406
  backward_status.append(new_status)
383
407
  else:
@@ -385,15 +409,17 @@ def analyse_csv(npu_data, gpu_data, config):
385
409
 
386
410
  if last_api_name is not None:
387
411
  if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
388
- message = unsupported_message
412
+ message = UNSUPPORTED_MESSAGE
389
413
  write_csv([[last_api_name, CompareConst.SKIP, CompareConst.SKIP, message]], config.result_csv_path)
390
414
  print_test_success(last_api_name, CompareConst.SKIP, CompareConst.SKIP)
391
415
  else:
392
416
  forward_result = get_api_checker_result(forward_status)
393
417
  backward_result = get_api_checker_result(backward_status)
394
418
  message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
419
+ message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
395
420
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
396
421
  print_test_success(last_api_name, forward_result, backward_result)
422
+ last_api_skip_message = ''
397
423
 
398
424
 
399
425
  def get_api_status(row_npu, row_gpu, api_name, compare_column):
@@ -576,8 +602,7 @@ def _api_precision_compare(parser=None):
576
602
  def _api_precision_compare_command(args):
577
603
  npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
578
604
  gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
579
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
580
- check_path_before_create(out_path)
605
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
581
606
  create_directory(out_path)
582
607
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
583
608
  out_path = out_path_checker.common_check()
@@ -595,7 +620,7 @@ def _api_precision_compare_parser(parser):
595
620
  parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
596
621
  help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
597
622
  "api_accuracy_checker tool.",
598
- required=False)
623
+ required=True)
599
624
  parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
600
625
  help="<optional> The api precision compare task result out path.",
601
626
  required=False)
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  # 进行比对及结果展示
2
19
  import os
3
20
  from collections import namedtuple
@@ -127,8 +144,12 @@ class Comparator:
127
144
  return test_rows
128
145
 
129
146
  def write_csv_title(self):
130
- summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
131
- self.COLUMN_BACKWARD_SUCCESS, "Message"]]
147
+ summary_test_rows = [
148
+ [self.COLUMN_API_NAME,
149
+ self.COLUMN_FORWARD_SUCCESS,
150
+ self.COLUMN_BACKWARD_SUCCESS,
151
+ "Message"]
152
+ ]
132
153
  for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list):
133
154
  if not os.path.exists(save_path):
134
155
  write_csv(summary_test_rows, save_path)
@@ -240,13 +261,15 @@ class Comparator:
240
261
  def _compare_core(self, api_name, bench_output, device_output):
241
262
  compare_column = CompareColumn()
242
263
  if not isinstance(bench_output, type(device_output)):
243
- return CompareConst.ERROR, compare_column, "bench and npu output type is different."
264
+ status = CompareConst.ERROR
265
+ message = "bench and npu output type is different."
244
266
  elif isinstance(bench_output, dict):
245
267
  b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
246
268
  if b_keys != n_keys:
247
- return CompareConst.ERROR, compare_column, "bench and npu output dict keys are different."
269
+ status = CompareConst.ERROR
270
+ message = "bench and npu output dict keys are different."
248
271
  else:
249
- status, compare_result, message = self._compare_core(api_name, list(bench_output.values()),
272
+ status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
250
273
  list(device_output.values()))
251
274
  elif isinstance(bench_output, torch.Tensor):
252
275
  copy_bench_out = bench_output.detach().clone()
@@ -254,19 +277,20 @@ class Comparator:
254
277
  compare_column.bench_type = str(copy_bench_out.dtype)
255
278
  compare_column.npu_type = str(copy_device_output.dtype)
256
279
  compare_column.shape = tuple(device_output.shape)
257
- status, compare_result, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
280
+ status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
258
281
  compare_column)
259
282
  elif isinstance(bench_output, (bool, int, float, str)):
260
283
  compare_column.bench_type = str(type(bench_output))
261
284
  compare_column.npu_type = str(type(device_output))
262
- status, compare_result, message = self._compare_builtin_type(bench_output, device_output, compare_column)
285
+ status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
263
286
  elif bench_output is None:
264
- return CompareConst.SKIP, compare_column, "Bench output is None, skip this test."
287
+ status = CompareConst.SKIP
288
+ message = "Bench output is None, skip this test."
265
289
  else:
266
- return CompareConst.PASS, compare_column,
267
- "Unexpected output type in compare_core: {}".format(type(bench_output))
290
+ status = CompareConst.ERROR
291
+ message = "Unexpected output type in compare_core: {}".format(type(bench_output))
268
292
 
269
- return status, compare_result, message
293
+ return status, compare_column, message
270
294
 
271
295
  def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
272
296
  cpu_shape = bench_output.shape
@@ -330,21 +354,23 @@ class Comparator:
330
354
  compare_column.max_ulp_error = np.max(ulp_err)
331
355
  compare_column.mean_ulp_error = np.mean(ulp_err)
332
356
  if dtype == torch.float32:
333
- compare_column.ulp_error_proportion = np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
357
+ compare_column.ulp_error_proportion = \
358
+ np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size
334
359
  else:
335
- compare_column.ulp_error_proportion = np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
360
+ compare_column.ulp_error_proportion = \
361
+ np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size
336
362
  else:
337
363
  dtype_config = precision_configs.get(dtype)
338
364
  small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0])
339
365
  abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0])
340
366
  compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask)
341
367
  rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask)
342
- compare_column.RMSE = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
343
- compare_column.EB = get_error_balance(bench_output, device_output)
368
+ compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask))
369
+ compare_column.eb = get_error_balance(bench_output, device_output)
344
370
  if rel_err.size == 0:
345
371
  return CompareConst.ERROR, compare_column, "Relative error result list is empty."
346
- compare_column.Max_rel_error = get_max_rel_err(rel_err)
347
- compare_column.Mean_rel_error = get_mean_rel_err(rel_err)
372
+ compare_column.max_rel_error = get_max_rel_err(rel_err)
373
+ compare_column.mean_rel_error = get_mean_rel_err(rel_err)
348
374
 
349
375
  cos_res, cos_status, msg = cosine_sim(bench_output, device_output)
350
376
  compare_column.cosine_sim = cos_res
@@ -363,7 +389,8 @@ class Comparator:
363
389
  hundred_res, hundred_status = get_rel_err_ratio(rel_err_orign, CompareConst.HUNDRED_RATIO_THRESHOLD)
364
390
  compare_column.rel_err_hundredth = hundred_res
365
391
  if not hundred_status:
366
- message += "Relative error is greater than 0.01, consider as error, skip other check and set to SPACE.\n"
392
+ message += "Relative error is greater than 0.01, consider as error, " \
393
+ "skip other check and set to SPACE.\n"
367
394
  return CompareConst.ERROR, compare_column, message
368
395
  thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
369
396
  compare_column.rel_err_thousandth = thousand_res
@@ -373,14 +400,17 @@ class Comparator:
373
400
  return CompareConst.PASS, compare_column, message
374
401
  message += "Relative error is greater than 0.001, consider as warning, skip other check and set to SPACE.\n"
375
402
  return CompareConst.WARNING, compare_column, message
376
- ten_thousand_res, ten_thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
403
+ ten_thousand_res, ten_thousand_status = get_rel_err_ratio(
404
+ rel_err_orign, CompareConst.TEN_THOUSAND_RATIO_THRESHOLD)
377
405
  compare_column.rel_err_ten_thousandth = ten_thousand_res
378
406
  if dtype in [torch.float32, torch.float64]:
379
407
  if not thousand_status:
380
- message += "Relative error is greater than 0.001, consider as error, skip other check and set to SPACE.\n"
408
+ message += "Relative error is greater than 0.001, consider as error, " \
409
+ "skip other check and set to SPACE.\n"
381
410
  return CompareConst.ERROR, compare_column, message
382
411
  if not ten_thousand_status:
383
- message += "Relative error is greater than 0.0001, consider as warning, skip other check and set to SPACE.\n"
412
+ message += "Relative error is greater than 0.0001, consider as warning, " \
413
+ "skip other check and set to SPACE.\n"
384
414
  return CompareConst.WARNING, compare_column, message
385
415
  message += "Relative error is less than 0.0001, consider as pass.\n"
386
416
  return CompareConst.PASS, compare_column, message
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  from msprobe.core.common.const import CompareConst
2
19
 
3
20
 
@@ -12,11 +29,11 @@ class CompareColumn:
12
29
  self.rel_err_thousandth = CompareConst.SPACE
13
30
  self.rel_err_ten_thousandth = CompareConst.SPACE
14
31
  self.error_rate = CompareConst.SPACE
15
- self.EB = CompareConst.SPACE
16
- self.RMSE = CompareConst.SPACE
32
+ self.eb = CompareConst.SPACE
33
+ self.rmse = CompareConst.SPACE
17
34
  self.small_value_err_ratio = CompareConst.SPACE
18
- self.Max_rel_error = CompareConst.SPACE
19
- self.Mean_rel_error = CompareConst.SPACE
35
+ self.max_rel_error = CompareConst.SPACE
36
+ self.mean_rel_error = CompareConst.SPACE
20
37
  self.inf_nan_error_ratio = CompareConst.SPACE
21
38
  self.rel_err_ratio = CompareConst.SPACE
22
39
  self.abs_err_ratio = CompareConst.SPACE
@@ -26,8 +43,8 @@ class CompareColumn:
26
43
 
27
44
  def to_column_value(self, is_pass, message):
28
45
  return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth,
29
- self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.EB, self.RMSE,
30
- self.small_value_err_ratio, self.Max_rel_error, self.Mean_rel_error, self.inf_nan_error_ratio,
46
+ self.rel_err_thousandth, self.rel_err_ten_thousandth, self.error_rate, self.eb, self.rmse,
47
+ self.small_value_err_ratio, self.max_rel_error, self.mean_rel_error, self.inf_nan_error_ratio,
31
48
  self.rel_err_ratio, self.abs_err_ratio, self.max_ulp_error, self.mean_ulp_error,
32
49
  self.ulp_error_proportion, is_pass, message]
33
50
 
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import time
2
19
  import os
3
20
  import math
@@ -32,7 +49,8 @@ threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
32
49
  apis_threshold = load_yaml(threshold_yaml_path)
33
50
 
34
51
 
35
- DETAIL_TEST_ROWS = [[
52
+ DETAIL_TEST_ROWS = [
53
+ [
36
54
  "API Name", "Bench Dtype", "DEVICE Dtype", "Shape",
37
55
  "余弦相似度",
38
56
  "最大绝对误差",
@@ -53,7 +71,8 @@ DETAIL_TEST_ROWS = [[
53
71
  "ULP误差大于阈值占比",
54
72
  "Status",
55
73
  "Message"
56
- ]]
74
+ ]
75
+ ]
57
76
 
58
77
 
59
78
  precision_configs = {
@@ -154,11 +173,11 @@ class ApiPrecisionCompareColumn:
154
173
  def to_required_columns():
155
174
  return [ApiPrecisionCompareColumn.API_NAME, ApiPrecisionCompareColumn.DEVICE_DTYPE,
156
175
  ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE, ApiPrecisionCompareColumn.RMSE,
157
- ApiPrecisionCompareColumn.MAX_REL_ERR, ApiPrecisionCompareColumn.MEAN_REL_ERR, ApiPrecisionCompareColumn.EB,
158
- ApiPrecisionCompareColumn.ERROR_RATE, ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO,
159
- ApiPrecisionCompareColumn.REL_ERR_RATIO, ApiPrecisionCompareColumn.ABS_ERR_RATIO,
160
- ApiPrecisionCompareColumn.MEAN_ULP_ERR, ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
161
- ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
176
+ ApiPrecisionCompareColumn.MAX_REL_ERR, ApiPrecisionCompareColumn.MEAN_REL_ERR,
177
+ ApiPrecisionCompareColumn.EB, ApiPrecisionCompareColumn.ERROR_RATE,
178
+ ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO, ApiPrecisionCompareColumn.REL_ERR_RATIO,
179
+ ApiPrecisionCompareColumn.ABS_ERR_RATIO, ApiPrecisionCompareColumn.MEAN_ULP_ERR,
180
+ ApiPrecisionCompareColumn.ULP_ERR_PROPORTION, ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
162
181
 
163
182
  @staticmethod
164
183
  def get_detail_csv_title():
@@ -175,7 +194,8 @@ class ApiPrecisionCompareColumn:
175
194
  ApiPrecisionCompareColumn.MEAN_ULP_ERR, ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
176
195
  ApiPrecisionCompareColumn.ULP_ERR_PROPORTION_RATIO, ApiPrecisionCompareColumn.ULP_ERR_STATUS,
177
196
  ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH, ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH_STATUS,
178
- ApiPrecisionCompareColumn.FINAL_RESULT, ApiPrecisionCompareColumn.ALGORITHM, ApiPrecisionCompareColumn.MESSAGE]
197
+ ApiPrecisionCompareColumn.FINAL_RESULT, ApiPrecisionCompareColumn.ALGORITHM,
198
+ ApiPrecisionCompareColumn.MESSAGE]
179
199
 
180
200
  @staticmethod
181
201
  def get_result_csv_title():
@@ -7,4 +7,4 @@ nfs_path: ""
7
7
  host: ""
8
8
  port: -1
9
9
  rank_list: [0]
10
- tls_path: ""
10
+ tls_path: "./"
@@ -0,0 +1,9 @@
1
+ {
2
+ "dump_json_path": "./dump.json",
3
+ "api_name": "",
4
+ "extract_api_path": "",
5
+ "propagation": "forward",
6
+ "data_mode": "random_data",
7
+ "random_seed": 1234,
8
+ "iter_times": 1
9
+ }