mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -41,6 +41,7 @@ from msprobe.core.common.utils import CompareException
41
41
 
42
42
  def split_json_file(input_file, num_splits, filter_api):
43
43
  forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
44
+ input_dir = os.path.dirname(os.path.abspath(input_file))
44
45
  if filter_api:
45
46
  forward_data = preprocess_forward_content(forward_data)
46
47
  for data_name in list(forward_data.keys()):
@@ -71,7 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
71
72
  **backward_data
72
73
  }
73
74
  }
74
- split_filename = f"temp_part{i}.json"
75
+ split_filename = os.path.join(input_dir, f"temp_part{i}.json")
75
76
  save_json(split_filename, temp_data)
76
77
  split_files.append(split_filename)
77
78
 
@@ -23,12 +23,14 @@ try:
23
23
  import torch_npu
24
24
  except ImportError:
25
25
  is_gpu = True
26
+ current_device = "cuda"
26
27
  else:
27
28
  is_gpu = False
29
+ current_device = "npu"
28
30
  import torch
29
31
  from tqdm import tqdm
30
32
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
31
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
33
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api, ExecParams
32
34
  from msprobe.core.common.file_utils import check_link, FileChecker
33
35
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
34
36
  from msprobe.core.common.const import FileCheckConst, Const
@@ -61,19 +63,33 @@ def check_tensor_overflow(x):
61
63
  return False
62
64
 
63
65
 
64
- def check_data_overflow(x):
65
- if isinstance(x, (tuple, list)) and x:
66
- for _, item in enumerate(x):
67
- if check_data_overflow(item):
68
- return True
69
- return False
66
+ def check_data_overflow(x, device):
67
+ if isinstance(x, (tuple, list)):
68
+ if not x:
69
+ return False
70
+ return any(check_data_overflow(item, device) for item in x)
70
71
  else:
71
- return check_tensor_overflow(x)
72
+ if device == Const.CPU_LOWERCASE:
73
+ return check_tensor_overflow(x)
74
+ else:
75
+ return torch_npu.npu.utils.npu_check_overflow(x)
76
+
77
+
78
+ def is_bool_output(x):
79
+ if isinstance(x, (tuple, list)):
80
+ if not x:
81
+ return False
82
+ return any(is_bool_output(item) for item in x)
83
+ else:
84
+ return isinstance(x, bool)
72
85
 
73
86
 
74
87
  def run_overflow_check(forward_file):
75
88
  logger.info("start UT test")
76
89
  forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
90
+ if real_data_path:
91
+ dump_path = os.path.dirname(forward_file)
92
+ real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
77
93
  for api_full_name, api_info_dict in tqdm(forward_content.items()):
78
94
  if is_unsupported_api(api_full_name, is_overflow_check=True):
79
95
  continue
@@ -87,6 +103,9 @@ def run_overflow_check(forward_file):
87
103
  elif "expected scalar type Long" in str(err):
88
104
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
89
105
  "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
106
+ elif "could not create a primitive descriptor for a matmul primitive" in str(err):
107
+ logger.warning(f"API {api_name} not support matmul primitive in CPU due to pytorch bug, "
108
+ "so it will be skipped.")
90
109
  else:
91
110
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
92
111
 
@@ -98,17 +117,26 @@ def run_torch_api(api_full_name, api_info_dict, real_data_path):
98
117
  if not need_grad:
99
118
  logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
100
119
  % api_full_name)
120
+ device_info_kwargs = kwargs.get(Const.DEVICE)
121
+ if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
122
+ kwargs[Const.DEVICE] = current_device
101
123
  npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
102
- if kwargs.get("device"):
103
- del kwargs["device"]
104
- out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs)
105
- npu_out = exec_api(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs)
124
+ if kwargs.get(Const.DEVICE):
125
+ del kwargs[Const.DEVICE]
126
+ cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None)
127
+ device_exec_params = ExecParams(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs, False, None)
128
+ out = exec_api(cpu_exec_params)
129
+ npu_out = exec_api(device_exec_params)
106
130
  if out is None and npu_out is None:
107
131
  logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
108
132
  return
133
+ if is_bool_output(out) or is_bool_output(npu_out):
134
+ logger.warning("The output of %s is bool type.This dtype not support overflow, so it will be skipped."
135
+ % api_full_name)
136
+ return
109
137
 
110
- cpu_overflow = check_data_overflow(out)
111
- npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
138
+ cpu_overflow = check_data_overflow(out, Const.CPU_LOWERCASE)
139
+ npu_overflow = check_data_overflow(npu_out, Const.NPU_LOWERCASE)
112
140
  if cpu_overflow == npu_overflow:
113
141
  logger.warning("The %s overflow is a normal overflow." % api_full_name)
114
142
  else:
@@ -31,6 +31,7 @@ except ImportError:
31
31
  else:
32
32
  is_gpu = False
33
33
  current_device = "npu"
34
+
34
35
  import torch
35
36
  from tqdm import tqdm
36
37
 
@@ -48,10 +49,12 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
48
49
  from msprobe.pytorch.common.log import logger
49
50
  from msprobe.pytorch.pt_config import parse_json_config
50
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
51
- from msprobe.core.common.utils import safe_get_value
52
+ from msprobe.core.common.utils import safe_get_value, CompareException
53
+ from msprobe.pytorch.common.utils import seed_all
52
54
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
53
55
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
54
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
56
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
57
+ ExecParams
55
58
 
56
59
 
57
60
  current_time = time.strftime("%Y%m%d%H%M%S")
@@ -61,6 +64,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
61
64
 
62
65
 
63
66
  not_backward_list = ['repeat_interleave']
67
+ unsupported_backward_list = ['masked_select']
64
68
 
65
69
 
66
70
  tqdm_params = {
@@ -237,7 +241,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
237
241
  in_fwd_data_list = []
238
242
  backward_message = ''
239
243
  api_type, api_name = extract_basic_api_segments(api_full_name)
240
- args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
244
+ args, kwargs, output_dtype = get_api_info(api_info_dict, api_name, real_data_path)
245
+ need_grad = check_need_grad(api_info_dict)
241
246
  in_fwd_data_list.append(args)
242
247
  in_fwd_data_list.append(kwargs)
243
248
  need_backward = api_full_name in backward_content
@@ -248,14 +253,30 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
248
253
  need_grad = False
249
254
  logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
250
255
  backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
256
+ if api_name in unsupported_backward_list:
257
+ need_grad = False
258
+ logger.info("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_API_MESSAGE))
259
+ backward_message += BackwardMessage.UNSUPPORT_API_MESSAGE
251
260
  need_backward = need_backward and need_grad
252
- if kwargs.get("device"):
253
- del kwargs["device"]
254
- cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
261
+
262
+ device_info_kwargs = kwargs.get(Const.DEVICE)
263
+ if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
264
+ kwargs[Const.DEVICE] = current_device
255
265
  device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
266
+ if kwargs.get(Const.DEVICE):
267
+ del kwargs[Const.DEVICE]
268
+ cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name)
269
+ cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
270
+ autocast_dtype, is_autocast = cpu_params.autocast_dtype, cpu_params.is_autocast
271
+ if not is_autocast and output_dtype:
272
+ is_autocast = autocast_dtype != output_dtype
273
+ autocast_dtype = output_dtype
256
274
  bench_grad_out, device_grad_out = None, None
257
- out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
258
- device_out = exec_api(api_type, api_name, current_device, device_args, device_kwargs)
275
+ cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, autocast_dtype)
276
+ out = exec_api(cpu_exec_params)
277
+ device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast,
278
+ autocast_dtype)
279
+ device_out = exec_api(device_exec_params)
259
280
  current_path = os.path.dirname(os.path.realpath(__file__))
260
281
  ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
261
282
  api_setting_dict = get_json_contents(ut_setting_path)
@@ -273,7 +294,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
273
294
  }
274
295
  grad = gen_args(backward_args, api_name, func_options)
275
296
  grad = safe_get_value(grad, 0, "grad")
276
- bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
297
+ grad_params = generate_cpu_params(grad, {}, False, api_name)
298
+ bench_grad = grad_params.cpu_args
277
299
  bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
278
300
  device_grad = grad.clone().detach().to(current_device)
279
301
  device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
@@ -300,13 +322,18 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
300
322
  return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
301
323
 
302
324
 
303
- def get_api_info(api_info_dict, api_name, real_data_path):
304
- convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
325
+ def check_need_grad(api_info_dict):
305
326
  need_grad = True
306
- if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"):
327
+ if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
307
328
  need_grad = False
308
- args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
309
- return args, kwargs, need_grad
329
+ return need_grad
330
+
331
+
332
+ def get_api_info(api_info_dict, api_name, real_data_path):
333
+ convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
334
+ need_grad = check_need_grad(api_info_dict)
335
+ args, kwargs, output_dtype = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
336
+ return args, kwargs, output_dtype
310
337
 
311
338
 
312
339
  def need_to_backward(grad_index, out):
@@ -323,15 +350,25 @@ def run_backward(args, grad, grad_index, out):
323
350
  out[grad_index].backward(grad)
324
351
  else:
325
352
  out.backward(grad)
326
- args_grad = []
327
- for arg in args:
328
- if isinstance(arg, torch.Tensor):
329
- args_grad.append(arg.grad)
330
- grad_out = args_grad
353
+
354
+ grad_out = extract_tensors_grad(args)
331
355
 
332
356
  return grad_out
333
357
 
334
358
 
359
+ def extract_tensors_grad(args, depth=0):
360
+ if depth > Const.MAX_DEPTH:
361
+ logger.error("The depth of arg_in is too large, please check the arg_in.")
362
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
363
+ grads = []
364
+ for arg in args:
365
+ if isinstance(arg, torch.Tensor):
366
+ grads.append(arg.grad)
367
+ elif isinstance(arg, (list, tuple)):
368
+ grads.extend(extract_tensors_grad(arg, depth+1))
369
+ return grads
370
+
371
+
335
372
  def initialize_save_error_data(error_data_path):
336
373
  create_directory(error_data_path)
337
374
  error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
@@ -479,6 +516,10 @@ def run_ut_command(args):
479
516
 
480
517
  if not is_gpu:
481
518
  torch.npu.set_compile_mode(jit_compile=args.jit_compile)
519
+ if args.jit_compile:
520
+ torch.npu.config.allow_internal_format = True
521
+ else:
522
+ torch.npu.config.allow_internal_format = False
482
523
  used_device = current_device + ":" + str(args.device_id[0])
483
524
  try:
484
525
  if is_gpu:
@@ -497,6 +538,9 @@ def run_ut_command(args):
497
538
  ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
498
539
  checked_api_info = api_info_file_checker.common_check()
499
540
  forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
541
+ if real_data_path:
542
+ dump_path = os.path.dirname(checked_api_info)
543
+ real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
500
544
  if args.filter_api:
501
545
  logger.info("Start filtering the api in the api_info_file.")
502
546
  forward_content = preprocess_forward_content(forward_content)
@@ -538,5 +582,6 @@ def run_ut_command(args):
538
582
 
539
583
 
540
584
  if __name__ == '__main__':
585
+ seed_all()
541
586
  _run_ut()
542
587
  logger.info("UT task completed.")
@@ -16,6 +16,7 @@
16
16
  # limitations under the License.
17
17
 
18
18
  import os
19
+ from collections import namedtuple
19
20
  import re
20
21
  import torch
21
22
 
@@ -23,8 +24,10 @@ try:
23
24
  import torch_npu
24
25
  except ImportError:
25
26
  current_device = "cuda"
27
+ from torch.cuda.amp import autocast
26
28
  else:
27
29
  current_device = "npu"
30
+ from torch_npu.npu.amp import autocast
28
31
 
29
32
  from msprobe.core.common.const import FileCheckConst, Const, CompareConst
30
33
  from msprobe.core.common.file_utils import FileChecker
@@ -47,11 +50,17 @@ PRECISION_MAPPING = {
47
50
  }
48
51
 
49
52
 
53
+ CpuParams = namedtuple("CpuArgs", ["cpu_args", "cpu_kwargs", "autocast_dtype", "is_autocast"])
54
+ ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
55
+ "is_autocast", "autocast_dtype"])
56
+
57
+
50
58
  class BackwardMessage:
51
59
  MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
52
60
  UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
53
61
  "skip backward."
54
62
  NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
63
+ UNSUPPORT_API_MESSAGE = "This API does not support backward ut, skip backward."
55
64
 
56
65
 
57
66
  class UtDataInfo:
@@ -91,7 +100,15 @@ def get_validated_details_csv_path(validated_result_csv_path):
91
100
  return validated_details_csv_path
92
101
 
93
102
 
94
- def exec_api(api_type, api_name, device, args, kwargs):
103
+ def exec_api(exec_params):
104
+ api_type = exec_params.api_type
105
+ api_name = exec_params.api_name
106
+ device = exec_params.device
107
+ args = exec_params.args
108
+ kwargs = exec_params.kwargs
109
+ is_autocast = exec_params.is_autocast
110
+ autocast_dtype = exec_params.autocast_dtype
111
+
95
112
  if api_type == "Functional":
96
113
  torch_api = FunctionalOPTemplate(api_name, str, False)
97
114
  if api_type == "Tensor":
@@ -102,7 +119,11 @@ def exec_api(api_type, api_name, device, args, kwargs):
102
119
  torch_api = AtenOPTemplate(api_name, None, False)
103
120
  if api_type == "NPU":
104
121
  torch_api = NpuOPTemplate(api_name, None, False, device)
105
- out = torch_api.forward(*args, **kwargs)
122
+ if is_autocast:
123
+ with autocast(dtype=autocast_dtype):
124
+ out = torch_api.forward(*args, **kwargs)
125
+ else:
126
+ out = torch_api.forward(*args, **kwargs)
106
127
  return out
107
128
 
108
129
 
@@ -196,19 +217,28 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
196
217
  return set()
197
218
 
198
219
  raise_dtype = None
220
+ autocast_dtype = None
221
+ is_autocast = False
199
222
  need_raise_dtypes = recursive_find_dtypes(input_args)
200
223
  need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
201
224
  if len(need_raise_dtypes) == 1:
202
- raise_dtype = PRECISION_MAPPING.get(need_raise_dtypes.pop(), torch.float32)
225
+ origin_dtype = need_raise_dtypes.pop()
226
+ raise_dtype = PRECISION_MAPPING.get(origin_dtype, torch.float32)
227
+ autocast_dtype = origin_dtype
228
+
203
229
  elif len(need_raise_dtypes) >= 2:
204
230
  raise_dtype = torch.float32
231
+ need_raise_dtypes.discard(torch.float32)
232
+ autocast_dtype = need_raise_dtypes.pop()
233
+ is_autocast = True
205
234
 
206
235
  raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
207
236
  is_detach = api_name not in not_detach_set
208
237
  cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
209
238
  cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
210
239
  key, value in input_kwargs.items()}
211
- return cpu_args, cpu_kwargs
240
+ cpu_params = CpuParams(cpu_args, cpu_kwargs, autocast_dtype, is_autocast)
241
+ return cpu_params
212
242
 
213
243
 
214
244
  def record_skip_info(api_full_name, compare, compare_alg_results):
@@ -24,7 +24,7 @@ from msprobe.core.common.const import Const, CompareConst
24
24
  from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
25
25
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
26
26
  binary_standard_api, absolute_standard_api
27
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
27
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
28
28
  from msprobe.pytorch.common.log import logger
29
29
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
30
30
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
@@ -92,8 +92,10 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
92
92
 
93
93
  try:
94
94
  # NPU vs CPU
95
- cpu_args, cpu_kwargs = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
96
- cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
95
+ cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
96
+ cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
97
+ cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
98
+ cpu_out = exec_api(cpu_exec_params)
97
99
  npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
98
100
  npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
99
101
  npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
@@ -0,0 +1,215 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import namedtuple
17
+ import torch
18
+
19
+
20
# Bundled inputs for the Nesterov variable update (consumed by
# _output_var_t_compute_use_nesterov): current variable, effective learning
# rate, updated moments, broadcast beta1, gradient, and epsilon.
VarParams = namedtuple('VarParams', ['var', 'lr_t', 'm_t', 'beta1_broad', 'grad', 'epsilon', 'v_t'])
21
+
22
+
23
+ def _output_m_compute(m, beta1_broad, grad):
24
+ """
25
+ _output_m_compute
26
+ do compute m_t = m + (beta1 - 1) * (m - grad)
27
+ """
28
+ input_dtype = m.dtype
29
+
30
+ sneg_one = torch.ones((1), dtype=input_dtype) * -1
31
+ sneg_one = sneg_one.to(beta1_broad.device)
32
+
33
+ # `formula; beta1 -1`
34
+ vsub_beta1_1 = torch.add(beta1_broad, sneg_one)
35
+
36
+ # `formula; m - grad`
37
+ vsub_m_grad = torch.sub(m, grad)
38
+
39
+ # `formula; (beta1 - 1) * (m - grad)`
40
+ vmul_m = torch.mul(vsub_beta1_1, vsub_m_grad)
41
+
42
+ # `formula; m_t = m + (beta1 - 1) * (m - grad)`
43
+ m_t = torch.add(m, vmul_m)
44
+
45
+ return m_t
46
+
47
+
48
+ def _output_v_compute(v, beta2, grad):
49
+ """
50
+ _output_v_compute
51
+ do compute v_t = v + (1 - beta2)*(grad*grad -v)
52
+ """
53
+ input_dtype = v.dtype
54
+
55
+ sneg_one = torch.ones((1), dtype=input_dtype) * -1
56
+
57
+ # `formula; broadcast beta2 to vector`
58
+ beta2_tensor = torch.tensor(beta2, dtype=input_dtype)
59
+ beta2_broad = beta2_tensor.expand_as(v)
60
+
61
+ # `formula; beta2 - 1`
62
+ vsub_beta2_1 = torch.add(beta2_broad, sneg_one)
63
+ vsub_beta2_1 = vsub_beta2_1.to(v.device)
64
+
65
+ # `formula; grad * grad`
66
+ vmul_grad_grad = torch.mul(grad, grad)
67
+
68
+ # `formula; (v - grad*grad)`
69
+ vsub_v_grad = torch.sub(v, vmul_grad_grad)
70
+
71
+ # `formula; (beta2 -1) * (v - grad * grad)`
72
+ vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad)
73
+
74
+ # `formula; v_t = v + (beta2 - 1) * (v - grad * grad)`
75
+ v_t = torch.add(v, vmul_grad)
76
+
77
+ return v_t
78
+
79
+
80
+ def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor):
81
+ """
82
+ _inner_lr_compute
83
+ `formula; lr_t = learning_rate * (sqrt(1-beta2_power)) / (1 - beta1_power)`
84
+ """
85
+
86
+ input_dtype = compute_shape_tensor.dtype
87
+
88
+ s_one = torch.ones((1), dtype=input_dtype)
89
+
90
+ s_neg_one = torch.ones((1), dtype=input_dtype) * -1
91
+
92
+ # `formula; (1 - beta2_power)`
93
+ v_neg_beta2_power = torch.mul(beta2_power, s_neg_one)
94
+ v_add_beta2_power = torch.add(v_neg_beta2_power, s_one)
95
+
96
+ # `formula; sqrt(1 - beta2_power)`
97
+ v_sqrt_beta2_power = torch.sqrt(v_add_beta2_power)
98
+
99
+ # `formula; (1 - beta1_power)`
100
+ v_neg_beta1_power = torch.mul(beta1_power, s_neg_one)
101
+ v_add_beta1_power = torch.add(v_neg_beta1_power, s_one)
102
+
103
+ # `formula; learning_rate * (sqrt(1-beta2_power)`
104
+ res = torch.mul(lr, v_sqrt_beta2_power)
105
+
106
+ # `formula; learning_rate*(sqrt(1-beta2_power))/(1-beta1_power)`
107
+ res = torch.div(res, v_add_beta1_power)
108
+ return res.expand_as(compute_shape_tensor)
109
+
110
+
111
+ def _inner_eps_add_sqrt_vt_compute(epsilon, v_t):
112
+ """
113
+ (epsilon + sqrt(v_t) )
114
+ """
115
+ # `formula; sqrt(v_t)`
116
+ sqrt_vt = torch.sqrt(v_t)
117
+
118
+ # `formula; broadcast epsilon to vector`
119
+ input_dtype = v_t.dtype
120
+ epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype)
121
+ epsilon_broad = epsilon_tensor.expand_as(v_t)
122
+ epsilon_broad = epsilon_broad.to(sqrt_vt.device)
123
+
124
+ # `formula; epsilon + sqrt(v_t)`
125
+ v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad)
126
+
127
+ return v_add_sqrt_v
128
+
129
+
130
+ def _output_var_t_compute_use_nesterov(varparams):
131
+ """
132
+ _output_var_t_compute_use_nesterov
133
+ `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
134
+ `formula; var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))`
135
+ """
136
+ var = varparams.var
137
+ lr_t = varparams.lr_t
138
+ m_t = varparams.m_t
139
+ beta1_broad = varparams.beta1_broad
140
+ grad = varparams.grad
141
+ epsilon = varparams.epsilon
142
+ v_t = varparams.v_t
143
+
144
+ input_dtype = var.dtype
145
+
146
+ s_one = torch.ones((1), dtype=input_dtype)
147
+
148
+ s_neg_one = torch.ones((1), dtype=input_dtype) * -1
149
+
150
+ # `formula; m_t * beta1`
151
+ v_muls_mt_beta1 = torch.mul(m_t, beta1_broad)
152
+
153
+ # `formula; 1 -beta1`
154
+ v_neg_beta1 = torch.mul(beta1_broad, s_neg_one)
155
+ vsub_1_beta1 = torch.add(v_neg_beta1, s_one)
156
+
157
+ # `formula; (1-beta1)* grad`
158
+ v_mul_grad = torch.mul(vsub_1_beta1, grad)
159
+
160
+ # `formula; (m_t*beta1 + (1 - beta1)*grad)`
161
+ v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad)
162
+
163
+ # `formula; lr_t * (m_t*beta1 + (1 - beta1) * grad)`
164
+ # broadcast lr_t to vector
165
+
166
+ lrt_broad = lr_t.expand_as(var)
167
+ v_mul_left = torch.mul(lrt_broad, v_div_left)
168
+
169
+ # `formula; (epsilon + sqrt(v_t))`
170
+ v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
171
+
172
+ # `formula; lr_t * (m_t*beta1 + (1-beta1)*grad / (epsilon + sqrt(v_t))`
173
+ v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
174
+
175
+ # `formula; var - lr_t * (m_t*beta1 + (1-beta1)*grad) / (epsilon + sqrt(v_t))`
176
+ v_t = torch.sub(var, v_div_res)
177
+
178
+ return v_t
179
+
180
+
181
+ def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t):
182
+ """
183
+ _output_var_t_compute
184
+ `var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))`
185
+ """
186
+ # `formula; lr_t * m_t`
187
+ lr_t = lr_t.to(m_t.device)
188
+ v_mul_left = torch.mul(lr_t, m_t)
189
+
190
+ # `formula; (epsilon + sqrt(v_t))`
191
+ v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
192
+
193
+ # `formula; lr_t * m_t /(epsilon + sqrt(v_t))`
194
+ v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
195
+
196
+ # `formula; var - lr_t * m_t / (epsilon + sqrt(v_t))`
197
+ v_t = torch.sub(var, v_div_res)
198
+
199
+ return v_t
200
+
201
+
202
def npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out):
    """
    CPU reference implementation of one Adam update step.

    `out` carries the (var, m, v) state tensors; returns the updated
    (var_t, m_t, v_t). `use_locking` is accepted for signature
    compatibility but not used by this implementation.
    """
    var, m, v = out

    # broadcast beta1 over m on m's device; moments and effective lr
    beta1_broad = torch.tensor(beta1, dtype=m.dtype).to(m.device).expand_as(m)
    m_t = _output_m_compute(m, beta1_broad, grad)
    v_t = _output_v_compute(v, beta2, grad)
    lr_t = _inner_lr_compute(lr, beta2_power, beta1_power, grad)

    if use_nesterov:
        var_t = _output_var_t_compute_use_nesterov(
            VarParams(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t))
    else:
        var_t = _output_var_t_compute(var, lr_t, m_t, epsilon, v_t)
    return var_t, m_t, v_t
@@ -0,0 +1,27 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+
19
def npu_group_norm_silu(x, gama, beta, group, eps):
    """
    Group normalization followed by SiLU on a 4-D (N, C, H, W) input.

    Returns the list produced by native_group_norm with the normalized
    output (element 0) passed through SiLU. Raises ValueError for a
    non-4-D input or an empty native_group_norm result.
    """
    if len(x.shape) != 4:
        raise ValueError("x shape should be (N, C, H, W)")
    batch, channels = x.shape[0], x.shape[1]
    spatial = x.shape[2] * x.shape[3]
    outputs = list(torch.ops.aten.native_group_norm(x, gama, beta, batch, channels, spatial, group, eps))
    if not outputs:
        raise ValueError("run native_group_norm failed")
    outputs[0] = torch.nn.functional.silu(outputs[0])
    return outputs
@@ -0,0 +1,21 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+
19
def npu_mish(x):
    """Apply the Mish activation, x * tanh(softplus(x)), element-wise."""
    # torch.nn.Mish()(x) forwards to this functional form; result identical.
    return torch.nn.functional.mish(x)