mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/msprobe.py CHANGED
@@ -16,10 +16,12 @@
16
16
  import argparse
17
17
  import sys
18
18
  import importlib.util
19
- from msprobe.core.compare.utils import _compare_parser
19
+
20
+ from msprobe.core.common.const import Const
20
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.compare.utils import _compare_parser
21
23
  from msprobe.core.compare.compare_cli import compare_cli
22
- from msprobe.core.common.const import Const
24
+ from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
23
25
 
24
26
 
25
27
  def is_module_available(module_name):
@@ -45,10 +47,15 @@ def main():
45
47
  multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
46
48
  api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
47
49
  run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
50
+ code_mapping_cmd_parser = subparsers.add_parser('code_mapping')
48
51
  graph_service_cmd_parser = subparsers.add_parser('graph')
52
+ op_generate_cmd_parser = subparsers.add_parser('op_generate')
53
+ merge_result_parser = subparsers.add_parser('merge_result')
49
54
  _compare_parser(compare_cmd_parser)
55
+ _merge_result_parser(merge_result_parser)
56
+
50
57
  is_torch_available = is_module_available("torch")
51
- is_mindspore_available = is_module_available("mindspore")
58
+
52
59
  if len(sys.argv) < 4:
53
60
  parser.print_help()
54
61
  sys.exit(0)
@@ -62,6 +69,8 @@ def main():
62
69
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
63
70
  _run_overflow_check_command
64
71
  from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
72
+ from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
73
+ _run_operator_generate_commond
65
74
 
66
75
  _run_ut_parser(run_ut_cmd_parser)
67
76
  _run_ut_parser(multi_run_ut_cmd_parser)
@@ -70,12 +79,15 @@ def main():
70
79
  _api_precision_compare_parser(api_precision_compare_cmd_parser)
71
80
  _run_overflow_check_parser(run_overflow_check_cmd_parser)
72
81
  _pt_graph_service_parser(graph_service_cmd_parser)
82
+ _op_generator_parser(op_generate_cmd_parser)
73
83
  elif framework_args.framework == Const.MS_FRAMEWORK:
74
84
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
75
85
  from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
76
86
  add_api_accuracy_checker_argument(run_ut_cmd_parser)
77
87
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
78
88
  multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
89
+ from msprobe.mindspore.code_mapping.cmd_parser import add_ir_parser_arguments
90
+ add_ir_parser_arguments(code_mapping_cmd_parser)
79
91
 
80
92
  _ms_graph_service_parser(graph_service_cmd_parser)
81
93
 
@@ -97,17 +109,23 @@ def main():
97
109
  _run_overflow_check_command(args)
98
110
  elif sys.argv[3] == "graph":
99
111
  _pt_graph_service_command(args)
112
+ elif sys.argv[3] == 'op_generate':
113
+ _run_operator_generate_commond(args)
100
114
  elif sys.argv[3] == "compare":
101
115
  if args.cell_mapping is not None or args.api_mapping is not None:
102
116
  logger.error("Argument -cm or -am is not supported in PyTorch framework")
103
117
  raise Exception("Argument -cm or -am is not supported in PyTorch framework")
104
118
  compare_cli(args)
119
+ elif sys.argv[3] == "merge_result":
120
+ merge_result_cli(args)
105
121
  else:
106
122
  if not is_module_available(Const.MS_FRAMEWORK):
107
123
  logger.error("MindSpore does not exist, please install MindSpore library")
108
124
  raise Exception("MindSpore does not exist, please install MindSpore library")
109
125
  if sys.argv[3] == "compare":
110
126
  compare_cli(args)
127
+ elif sys.argv[3] == "merge_result":
128
+ merge_result_cli(args)
111
129
  elif sys.argv[3] == "run_ut":
112
130
  from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
113
131
  api_checker_main(args)
@@ -116,6 +134,9 @@ def main():
116
134
  mul_api_checker_main(args)
117
135
  elif sys.argv[3] == "graph":
118
136
  _ms_graph_service_command(args)
137
+ elif sys.argv[3] == "code_mapping":
138
+ from msprobe.mindspore.code_mapping.main import code_mapping_main
139
+ code_mapping_main(args)
119
140
 
120
141
 
121
142
  if __name__ == "__main__":
@@ -1,6 +1,4 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
2
  # All rights reserved.
5
3
  #
6
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,10 +13,12 @@
15
13
  # See the License for the specific language governing permissions and
16
14
  # limitations under the License.
17
15
 
18
-
19
- from msprobe.pytorch.monitor.module_hook import TrainerMon
16
+ import torch
20
17
  from .compare.distributed_compare import compare_distributed
21
18
  from .compare.pt_compare import compare
22
19
  from .common.utils import seed_all
23
- from .debugger.precision_debugger import PrecisionDebugger
24
- from .functional.module_dump import module_dump, module_dump_end
20
+ from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end
21
+
22
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
23
+ if torch_version_above_or_equal_2:
24
+ from msprobe.pytorch.monitor.module_hook import TrainerMon
@@ -72,38 +72,53 @@ def check_need_convert(api_name):
72
72
  return convert_type
73
73
 
74
74
 
75
- def api_info_preprocess(api_name, api_info_dict):
75
+ def cross_entropy_process(api_info_dict):
76
76
  """
77
77
  Function Description:
78
- Preprocesses the API information.
78
+ Preprocesses the cross_entropy API information.
79
79
  Parameter:
80
- api_name: Name of the API.
81
80
  api_info_dict: argument of the API.
82
81
  Return api_info_dict:
83
- convert_type: Type of conversion.
84
82
  api_info_dict: Processed argument of the API.
85
83
  """
86
- convert_type = check_need_convert(api_name)
87
- if api_name == 'cross_entropy':
88
- api_info_dict = cross_entropy_process(api_info_dict)
89
- return convert_type, api_info_dict
84
+ if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
85
+ and 'Min' in api_info_dict['input_args'][1]:
86
+ if api_info_dict['input_args'][1]['Min'] <= 0:
87
+ # The second argument in cross_entropy should be -100 or not less than 0
88
+ api_info_dict['input_args'][1]['Min'] = 0
89
+ return api_info_dict
90
90
 
91
91
 
92
- def cross_entropy_process(api_info_dict):
92
+ def histc_process(api_info_dict):
93
+ input_args = api_info_dict['input_args']
94
+ if input_args and input_args[0].get('dtype'):
95
+ dtype = input_args[0]['dtype']
96
+ if dtype in Const.TORCH_INT_DTYPE:
97
+ api_info_dict['input_args'][0]['dtype'] = Const.TORCH_FLOAT32
98
+ return api_info_dict
99
+
100
+
101
+ API_PROCESS_MAP = {
102
+ 'cross_entropy': cross_entropy_process,
103
+ 'histc': histc_process
104
+ }
105
+
106
+
107
+ def api_info_preprocess(api_name, api_info_dict):
93
108
  """
94
109
  Function Description:
95
- Preprocesses the cross_entropy API information.
110
+ Preprocesses the API information.
96
111
  Parameter:
112
+ api_name: Name of the API.
97
113
  api_info_dict: argument of the API.
98
114
  Return api_info_dict:
115
+ convert_type: Type of conversion.
99
116
  api_info_dict: Processed argument of the API.
100
117
  """
101
- if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
102
- and 'Min' in api_info_dict['input_args'][1]:
103
- if api_info_dict['input_args'][1]['Min'] <= 0:
104
- # The second argument in cross_entropy should be -100 or not less than 0
105
- api_info_dict['input_args'][1]['Min'] = 0
106
- return api_info_dict
118
+ convert_type = check_need_convert(api_name)
119
+ if api_name in API_PROCESS_MAP:
120
+ api_info_dict = API_PROCESS_MAP[api_name](api_info_dict)
121
+ return convert_type, api_info_dict
107
122
 
108
123
 
109
124
  def initialize_save_path(save_path, dir_name):
@@ -16,10 +16,12 @@
16
16
  # limitations under the License.
17
17
 
18
18
  # 定义比对算法及比对标准
19
+ import math
19
20
  import torch
20
21
  import numpy as np
21
22
 
22
23
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
24
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
23
25
  from msprobe.core.common.const import CompareConst
24
26
 
25
27
 
@@ -179,13 +181,13 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
179
181
 
180
182
  def check_small_value(abs_err, small_value_mask, small_value_atol):
181
183
  '''
182
- 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
184
+ 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
183
185
  输入:
184
- rel_err:npu输出和golden输出的相对误差
186
+ abs_err:npu输出和golden输出的绝对误差
185
187
  normal_value_mask:npu输出和golden输出的正常值mask
186
- rtol:相对误差的阈值
188
+ atol:绝对误差的阈值
187
189
  输出:
188
- rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
190
+ abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
189
191
  '''
190
192
  greater_mask = np.greater(abs_err, small_value_atol)
191
193
  err_mask = np.logical_and(greater_mask, small_value_mask)
@@ -195,13 +197,13 @@ def check_small_value(abs_err, small_value_mask, small_value_atol):
195
197
 
196
198
  def check_norm_value(normal_value_mask, rel_err, rtol):
197
199
  '''
198
- 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
200
+ 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
199
201
  输入:
200
- abs_err:npu输出和golden输出的绝对误差
202
+ rel_err:npu输出和golden输出的相对误差
201
203
  normal_value_mask:npu输出和golden输出的正常值mask
202
- atol:绝对误差的阈值
204
+ rtol:相对误差的阈值
203
205
  输出:
204
- abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
206
+ rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
205
207
  '''
206
208
  err_mask = np.greater(rel_err, rtol)
207
209
  err_mask = np.logical_and(err_mask, normal_value_mask)
@@ -228,3 +230,34 @@ def get_ulp_err(bench_output, device_output, dtype):
228
230
  def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
229
231
  return (device_output.astype(data_type) - bench_output).astype(data_type) * \
230
232
  np.exp2(-eb + exponent_num).astype(data_type)
233
+
234
+
235
+ def calc_ratio(x, y, dtype):
236
+ """
237
+ Calculate the ratio between NPU and GPU statistical values.
238
+
239
+ Args:
240
+ x (float): Statistical value from the NPU side
241
+ y (float): Statistical value from the GPU side
242
+ dtype: Data type used to determine the minimum error value
243
+
244
+ Returns:
245
+ float: The ratio of NPU to GPU statistical values
246
+
247
+ Notes:
248
+ - Takes absolute values of both x and y for calculation
249
+ - Uses StandardConfig.get_minmum_err(dtype) to get minimum error for the specified dtype
250
+ - Prevents division by zero by ensuring denominator is not less than minimum error
251
+ - Returns |x| / max(|y|, minimum_error)
252
+ """
253
+ x, y = abs(x), abs(y)
254
+ minmum_err = StandardConfig.get_minmum_err(dtype)
255
+ err_y = max(y, minmum_err)
256
+ return x / err_y
257
+
258
+
259
+ def compare_bool_tensor(bench_output, device_output):
260
+ error_nums = (bench_output != device_output).sum()
261
+ error_rate = float(error_nums / bench_output.size)
262
+ result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
263
+ return error_rate, result, ""