mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -50,8 +50,8 @@ else:
  from msprobe.pytorch.common.utils import logger
  from msprobe.core.common.const import Const, CompareConst

- gtype = torch.float64 # float64 is required on ARM hosts; on x86, float32 is sufficient (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios
- softmax_build_mode = "QKV" # "MAX_SUM"
+ GTYPE = torch.float64 # float64 is required on ARM hosts; on x86, float32 is sufficient (float64 also works). ARM computation is slow, so x86 is recommended for s=8k scenarios
+ SOFTMAX_BUILD_MODE = "QKV" # "MAX_SUM"


  def softmax_forward(x):
@@ -166,6 +166,18 @@ def parse_bsnd_args(query, key, head_num, input_layout):


  def convert_from_bnsd(_input, input_layout):
+ """
+ transform qkv from bnsd to input_layout.
+ B: batch_size
+ S: sequence_length
+ N: num_heads
+ D: head_dim
+ Args:
+ _input (torch.Tensor): tensor of shape (B,N,S,D)
+ input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+ Returns:
+ tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+ """
  if input_layout == "BSH":
  # (B,N,S,D)=>(B,S,N*D)
  out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
@@ -183,7 +195,19 @@ def convert_from_bnsd(_input, input_layout):


  def convert_to_bnsd(_input, n, input_layout):
- # The default "BNSD" layout requires no processing
+ """
+ transform qkv from input_layout to bnsd.
+ B: batch_size
+ S: sequence_length
+ N: num_heads
+ D: head_dim
+ Args:
+ _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+ n (int): num_heads
+ input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+ Returns:
+ tensor of shape (B,N,S,D)
+ """
  if input_layout == "BSH":
  # (B,S,N*D)=>(B,N,S,D)
  out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
@@ -199,7 +223,68 @@ def convert_to_bnsd(_input, n, input_layout):
  out = _input
  if out.dim() != 4:
  raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
- return out.to(gtype)
+ return out.to(GTYPE)
+
+
+ def convert_from_bsnd(_input, input_layout):
+ """
+ transform qkv from bsnd to input_layout.
+ B: batch_size
+ S: sequence_length
+ N: num_heads
+ D: head_dim
+ Args:
+ _input (torch.Tensor): tensor of shape (B,S,N,D)
+ input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+ Returns:
+ tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+ """
+ if input_layout == "BSH":
+ # (B,S,N,D)=>(B,S,N*D)
+ out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+ elif input_layout == "SBH":
+ # (B,S,N,D)=>(S,B,N*D)
+ out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+ elif input_layout == "BNSD":
+ # (B,S,N,D)=>(B,N,S,D)
+ out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+ elif input_layout == "TND":
+ raise ValueError(f"input_layout {input_layout} does not supported for now.")
+ else:
+ out = _input
+ return out
+
+
+ def convert_to_bsnd(_input, n, input_layout):
+ """
+ transform qkv from input_layout to bsnd.
+ B: batch_size
+ S: sequence_length
+ N: num_heads
+ D: head_dim
+ Args:
+ _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+ n (int): num_heads
+ input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+ Returns:
+ tensor of shape (B,S,N,D)
+ """
+ if input_layout == "BSH":
+ # (B,S,N*D)=>(B,S,N,D)
+ out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+ elif input_layout == "SBH":
+ # (S,B,N*D)=>(B,S,N,D)
+ out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+ elif input_layout == "BNSD":
+ # (B,N,S,D)=>(B,S,N,D)
+ out = rearrange(_input, 'b n s d -> b s n d', n=n)
+ elif input_layout == "TND":
+ raise ValueError(f"input_layout {input_layout} does not supported for now.")
+ else:
+ out = _input
+ if out.dim() != 4:
+ raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+ return out


  def generate_atten_mask(*args):
@@ -279,7 +364,7 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
  """
  logger.info("Using QKV to rebuild original softmax")
  qk = calculate_qk(q, k, atten_mask, pse, scale)
- softmax_res, x_max, x_sum = softmax_forward(qk)
+ softmax_res, _, _ = softmax_forward(qk)
  return softmax_res


@@ -319,6 +404,10 @@ def get_input_layout(*args, **kwargs):


  def npu_fusion_attention_forward_patch(*args, **kwargs):
+
+ if len(args) < 2:
+ raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
+
  # query, key, value, head_num, input_layout
  head_num = get_head_num(*args, **kwargs)
  input_layout = get_input_layout(*args, **kwargs)
@@ -454,7 +543,7 @@ def npu_fusion_attention_grad(*args, **kwargs):
  value = convert_to_bnsd(value, n2, input_layout)
  k_new, v_new = generate_kv(key, value, n1, n2)

- if softmax_build_mode == "QKV":
+ if SOFTMAX_BUILD_MODE == "QKV":
  softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
  else:
  softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
@@ -531,8 +620,13 @@ def gpu_fusion_attention(*args, **kwargs):
  else:
  alibi_slopes = None

+ input_layout = get_input_layout(*args, **kwargs)
+ query = convert_to_bsnd(query, n1, input_layout)
+ key = convert_to_bsnd(key, n2, input_layout)
+ value = convert_to_bsnd(value, n2, input_layout)
  out = flash_attn_func(
  query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
  window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
  )
+ out = convert_from_bsnd(out, input_layout)
  return out, Const.NONE, Const.NONE
@@ -40,6 +40,9 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
  x_shape = x.shape
  h = x.float()
  grad = dy_tensor.float()
+ if len(r1_shape) < 4 or len(x_shape) < 4:
+ raise RuntimeError(f"Shape of r1 and x should at least be 4-dimension, "
+ f"but got r1 shape:{r1_shape}, x shape:{x_shape}")
  condition_1 = (r1_shape[0] == 1
  and r1_shape[1] == x_shape[1]
  and r1_shape[2] == 1
@@ -68,4 +71,5 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
  for j in range(x_shape[2]):
  r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
  r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
+
  return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
@@ -19,7 +19,11 @@ import torch
  def npu_swiglu(x, dim=-1):
  tensor_dtype = x.dtype

- in_tensors = torch.chunk(x, 2, dim=dim)
+ try:
+ in_tensors = torch.chunk(x, 2, dim=dim)
+ except Exception as e:
+ raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
+
  if tensor_dtype == torch.float32:
  tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
  output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
@@ -34,7 +38,11 @@ def npu_swiglu(x, dim=-1):

  def npu_swiglu_backward(grad, x, dim=-1):
  tensor_dtype = grad.dtype
- in_tensors = torch.chunk(x, 2, dim=dim)
+ try:
+ in_tensors = torch.chunk(x, 2, dim=dim)
+ except Exception as e:
+ raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
+
  tensor_grad_out = grad

  if tensor_dtype == torch.float16:
@@ -13,20 +13,20 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import json
-
  from msprobe.core.common.exceptions import ParseJsonException
- from msprobe.core.common.file_utils import FileOpen
+ from msprobe.core.common.file_utils import load_json
+ from msprobe.core.common.log import logger


  def parse_json_info_forward_backward(json_path):
- with FileOpen(json_path, 'r') as f:
- dump_json = json.load(f)
+ dump_json = load_json(json_path)

  real_data_path = dump_json.get("dump_data_dir")
  dump_data = dump_json.get("data")
+ if dump_data is None:
+ raise ParseJsonException(ParseJsonException.InvalidDumpJson, "something wrong with dump, no data found in dump.json")
  if not dump_data:
- raise ParseJsonException(ParseJsonException.InvalidDumpJson, "no 'data' field found in the dump data")
+ logger.warning("data field is empty, no overflow data found.")

  forward_data = {}
  backward_data = {}
@@ -15,6 +15,7 @@

  import io
  import os
+ import pickle
  import random
  import stat
  from functools import wraps
@@ -24,7 +25,7 @@ import torch
  import torch.distributed as dist
  from msprobe.core.common.exceptions import DistributedNotInitializedError
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
- check_file_or_directory_path, check_path_before_create)
+ check_file_or_directory_path, check_path_before_create, FileOpen)
  from msprobe.core.common.log import logger
  from msprobe.core.common.utils import check_seed_all
  from packaging import version
@@ -75,7 +76,7 @@ def parameter_adapter(func):
  else:
  res = [input_tensor[tensor_index] for tensor_index in indices]
  return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
- if self.op_name_ == "__eq__" and args[1] is None:
+ if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
  return False
  return func(self, *args, **kwargs)

@@ -269,17 +270,17 @@ def load_pt(pt_path, to_cpu=False):
  check_file_or_directory_path(pt_path)
  try:
  if to_cpu:
- pt = torch.load(pt_path, map_location=torch.device("cpu"))
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
  else:
- pt = torch.load(pt_path)
+ pt = torch.load(pt_path, weights_only=True)
  except Exception as e:
  raise RuntimeError(f"load pt file {pt_path} failed") from e
  return pt


  def save_pt(tensor, filepath):
- filepath = os.path.realpath(filepath)
  check_path_before_create(filepath)
+ filepath = os.path.realpath(filepath)
  try:
  torch.save(tensor, filepath)
  except Exception as e:
@@ -290,6 +291,56 @@ def save_pt(tensor, filepath):
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)


+ class TypeCheckingUnpickler(pickle.Unpickler):
+ """
+ This class is a subclass of pickle.Unpickler, which is used to unpickle pickled objects.
+ It overrides the find_class method to add type checking functionality.
+ """
+ allowed_types = [
+ "str",
+ "ApiData",
+ "OrderedDict",
+ "_rebuild_tensor_v2", # from torch.utils
+ "_load_from_bytes" # from torch.storage
+ ]
+
+ def find_class(self, module, name):
+ """
+ Method to find the class of the object to be unpickled.
+ Throws pickle.UnpicklingError If the object type is not in the allowed types list.
+ """
+ if name in self.allowed_types:
+ return super().find_class(module, name)
+ raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
+
+
+ def save_pkl(tensor, filepath):
+ """Save ApiData or str objection by pickle"""
+ check_path_before_create(filepath)
+ filepath = os.path.realpath(filepath)
+ try:
+ with FileOpen(filepath, 'wb') as f:
+ pickle.dump(tensor, f)
+ except Exception as e:
+ logger.error("Save pt file failed, please check according possible error causes: "
+ "1. out of disk space or disk error, "
+ "2. no permission to write files, etc.")
+ raise RuntimeError(f"save pt file {filepath} failed") from e
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+ def load_pkl(pt_path):
+ """Load ApiData or str objection by pickle for accuracy_checker_online"""
+ check_file_or_directory_path(pt_path)
+ pt_path = os.path.realpath(pt_path)
+ try:
+ with FileOpen(pt_path, 'rb') as f:
+ pt = TypeCheckingUnpickler(f).load()
+ except Exception as e:
+ raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
+ return pt
+
+
  def save_api_data(api_data):
  """Save data to io stream"""
  try:
@@ -15,7 +15,7 @@

  import os
  from msprobe.core.common.utils import CompareException, check_compare_param, \
- check_configuration_param, task_dumppath_get
+ check_configuration_param, set_dump_path, get_dump_mode
  from msprobe.core.common.file_utils import create_directory
  from msprobe.core.common.exceptions import FileCheckException
  from msprobe.pytorch.common.log import logger
@@ -30,6 +30,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
  stack_mode = kwargs.get('stack_mode', False)
  auto_analyze = kwargs.get('auto_analyze', True)
  fuzzy_match = kwargs.get('fuzzy_match', False)
+ is_print_compare_log = kwargs.get('is_print_compare_log', True)
  # get the ranks and match by order
  npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
  bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
@@ -49,18 +50,16 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
  'npu_json_path': npu_path,
  'bench_json_path': bench_path,
  'stack_json_path': stack_path,
- 'is_print_compare_log': True
+ 'is_print_compare_log': is_print_compare_log
  }
  try:
- summary_compare, md5_compare = task_dumppath_get(dump_result_param)
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
- dump_result_param.get('is_print_compare_log', True))
+ set_dump_path(dump_result_param)
+ dump_mode = get_dump_mode(dump_result_param)
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
  create_directory(output_path)
- check_compare_param(dump_result_param, output_path,
- summary_compare=summary_compare, md5_compare=md5_compare)
+ check_compare_param(dump_result_param, output_path, dump_mode)
  except (CompareException, FileCheckException) as error:
  logger.error('Compare failed. Please check the arguments and do it again!')
  raise CompareException(error.code) from error
  pt_comparator = PTComparator()
- pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
- summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
+ pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
@@ -19,8 +19,8 @@ from msprobe.core.common.const import FileCheckConst
  from msprobe.pytorch.common.log import logger
  from msprobe.core.common.exceptions import FileCheckException
  from msprobe.core.compare.acc_compare import Comparator
- from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
- CompareException
+ from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
+ CompareException, set_dump_path, get_dump_mode
  from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
  from msprobe.pytorch.common.utils import load_pt

@@ -45,6 +45,8 @@ class PTComparator (Comparator):
  return mapping_dict

  def read_npy_data(self, dir_path, file_name):
+ if not file_name:
+ return None
  data_path = os.path.join(dir_path, file_name)
  path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
  FileCheckConst.PT_SUFFIX, False)
@@ -68,15 +70,15 @@ class PTComparator (Comparator):

  def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
  try:
- summary_compare, md5_compare = task_dumppath_get(input_param)
+ set_dump_path(input_param)
+ dump_mode = get_dump_mode(input_param)
  check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
  create_directory(output_path)
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
+ check_compare_param(input_param, output_path, dump_mode)
  data_mapping = kwargs.get('data_mapping', None)
  except (CompareException, FileCheckException) as error:
  logger.error('Compare failed. Please check the arguments and do it again!')
  raise CompareException(error.code) from error
  pt_comparator = PTComparator(data_mapping)
  pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
- md5_compare=md5_compare)
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)
@@ -31,14 +31,14 @@ class DebuggerConfig:
  self.scope = task_config.scope if task_config.scope else []
  self.list = task_config.list if task_config.list else []
  self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
- self.backward_input_list = task_config.backward_input if task_config.backward_input else []
- self.backward_input = {}
- self.acl_config = common_config.acl_config if common_config.acl_config else ""
- self.is_forward_acl_dump = True
  self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
  self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
  self.framework = Const.PT_FRAMEWORK

+ if self.level == Const.LEVEL_L2:
+ self.is_backward_kernel_dump = False
+ self._check_and_adjust_config_with_l2()
+
  if self.task == Const.FREE_BENCHMARK:
  self.fuzz_device = task_config.fuzz_device
  self.handler_type = task_config.handler_type
@@ -59,20 +59,11 @@ class DebuggerConfig:
  self.tls_path = task_config.tls_path if task_config.tls_path else ""
  self.host = task_config.host if task_config.host else ""
  self.port = task_config.port if task_config.port else -1
+ self.online_run_ut_recompute = task_config.online_run_ut_recompute \
+ if isinstance(task_config.online_run_ut_recompute, bool) else False

  self.check()

- if self.level == "L2":
- if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
- raise ValueError("scope must be configured as a list with one api name")
- if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
- raise ValueError("backward_input must be configured when scope contains 'backward'")
- if Const.BACKWARD in self.scope[0]:
- self.is_forward_acl_dump = False
- for index, scope_spec in enumerate(self.scope):
- self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
- self.backward_input[self.scope[index]] = self.backward_input_list[index]
-
  def check_kwargs(self):
  if self.task and self.task not in Const.TASK_LIST:
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
@@ -106,3 +97,16 @@ class DebuggerConfig:
  logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
  raise MsprobeException(
  MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")
+
+ def _check_and_adjust_config_with_l2(self):
+ if self.scope:
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"When level is set to L2, the scope cannot be configured.")
+ if not self.list or len(self.list) != 1:
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"When level is set to L2, the list must be configured as a list with one api name.")
+ api_name = self.list[0]
+ if api_name.endswith(Const.BACKWARD):
+ self.is_backward_kernel_dump = True
+ api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
+ self.list.append(api_forward_name)
@@ -0,0 +1,33 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ from msprobe.core.common.file_utils import save_json
+
+
+ def create_kernel_config_json(dump_path, cur_rank):
+ kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
+ kernel_config_path = os.path.join(dump_path, kernel_config_name)
+ config_info = {
+ "dump": {
+ "dump_list": [],
+ "dump_path": dump_path,
+ "dump_mode": "all",
+ "dump_op_switch": "on"
+ }
+ }
+ save_json(kernel_config_path, config_info, indent=4)
+ return kernel_config_path
@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Dict

  import numpy as np
@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from collections import defaultdict
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig

@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from msprobe.core.common.const import Const


@@ -17,6 +17,7 @@ from dataclasses import dataclass
  from typing import Any, Callable, Dict, List, Optional, Tuple

  import torch
+ from msprobe.core.common.exceptions import FreeBenchmarkException
  from msprobe.pytorch.free_benchmark import logger
  from msprobe.pytorch.free_benchmark.common.enums import (
  DeviceType,
@@ -128,7 +129,13 @@ def make_unequal_row(
  row.max_rel = ratio - 1
  origin_tensor = data_params.original_result
  perturbed_tensor = data_params.perturbed_result
- if index:
+ if index is not None:
+ if index >= len(origin_tensor) or index >= len(perturbed_tensor):
+ err_msg = f"When generating unequal results, index {index} of output is out of bounds. please check!"
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.OutputIndexError,
+ error_info=err_msg,
+ )
  origin_tensor = origin_tensor[index]
  perturbed_tensor = perturbed_tensor[index]
  row.output_index = index
@@ -13,7 +13,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+
  import torch
+ from msprobe.core.common.exceptions import FreeBenchmarkException
+ from msprobe.core.common.utils import recursion_depth_decorator
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType


@@ -51,6 +54,7 @@ class Tools:
  return api_name.rsplit(".", 2)[0]

  @staticmethod
+ @recursion_depth_decorator("FreeBenchmark: Tools.convert_device_and_dtype")
  def convert_device_and_dtype(
  tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
  ):
@@ -73,23 +77,41 @@ class Tools:
  return tensor_seq

  @staticmethod
+ @recursion_depth_decorator("FreeBenchmark: Tools.convert_fuzz_output_to_origin")
  def convert_fuzz_output_to_origin(origin, perturbed):
- if isinstance(origin, torch.Tensor):
+ if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor):
  origin.data = perturbed.to(origin.dtype).to(origin.device)
  return origin
- if isinstance(origin, dict):
+ if isinstance(origin, dict) and isinstance(perturbed, dict):
  output = dict()
  for key, value in origin.items():
+ if key not in perturbed:
+ err_msg = f"'{key}' not in perturbed output."
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.InvalidPerturbedOutput,
+ error_info=err_msg,
+ )
  output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
  return output
- if isinstance(origin, (tuple, list)):
+ if isinstance(origin, (tuple, list)) and isinstance(perturbed, (tuple, list)):
  result = list()
+ if len(perturbed) != len(origin):
+ err_msg = (
+ f"length of perturbed output ({len(perturbed)}) is different "
+ f"from the length of original output ({len(origin)})."
+ )
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.InvalidPerturbedOutput, error_info=err_msg
+ )
  for index_, value in enumerate(origin):
  result.append(
  Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
  )
  return type(origin)(result)
- return origin
+ err_msg = f"conversion of two outputs with types ({type(origin)}, {type(perturbed)}) is not supported."
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.UnsupportedType, error_info=err_msg
+ )


  class TorchC: