mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -20,6 +20,7 @@ from typing import Any, Optional, Tuple
20
20
  import numpy as np
21
21
  import torch
22
22
  from msprobe.core.common.const import Const
23
+ from msprobe.core.common.exceptions import FreeBenchmarkException
23
24
  from msprobe.pytorch.free_benchmark import logger
24
25
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
25
26
  from msprobe.pytorch.free_benchmark.common.enums import (
@@ -88,12 +89,6 @@ class FuzzHandler(ABC):
88
89
  )
89
90
  return origin_output_chunks, perturbed_output_chunks
90
91
 
91
- @staticmethod
92
- def convert_overflow_ratio_to_consistent(ratio):
93
- if math.isnan(ratio) or math.isinf(ratio):
94
- return ThresholdConfig.COMP_CONSISTENT
95
- return ratio
96
-
97
92
  @abstractmethod
98
93
  def get_threshold(self, dtype):
99
94
  pass
@@ -106,49 +101,45 @@ class FuzzHandler(ABC):
106
101
  self, origin_output, perturbed_output, norm_type, abs_tol
107
102
  ):
108
103
  if norm_type == NormType.ENDLESS_NORM:
109
- return self.calculate_error(origin_output, perturbed_output, abs_tol)
104
+ return self.calculate_max_ratio(origin_output, perturbed_output, abs_tol)
110
105
  return ThresholdConfig.COMP_CONSISTENT
111
106
 
112
- def calculate_error(self, origin_output, perturbed_output, abs_tol):
107
+ def calculate_max_ratio(self, origin_output, perturbed_output, abs_tol):
113
108
  origin_output_chunks, perturbed_output_chunks = (
114
109
  self.tensor_split_for_error_calculate(origin_output, perturbed_output)
115
110
  )
116
- norm1 = -np.inf
117
- norm2 = -np.inf
118
- norm3 = np.inf
111
+ if len(origin_output_chunks) != len(perturbed_output_chunks):
112
+ err_msg = (
113
+ f"For {self.params.api_name}, the number of compare tensor chunks is different: "
114
+ f"{len(origin_output_chunks)} != {len(perturbed_output_chunks)}. please check!"
115
+ )
116
+ raise FreeBenchmarkException(
117
+ FreeBenchmarkException.OutputIndexError, err_msg
118
+ )
119
+
120
+ max_ratio = ThresholdConfig.COMP_CONSISTENT
119
121
  for i, chunk_origin in enumerate(origin_output_chunks):
120
122
  if chunk_origin.nelement() == 0:
121
123
  break
122
124
  chunk_perturbed = perturbed_output_chunks[i]
123
- ratio_tensor1 = TorchC.where(
124
- TorchC.abs(chunk_perturbed) > abs_tol,
125
- TorchC.div(
126
- TorchC.clamp(chunk_origin, min=abs_tol),
127
- TorchC.clamp(chunk_perturbed, min=abs_tol),
128
- ),
129
- 1,
125
+ # 如果乘积最小值 < 极小值乘积的负值,认为存在非极小值符号相反的情况
126
+ if TorchC.lt(
127
+ TorchC.min(TorchC.mul(chunk_origin, chunk_perturbed)), -(abs_tol**2)
128
+ ):
129
+ return ThresholdConfig.SYMBOL_FLIPPING
130
+ # 求A/B B/A的比值前,将值限制在大于极小值范围内
131
+ clamp_origin = TorchC.clamp(TorchC.abs(chunk_origin), min=abs_tol)
132
+ clamp_perturbed = TorchC.clamp(TorchC.abs(chunk_perturbed), min=abs_tol)
133
+ # 对于计算结果为nan的情况,认为两者没有差异
134
+ ratio_tensor = TorchC.nan_to_num(
135
+ TorchC.div(clamp_origin, clamp_perturbed),
136
+ nan=ThresholdConfig.COMP_CONSISTENT,
130
137
  )
131
- ratio_tensor2 = TorchC.where(
132
- TorchC.abs(chunk_origin) > abs_tol,
133
- TorchC.div(
134
- TorchC.clamp(chunk_perturbed, min=abs_tol),
135
- TorchC.clamp(chunk_origin, min=abs_tol),
136
- ),
137
- 1,
138
- )
139
- norm_values = TorchC.stack(
140
- [TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]
141
- )
142
- max_ratio1, max_ratio2 = norm_values.tolist()
143
- norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
144
- norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
145
- norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
146
-
147
- if norm3 < 0:
148
- ratio = ThresholdConfig.SYMBOL_FLIPPING
149
- else:
150
- ratio = max(norm1, norm2)
151
- return ratio
138
+ # 求A/B 和 B/A比值最大值,其中 B/A的最大值为 A/B的最小值的倒数
139
+ min_ratio, max_ratio = TorchC.stack([*TorchC.aminmax(ratio_tensor)]).tolist()
140
+ min_ratio_reciprocal = np.inf if min_ratio == 0 else 1 / min_ratio
141
+ max_ratio = max(max_ratio, min_ratio_reciprocal)
142
+ return max_ratio
152
143
 
153
144
  def ratio_calculate(self, origin_output, perturbed_output, norm_type) -> float:
154
145
  try:
@@ -189,6 +180,7 @@ class FuzzHandler(ABC):
189
180
  f"[msprobe] Free Benchmark: For {self.params.api_name} "
190
181
  f"The compare for output type {type(perturbed_output)} is not supported"
191
182
  )
183
+ return True, 1
192
184
 
193
185
  threshold = self.get_threshold(Tools.get_first_tensor_dtype(origin_output))
194
186
  ratio = self.ratio_calculate(
@@ -210,10 +202,12 @@ class FuzzHandler(ABC):
210
202
  )
211
203
  npu_consistent = is_consistent
212
204
  max_fuzz_ratio = (
213
- max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
205
+ max_fuzz_ratio
206
+ if not isinstance(ratio, (int, float))
207
+ else max(max_fuzz_ratio, ratio)
214
208
  )
215
- data_params.is_consistent = is_consistent and data_params.is_consistent
216
- if not is_consistent and data_params.grad_unequal_flag:
209
+ data_params.is_consistent = is_consistent
210
+ if not is_consistent:
217
211
  self.unequal_rows.append(
218
212
  make_unequal_row(data_params, self.params, ratio=ratio)
219
213
  )
@@ -225,12 +219,12 @@ class FuzzHandler(ABC):
225
219
  )
226
220
  npu_consistent = npu_consistent and is_consistent
227
221
  max_fuzz_ratio = (
228
- max_fuzz_ratio if ratio is None else max(max_fuzz_ratio, ratio)
229
- )
230
- data_params.is_consistent = (
231
- is_consistent and data_params.is_consistent
222
+ max_fuzz_ratio
223
+ if not isinstance(ratio, (int, float))
224
+ else max(max_fuzz_ratio, ratio)
232
225
  )
233
- if not is_consistent and data_params.grad_unequal_flag:
226
+ data_params.is_consistent = is_consistent
227
+ if not is_consistent:
234
228
  self.unequal_rows.append(
235
229
  make_unequal_row(
236
230
  data_params, self.params, ratio=ratio, index=index_
@@ -15,10 +15,11 @@
15
15
 
16
16
  from typing import Any
17
17
 
18
+ from msprobe.core.common.exceptions import FreeBenchmarkException
19
+ from msprobe.pytorch.free_benchmark import logger
18
20
  from msprobe.pytorch.free_benchmark.common.params import DataParams
19
21
  from msprobe.pytorch.free_benchmark.common.utils import Tools
20
22
  from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler
21
- from msprobe.pytorch.free_benchmark import logger
22
23
 
23
24
 
24
25
  class FixHandler(FuzzHandler):
@@ -31,9 +32,9 @@ class FixHandler(FuzzHandler):
31
32
  return Tools.convert_fuzz_output_to_origin(
32
33
  data_params.original_result, data_params.perturbed_result
33
34
  )
34
- except Exception as e:
35
- logger.warning_on_rank_0(
35
+ except FreeBenchmarkException as e:
36
+ logger.warning(
36
37
  f"[msprobe] Free Benchmark: For {self.params.api_name} "
37
- f"Fix output failed. "
38
+ f"Fix output failed because of: \n{e}"
38
39
  )
39
- return data_params.original_result
40
+ return data_params.original_result
@@ -75,10 +75,6 @@ class PreheatHandler(FuzzHandler):
75
75
  if self.params.preheat_config.get("preheat_step") <= self.params.step:
76
76
  return data_params.original_result
77
77
 
78
- if not data_params.grad_unequal_flag:
79
- data_params.grad_unequal_flag = True
80
- data_params.is_consistent = False
81
- return data_params.original_result
82
78
  preheat_counter.add_api_called_time(self.pure_name)
83
79
 
84
80
  if not self._is_take_a_sample():
@@ -1,15 +1,31 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from collections import defaultdict
3
18
 
4
19
  import torch
5
- if int(torch.__version__.split('.')[0]) >= 2:
6
- from torch.optim.optimizer import register_optimizer_step_pre_hook
7
- from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
8
- from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
20
+ from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
9
21
  from msprobe.core.grad_probe.constant import level_adp
22
+ from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
10
23
  from msprobe.pytorch.common.log import logger
11
- from msprobe.core.common.file_utils import remove_path, save_npy, write_csv, create_directory
12
24
  from msprobe.pytorch.common.utils import get_rank_id, print_rank_0
25
+ from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
26
+
27
+ if int(torch.__version__.split('.')[0]) >= 2:
28
+ from torch.optim.optimizer import register_optimizer_step_pre_hook
13
29
 
14
30
 
15
31
  class GradientMonitor:
@@ -75,7 +91,7 @@ class GradientMonitor:
75
91
  output_lines.append(grad_info)
76
92
  if self._level_adp["have_grad_direction"]:
77
93
  GradientMonitor.save_grad_direction(param_name, grad,
78
- f'{self._output_path}/rank{self._rank}/step{self._step}')
94
+ f'{self._output_path}/rank{self._rank}/step{self._step}')
79
95
  output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}")
80
96
  if not os.path.isdir(output_dirpath):
81
97
  create_directory(output_dirpath)
@@ -87,5 +103,6 @@ class GradientMonitor:
87
103
  output_lines.insert(0, header_result)
88
104
  write_csv(output_lines, output_path)
89
105
  logger.info(f"write grad data to {output_path}")
106
+
90
107
  if int(torch.__version__.split('.')[0]) >= 2:
91
108
  register_optimizer_step_pre_hook(optimizer_pre_step_hook)
@@ -1,11 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
17
  from collections import namedtuple
3
18
  import hashlib
19
+ from functools import wraps
4
20
  import torch
5
21
  from msprobe.core.grad_probe.constant import GradConst
6
22
 
7
- CSV_header_input = namedtuple("CSV_header_input", ["bounds"])
8
- CSV_content_input = namedtuple("CSV_content_input", ["grad", "bounds"])
23
+ CsvHeaderInput = namedtuple("CsvHeaderInput", ["bounds"])
24
+ CsvContentInput = namedtuple("CsvContentInput", ["grad", "bounds"])
9
25
 
10
26
 
11
27
  class GradStatCsv:
@@ -15,7 +31,7 @@ class GradStatCsv:
15
31
  def generate_csv_header(level, bounds):
16
32
  header = ["param_name"]
17
33
  for key in level["header"]:
18
- csv_header_input = CSV_header_input(bounds=bounds)
34
+ csv_header_input = CsvHeaderInput(bounds=bounds)
19
35
  header.extend(GradStatCsv.csv[key].generate_csv_header(csv_header_input))
20
36
  return header
21
37
 
@@ -23,7 +39,7 @@ class GradStatCsv:
23
39
  def generate_csv_line(param_name, level, grad, bounds):
24
40
  line = [param_name]
25
41
  for key in level["header"]:
26
- csv_content_input = CSV_content_input(grad=grad, bounds=bounds)
42
+ csv_content_input = CsvContentInput(grad=grad, bounds=bounds)
27
43
  line.extend(GradStatCsv.csv[key].generate_csv_content(csv_content_input))
28
44
  return line
29
45
 
@@ -37,20 +53,24 @@ def register_csv_item(key, cls=None):
37
53
 
38
54
 
39
55
  class CsvItem(ABC):
56
+ @staticmethod
40
57
  @abstractmethod
41
58
  def generate_csv_header(csv_header_input):
42
59
  pass
43
60
 
61
+ @staticmethod
44
62
  @abstractmethod
45
63
  def generate_csv_content(csv_content_input):
46
64
  pass
47
65
 
48
66
 
49
67
  @register_csv_item(GradConst.MD5)
50
- class CSV_md5(CsvItem):
68
+ class CsvMd5(CsvItem):
69
+ @staticmethod
51
70
  def generate_csv_header(csv_header_input):
52
71
  return ["MD5"]
53
72
 
73
+ @staticmethod
54
74
  def generate_csv_content(csv_content_input):
55
75
  grad = csv_content_input.grad
56
76
  tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
@@ -59,7 +79,8 @@ class CSV_md5(CsvItem):
59
79
 
60
80
 
61
81
  @register_csv_item(GradConst.DISTRIBUTION)
62
- class CSV_distribution(CsvItem):
82
+ class CsvDistribution(CsvItem):
83
+ @staticmethod
63
84
  def generate_csv_header(csv_header_input):
64
85
  bounds = csv_header_input.bounds
65
86
  intervals = []
@@ -73,6 +94,7 @@ class CSV_distribution(CsvItem):
73
94
 
74
95
  return intervals
75
96
 
97
+ @staticmethod
76
98
  def generate_csv_content(csv_content_input):
77
99
  grad = csv_content_input.grad
78
100
  bounds = csv_content_input.bounds
@@ -90,40 +112,48 @@ class CSV_distribution(CsvItem):
90
112
 
91
113
 
92
114
  @register_csv_item(GradConst.MAX)
93
- class CSV_max(CsvItem):
115
+ class CsvMax(CsvItem):
116
+ @staticmethod
94
117
  def generate_csv_header(csv_header_input):
95
118
  return ["max"]
96
119
 
120
+ @staticmethod
97
121
  def generate_csv_content(csv_content_input):
98
122
  grad = csv_content_input.grad
99
123
  return [torch.max(grad).cpu().detach().float().numpy().tolist()]
100
124
 
101
125
 
102
126
  @register_csv_item(GradConst.MIN)
103
- class CSV_min(CsvItem):
127
+ class CsvMin(CsvItem):
128
+ @staticmethod
104
129
  def generate_csv_header(csv_header_input):
105
130
  return ["min"]
106
131
 
132
+ @staticmethod
107
133
  def generate_csv_content(csv_content_input):
108
134
  grad = csv_content_input.grad
109
135
  return [torch.min(grad).cpu().detach().float().numpy().tolist()]
110
136
 
111
137
 
112
138
  @register_csv_item(GradConst.NORM)
113
- class CSV_norm(CsvItem):
139
+ class CsvNorm(CsvItem):
140
+ @staticmethod
114
141
  def generate_csv_header(csv_header_input):
115
142
  return ["norm"]
116
143
 
144
+ @staticmethod
117
145
  def generate_csv_content(csv_content_input):
118
146
  grad = csv_content_input.grad
119
147
  return [torch.norm(grad).cpu().detach().float().numpy().tolist()]
120
148
 
121
149
 
122
150
  @register_csv_item(GradConst.SHAPE)
123
- class CSV_shape(CsvItem):
151
+ class CsvShape(CsvItem):
152
+ @staticmethod
124
153
  def generate_csv_header(csv_header_input):
125
154
  return ["shape"]
126
155
 
156
+ @staticmethod
127
157
  def generate_csv_content(csv_content_input):
128
158
  grad = csv_content_input.grad
129
159
  return [list(grad.shape)]
@@ -13,4 +13,4 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from .wrap_functional import remove_dropout
16
+ from msprobe.pytorch.common.utils import remove_dropout
@@ -15,17 +15,17 @@
15
15
 
16
16
  import functools
17
17
  import threading
18
+ from collections import defaultdict
18
19
 
19
20
  import torch
20
21
  import torch.nn as nn
21
22
  import torch.utils.hooks as full_hooks
22
23
 
23
- from msprobe.core.common.const import Const
24
24
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
25
25
 
26
26
 
27
27
  class HOOKModule(nn.Module):
28
- module_count = {}
28
+ module_count = defaultdict(int)
29
29
  inner_stop_hook = {}
30
30
 
31
31
  def __init__(self, build_hook) -> None:
@@ -41,12 +41,7 @@ class HOOKModule(nn.Module):
41
41
  if hasattr(self, "prefix_op_name_"):
42
42
  self.prefix = self.prefix_op_name_
43
43
 
44
- if self.prefix not in HOOKModule.module_count:
45
- HOOKModule.module_count[self.prefix] = 1
46
- self.prefix += '0' + Const.SEP
47
- else:
48
- HOOKModule.module_count[self.prefix] += 1
49
- self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
44
+ self.forward_data_collected = False
50
45
  forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
51
46
  if torch_version_above_or_equal_2:
52
47
  self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
@@ -66,9 +61,17 @@ class HOOKModule(nn.Module):
66
61
  HOOKModule.inner_stop_hook[self.current_thread] = False
67
62
  return result
68
63
 
69
- @classmethod
70
- def reset_module_stats(cls):
71
- cls.module_count = {}
64
+ @staticmethod
65
+ def reset_module_stats():
66
+ HOOKModule.module_count = defaultdict(int)
67
+
68
+ @staticmethod
69
+ def add_module_count(name):
70
+ HOOKModule.module_count[name] += 1
71
+
72
+ @staticmethod
73
+ def get_module_count(name):
74
+ return HOOKModule.module_count[name]
72
75
 
73
76
  def _call_func(self, *args, **kwargs):
74
77
  full_backward_hooks, non_full_backward_hooks = [], []
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.pytorch.common.log import logger
19
+
20
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
21
+ if torch_version_above_or_equal_2:
22
+ from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
23
+
24
+
25
+ def register_optimizer_hook(data_collector):
26
+ def optimizer_pre_step_hook(optimizer, args, kwargs):
27
+ data_collector.optimizer_status = Const.OPTIMIZER
28
+
29
+ def optimizer_post_step_hook(optimizer, args, kwargs):
30
+ data_collector.optimizer_status = Const.END_PREFIX + Const.OPTIMIZER
31
+
32
+ def patch_clip_grad(func):
33
+ def wrapper(*args, **kwargs):
34
+ data_collector.optimizer_status = Const.CLIP_GRAD
35
+ func(*args, **kwargs)
36
+ data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
37
+
38
+ return wrapper
39
+
40
+ if torch_version_above_or_equal_2:
41
+ register_optimizer_step_pre_hook(optimizer_pre_step_hook)
42
+ register_optimizer_step_post_hook(optimizer_post_step_hook)
43
+ else:
44
+ logger.info_on_rank_0("Pytorch version is below 2.0, cannot register optimizer hook.")
45
+
46
+ try:
47
+ torch.nn.utils.clip_grad_norm_ = patch_clip_grad(torch.nn.utils.clip_grad_norm_)
48
+ torch.nn.utils.clip_grad_norm = patch_clip_grad(torch.nn.utils.clip_grad_norm)
49
+ torch.nn.utils.clip_grad_value_ = patch_clip_grad(torch.nn.utils.clip_grad_value_)
50
+ except Exception as e:
51
+ logger.info_on_rank_0("Cannot patch clip grad function. detail:%s" % str(e))
52
+
53
+ try:
54
+ from megatron.core.optimizer import MegatronOptimizer
55
+ MegatronOptimizer.clip_grad_norm = patch_clip_grad(MegatronOptimizer.clip_grad_norm)
56
+ except ImportError:
57
+ pass
58
+ except Exception as e:
59
+ logger.info_on_rank_0("Cannot patch megatron clip grad function. detail:%s" % str(e))
@@ -138,6 +138,10 @@ functional:
138
138
  - fold
139
139
  - multi_head_attention_forward
140
140
  - scaled_dot_product_attention
141
+ - lp_pool3d
142
+ - dropout1d
143
+ - mish
144
+ - huber_loss
141
145
 
142
146
  tensor:
143
147
  - __add__
@@ -172,6 +176,7 @@ tensor:
172
176
  - __sub__
173
177
  - __truediv__
174
178
  - __xor__
179
+ - __pow__
175
180
  - abs
176
181
  - abs_
177
182
  - absolute
@@ -557,6 +562,27 @@ tensor:
557
562
  - view_as
558
563
  - xlogy
559
564
  - xlogy_
565
+ - split
566
+ - stft
567
+ - nan_to_num
568
+ - dsplit
569
+ - orgqr
570
+ - bitwise_left_shift_
571
+ - arctan2
572
+ - histogram
573
+ - q_zero_point
574
+ - adjoint
575
+ - ormqr
576
+ - bitwise_right_shift_
577
+ - nanquantile
578
+ - lu
579
+ - quantile
580
+ - arctan2_
581
+ - qr
582
+ - diagonal_scatter
583
+ - corrcoef
584
+ - vsplit
585
+ - aminmax
560
586
 
561
587
  torch:
562
588
  - linalg.norm
@@ -1130,6 +1156,15 @@ torch_npu:
1130
1156
  - npu_prompt_flash_attention
1131
1157
  - npu_lstm
1132
1158
  - npu_apply_adam
1159
+ - npu_apply_adam_w
1160
+ - npu_anti_quant
1161
+ - npu_grouped_matmul
1162
+ - npu_quant_scatter
1163
+ - npu_group_norm_silu
1164
+ - npu_format_cast
1165
+ - npu_moe_finalize_routing
1166
+ - npu_moe_gating_top_k_softmax
1167
+ - npu_trans_quant_param
1133
1168
 
1134
1169
  aten:
1135
1170
  - signbit
@@ -21,7 +21,6 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
21
21
  from msprobe.pytorch.common.utils import torch_device_guard
22
22
  from msprobe.core.common.const import Const
23
23
  from msprobe.core.common.file_utils import load_yaml
24
- from msprobe.core.common.inplace_op_checker import InplaceOpChecker
25
24
 
26
25
 
27
26
  cur_path = os.path.dirname(os.path.realpath(__file__))
@@ -49,17 +48,16 @@ class DistributedOPTemplate(HOOKModule):
49
48
  self.op_name_ = op_name
50
49
  self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP
51
50
  super().__init__(build_hook)
52
- if not self.stop_hook and InplaceOpChecker.check(self.op_name_, InplaceOpChecker.OP_DISTRIBUTED):
53
- self.op_is_inplace = True
51
+ if not self.stop_hook:
52
+ self.op_is_distributed = True
54
53
 
55
54
  @torch_device_guard
56
55
  def forward(self, *args, **kwargs):
56
+ handle = distributed_func.get(self.op_name_)(*args, **kwargs)
57
57
  if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
58
- handle = distributed_func.get(self.op_name_)(*args, **kwargs)
59
- handle.wait()
60
- return handle
61
- else:
62
- return distributed_func.get(self.op_name_)(*args, **kwargs)
58
+ if handle and hasattr(handle, 'wait'):
59
+ handle.wait()
60
+ return handle
63
61
 
64
62
 
65
63
  def wrap_distributed_op(op_name, hook):
@@ -23,44 +23,6 @@ from msprobe.pytorch.common.log import logger
23
23
  from msprobe.core.common.file_utils import load_yaml
24
24
 
25
25
 
26
- def remove_dropout():
27
- if torch.__version__ > "1.8":
28
- logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.")
29
- import torch.nn.functional as F
30
- from torch import _VF
31
- from torch.overrides import has_torch_function_unary, handle_torch_function
32
-
33
- def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True,
34
- inplace: bool = False) -> torch.Tensor:
35
- if has_torch_function_unary(input):
36
- return handle_torch_function(
37
- function_dropout, (input,), input, p=0., training=training, inplace=inplace)
38
- if p < 0.0 or p > 1.0:
39
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
40
- return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training)
41
-
42
- def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True,
43
- inplace: bool = False) -> torch.Tensor:
44
- if has_torch_function_unary(input):
45
- return handle_torch_function(
46
- function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
47
- if p < 0.0 or p > 1.0:
48
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
49
- return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
50
-
51
- def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True,
52
- inplace: bool = False) -> torch.Tensor:
53
- if has_torch_function_unary(input):
54
- return handle_torch_function(
55
- function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
56
- if p < 0.0 or p > 1.0:
57
- raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
58
- return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
59
-
60
- F.dropout = function_dropout
61
- F.dropout2d = function_dropout2d
62
- F.dropout3d = function_dropout3d
63
-
64
26
  cur_path = os.path.dirname(os.path.realpath(__file__))
65
27
  yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
66
28
 
File without changes