mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,24 @@
+ from typing import Any
+
+ from msprobe.pytorch.free_benchmark.common.params import DataParams
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+ from msprobe.pytorch.free_benchmark import logger
+
+
+ class FixHandler(FuzzHandler):
+
+     def get_threshold(self, dtype):
+         return self._get_default_threshold(dtype)
+
+     def handle(self, data_params: DataParams) -> Any:
+         try:
+             return Tools.convert_fuzz_output_to_origin(
+                 data_params.original_result, data_params.perturbed_result
+             )
+         except Exception as e:
+             logger.warning_on_rank_0(
+                 f"[msprobe] Free Benchmark: For {self.params.api_name} "
+                 f"fixing the output failed: {e}"
+             )
+             return data_params.original_result
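A note on the fix mode shown above: handle replaces the API's perturbed output with a value converted back to match the original result via Tools.convert_fuzz_output_to_origin. As a rough, standalone illustration of that idea (not the package's implementation, whose exact rules live in common/utils.py), the conversion amounts to casting the perturbed result back to the original result's dtype:

import torch

def cast_back_to_origin(original, perturbed):
    # Illustration only: return the perturbed value, cast to the original dtype,
    # so callers keep seeing the dtype they expect while benefiting from the
    # (for example, higher-precision) perturbed computation.
    if isinstance(original, torch.Tensor) and isinstance(perturbed, torch.Tensor):
        return perturbed.to(original.dtype)
    if isinstance(original, (list, tuple)) and isinstance(perturbed, (list, tuple)):
        return type(original)(cast_back_to_origin(o, p) for o, p in zip(original, perturbed))
    return perturbed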
@@ -0,0 +1,31 @@
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkException
+ from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
+ from msprobe.pytorch.free_benchmark.common.enums import HandlerType
+ from msprobe.pytorch.free_benchmark.common.params import HandlerParams
+ from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler
+ from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler
+ from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler
+
+
+ class FuzzHandlerFactory:
+
+     result_handlers = {
+         HandlerType.CHECK: CheckerHandler,
+         HandlerType.FIX: FixHandler,
+         HandlerType.PREHEAT: PreheatHandler,
+     }
+
+     @staticmethod
+     def create(params: HandlerParams):
+         if_preheat = params.preheat_config.get(PreheatConfig.IF_PREHEAT)
+         if not if_preheat:
+             handler = FuzzHandlerFactory.result_handlers.get(params.handler_type)
+         else:
+             handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT)
+         # TODO
+         if not handler:
+             raise FreeBenchmarkException(
+                 FreeBenchmarkException.UnsupportedType,
+                 f"The free benchmark tool only supports the [{HandlerType.CHECK}, {HandlerType.FIX}] handler types",
+             )
+         return handler(params)
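A hedged usage sketch for the factory above: create returns PreheatHandler whenever preheating is enabled in preheat_config, and otherwise the class registered for params.handler_type. The attribute and key names come from the code above; treating HandlerParams as default-constructible is an assumption made only for illustration.

from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig
from msprobe.pytorch.free_benchmark.common.enums import HandlerType
from msprobe.pytorch.free_benchmark.common.params import HandlerParams
from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import FuzzHandlerFactory

params = HandlerParams()                                   # assumed default-constructible
params.handler_type = HandlerType.FIX                      # request the fix-mode handler
params.preheat_config = {PreheatConfig.IF_PREHEAT: False}  # preheating disabled

handler = FuzzHandlerFactory.create(params)                # -> a FixHandler bound to params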
@@ -0,0 +1,170 @@
+ import math
+ from typing import Any
+
+ from msprobe.pytorch.free_benchmark import logger
+ from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
+ from msprobe.pytorch.free_benchmark.common.counter import preheat_counter
+ from msprobe.pytorch.free_benchmark.common.enums import DeviceType
+ from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams
+ from msprobe.pytorch.free_benchmark.common.utils import Tools
+ from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare
+ from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
+
+
+ class PreheatHandler(FuzzHandler):
+
+     def __init__(self, params: HandlerParams) -> None:
+         super().__init__(params)
+         self.pure_name = Tools.get_pure_api_name(self.params.api_name)
+
+     def get_threshold(self, dtype):
+         return preheat_counter.get_api_thd(self.pure_name, dtype)
+
+     def compare_npu_and_cpu(self, data_params: DataParams):
+         args = Tools.convert_device_and_dtype(
+             data_params.args, DeviceType.CPU, change_dtype=True
+         )
+         kwargs = Tools.convert_device_and_dtype(
+             data_params.kwargs, DeviceType.CPU, change_dtype=True
+         )
+         cpu_result = data_params.origin_func(*args, **kwargs)
+         return SingleCompare().compare_seq(data_params.original_result, cpu_result)
+
+     def preheat(self, max_fuzz_ratio, cpu_consistent, first_dtype):
+         # Store every output ratio of the current step together with its NPU/CPU comparison result
+         preheat_counter.update_preheat_record(
+             self.pure_name,
+             first_dtype,
+             (max_fuzz_ratio, cpu_consistent),
+         )
+         if self._need_adjust_threshold():
+             self._adjust_threshold()
+
+     def handle(self, data_params: DataParams) -> Any:
+
+         if isinstance(data_params.perturbed_result, bool) or not Tools.is_float_tensor(
+             data_params.perturbed_result
+         ):
+             return data_params.original_result
+
+         if self.params.step == 0:
+             preheat_counter.add_one_step_used_api(self.pure_name)
+             return data_params.original_result
+
+         # Only relevant if the current API still needs preheating at this step
+         npu_consistent, max_fuzz_ratio = self.cmp_output_npu(data_params)
+         data_params.is_consistent = npu_consistent
+
+         preheat_counter.check_step(self.params.step)
+
+         if self.params.preheat_config.get("preheat_step") <= self.params.step:
+             return data_params.original_result
+
+         if not data_params.grad_unequal_flag:
+             data_params.grad_unequal_flag = True
+             data_params.is_consistent = False
+             return data_params.original_result
+         preheat_counter.add_api_called_time(self.pure_name)
+
+         if not self._is_take_a_sample():
+             return data_params.original_result
+
+         cpu_consistent = True
+         try:
+             cpu_consistent = self.compare_npu_and_cpu(data_params)
+         except Exception as e:
+             logger.warning_on_rank_0(
+                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
+                 f"an exception was raised while comparing to cpu: {e}"
+             )
+         try:
+             first_dtype = Tools.get_first_tensor_dtype(data_params.original_result)
+         except RuntimeError:
+             logger.warning_on_rank_0(
+                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
+                 f"the output sequence does not contain tensors."
+             )
+         if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)):
+             self.preheat(max_fuzz_ratio, cpu_consistent, first_dtype)
+
+         return data_params.original_result
+
+     def _is_take_a_sample(self) -> bool:
+         need_sample_set = self._get_need_sample_set()
+         curr_called_seq = preheat_counter.get_api_called_time(self.pure_name)
+         res = curr_called_seq in need_sample_set
+         if res:
+             total_count = preheat_counter.get_one_step_used_api(self.pure_name)
+             logger.info_on_rank_0(
+                 f"[msprobe] Free benchmark: preheat sample in step {self.params.step}, "
+                 f"api_name {self.params.api_name}, "
+                 f"curr_called_seq: {curr_called_seq}/{total_count}"
+             )
+             preheat_counter.add_api_sample_time(self.pure_name)
+         return res
+
+     def _get_sample_count_per_step(self) -> int:
+         """
+         Number of samples that should be collected in each step
+         """
+         total_count = preheat_counter.get_one_step_used_api(self.pure_name)
+         preheat_step = self.params.preheat_config.get("preheat_step")
+         max_sample = self.params.preheat_config.get("max_sample")
+         return min(math.ceil(total_count / preheat_step), max_sample)
+
+     def _get_need_sample_set(self):
+         """
+         Set of call sequence numbers that need to be sampled
+         """
+         # Samples per step
+         total_count = preheat_counter.get_one_step_used_api(self.pure_name)
+         sample_count_per_step = self._get_sample_count_per_step()
+         need_sample_set = set()
+         preheat_step = self.params.preheat_config.get("preheat_step")
+         for i in range(1, sample_count_per_step + 1):
+             count = (preheat_step * (i - 1) + self.params.step) % total_count
+             if count == 0:
+                 count = total_count
+             need_sample_set.add(count)
+         return need_sample_set
+
+     def _need_adjust_threshold(self) -> bool:
+         sample_count_per_step = self._get_sample_count_per_step()
+         sampled_time = preheat_counter.get_api_sample_time(self.pure_name)
+         res = sampled_time >= sample_count_per_step
+         return res
+
+     def _adjust_threshold_for_dtype(self, dtype_str, compare_result):
+         con_ratio = [ratio for ratio, is_consistent in compare_result if is_consistent]
+         incon_ratio = [
+             ratio for ratio, is_consistent in compare_result if not is_consistent
+         ]
+         old_thd = preheat_counter.get_api_thd(self.pure_name, dtype_str)
+         new_thd = old_thd
+         # Both consistent and inconsistent samples exist
+         if con_ratio and incon_ratio:
+             if min(incon_ratio) > max(con_ratio):
+                 new_thd = min(min(incon_ratio), old_thd)
+                 preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
+         elif con_ratio:
+             # A missed report exists: a consistent ratio exceeds the old threshold
+             if max(con_ratio) > old_thd:
+                 new_thd = 1 + ((old_thd - 1) * ThresholdConfig.API_THD_STEP)
+             else:
+                 new_thd = 1 + ((old_thd - 1) / ThresholdConfig.API_THD_STEP)
+         else:
+             new_thd = min(min(incon_ratio), old_thd)
+             preheat_counter.set_api_preheat(self.pure_name, dtype_str, is_preheat=False)
+         return new_thd
+
+     def _adjust_threshold(self):
+         for dtype_str, compare_result in preheat_counter.preheat_record[
+             self.pure_name
+         ].items():
+             new_thd = self._adjust_threshold_for_dtype(dtype_str, compare_result)
+             threshold = self._get_default_threshold(
+                 preheat_counter.dtype_map.get(dtype_str)
+             )
+             preheat_counter.update_api_thd(
+                 self.pure_name, dtype_str, new_thd, threshold
+             )
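The threshold-adjustment rule in _adjust_threshold_for_dtype can be restated on its own. The sketch below mirrors that logic with the preheat_counter bookkeeping stripped out; API_THD_STEP is a placeholder value here, since the real constant lives in ThresholdConfig:

API_THD_STEP = 2.0  # placeholder; the package reads this from ThresholdConfig

def adjust_threshold(old_thd, compare_result):
    """compare_result: list of (ratio, is_consistent) pairs collected during preheat."""
    con = [r for r, ok in compare_result if ok]        # NPU/CPU-consistent samples
    incon = [r for r, ok in compare_result if not ok]  # inconsistent samples
    if con and incon:
        # Only tighten when the two classes are separable: every inconsistent
        # ratio sits above every consistent ratio.
        if min(incon) > max(con):
            return min(min(incon), old_thd)
        return old_thd
    if con:
        # Only consistent samples: loosen if one exceeded the old threshold
        # (a missed report), otherwise tighten toward 1.
        if max(con) > old_thd:
            return 1 + (old_thd - 1) * API_THD_STEP
        return 1 + (old_thd - 1) / API_THD_STEP
    # Only inconsistent samples: clamp the threshold down to the smallest ratio seen.
    return min(min(incon), old_thd)

# Example: threshold 1.1, one consistent ratio at 1.02, one inconsistent at 1.3
print(adjust_threshold(1.1, [(1.02, True), (1.3, False)]))  # -> 1.1 (already below 1.3)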
File without changes
File without changes
@@ -0,0 +1,39 @@
+ import torch.nn as nn
+ from msprobe.pytorch.common.log import logger
+ from msprobe.core.common.const import Const
+ from msprobe.pytorch.hook_module.api_registry import api_register
+ from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger
+ from msprobe.core.common.exceptions import MsaccException
+ from msprobe.core.data_dump.scope import BaseScope
+
+ module_count = {}
+
+
+ def module_dump(module, dump_name):
+     if not isinstance(module, nn.Module):
+         logger.error("The parameter 'module' passed to module_dump is not an nn.Module subclass.")
+         raise MsaccException(MsaccException.INVALID_PARAM_ERROR)
+     if not isinstance(dump_name, str):
+         logger.error("The parameter 'dump_name' passed to module_dump is not a str.")
+         raise MsaccException(MsaccException.INVALID_PARAM_ERROR)
+     api_register.api_originality()
+     if dump_name not in module_count:
+         module_count[dump_name] = 0
+     else:
+         module_count[dump_name] += 1
+     dump_name = dump_name + Const.SEP + str(module_count.get(dump_name)) + Const.SEP
+
+     pdg = PrecisionDebugger()
+     _, forward_hook, backward_hook = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name)
+     module.register_forward_hook(forward_hook, with_kwargs=True)
+     module.register_full_backward_hook(backward_hook)
+
+     module.register_forward_pre_hook(pdg.service.module_processor.node_hook(dump_name + Const.FORWARD, Const.START))
+     module.register_forward_hook(pdg.service.module_processor.node_hook(dump_name + Const.FORWARD, Const.STOP))
+     module.register_full_backward_pre_hook(
+         pdg.service.module_processor.node_hook(dump_name + Const.BACKWARD, Const.START))
+     module.register_full_backward_hook(pdg.service.module_processor.node_hook(dump_name + Const.BACKWARD, Const.STOP))
+
+
+ def module_dump_end():
+     api_register.api_modularity()
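A hedged usage sketch for the helpers above: module_dump restores the original APIs, then registers dump and graph-node hooks on a single nn.Module under the given name, and module_dump_end re-applies the API wrappers. The sketch assumes a PrecisionDebugger has already been constructed elsewhere, since module_dump calls PrecisionDebugger() to reuse that instance; everything else uses only names from the code above.

import torch
import torch.nn as nn
from msprobe.pytorch.functional.dump_module import module_dump, module_dump_end

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

net = TinyNet()
module_dump(net, "TinyNet")   # attach forward/backward dump hooks to this module
out = net(torch.randn(2, 4))  # data for this module is collected by the hooks
module_dump_end()             # re-apply the wrapped (hooked) torch APIs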
@@ -0,0 +1 @@
+ from .wrap_functional import remove_dropout
@@ -0,0 +1,161 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+
+ import torch
+ import torch.distributed as dist
+
+ from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten
+ from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops
+ from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops
+ from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops
+ from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops
+ from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops
+ from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops
+ from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu
+ from msprobe.core.common.const import Const
+
+ torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
+
+ if not is_gpu:
+     import torch_npu
+     from . import wrap_npu_custom
+     from .wrap_npu_custom import get_npu_ops
+
+
+ class ApiRegistry:
+     def __init__(self):
+         self.tensor_ori_attr = {}
+         self.torch_ori_attr = {}
+         self.functional_ori_attr = {}
+         self.distributed_ori_attr = {}
+         self.npu_distributed_ori_attr = {}
+         self.vf_ori_attr = {}
+         self.aten_ori_attr = {}
+         self.torch_npu_ori_attr = {}
+
+         self.tensor_hook_attr = {}
+         self.torch_hook_attr = {}
+         self.functional_hook_attr = {}
+         self.distributed_hook_attr = {}
+         self.npu_distributed_hook_attr = {}
+         self.vf_hook_attr = {}
+         self.aten_hook_attr = {}
+         self.torch_npu_hook_attr = {}
+
+     @staticmethod
+     def store_ori_attr(ori_api_group, api_list, api_ori_attr):
+         for api in api_list:
+             if '.' in api:
+                 sub_module_name, sub_op = api.rsplit('.', 1)
+                 sub_module = getattr(ori_api_group, sub_module_name)
+                 api_ori_attr[api] = getattr(sub_module, sub_op)
+             else:
+                 api_ori_attr[api] = getattr(ori_api_group, api)
+
+     @staticmethod
+     def set_api_attr(api_group, attr_dict):
+         for api, api_attr in attr_dict.items():
+             if '.' in api:
+                 sub_module_name, sub_op = api.rsplit('.', 1)
+                 sub_module = getattr(api_group, sub_module_name, None)
+                 if sub_module is not None:
+                     setattr(sub_module, sub_op, api_attr)
+             else:
+                 setattr(api_group, api, api_attr)
+
+     def api_modularity(self):
+         self.set_api_attr(torch.Tensor, self.tensor_hook_attr)
+         self.set_api_attr(torch, self.torch_hook_attr)
+         self.set_api_attr(torch.nn.functional, self.functional_hook_attr)
+         self.set_api_attr(dist, self.distributed_hook_attr)
+         self.set_api_attr(dist.distributed_c10d, self.distributed_hook_attr)
+         if not is_gpu and not torch_without_guard_version:
+             self.set_api_attr(torch_npu.distributed, self.npu_distributed_hook_attr)
+             self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_hook_attr)
+         if torch_version_above_2:
+             self.set_api_attr(torch.ops.aten, self.aten_hook_attr)
+         self.set_api_attr(torch._VF, self.vf_hook_attr)
+         if not is_gpu:
+             self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
+
+     def api_originality(self):
+         self.set_api_attr(torch.Tensor, self.tensor_ori_attr)
+         self.set_api_attr(torch, self.torch_ori_attr)
+         self.set_api_attr(torch.nn.functional, self.functional_ori_attr)
+         self.set_api_attr(dist, self.distributed_ori_attr)
+         self.set_api_attr(dist.distributed_c10d, self.distributed_ori_attr)
+         if not is_gpu and not torch_without_guard_version:
+             self.set_api_attr(torch_npu.distributed, self.npu_distributed_ori_attr)
+             self.set_api_attr(torch_npu.distributed.distributed_c10d, self.npu_distributed_ori_attr)
+         if torch_version_above_2:
+             self.set_api_attr(torch.ops.aten, self.aten_ori_attr)
+         self.set_api_attr(torch._VF, self.vf_ori_attr)
+         if not is_gpu:
+             self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
+
+     def initialize_hook(self, hook):
+         self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
+         wrap_tensor.wrap_tensor_ops_and_bind(hook)
+         for attr_name in dir(wrap_tensor.HOOKTensor):
+             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                 self.tensor_hook_attr[attr_name[5:]] = getattr(wrap_tensor.HOOKTensor, attr_name)
+
+         self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr)
+         wrap_torch.wrap_torch_ops_and_bind(hook)
+         for attr_name in dir(wrap_torch.HOOKTorchOP):
+             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                 self.torch_hook_attr[attr_name[5:]] = getattr(wrap_torch.HOOKTorchOP, attr_name)
+
+         self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr)
+         wrap_functional.wrap_functional_ops_and_bind(hook)
+         for attr_name in dir(wrap_functional.HOOKFunctionalOP):
+             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                 self.functional_hook_attr[attr_name[5:]] = getattr(wrap_functional.HOOKFunctionalOP, attr_name)
+
+         self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr)
+         wrap_distributed.wrap_distributed_ops_and_bind(hook)
+         if not is_gpu and not torch_without_guard_version:
+             self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr)
+         for attr_name in dir(wrap_distributed.HOOKDistributedOP):
+             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                 self.distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP, attr_name)
+                 if not is_gpu and not torch_without_guard_version and attr_name[5:] in npu_distributed_api:
+                     self.npu_distributed_hook_attr[attr_name[5:]] = getattr(wrap_distributed.HOOKDistributedOP,
+                                                                             attr_name)
+
+         if torch_version_above_2:
+             self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr)
+             wrap_aten.wrap_aten_ops_and_bind(hook)
+             for attr_name in dir(wrap_aten.HOOKAtenOP):
+                 if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                     self.aten_hook_attr[attr_name[5:]] = getattr(wrap_aten.HOOKAtenOP, attr_name)
+
+         self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr)
+         wrap_vf.wrap_vf_ops_and_bind(hook)
+         for attr_name in dir(wrap_vf.HOOKVfOP):
+             if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                 self.vf_hook_attr[attr_name[5:]] = getattr(wrap_vf.HOOKVfOP, attr_name)
+
+         if not is_gpu:
+             self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr)
+             wrap_npu_custom.wrap_npu_ops_and_bind(hook)
+             for attr_name in dir(wrap_npu_custom.HOOKNpuOP):
+                 if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+                     self.torch_npu_hook_attr[attr_name[5:]] = getattr(wrap_npu_custom.HOOKNpuOP, attr_name)
+
+
+ api_register = ApiRegistry()
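The registry above boils down to a save/patch/restore pattern: store_ori_attr records the original attributes of an API group (resolving dotted names one level deep), and set_api_attr installs either the hooked or the original versions. A small self-contained illustration of that pattern, independent of msprobe:

import torch

saved = {}

def store_ori_attr(group, names, out):
    # Record the current attribute objects so they can be restored later.
    for name in names:
        if "." in name:
            sub_name, op = name.rsplit(".", 1)
            out[name] = getattr(getattr(group, sub_name), op)
        else:
            out[name] = getattr(group, name)

def set_api_attr(group, attrs):
    # Install the given attributes, resolving dotted names one level deep.
    for name, value in attrs.items():
        if "." in name:
            sub_name, op = name.rsplit(".", 1)
            sub = getattr(group, sub_name, None)
            if sub is not None:
                setattr(sub, op, value)
        else:
            setattr(group, name, value)

store_ori_attr(torch, ["add"], saved)

def noisy_add(*args, **kwargs):
    print("torch.add called")             # stand-in for a real dump hook
    return saved["add"](*args, **kwargs)  # delegate to the saved original

set_api_attr(torch, {"add": noisy_add})   # patch
torch.add(torch.ones(2), torch.ones(2))   # prints, then adds
set_api_attr(torch, saved)                # restore the originals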
@@ -0,0 +1,109 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+
+ import functools
+ import threading
+ import torch
+ import torch.nn as nn
+ import torch.utils.hooks as full_hooks
+ from msprobe.core.common.const import Const
+
+
+ class HOOKModule(nn.Module):
+     module_count = {}
+     inner_stop_hook = {}
+
+     def __init__(self, build_hook) -> None:
+         super(HOOKModule, self).__init__()
+         self.has_overflow = False
+         self.prefix = ""
+         self.current_thread = threading.current_thread().ident
+         if self.current_thread not in HOOKModule.inner_stop_hook:
+             HOOKModule.inner_stop_hook[self.current_thread] = False
+         self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)
+
+         if not self.stop_hook:
+             if hasattr(self, "prefix_op_name_"):
+                 self.prefix = self.prefix_op_name_
+
+             if self.prefix not in HOOKModule.module_count:
+                 HOOKModule.module_count[self.prefix] = 1
+                 self.prefix += '0' + Const.SEP
+             else:
+                 HOOKModule.module_count[self.prefix] += 1
+                 self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
+             forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix)
+             self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+             self.register_forward_hook(forward_hook, with_kwargs=True)
+             self.register_backward_hook(backward_hook)
+
+     def __call__(self, *input, **kwargs):
+         changed = False
+         if not self.stop_hook:
+             HOOKModule.inner_stop_hook[self.current_thread] = True
+             changed = True
+         result = self._call_func(*input, **kwargs)
+         if changed:
+             HOOKModule.inner_stop_hook[self.current_thread] = False
+         return result
+
+     def _call_func(self, *input, **kwargs):
+         full_backward_hooks, non_full_backward_hooks = [], []
+         if len(self._backward_hooks) > 0:
+             full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
+         for hook in self._forward_pre_hooks.values():
+             result_input, result_kwargs = hook(self, input, kwargs)
+             if result_input is not None:
+                 if not isinstance(result_input, tuple):
+                     result_input = (result_input,)
+                 input = result_input
+             if result_kwargs is not None:
+                 kwargs = result_kwargs
+         bw_hook = None
+         if len(full_backward_hooks) > 0:
+             bw_hook = full_hooks.BackwardHook(self, full_backward_hooks)
+             input = bw_hook.setup_input_hook(input)
+         if torch._C._get_tracing_state():
+             result = self._slow_forward(*input, **kwargs)
+         else:
+             result = self.forward(*input, **kwargs)
+         for hook in self._forward_hooks.values():
+             hook_result = hook(self, input, kwargs, result)
+             if hook_result is not None:
+                 result = hook_result
+         if bw_hook:
+             result = bw_hook.setup_output_hook(result)
+         if len(non_full_backward_hooks) > 0:
+             var = result
+             while not isinstance(var, torch.Tensor):
+                 if isinstance(var, dict):
+                     var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
+                 elif isinstance(var, (list, tuple)):
+                     if var:
+                         var = var[0]
+                     else:
+                         return result
+                 else:
+                     return result
+             grad_fn = var.grad_fn
+             if grad_fn is not None:
+                 for hook in non_full_backward_hooks:
+                     wrapper = functools.partial(hook, self)
+                     functools.update_wrapper(wrapper, hook)
+                     grad_fn.register_hook(wrapper)
+                 self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
+         return result
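For context, the wrap_* modules listed in this diff generate subclasses of HOOKModule whose forward calls the original op; those templates are not shown in this hunk. Below is a simplified, hypothetical subclass illustrating the contract HOOKModule expects: a prefix_op_name_ class attribute and a build_hook callable returning a (forward_pre_hook, forward_hook, backward_hook) triple. In the package the real hooks come from the debugger service's build_hook and record dump data; the sketch also assumes torch >= 2.0 for the with_kwargs=True registrations used in __init__.

import torch
from msprobe.pytorch.hook_module.hook_module import HOOKModule

def build_hook(prefix):
    # Toy hooks for illustration; they only trace the call.
    def forward_pre_hook(module, args, kwargs):
        return args, kwargs
    def forward_hook(module, args, kwargs, output):
        print(f"{prefix}forward captured, output shape: {tuple(output.shape)}")
        return output
    def backward_hook(module, grad_input, grad_output):
        return None
    return forward_pre_hook, forward_hook, backward_hook

class HOOKAddOP(HOOKModule):
    # Hypothetical wrapper for torch.add, mirroring the naming style of wrap_torch.
    prefix_op_name_ = "Torch.add."

    def forward(self, *args, **kwargs):
        return torch.add(*args, **kwargs)

result = HOOKAddOP(build_hook)(torch.ones(2), torch.ones(2))  # prints via forward_hook, returns tensor([2., 2.])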