mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -17,7 +17,7 @@
17
17
 
18
18
  import argparse
19
19
  import os
20
- import csv
20
+ import re
21
21
  import sys
22
22
  import time
23
23
  import gc
@@ -35,19 +35,20 @@ import torch
35
35
  from tqdm import tqdm
36
36
 
37
37
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
38
- get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
38
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
39
39
  from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
40
40
  from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
41
41
  initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
42
42
  from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
43
43
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
44
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
44
+ from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
45
45
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
46
- from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, \
47
- create_directory, get_json_contents, read_csv
46
+ from msprobe.core.common.file_utils import FileChecker, change_mode, \
47
+ create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
48
48
  from msprobe.pytorch.common.log import logger
49
49
  from msprobe.pytorch.pt_config import parse_json_config
50
50
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
51
+ from msprobe.core.common.utils import safe_get_value
51
52
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
52
53
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
53
54
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
@@ -57,11 +58,7 @@ current_time = time.strftime("%Y%m%d%H%M%S")
57
58
  UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
58
59
  RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
59
60
  DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
60
- RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
61
- 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
62
- 'black_list', 'error_data_path', 'online_config'])
63
61
 
64
- OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
65
62
 
66
63
  not_backward_list = ['repeat_interleave']
67
64
 
@@ -99,7 +96,11 @@ def run_ut(config):
99
96
  run_api_online(config, compare)
100
97
  else:
101
98
  csv_df = read_csv(config.result_csv_path)
102
- api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
99
+ try:
100
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
101
+ except IndexError:
102
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
103
+ api_name_set = set()
103
104
  run_api_offline(config, compare, api_name_set)
104
105
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
105
106
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -140,7 +141,7 @@ def run_api_offline(config, compare, api_name_set):
140
141
  except Exception as err:
141
142
  if "expected scalar type Long" in str(err):
142
143
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
143
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
144
+ "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
144
145
  else:
145
146
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
146
147
  compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
@@ -220,14 +221,6 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
220
221
  return False
221
222
 
222
223
 
223
- def is_unsupported_api(api_name):
224
- split_name = api_name.split(Const.SEP)[0]
225
- flag = split_name == Const.DISTRIBUTED
226
- if flag:
227
- logger.info(f"{split_name} api is not supported for run ut. SKIP.")
228
- return flag
229
-
230
-
231
224
  def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
232
225
  if not is_fwd_success or not is_bwd_success:
233
226
  processor = UtDataProcessor(error_data_path)
@@ -253,7 +246,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
253
246
  backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
254
247
  if api_name in not_backward_list:
255
248
  need_grad = False
256
- logger.warning("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
249
+ logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
257
250
  backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
258
251
  need_backward = need_backward and need_grad
259
252
  if kwargs.get("device"):
@@ -278,7 +271,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
278
271
  func_options = {
279
272
  'real_data_path': real_data_path
280
273
  }
281
- grad = gen_args(backward_args, api_name, func_options)[0]
274
+ grad = gen_args(backward_args, api_name, func_options)
275
+ grad = safe_get_value(grad, 0, "grad")
282
276
  bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
283
277
  bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
284
278
  device_grad = grad.clone().detach().to(current_device)
@@ -286,8 +280,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
286
280
  else:
287
281
  backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
288
282
  if api_name == "npu_fusion_attention":
289
- out = out[0]
290
- device_out = device_out[0]
283
+ out = safe_get_value(out, 0, "out")
284
+ device_out = safe_get_value(device_out, 0, "device_out")
291
285
 
292
286
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
293
287
 
@@ -323,6 +317,9 @@ def need_to_backward(grad_index, out):
323
317
 
324
318
  def run_backward(args, grad, grad_index, out):
325
319
  if grad_index is not None:
320
+ if grad_index >= len(out):
321
+ logger.error(f"Run backward error when grad_index is {grad_index}")
322
+ raise IndexError(f"Run backward error when grad_index is {grad_index}")
326
323
  out[grad_index].backward(grad)
327
324
  else:
328
325
  out.backward(grad)
@@ -336,7 +333,6 @@ def run_backward(args, grad, grad_index, out):
336
333
 
337
334
 
338
335
  def initialize_save_error_data(error_data_path):
339
- check_path_before_create(error_data_path)
340
336
  create_directory(error_data_path)
341
337
  error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
342
338
  ability=FileCheckConst.WRITE_ABLE)
@@ -438,7 +434,49 @@ def _run_ut(parser=None):
438
434
  run_ut_command(args)
439
435
 
440
436
 
437
+ def checked_online_config(online_config):
438
+ if not online_config.is_online:
439
+ return
440
+ if not isinstance(online_config.is_online, bool):
441
+ raise ValueError("is_online must be bool type")
442
+ # rank_list
443
+ if not isinstance(online_config.rank_list, list):
444
+ raise ValueError("rank_list must be a list")
445
+ if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
446
+ raise ValueError("All elements in rank_list must be integers")
447
+
448
+ # nfs_path
449
+ if online_config.nfs_path:
450
+ check_file_or_directory_path(online_config.nfs_path, isdir=True)
451
+ return
452
+ # tls_path
453
+ if online_config.tls_path:
454
+ check_file_or_directory_path(online_config.tls_path, isdir=True)
455
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
456
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
457
+ check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
458
+
459
+ # host and port
460
+ if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
461
+ raise Exception(f"host: {online_config.host} is invalid.")
462
+ if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
463
+ raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
464
+
465
+
441
466
  def run_ut_command(args):
467
+ if args.config_path:
468
+ config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
469
+ FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
470
+ checked_config_path = config_path_checker.common_check()
471
+ _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
472
+ checker_config = CheckerConfig(task_config)
473
+ else:
474
+ checker_config = CheckerConfig()
475
+
476
+ if not checker_config.is_online and not args.api_info_file:
477
+ logger.error("Please provide api_info_file for offline run ut.")
478
+ raise Exception("Please provide api_info_file for offline run ut.")
479
+
442
480
  if not is_gpu:
443
481
  torch.npu.set_compile_mode(jit_compile=args.jit_compile)
444
482
  used_device = current_device + ":" + str(args.device_id[0])
@@ -464,8 +502,7 @@ def run_ut_command(args):
464
502
  forward_content = preprocess_forward_content(forward_content)
465
503
  logger.info("Finish filtering the api in the api_info_file.")
466
504
 
467
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
468
- check_path_before_create(out_path)
505
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
469
506
  create_directory(out_path)
470
507
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
471
508
  out_path = out_path_checker.common_check()
@@ -476,40 +513,27 @@ def run_ut_command(args):
476
513
  if args.result_csv_path:
477
514
  result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
478
515
  details_csv_path = get_validated_details_csv_path(result_csv_path)
479
- white_list = msCheckerConfig.white_list
480
- black_list = msCheckerConfig.black_list
481
- error_data_path = msCheckerConfig.error_data_path
482
- is_online = msCheckerConfig.is_online
483
- nfs_path = msCheckerConfig.nfs_path
484
- host = msCheckerConfig.host
485
- port = msCheckerConfig.port
486
- rank_list = msCheckerConfig.rank_list
487
- tls_path = msCheckerConfig.tls_path
488
- if args.config_path:
489
- config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
490
- FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
491
- checked_config_path = config_path_checker.common_check()
492
- _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
493
- white_list = task_config.white_list
494
- black_list = task_config.black_list
495
- error_data_path = task_config.error_data_path
496
- is_online = task_config.is_online
497
- nfs_path = task_config.nfs_path
498
- host = task_config.host
499
- port = task_config.port
500
- rank_list = task_config.rank_list
501
- tls_path = task_config.tls_path
502
516
 
517
+ error_data_path = checker_config.error_data_path
503
518
  if save_error_data:
504
519
  if args.result_csv_path:
505
520
  time_info = result_csv_path.split('.')[0].split('_')[-1]
506
521
  global UT_ERROR_DATA_DIR
507
522
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
508
523
  error_data_path = initialize_save_error_data(error_data_path)
509
- online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
510
- run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
511
- args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
512
- online_config)
524
+ online_config = checker_config.get_online_config()
525
+ checked_online_config(online_config)
526
+ config_params = {
527
+ 'forward_content': forward_content,
528
+ 'backward_content': backward_content,
529
+ 'result_csv_path': result_csv_path,
530
+ 'details_csv_path': details_csv_path,
531
+ 'save_error_data': save_error_data,
532
+ 'is_continue_run_ut': args.result_csv_path,
533
+ 'real_data_path': real_data_path,
534
+ 'error_data_path': error_data_path
535
+ }
536
+ run_ut_config = checker_config.get_run_ut_config(**config_params)
513
537
  run_ut(run_ut_config)
514
538
 
515
539
 
@@ -51,7 +51,7 @@ class BackwardMessage:
51
51
  MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
52
52
  UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
53
53
  "skip backward."
54
- NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
54
+ NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
55
55
 
56
56
 
57
57
  class UtDataInfo:
@@ -186,11 +186,13 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
186
186
  logger.error("The depth of arg_in is too large, please check the arg_in.")
187
187
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
188
188
  if isinstance(arg_in, (list, tuple)):
189
- return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for arg in arg_in))
189
+ return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
190
+ arg in arg_in))
190
191
  elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
191
192
  return set([arg_in.dtype])
192
193
  elif isinstance(arg_in, dict) and check_kwargs:
193
- return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for v in arg_in.values()))
194
+ return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
195
+ v in arg_in.values()))
194
196
  return set()
195
197
 
196
198
  raise_dtype = None
@@ -204,10 +206,19 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
204
206
  raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
205
207
  is_detach = api_name not in not_detach_set
206
208
  cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
207
- cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
209
+ cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
210
+ key, value in input_kwargs.items()}
208
211
  return cpu_args, cpu_kwargs
209
212
 
210
213
 
211
214
  def record_skip_info(api_full_name, compare, compare_alg_results):
212
215
  result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
213
216
  compare.record_results(result_info)
217
+
218
+
219
+ def is_unsupported_api(api_name, is_overflow_check=False):
220
+ split_name = api_name.split(Const.SEP)[0]
221
+ flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
222
+ if flag:
223
+ logger.info(f"{split_name} api is not supported for run ut. SKIP.")
224
+ return flag
@@ -16,7 +16,6 @@
16
16
  import glob
17
17
  import os.path
18
18
  import time
19
- import re
20
19
  from multiprocessing import Queue
21
20
  from typing import Optional, Union, Dict, Any
22
21
  from dataclasses import dataclass
@@ -26,9 +25,8 @@ import torch
26
25
  from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
27
26
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
28
27
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
29
- from msprobe.pytorch.common.utils import logger
30
28
  from msprobe.core.common.file_utils import remove_path
31
- from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
29
+ from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
32
30
 
33
31
  BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
34
32
 
@@ -55,7 +53,6 @@ class ATTL:
55
53
  self.dequeue_list = []
56
54
  self.message_end = False
57
55
  self.kill_progress = False
58
- self.check_attl_config()
59
56
  self.nfs_path = None
60
57
  if self.session_config.nfs_path:
61
58
  self.nfs_path = self.session_config.nfs_path
@@ -73,18 +70,6 @@ class ATTL:
73
70
  self.session_config.tls_path)
74
71
  self.socket_manager.start()
75
72
 
76
- def check_attl_config(self):
77
- if self.session_config.nfs_path:
78
- if os.path.exists(self.session_config.nfs_path):
79
- return
80
- else:
81
- raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
82
- ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
83
- if not re.match(ipv4_pattern, self.session_config.connect_ip):
84
- raise Exception(f"host {self.session_config.connect_ip} is invalid.")
85
- if not (0 < self.session_config.connect_port <= 65535):
86
- raise Exception(f"port {self.session_config.connect_port} is invalid.")
87
-
88
73
  def stop_serve(self):
89
74
  if isinstance(self.socket_manager, TCPServer):
90
75
  self.socket_manager.stop()
@@ -115,21 +100,21 @@ class ATTL:
115
100
  self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
116
101
 
117
102
  def recv(self, timeout_ms=0) -> Optional[BufferType]:
118
- buffer = None
119
- while buffer is None:
103
+ buffer = ''
104
+ while not buffer:
120
105
  if timeout_ms > 0:
121
106
  time.sleep(timeout_ms / 1000.0)
122
- if buffer is None and not self.data_queue.empty():
107
+ if not buffer and not self.data_queue.empty():
123
108
  buffer = self.data_queue.get()
124
109
  break
125
- if buffer is None and timeout_ms > 0: # timeout is the only case we give up and return None
110
+ if not buffer and timeout_ms > 0: # timeout is the only case we give up and return None
126
111
  break
127
112
  if self.message_end and self.data_queue.empty():
128
113
  buffer = b"KILL_CONFIRM"
129
114
  self.kill_progress = True
130
115
  break
131
116
  time.sleep(0.1) # waiting outside the lock before next attempt
132
- if buffer is None:
117
+ if not buffer:
133
118
  # this is a result of a timeout
134
119
  self.logger.info(f"RECEIVE API DATA TIMED OUT")
135
120
  else:
@@ -146,7 +131,7 @@ class ATTL:
146
131
  except Exception as e:
147
132
  self.logger.warning("there is something error. please check it. %s", e)
148
133
  if isinstance(buffer, bytes):
149
- return None
134
+ return ''
150
135
  if isinstance(buffer, str):
151
136
  return buffer
152
137
 
@@ -160,7 +145,7 @@ class ATTL:
160
145
  file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
161
146
 
162
147
  try:
163
- save_pt(buffer, file_path)
148
+ save_pkl(buffer, file_path)
164
149
  except Exception as e:
165
150
  self.logger.warning("there is something error in save_pt. please check it. %s", e)
166
151
 
@@ -176,7 +161,7 @@ class ATTL:
176
161
 
177
162
  if cur_file is not None:
178
163
  try:
179
- buffer = load_pt(cur_file)
164
+ buffer = load_pkl(cur_file)
180
165
  except Exception as e:
181
166
  self.logger.warning("there is something error. please check it. %s", e)
182
167
  remove_path(cur_file)
@@ -27,8 +27,8 @@ from twisted.internet import reactor, protocol, endpoints
27
27
  from twisted.protocols.basic import FileSender
28
28
 
29
29
  from msprobe.pytorch.common.utils import logger
30
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import struct_unpack_mode as unpack_mode, \
31
- str_to_bytes_order as bytes_order
30
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
31
+ STR_TO_BYTES_ORDER as bytes_order
32
32
 
33
33
  MAX_SENDING_QUEUE_SIZE = 20
34
34
 
@@ -84,15 +84,6 @@ class TCPClient:
84
84
  def run_reactor():
85
85
  reactor.run(installSignalHandlers=False)
86
86
 
87
- def check_tls_path(self):
88
- client_key = os.path.join(self.tls_path, "client.key")
89
- client_crt = os.path.join(self.tls_path, "client.crt")
90
- if not os.path.exists(client_key):
91
- raise Exception(f"client_key: {client_key} is not exists.")
92
- if not os.path.exists(client_crt):
93
- raise Exception(f"client_crt: {client_crt} is not exists.")
94
- return client_key, client_crt
95
-
96
87
  def start(self):
97
88
  def conn_callback(cur_protocol):
98
89
  if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
@@ -114,7 +105,8 @@ class TCPClient:
114
105
  self.factory.protocol = cur_protocol
115
106
  if self.tls_path:
116
107
  from twisted.internet import ssl
117
- client_key, client_crt = self.check_tls_path()
108
+ client_key = os.path.join(self.tls_path, "client.key")
109
+ client_crt = os.path.join(self.tls_path, "client.crt")
118
110
  client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
119
111
  endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
120
112
  else:
@@ -1,3 +1,4 @@
1
+
1
2
  # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
3
  # All rights reserved.
3
4
  #
@@ -14,6 +15,7 @@
14
15
  # limitations under the License.
15
16
 
16
17
  import os
18
+ from collections import defaultdict
17
19
  from functools import wraps
18
20
 
19
21
  import torch
@@ -39,7 +41,7 @@ def singleton(cls):
39
41
  @singleton
40
42
  class Counter:
41
43
  def __init__(self) -> None:
42
- self.index_dict = {}
44
+ self.index_dict = defaultdict(int)
43
45
 
44
46
 
45
47
  counter = Counter()
@@ -67,9 +69,9 @@ class AccuracyCheckerDispatch(TorchDispatchMode):
67
69
 
68
70
  res = func(*args, **kwargs)
69
71
  cur_rank = get_tensor_rank(args, res)
70
- cur_api_number = self.counter.index_dict.setdefault(aten_api, 0)
72
+ cur_api_number = self.counter.index_dict[aten_api]
71
73
  api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
72
- logger.info(f"tools is dumping api: {api_name}")
74
+ logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
73
75
  api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
74
76
  if "device" in api_data.kwargs:
75
77
  api_data.kwargs.pop("device")
@@ -98,7 +100,7 @@ def dispatch4data(func, attl, status):
98
100
  return wrapper
99
101
 
100
102
 
101
- def run_ut_dispatch(attl, status):
103
+ def run_ut_dispatch(attl, status, is_recompute=False):
102
104
  """
103
105
  This function called by online_run_ut.
104
106
  It is used to enable or disable dispatch for torch.autograd.backward function.
@@ -106,5 +108,8 @@ def run_ut_dispatch(attl, status):
106
108
  Args:
107
109
  attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
108
110
  status (bool): True means enable dispatch, False means disable dispatch.
111
+ is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data.
109
112
  """
113
+ if is_recompute:
114
+ return
110
115
  torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
@@ -24,7 +24,7 @@ from twisted.internet import reactor, protocol, endpoints
24
24
 
25
25
  from msprobe.pytorch.common.utils import logger
26
26
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
27
- struct_unpack_mode as unpack_mode, str_to_bytes_order as bytes_order
27
+ STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
28
28
 
29
29
 
30
30
  class TCPServer:
@@ -40,22 +40,14 @@ class TCPServer:
40
40
  def run_reactor():
41
41
  reactor.run(installSignalHandlers=False)
42
42
 
43
- def check_tls_path(self):
44
- server_key = os.path.join(self.tls_path, "server.key")
45
- server_crt = os.path.join(self.tls_path, "server.crt")
46
- if not os.path.exists(server_key):
47
- raise Exception(f"server_key: {server_key} is not exists.")
48
- if not os.path.exists(server_crt):
49
- raise Exception(f"server_crt: {server_crt} is not exists.")
50
- return server_key, server_crt
51
-
52
43
  def start(self):
53
44
  self.factory.protocol = self.build_protocol
54
45
 
55
46
  if self.tls_path:
56
47
  from OpenSSL import SSL
57
48
  from twisted.internet import ssl
58
- server_key, server_crt = self.check_tls_path()
49
+ server_key = os.path.join(self.tls_path, "server.key")
50
+ server_crt = os.path.join(self.tls_path, "server.crt")
59
51
  server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
60
52
  server_context_ = server_context_factory.getContext()
61
53
  server_context_.set_cipher_list(cipher_list)
@@ -40,5 +40,5 @@ cipher_list = ":".join(
40
40
  "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
41
41
  ).encode()
42
42
 
43
- struct_unpack_mode = "!Q"
44
- str_to_bytes_order = "big"
43
+ STRUCT_UNPACK_MODE = "!Q"
44
+ STR_TO_BYTES_ORDER = "big"
@@ -22,7 +22,11 @@ def npu_confusion_transpose(data, perm, shape, transpose_first):
22
22
 
23
23
 
24
24
  def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
25
- shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
25
+ try:
26
+ shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
27
+ except IndexError as e:
28
+ raise IndexError("npu_confusion_transpose_backward: Invalid perm index for shape") from e
29
+
26
30
  perm_cal = [0] * len(perm)
27
31
  for i, perm_dim in enumerate(perm):
28
32
  perm_cal[perm_dim] = i
@@ -17,6 +17,9 @@ import torch
17
17
 
18
18
 
19
19
  def matmul_backward(grad, self, other, mask):
20
+ if len(mask) < 2:
21
+ raise RuntimeError("Mask size at least 2")
22
+
20
23
  grad_self, grad_other = None, None
21
24
  dim_self = self.dim()
22
25
  dim_other = other.dim()
@@ -24,6 +27,7 @@ def matmul_backward(grad, self, other, mask):
24
27
  size_grad = list(grad.size())
25
28
  size_self = list(self.size())
26
29
  size_other = list(other.size())
30
+
27
31
  if dim_self == 1 and dim_other == 1:
28
32
  grad_self = other.mul(grad) if mask[0] else grad_self
29
33
  grad_other = self.mul(grad) if mask[1] else grad_other
@@ -34,19 +38,27 @@ def matmul_backward(grad, self, other, mask):
34
38
  grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
35
39
  grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
36
40
  elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
41
+ if len(size_grad) < 1:
42
+ raise RuntimeError("size_grad's length at least 1")
37
43
  view_size = 1 if dim_other == 1 else size_grad[-1]
38
44
  unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
39
45
  if mask[0]:
40
46
  grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
41
47
  .view(size_self)
42
48
  if mask[1]:
49
+ if len(size_self) < 1:
50
+ raise RuntimeError("size_self's length at least 1")
43
51
  unfolded_self = self.contiguous().view([-1, size_self[-1]])
44
52
  grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
45
53
  elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
54
+ if len(size_grad) < 2:
55
+ raise RuntimeError("size_grad's length at least 2")
46
56
  view_size = 1 if dim_self == 1 else size_grad[-2]
47
57
  unfolded_grad_t = grad.view([-1, view_size]) \
48
58
  if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
49
59
  if mask[0]:
60
+ if len(size_other) < 2:
61
+ raise RuntimeError("size_other's length at least 2")
50
62
  # create a 2D-matrix from other
51
63
  unfolded_other_t = \
52
64
  other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)