mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc/在线精度比对.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,90 @@
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
4
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
6
+ from msprobe.pytorch.free_benchmark.common.utils import TorchC
7
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
8
+ NpuBaseLayer,
9
+ )
10
+
11
+
12
class AddNoiseLayer(NpuBaseLayer):
    """Perturb the first qualifying floating-point input by adding a small constant."""

    def add_noise(self, tensor_obj):
        """
        Recursively add noise to the first tensor in ``tensor_obj`` that passes ``pre_check``.

        For a tensor, the per-dtype value from ``ThresholdConfig.PERTURBATION_VALUE_DICT``
        is added to every element whose absolute value exceeds the square root of that
        value; other elements are kept. Dicts, lists and tuples are rebuilt with their
        elements processed recursively; anything else is returned unchanged.
        """
        if isinstance(tensor_obj, torch.Tensor):
            self.perturbed_value = ThresholdConfig.PERTURBATION_VALUE_DICT.get(
                tensor_obj.dtype
            )
            if not self.pre_check(tensor_obj):
                return tensor_obj
            noise = self._get_noise(tensor_obj)
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), self.perturbed_value ** 0.5),
                TorchC.add(noise, tensor_obj),
                tensor_obj,
            ).to(tensor_obj.dtype)
            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_noise(value) for value in tensor_obj])
        return tensor_obj

    # Return annotation: `torch.Any` is not public PyTorch API (it was an accidental
    # re-export of `typing.Any` that newer releases no longer provide), and evaluating
    # it at class-definition time raised AttributeError. `object` needs no import.
    def handle(self, params: DataParams) -> object:
        """
        Add noise to the input at ``params.valid_input_index`` and re-run the API.

        The perturbed input is stored in ``params.perturbed_value``; the output of
        ``perturbed_result`` (also stored on ``params``) is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.ADD_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _get_noise(self, tensor_obj):
        """Return a tensor shaped like ``tensor_obj`` filled with ``self.perturbed_value``."""
        dtype = tensor_obj.dtype
        device = str(tensor_obj.device)
        noise = TorchC.full(
            tensor_obj.shape,
            self.perturbed_value,
            device=device,
            dtype=dtype,
        )
        return noise

    def _check_details(self, tensor_obj):
        """
        Decide whether noise should be added to ``tensor_obj``.

        Cancels when the dtype has no configured noise value, when the tensor is
        empty, or when its maximum absolute value is below the dtype's tolerance
        (``ThresholdConfig.ABS_TOL_VALUE_DICT``, default ``NOISE_INPUT_LOWER_BOUND``).
        """
        if not self.perturbed_value:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0."
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Best-effort fallback: retry the reduction in float32 if the native
            # dtype's max/abs fails.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True
@@ -0,0 +1,104 @@
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
4
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
6
+ from msprobe.pytorch.free_benchmark.common.utils import TorchC
7
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
8
+ NpuBaseLayer,
9
+ )
10
+
11
+
12
class BitNoiseLayer(NpuBaseLayer):
    """Perturb the first qualifying floating-point input by flipping low-order bits."""

    def __init__(self, api_name):
        super().__init__(api_name)
        # Bitwise operation applied to the integer view of the tensor.
        self.bit_mode = TorchC.bitwise_xor
        # Bit pattern combined with every eligible element (1 -> least-significant bit).
        self.bit_tail: int = 1
        # Integer dtype used to reinterpret the float tensor's bits; None = unsupported.
        self.bit_type = None

    def add_bit_noise(self, tensor_obj):
        """
        Recursively flip bits of the first tensor in ``tensor_obj`` that passes ``pre_check``.

        The tensor is viewed as ``self.bit_type`` and XOR-ed with ``self.bit_tail``
        wherever its absolute value exceeds the dtype's smallest normal number, then
        viewed back as the original dtype. Dicts, lists and tuples are rebuilt with
        their elements processed recursively; anything else is returned unchanged.
        """
        # NOTE (original author): torch.finfo should be put on the blacklist —
        # presumably the list of hooked APIs; confirm.
        if isinstance(tensor_obj, torch.Tensor):
            self._set_perturbation_bit(tensor_obj)
            if not self.pre_check(tensor_obj):
                return tensor_obj
            sub_normal = torch.finfo(tensor_obj.dtype).smallest_normal
            noise = TorchC.full(
                tensor_obj.shape,
                self.bit_tail,
                device=tensor_obj.device,
                dtype=self.bit_type,
            )
            result = tensor_obj.view(self.bit_type)
            result = TorchC.where(
                TorchC.gt(TorchC.abs(tensor_obj), sub_normal),
                self.bit_mode(result, noise),
                result,
            ).view(tensor_obj.dtype)

            self.is_added = True
            return result
        if isinstance(tensor_obj, dict):
            return {key: self.add_bit_noise(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj])
        return tensor_obj

    # Return annotation: `torch.Any` is not public PyTorch API (it was an accidental
    # re-export of `typing.Any` that newer releases no longer provide), and evaluating
    # it at class-definition time raised AttributeError. `object` needs no import.
    def handle(self, params: DataParams) -> object:
        """
        Flip bits of the input at ``params.valid_input_index`` and re-run the API.

        The perturbed input is stored in ``params.perturbed_value``; the output of
        ``perturbed_result`` (also stored on ``params``) is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.BIT_NOISE} of {self.api_name}."
        )
        params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """
        Decide whether the bit flip should be applied to ``tensor_obj``.

        Cancels when the dtype has no integer view type, when the tensor is empty,
        or when its maximum absolute value is below the dtype's tolerance
        (``ThresholdConfig.ABS_TOL_VALUE_DICT``, default ``NOISE_INPUT_LOWER_BOUND``).
        """
        if not self.bit_type:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"dtype unsupported. Cancel perturbation."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.warning_on_rank_0(
                f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0"
                f" Cancel adding noise."
            )
            return False
        abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(
            tensor_obj.dtype, ThresholdConfig.NOISE_INPUT_LOWER_BOUND
        )
        try:
            max_val = TorchC.max(TorchC.abs(tensor_obj)).item()
        except Exception:
            # Best-effort fallback: retry the reduction in float32.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"when calculate maximun value, tensor is changed to float32."
            )
            max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
        if max_val < abs_tol:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"Maximun value is less than the minimun threshold. Cancel add noise."
            )
            return False
        return True

    def _set_perturbation_bit(self, tensor_obj):
        """
        Select the integer view dtype for ``tensor_obj`` from
        ``ThresholdConfig.PERTURBATION_BIT_DICT``.

        ``bit_type`` is reset on every call (to None when the dtype is not listed).
        Previously an unlisted dtype kept the prior tensor's integer type. That let
        ``_check_details`` accept it and produced a mismatched ``view``/``full``
        for a tensor of a different element size.
        """
        self.bit_tail = 1
        self.bit_type = ThresholdConfig.PERTURBATION_BIT_DICT.get(tensor_obj.dtype)
@@ -0,0 +1,63 @@
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
5
+ from msprobe.pytorch.free_benchmark.common.utils import TorchC
6
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
7
+ NpuBaseLayer,
8
+ )
9
+
10
+
11
class ChangeValueLayer(NpuBaseLayer):
    """Perturb the first qualifying floating-point input by swapping its head and tail values."""

    def __init__(self, api_name):
        super().__init__(api_name)
        self.head: int = 0  # index of the first entry along each indexed dimension
        self.tail: int = -1  # index of the last entry along each indexed dimension

    def change_value(self, tensor_obj):
        """
        Swap the head and tail entries of the first tensor that passes ``pre_check``.

        1-D tensors swap ``t[0]`` and ``t[-1]``. Higher-rank tensors swap
        ``t[0][0]`` and ``t[-1][-1]``, which are scalars for 2-D tensors and
        sub-tensors otherwise. The swap is done on a clone, so the input itself is
        not modified. Dicts, lists and tuples are rebuilt with their elements
        processed recursively; anything else is returned unchanged.
        """
        if isinstance(tensor_obj, torch.Tensor) and self.pre_check(tensor_obj):
            new_tensor = TorchC.clone(tensor_obj)
            if new_tensor.ndim == 1:
                temp_first = TorchC.clone(new_tensor[self.head])
                temp_last = TorchC.clone(new_tensor[self.tail])
                new_tensor[self.head] = temp_last
                new_tensor[self.tail] = temp_first
            else:
                temp_first = TorchC.clone(new_tensor[self.head][self.head])
                temp_last = TorchC.clone(new_tensor[self.tail][self.tail])
                new_tensor[self.head][self.head] = temp_last
                new_tensor[self.tail][self.tail] = temp_first

            self.is_added = True
            return new_tensor
        if isinstance(tensor_obj, dict):
            return {key: self.change_value(value) for key, value in tensor_obj.items()}
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)([self.change_value(value) for value in tensor_obj])
        return tensor_obj

    # Return annotation: `torch.Any` is not public PyTorch API (it was an accidental
    # re-export of `typing.Any` that newer releases no longer provide), and evaluating
    # it at class-definition time raised AttributeError. `object` needs no import.
    def handle(self, params: DataParams) -> object:
        """
        Swap head/tail values of the input at ``params.valid_input_index`` and re-run the API.

        The perturbed input is stored in ``params.perturbed_value``; the output of
        ``perturbed_result`` (also stored on ``params``) is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}."
        )
        params.perturbed_value = self.change_value(params.args[params.valid_input_index])
        return self.perturbed_result(params)

    def _check_details(self, tensor_obj):
        """
        Decide whether the head/tail swap can be applied to ``tensor_obj``.

        Requires at least one dimension, ``size(0) >= 2`` and at least one element.
        The rank check guards ``size(0)``, which raises IndexError on 0-dim tensors.
        The element check guards the ``[0][0]`` indexing in ``change_value``, which
        raises IndexError on empty tensors such as shape ``(2, 0)``.
        """
        if tensor_obj.dim() == 0 or tensor_obj.size(0) < 2:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"size 0 must greater than 1. Cancel change value."
            )
            return False
        if tensor_obj.numel() == 0:
            logger.info_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.api_name}, "
                f"tensor is empty. Cancel change value."
            )
            return False
        return True
@@ -0,0 +1,68 @@
1
+ import torch
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.pytorch.free_benchmark import logger
4
+ from msprobe.pytorch.free_benchmark.common.constant import CommonField
5
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
6
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
7
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
8
+ NpuBaseLayer,
9
+ )
10
+
11
+
12
class ImprovePrecisionLayer(NpuBaseLayer):
    """Re-run the API with low-precision floating-point inputs upcast to higher precision."""

    def improve_tensor_precision(self, tensor_obj):
        """
        Recursively upcast floating-point tensors whose dtype is not float32/float64.

        Sets ``is_added`` when at least one tensor was converted. Dicts, lists and
        tuples are rebuilt with their elements processed recursively; anything else
        is returned unchanged.
        """
        if (
            isinstance(tensor_obj, torch.Tensor)
            and torch.is_floating_point(tensor_obj)
            and tensor_obj.dtype not in [torch.float32, torch.float64]
        ):
            self._set_improve_valus(tensor_obj)
            tensor_obj = self._change_dtype(tensor_obj)
            self.is_added = True
            return tensor_obj
        if isinstance(tensor_obj, dict):
            return {
                key: self.improve_tensor_precision(value)
                for key, value in tensor_obj.items()
            }
        if isinstance(tensor_obj, (tuple, list)):
            return type(tensor_obj)(
                [self.improve_tensor_precision(value) for value in tensor_obj]
            )
        return tensor_obj

    # Return annotation: `torch.Any` is not public PyTorch API (it was an accidental
    # re-export of `typing.Any` that newer releases no longer provide), and evaluating
    # it at class-definition time raised AttributeError. `object` needs no import.
    def handle(self, params: DataParams) -> object:
        """
        Re-execute ``params.origin_func`` on upcast inputs and store the output.

        In the backward stage only ``params.args`` are converted and kwargs are
        dropped. An ``inplace`` keyword is forced to False on the rebuilt kwargs
        dict, so ``params.kwargs`` itself is not modified. If no input needed
        upcasting, the re-execution is skipped and the current
        ``params.perturbed_result`` is returned as-is.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}."
        )
        new_args = self.improve_tensor_precision(params.args)
        if params.fuzz_stage == Const.BACKWARD:
            new_kwargs = {}
        else:
            new_kwargs = self.improve_tensor_precision(params.kwargs)
        # All inputs are already high precision: skip the second execution to avoid
        # holding extra device-memory references.
        if not self.is_added:
            return params.perturbed_result
        if "inplace" in new_kwargs:
            new_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result

    def _set_improve_valus(self, inputs):
        """Choose float32 as the target dtype for float16/bfloat16 inputs."""
        if inputs.dtype in [torch.float16, torch.bfloat16]:
            self.perturbed_value = torch.float32

    def _change_dtype(self, inputs):
        """Cast ``inputs`` to ``self.perturbed_value`` while keeping its device."""
        if hasattr(inputs, CommonField.DEVICE):
            device = inputs.device
            # Compare by value: `inputs.device` yields a torch.device object, so the
            # previous identity test (`device is CommonField.META`) was effectively
            # never true. Assumes CommonField.META is the device string "meta" — confirm.
            if str(device) == CommonField.META:
                new_inputs = inputs.to(
                    device=CommonField.META, dtype=self.perturbed_value
                )
            else:
                new_inputs = inputs.to(dtype=self.perturbed_value).to(device)
        else:
            new_inputs = inputs.to(dtype=self.perturbed_value)
        return new_inputs
@@ -0,0 +1,28 @@
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
4
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
5
+ from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import (
6
+ NpuBaseLayer,
7
+ )
8
+
9
+
10
class NoChangeLayer(NpuBaseLayer):
    """Re-execute the API a second time without modifying its inputs."""

    def no_change(self, tensor_obj):
        """
        Return ``tensor_obj`` untouched, marking the input as handled so that
        ``pre_check`` rejects any further tensors.
        """
        self.is_added = True
        return tensor_obj

    # Return annotation: `torch.Any` is not public PyTorch API (it was an accidental
    # re-export of `typing.Any` that newer releases no longer provide), and evaluating
    # it at class-definition time raised AttributeError. `object` needs no import.
    def handle(self, params: DataParams) -> object:
        """
        Re-run the API with the unchanged input at ``params.valid_input_index``.

        The input is stored in ``params.perturbed_value``; the output of
        ``perturbed_result`` (also stored on ``params``) is returned.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is "
            f"{PerturbationMode.NO_CHANGE} of {self.api_name}."
        )
        params.perturbed_value = self.no_change(params.args[params.valid_input_index])
        return self.perturbed_result(params)
@@ -0,0 +1,45 @@
1
+ from abc import abstractmethod
2
+ from typing import Any
3
+
4
+ import torch
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
6
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
+
8
+
9
class NpuBaseLayer(BaseLayer):
    """
    Shared base for perturbation layers that re-run an API on modified inputs.

    Subclasses implement ``handle`` and may override ``_check_details`` to add
    strategy-specific eligibility rules consulted by ``pre_check``.
    """

    def __init__(self, api_name: str) -> None:
        super().__init__(api_name)
        # Perturbation payload chosen by the subclass.
        self.perturbed_value = None
        # Becomes True once an input of the current operator has been modified.
        self.is_added = False

    @staticmethod
    def perturbed_result(params: DataParams) -> Any:
        """
        Call ``params.origin_func`` with the argument at ``params.valid_input_index``
        replaced by ``params.perturbed_value``, then store and return the output.

        An ``inplace`` keyword, when present, is forced to False directly on
        ``params.kwargs`` so the perturbed call runs out of place.
        """
        pos = params.valid_input_index
        leading = params.args[:pos]
        trailing = params.args[pos + 1:]
        call_kwargs = params.kwargs
        if "inplace" in call_kwargs:
            call_kwargs["inplace"] = False
        params.perturbed_result = params.origin_func(
            *leading, params.perturbed_value, *trailing, **call_kwargs
        )
        return params.perturbed_result

    @abstractmethod
    def handle(self, params: DataParams) -> Any:
        """Apply this layer's perturbation described by ``params`` and return the result."""

    def pre_check(self, tensor_obj):
        """
        Return True when ``tensor_obj`` should be perturbed.

        Only the first qualifying tensor is perturbed, so this is always False once
        ``is_added`` is set. Otherwise the tensor must be floating-point and pass
        the subclass hook ``_check_details``.
        """
        if self.is_added or not torch.is_floating_point(tensor_obj):
            return False
        return bool(self._check_details(tensor_obj))

    def _check_details(self, tensor_obj):
        """Subclass hook for extra eligibility checks; the default accepts every tensor."""
        return True
@@ -0,0 +1,19 @@
1
+ import torch
2
+ from msprobe.pytorch.free_benchmark import logger
3
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
4
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
5
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
6
+ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer
7
+
8
+
9
class CpuLayer(BaseLayer):
    """Re-execute the API with all of its inputs moved to CPU."""

    # Return annotation: `torch.Any` is not public PyTorch API (it was an accidental
    # re-export of `typing.Any` that newer releases no longer provide), and evaluating
    # it at class-definition time raised AttributeError. `object` needs no import.
    def handle(self, params: DataParams) -> object:
        """
        Convert ``params.args`` and ``params.kwargs`` to CPU via
        ``Tools.convert_device_and_dtype`` (with ``change_dtype=True``), call
        ``params.origin_func`` on them, and store/return the output in
        ``params.perturbed_result``.
        """
        logger.info_on_rank_0(
            f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}."
        )
        new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True)
        new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True)
        params.perturbed_result = params.origin_func(*new_args, **new_kwargs)
        return params.perturbed_result
@@ -0,0 +1,203 @@
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import torch
6
+ from msprobe.core.common.const import Const
7
+ from msprobe.pytorch.free_benchmark import logger
8
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
9
+ from msprobe.pytorch.free_benchmark.common.enums import (
10
+ FuzzThreshold,
11
+ NormType,
12
+ PerturbationMode,
13
+ )
14
+ from msprobe.pytorch.free_benchmark.common.params import (
15
+ DataParams,
16
+ HandlerParams,
17
+ make_unequal_row,
18
+ )
19
+ from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC
20
+
21
+
22
class FuzzHandler(ABC):
    """Base class for free-benchmark result handlers.

    Holds the handler parameters and collects "unequal rows" (records built by
    ``make_unequal_row`` for outputs whose original and perturbed values
    disagree). Subclasses provide the dtype threshold (``get_threshold``) and
    the top-level ``handle`` logic.
    """

    def __init__(self, params: HandlerParams) -> None:
        self.params = params
        # Rows appended by ``cmp_output_npu`` (and subclasses) on inconsistency.
        self.unequal_rows = []

    @staticmethod
    def _f32_threshold():
        """Return ``FuzzThreshold.F32_THD`` as a plain number.

        The call sites previously disagreed on whether the constant needs
        ``.value``. Only one spelling can work, depending on whether it is an
        enum member or a number. Unwrap ``.value`` only when it is present.
        """
        thd = FuzzThreshold.F32_THD
        return getattr(thd, "value", thd)

    @staticmethod
    def pre_process(origin_ouput, perturbed_output):
        """Align the two outputs and choose the absolute tolerance.

        For tuple results exposing ``values`` and ``indices`` (e.g. the
        namedtuples returned by ``torch.max``/``torch.topk`` along a dim), only
        ``values`` is compared. The original output is cast to the perturbed
        output's dtype and device.

        Returns:
            ``(origin, perturbed, abs_tol)``. ``abs_tol`` is looked up in
            ``ThresholdConfig.ABS_TOL_VALUE_DICT`` by the perturbed dtype, and
            is ``None`` if that dtype is not listed. It falls back to the float32
            threshold when the perturbed output has no ``dtype``.
        """
        if (
            isinstance(origin_ouput, tuple)
            and hasattr(origin_ouput, "values")
            and hasattr(origin_ouput, "indices")
        ):
            origin_ouput = origin_ouput.values
            perturbed_output = perturbed_output.values
        if hasattr(perturbed_output, "dtype"):
            abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype)
        else:
            # NOTE(review): the cast below still reads perturbed_output.dtype,
            # so this branch ends in an exception that ratio_calculate reports.
            abs_tol = FuzzHandler._f32_threshold()
        return (
            origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
            perturbed_output,
            abs_tol,
        )

    @staticmethod
    def convert_overflow_ratio_to_consistent(ratio):
        """Map a NaN/inf ratio to ``ThresholdConfig.COMP_CONSISTENT``; else pass it through."""
        if math.isnan(ratio) or math.isinf(ratio):
            return ThresholdConfig.COMP_CONSISTENT
        return ratio

    @abstractmethod
    def get_threshold(self, dtype):
        """Return the ratio threshold used for outputs of ``dtype``."""

    @abstractmethod
    def handle(self, data_params: DataParams) -> Any:
        """Process one comparison; implemented by subclasses."""

    def get_ratio_from_specific_norm(
        self, origin_output, perturbed_output, norm_type, abs_tol
    ):
        """Dispatch on ``norm_type``. Only ``ENDLESS_NORM`` is computed.

        Any other norm type yields ``ThresholdConfig.COMP_CONSISTENT``.
        """
        if norm_type == NormType.ENDLESS_NORM:
            return self.get_endless_norm(origin_output, perturbed_output, abs_tol)
        return ThresholdConfig.COMP_CONSISTENT

    def get_endless_norm(self, origin_output, perturbed_output, abs_tol):
        """Return the worst-case elementwise magnitude ratio between outputs.

        Elements whose denominator magnitude does not exceed ``abs_tol`` count
        as ratio 1. The result is the larger of ``max(|o| / (|p| + tol))`` and
        ``max(|p| / (|o| + tol))``, with NaN/inf mapped to consistent.
        """
        ratio_tensor1 = TorchC.where(
            TorchC.gt(TorchC.abs(perturbed_output), abs_tol),
            TorchC.div(
                TorchC.abs(origin_output),
                TorchC.add(TorchC.abs(perturbed_output), abs_tol),
            ),
            1,
        )
        ratio_tensor2 = TorchC.where(
            TorchC.gt(TorchC.abs(origin_output), abs_tol),
            TorchC.div(
                TorchC.abs(perturbed_output),
                TorchC.add(TorchC.abs(origin_output), abs_tol),
            ),
            1,
        )

        norm1 = self.convert_overflow_ratio_to_consistent(
            TorchC.max(ratio_tensor1).item()
        )
        norm2 = self.convert_overflow_ratio_to_consistent(
            TorchC.max(ratio_tensor2).item()
        )
        norm3 = self.convert_overflow_ratio_to_consistent(
            TorchC.min(ratio_tensor1).item()
        )
        # NOTE(review): ratio_tensor1 holds either |o| / (|p| + tol) or 1, which is
        # non-negative for a non-negative tol. So norm3 < 0 never holds and
        # SYMBOL_FLIPPING is unreachable. Sign flips would need a signed
        # comparison. Confirm the intended detection before changing it.
        if norm3 < 0:
            ratio = ThresholdConfig.SYMBOL_FLIPPING
        else:
            ratio = max(norm1, norm2)
        return ratio

    def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
        """Compute the comparison ratio between original and perturbed outputs.

        Returns ``ThresholdConfig.COMP_NAN`` (after a warning) when
        ``pre_process`` fails. In the backward stage the tolerance is
        ``ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND``. Otherwise it is the
        square root of the dtype tolerance from ``pre_process``.
        """
        try:
            origin_output, perturbed_output, abs_tol = self.pre_process(
                origin_output, perturbed_output
            )
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when computing ratio,"
                f" y1 or y2 dtype is not supported {e}"
            )
            return ThresholdConfig.COMP_NAN
        if self.params.fuzz_stage == Const.BACKWARD:
            abs_tol = ThresholdConfig.BACKWARD_OUTPUT_LOWER_BOUND
        else:
            # NOTE(review): abs_tol is None for dtypes missing from
            # ABS_TOL_VALUE_DICT; this then raises TypeError to the caller.
            abs_tol = abs_tol ** 0.5
        return self.get_ratio_from_specific_norm(
            origin_output, perturbed_output, norm_type, abs_tol
        )

    def npu_compare(
        self, origin_output, perturbed_output
    ) -> Tuple[bool, Optional[float]]:
        """Compare one original/perturbed output pair.

        - ``int`` outputs: exact equality, ratio ``None``.
        - ``float`` outputs: ``math.isclose`` plus the plain quotient. Both
          values are shifted by the float32 threshold when the perturbed value
          is 0, which avoids dividing by zero.
        - Otherwise: endless-norm ratio checked against ``get_threshold`` for the
          first tensor dtype, i.e. ``1 / threshold <= ratio <= threshold``.

        Returns:
            ``(is_consistent, ratio)``.
        """
        if isinstance(perturbed_output, int):
            return origin_output == perturbed_output, None
        elif isinstance(perturbed_output, float):
            if perturbed_output == 0:
                f32_thd = self._f32_threshold()
                origin_output += f32_thd
                perturbed_output += f32_thd
            return (
                math.isclose(origin_output, perturbed_output),
                origin_output / perturbed_output,
            )
        elif not isinstance(perturbed_output, torch.Tensor):
            # Only a warning: execution still falls through to the tensor
            # comparison below, which may raise for unsupported types.
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name} "
                f"The compare for output type {type(perturbed_output)} is not supported"
            )

        threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
        ratio = self.ratio_calculate(
            origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
        )
        if ratio == ThresholdConfig.SYMBOL_FLIPPING:
            is_consistent = False
        else:
            is_consistent = threshold >= ratio >= 1 / threshold
        return is_consistent, ratio

    def cmp_output_npu(self, data_params: DataParams):
        """Compare ``original_result`` with ``perturbed_result`` via ``npu_compare``.

        A single tensor is compared directly. Lists and tuples are compared
        element by element against the same index of the perturbed result.
        Each result is folded into ``data_params.is_consistent``. When
        ``data_params.grad_unequal_flag`` is set, an unequal row is recorded
        for every inconsistent element. Any exception is logged and swallowed.

        Returns:
            ``(npu_consistent, max_fuzz_ratio)``, accumulated up to the point
            of any exception.
        """
        npu_consistent = True
        max_fuzz_ratio = 0
        try:
            if isinstance(data_params.original_result, torch.Tensor):
                is_consistent, ratio = self.npu_compare(
                    data_params.original_result, data_params.perturbed_result
                )
                npu_consistent = is_consistent
                max_fuzz_ratio = (
                    max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
                )
                data_params.is_consistent = is_consistent and data_params.is_consistent
                if not is_consistent and data_params.grad_unequal_flag:
                    self.unequal_rows.append(
                        make_unequal_row(data_params, self.params, ratio=ratio)
                    )

            elif isinstance(data_params.original_result, (list, tuple)):
                for index_, origin_item in enumerate(data_params.original_result):
                    is_consistent, ratio = self.npu_compare(
                        origin_item, data_params.perturbed_result[index_]
                    )
                    npu_consistent = npu_consistent and is_consistent
                    max_fuzz_ratio = (
                        max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
                    )
                    data_params.is_consistent = (
                        is_consistent and data_params.is_consistent
                    )
                    if not is_consistent and data_params.grad_unequal_flag:
                        self.unequal_rows.append(
                            make_unequal_row(
                                data_params, self.params, ratio=ratio, index=index_
                            )
                        )
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare the result exception raise {e}"
            )
        return npu_consistent, max_fuzz_ratio

    def get_unequal_rows(self):
        """Return the list of unequal rows collected so far."""
        return self.unequal_rows

    def _get_default_threshold(self, dtype):
        """Return the default ratio threshold for ``dtype``.

        ``PerturbationMode.NO_CHANGE`` uses ``COMP_CONSISTENT``. Other modes use
        ``ThresholdConfig.DTYPE_PER_THD[dtype]``, falling back to the
        ``torch.float32`` entry.
        """
        if self.params.pert_mode == PerturbationMode.NO_CHANGE:
            threshold = ThresholdConfig.COMP_CONSISTENT
        else:
            threshold = ThresholdConfig.DTYPE_PER_THD.get(
                dtype, ThresholdConfig.DTYPE_PER_THD.get(torch.float32)
            )
        return threshold
@@ -0,0 +1,39 @@
1
+ from typing import Any
2
+
3
+ from msprobe.pytorch.free_benchmark import logger
4
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
5
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row
6
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
7
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
8
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
9
+
10
+
11
class CheckerHandler(FuzzHandler):
    """Check-mode handler: compares results and records inconsistencies.

    It never alters the returned value; ``handle`` always hands back the
    original result.
    """

    def other_compare(self, data_params: DataParams) -> bool:
        """Compare original and perturbed results with ``SingleCompare``.

        Used when the fuzz device is not the NPU. On mismatch, a row from
        ``make_unequal_row`` is appended to ``self.unequal_rows``.

        Returns:
            The flag produced by ``SingleCompare().compare_seq``, as the
            ``-> bool`` annotation promises.
        """
        is_consistent = SingleCompare().compare_seq(
            data_params.original_result, data_params.perturbed_result
        )
        if not is_consistent:
            self.unequal_rows.append(
                make_unequal_row(data_params, self.params)
            )
        return is_consistent

    def get_threshold(self, dtype):
        """Use the base class's default per-dtype threshold."""
        return self._get_default_threshold(dtype)

    def handle(self, data_params: DataParams) -> Any:
        """Compare the perturbed result against the original and return the original.

        Comparison is skipped when the perturbed result is a ``bool`` or is not
        a float tensor according to ``Tools.is_float_tensor``. NPU fuzzing uses
        ``cmp_output_npu``; any other device uses ``other_compare``. Exceptions
        during comparison are logged and swallowed.
        """
        if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
            data_params.perturbed_result
        ):
            return data_params.original_result
        try:
            if self.params.fuzz_device == DeviceType.NPU:
                self.cmp_output_npu(data_params)
            else:
                self.other_compare(data_params)
        except Exception as e:
            logger.warning_on_rank_0(
                f"[msprobe] Free Benchmark: For {self.params.api_name}, "
                f"when campare the result exception raise {e}"
            )
        return data_params.original_result