mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
--- /dev/null
+++ b/msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -0,0 +1,365 @@
+import json
+import os
+import math
+from enum import Enum, auto
+import torch
+try:
+    import torch_npu
+except ImportError:
+    pass
+from tabulate import tabulate
+
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+TORCH_BOOL_TYPE = ["torch.bool"]
+TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
+                  "torch.int64", "torch.long"]
+TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
+                    "torch.float64", "torch.double"]
+TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
+RAISE_PRECISION = {{
+    "torch.float16": torch.float32,
+    "torch.half": torch.float32,
+    "torch.bfloat16": torch.float32,
+    "torch.float32": torch.float64,
+    "torch.float": torch.float64
+}}
+THOUSANDTH_THRESHOLDING = 0.001
+BACKWARD = 'backward'
+
+class CompareStandard(Enum):
+    BINARY_EQUALITY_STANDARD = auto()
+    ABSOLUTE_THRESHOLD_STANDARD = auto()
+    ULP_ERROR_STANDARD = auto()
+    BENCHMARK_STANDARD = auto()
+    THOUSANDTH_STANDARD = auto()
+
+def load_pt(pt_path, to_cpu=False):
+    pt_path = os.path.realpath(pt_path)
+    try:
+        if to_cpu:
+            pt = torch.load(pt_path, map_location=torch.device("cpu"))
+        else:
+            pt = torch.load(pt_path)
+    except Exception as e:
+        raise RuntimeError(f"load pt file {{pt_path}} failed") from e
+    return pt
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch_npu.npu.is_available():
+        device = torch.device("npu")
+    else:
+        raise Exception("Error: This device is not NPU or GPU!")
+    return device
+
+
+def generate_bool_tensor(low, high, shape):
+    low, high = int(low), int(high)
+    tensor = torch.randint(low, high + 1, shape)
+    bool_tensor = torch.gt(tensor, 0)
+    return bool_tensor
+
+
+def generate_numerical_tensor(low, high, shape, data_dtype):
+    if data_dtype in TORCH_FLOAT_TYPE:
+        scale = high - low
+        rand01 = torch.rand(shape, dtype=eval(data_dtype))
+        tensor = rand01 * scale + low
+    elif data_dtype in TORCH_INT_TYPE:
+        low, high = int(low), int(high)
+        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+    else:
+        raise NotImplementedError(f"{{data_dtype}} is not supported!")
+    if torch.numel(tensor) == 0:
+        return tensor
+    tmp_tensor = tensor.reshape(-1)
+    tmp_tensor[0] = low
+    tmp_tensor[-1] = high
+    data = tmp_tensor.reshape(shape)
+    return data
+
+
+def generate_random_tensor(info):
+    low, high = info.get('Min'), info.get('Max')
+    data_dtype = info.get('dtype')
+    shape = tuple(info.get('shape'))
+    if data_dtype == "torch.bool":
+        data = generate_bool_tensor(low, high, shape)
+    else:
+        data = generate_numerical_tensor(low, high, shape, data_dtype)
+    return data
+
+
+def generate_real_tensor(data_path):
+    data_path = os.path.realpath(data_path)
+    data = load_pt(data_path, to_cpu=True)
+    return data
+
+
+def generate_data(info):
+    data_type = info.get("type")
+    data_path = info.get("data_name")
+    data_grad = info.get("requires_grad")
+    if data_type in TENSOR_DATA_LIST:
+        if data_path:
+            data = generate_real_tensor(data_path)
+        else:
+            data = generate_random_tensor(info)
+    else:
+        data = info.get("value")
+    if data_grad == True:
+        data.requires_grad_(True)
+    return data
+
+
+def get_input(propagation):
+    {args_element_assignment}
+    args_device = [{args_list_generator_device}]
+    args_bench = [{args_list_generator_bench}]
+    {kwargs_value_assignment}
+    kwargs_device = {{{kwargs_dict_generator_device}}}
+    kwargs_bench = {{{kwargs_dict_generator_bench}}}
+    {args_element_assignment_backward}
+    args_device_backward = [{args_list_generator_device_backward}]
+    args_bench_backward = [{args_list_generator_bench_backward}]
+    if propagation == BACKWARD:
+        return args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward
+    return args_device, kwargs_device, args_bench, kwargs_bench
+
+def exec_api(args, kwargs, args_grad_input, propagation):
+    output = {api_type}.{api_name}(*args, **kwargs)
+    if propagation == BACKWARD:
+        args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad]
+        args_input_tensor.extend(
+            [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad])
+        output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input)
+        return output_backward
+    return output
+
+def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
+    out_bench = out_bench.to(out_device.dtype)
+    min = torch.finfo(out_device.dtype).min
+    max = torch.finfo(out_device.dtype).max
+    bench_clip = torch.clamp(out_bench, min=min, max=max)
+    device_clip = torch.clamp(out_device, min=min, max=max)
+    clipped_abs_ae = torch.abs(device_clip - bench_clip)
+    clipped_re = clipped_abs_ae / abs_bench_with_eps
+    pass_mask = torch.less_equal(clipped_re, rtol)
+    both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
+    pass_mask = torch.logical_or(pass_mask, both_nan_mask)
+    not_pass_mask = torch.logical_not(pass_mask)
+    not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask)
+    inf_nan_err_cnt = torch.sum(not_pass_mask)
+    return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask)
+
+
+def compute_rmse(abs_err, normal_value_mask):
+    if torch.sum(normal_value_mask) == 0:
+        return 0
+    else:
+        masked_ae = torch.where(normal_value_mask, abs_err, 0)
+        mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask)
+        rmse = torch.sqrt(mse)
+        return rmse
+
+
+def compute_error_balance(out_device, out_bench):
+    larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0))
+    smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0))
+    if torch.numel(out_bench) == 0:
+        raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
+    error_balance = abs(larger_count - smaller_count) / torch.numel(out_bench)
+    return error_balance
+
+
+def compare_tensor(out_device, out_bench, api_name):
+    if out_device.shape != out_bench.shape:
+        print("ERROR: shape of out_device and out_bench is not equal!")
+        return None
+    if torch.numel(out_bench) == 0:
+        print("Both out_device and out_bench have zero elements.")
+        return None
+    dtype_device = out_device.dtype
+    dtype_bench = out_bench.dtype
+    headers = ["Metric", "Value"]
+    table = [
+        ["Shape", out_bench.shape],
+        ["Dtype of out_device", out_device.dtype],
+        ["Dtype of out_bench", out_bench.dtype]
+    ]
+    if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
+            or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
+            or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
+        out_device = out_device.to(torch.device("cpu"))
+        if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
+            error_number = torch.sum(out_device != out_bench).item()
+            if torch.numel(out_bench) == 0:
+                raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
+            error_rate = error_number / torch.numel(out_bench)
+            table.append(["Compare Standard", "Binary Equality Standard"])
+            table.append(["Error Rate", error_rate])
+        else:
+            abs_err = torch.abs(out_device - out_bench)
+            abs_bench = torch.abs(out_bench)
+            if dtype_bench == torch.float32:
+                eps = 2 ** -23
+            if dtype_bench == torch.float64:
+                eps = 2 ** -52
+            abs_bench_with_eps = abs_bench + eps
+            rel_err = torch.abs(abs_err / abs_bench_with_eps)
+            device_finite_mask = torch.isfinite(out_device)
+            bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
+            both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
+            inf_nan_mask = torch.logical_not(both_finite_mask)
+            if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
+                if dtype_device == torch.float16:
+                    rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
+                elif dtype_device == torch.bfloat16:
+                    rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
+                else:
+                    rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
+                small_value_mask = torch.less_equal(abs_bench, small_value)
+                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+                inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
+                rel_err_mask = torch.greater(rel_err, rtol)
+                rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
+                if torch.sum(normal_value_mask) == 0:
+                    rel_err_proportion = 0
+                else:
+                    rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
+                abs_err_mask = torch.greater(abs_err, small_value_atol)
+                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+                if torch.sum(small_value_mask) == 0:
+                    abs_err_proportion = 0
+                else:
+                    abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+                table.append(["Compare Standard", "Absolute Threshold Standard"])
+                table.append(["Relative Error Ratio", rel_err_proportion])
+                table.append(["Absolute Error Ratio", abs_err_proportion])
+            elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
+                if dtype_device == torch.float16:
+                    min_eb, exponent_num = -14, 10
+                elif dtype_device == torch.bfloat16:
+                    min_eb, exponent_num = -126, 7
+                else:
+                    min_eb, exponent_num = -126, 23
+                eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
+                eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
+                if dtype_device == torch.float32:
+                    ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
+                else:
+                    ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
+                ulp_err = torch.abs(ulp_err)
+                max_ulp_err = torch.max(ulp_err)
+                mean_ulp_err = torch.mean(ulp_err)
+                if torch.numel(out_bench) == 0:
+                    raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
+                if dtype_device == torch.float32:
+                    ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
+                else:
+                    ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
+                table.append(["Compare Standard", "ULP error Standard"])
+                table.append(["Maximum ULP Error", max_ulp_err])
+                table.append(["Mean ULP Error", mean_ulp_err])
+                table.append(["ULP Error Proportion", ulp_err_proportion])
+            elif compare_standard == CompareStandard.THOUSANDTH_STANDARD:
+                rel_err_origin = torch.abs(abs_err / abs_bench_with_eps)
+                if torch.numel(rel_err_origin) == 0:
+                    thousand_res = 1
+                else:
+                    thousand_res = torch.divide(torch.sum(rel_err < THOUSANDTH_THRESHOLDING), torch.numel(rel_err_origin))
+                thousand_status = thousand_res > (1 - THOUSANDTH_THRESHOLDING)
+                table.append(["Compare Standard", "Thousandth Standard"])
+                table.append(["Thousandth ratio", thousand_res])
+            else:
+                if dtype_device == torch.float16:
+                    small_value, small_value_atol = 1.0e-3, 1.0e-5
+                elif dtype_device == torch.bfloat16:
+                    small_value, small_value_atol = 1.0e-3, 1.0e-5
+                else:
+                    small_value, small_value_atol = 1.0e-6, 1.0e-9
+                small_value_mask = torch.less_equal(abs_bench, small_value)
+                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+                abs_err_mask = torch.greater(abs_err, small_value_atol)
+                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+                if torch.sum(small_value_mask) == 0:
+                    small_value_err_proportion = 0
+                else:
+                    small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+                rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
+                if torch.max(rel_err) >= 0:
+                    max_rel_err = torch.max(rel_err)
+                else:
+                    max_rel_err = 0
+                if torch.sum(normal_value_mask) == 0:
+                    mean_rel_err = 0
+                else:
+                    mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
+                rmse = compute_rmse(abs_err, normal_value_mask)
+                error_balance = compute_error_balance(out_device, out_bench)
+                table.append(["Compare Standard", "Benchmark Standard"])
+                table.append(["Small Value Error Proportion", small_value_err_proportion])
+                table.append(["Maximum Relative Error", max_rel_err])
+                table.append(["Mean Relative Error", mean_rel_err])
+                table.append(["Root Mean Squared Error", rmse])
+                table.append(["Error Balance", error_balance])
+    else:
+        print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
+        return None
+    print(tabulate(table, headers, tablefmt='grid'))
+    return None
+
+
+def compare_element(out_device, out_bench, api_name):
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench is not the same type!")
+        return None
+    if isinstance(out_bench, torch.Tensor):
+        compare_tensor(out_device, out_bench, api_name)
+    elif isinstance(out_bench, (bool, int, float, str)):
+        if out_device == out_bench:
+            print("PASS: out_device and out_bench equals.")
+        else:
+            print("ERROR: out_device and out_bench is not equal!")
+    else:
+        print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
+    return None
+
+
+def compare(out_device, out_bench, api_name):
+    print("Compare result:")
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench is not the same type!")
+        return None
+    if isinstance(out_bench, (list, tuple)):
+        if len(out_device) != len(out_bench):
+            print("ERROR: len of out_device and out_bench is different!")
+            return None
+        for index, _ in enumerate(out_bench):
+            print(f"index {{index}}:")
+            compare_element(out_device[index], out_bench[index], api_name)
+    else:
+        compare_element(out_device, out_bench, api_name)
+
+if __name__ == "__main__":
+    device = get_device()
+    api_name = "{api_name}"
+    propagation = "{propagation}"
+    compare_standard = {compare_standard}
+    torch.manual_seed({random_seed})
+    for i in range({iter_times}):
+        print(f"iter: {{i}}:")
+        if propagation == BACKWARD:
+            args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward = get_input(propagation)
+            output_device = exec_api(args_device, kwargs_device, args_device_backward, propagation)
+            output_bench = exec_api(args_bench, kwargs_bench, args_bench_backward, propagation)
+            compare(output_device, output_bench, api_name)
+        else:
+            args_device, kwargs_device, args_bench, kwargs_bench = get_input(propagation)
+            output_device = exec_api(args_device, kwargs_device, None, propagation)
+            output_bench = exec_api(args_bench, kwargs_bench, None, propagation)
+            compare(output_device, output_bench, api_name)
+    print("Compare finished.")
--- a/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
+++ b/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -139,7 +139,12 @@ def gen_random_tensor(info, convert_type):
     high_info = [high, high_origin]
     data_dtype = info.get('dtype')
     shape = tuple(info.get('shape'))
-    if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
+    if 0 in shape:
+        low, low_origin = 0, 0
+        high, high_origin = 0, 0
+        low_info = [low, low_origin]
+        high_info = [high, high_origin]
+    elif not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
         error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
         raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
     if data_dtype == "torch.bool":
--- a/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -33,9 +33,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
 from msprobe.pytorch.common import parse_json_info_forward_backward
 from msprobe.pytorch.common.log import logger
 from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
-    check_path_before_create, create_directory
+    create_directory, load_json, save_json
 from msprobe.core.common.file_utils import remove_path
-from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.const import FileCheckConst, Const
+from msprobe.core.common.utils import CompareException


 def split_json_file(input_file, num_splits, filter_api):
@@ -47,9 +48,11 @@ def split_json_file(input_file, num_splits, filter_api):
     for data_name in list(backward_data.keys()):
         backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)

-    with FileOpen(input_file, 'r') as file:
-        input_data = json.load(file)
-        input_data.pop("data")
+    input_data = load_json(input_file)
+    if input_data.get("data") is None:
+        logger.error("Invalid input file, 'data' field is missing")
+        raise CompareException("Invalid input file, 'data' field is missing")
+    input_data.pop("data")

     items = list(forward_data.items())
     total_items = len(items)
@@ -69,8 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
            }
        }
        split_filename = f"temp_part{i}.json"
-        with FileOpen(split_filename, 'w') as split_file:
-            json.dump(temp_data, split_file)
+        save_json(split_filename, temp_data)
        split_files.append(split_filename)

     return split_files, total_items
@@ -122,7 +124,7 @@ def run_parallel_ut(config):
                if output == '':
                    break
                if '[ERROR]' in output:
-                    logger.warning(output, end='')
+                    logger.warning(output)
                sys.stdout.flush()
        except ValueError as e:
            logger.warning(f"An error occurred while reading subprocess output: {e}")
@@ -182,16 +184,19 @@ def run_parallel_ut(config):


 def prepare_config(args):
-    check_link(args.api_info_file)
-    api_info = os.path.realpath(args.api_info_file)
-    check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
-    out_path = os.path.realpath(args.out_path) if args.out_path else "./"
-    check_path_before_create(out_path)
+    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
+                                        ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
+    api_info = api_info_file_checker.common_check()
+    out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
     create_directory(out_path)
     out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
     out_path = out_path_checker.common_check()
     split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
-    config_path = os.path.realpath(args.config_path) if args.config_path else None
+    config_path = args.config_path if args.config_path else None
+    if config_path:
+        config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
+                                          FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
+        config_path = config_path_checker.common_check()
     result_csv_path = args.result_csv_path or os.path.join(
         out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
     if not args.result_csv_path:
--- a/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
+++ b/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py
@@ -28,11 +28,12 @@ else:
 import torch
 from tqdm import tqdm
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
-from msprobe.core.common.file_utils import check_link
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
+from msprobe.core.common.file_utils import check_link, FileChecker
+from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
+from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
-from msprobe.core.common.const import Const


 def check_tensor_overflow(x):
@@ -74,23 +75,25 @@ def run_overflow_check(forward_file):
     logger.info("start UT test")
     forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
     for api_full_name, api_info_dict in tqdm(forward_content.items()):
+        if is_unsupported_api(api_full_name, is_overflow_check=True):
+            continue
         try:
             run_torch_api(api_full_name, api_info_dict, real_data_path)
         except Exception as err:
             _, api_name, _ = api_full_name.split(Const.SEP)
             if "not implemented for 'Half'" in str(err):
-                logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
-                               f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
+                logger.warning(f"API {api_name} not support half tensor in CPU. This API does not support overflow "
+                               "check, so it will be skipped.")
             elif "expected scalar type Long" in str(err):
                 logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
-                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
+                               "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
             else:
                 logger.error(f"Run {api_full_name} UT Error: %s" % str(err))


 def run_torch_api(api_full_name, api_info_dict, real_data_path):
     torch.npu.clear_npu_overflow_flag()
-    api_type, api_name, _ = api_full_name.split(Const.SEP)
+    api_type, api_name = extract_basic_api_segments(api_full_name)
     args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
     if not need_grad:
         logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
@@ -135,8 +138,9 @@ def _run_overflow_check(parser=None):
 def _run_overflow_check_command(args):
     torch.npu.set_compile_mode(jit_compile=args.jit_compile)
     npu_device = "npu:" + str(args.device_id)
-    check_link(args.api_info_file)
-    api_info = os.path.realpath(args.api_info_file)
+    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
+                                        ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
+    api_info = api_info_file_checker.common_check()
     try:
         torch.npu.set_device(npu_device)
     except Exception as error: