mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,6 +1,23 @@
- from msprobe.pytorch.common import seed_all
- from msprobe.pytorch.common.log import logger
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
  from msprobe.core.common.const import Const
+ from msprobe.core.common.exceptions import MsprobeException
+ from msprobe.pytorch.common.log import logger


  class DebuggerConfig:
@@ -10,30 +27,28 @@ class DebuggerConfig:
  self.rank = common_config.rank if common_config.rank else []
  self.step = common_config.step if common_config.step else []
  self.level = level or common_config.level or "L1"
- self.seed = common_config.seed if common_config.seed else 1234
- self.is_deterministic = common_config.is_deterministic
  self.enable_dataloader = common_config.enable_dataloader
  self.scope = task_config.scope if task_config.scope else []
  self.list = task_config.list if task_config.list else []
  self.data_mode = task_config.data_mode if task_config.data_mode else ["all"]
- self.backward_input_list = task_config.backward_input if task_config.backward_input else []
- self.backward_input = {}
- self.acl_config = common_config.acl_config if common_config.acl_config else ""
- self.is_forward_acl_dump = True
  self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
  self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
  self.framework = Const.PT_FRAMEWORK

+ if self.level == Const.LEVEL_L2:
+ self.is_backward_kernel_dump = False
+ self._check_and_adjust_config_with_l2()
+
  if self.task == Const.FREE_BENCHMARK:
- self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
- self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
- self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
- self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
- self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
+ self.fuzz_device = task_config.fuzz_device
+ self.handler_type = task_config.handler_type
+ self.pert_mode = task_config.pert_mode
+ self.fuzz_level = task_config.fuzz_level
+ self.fuzz_stage = task_config.fuzz_stage
  self.preheat_config = {
- "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
- "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
- "max_sample": task_config.max_sample if task_config.max_sample else 20,
+ "if_preheat": task_config.if_preheat,
+ "preheat_step": task_config.preheat_step,
+ "max_sample": task_config.max_sample
  }

  self.online_run_ut = False
@@ -44,52 +59,54 @@
  self.tls_path = task_config.tls_path if task_config.tls_path else ""
  self.host = task_config.host if task_config.host else ""
  self.port = task_config.port if task_config.port else -1
+ self.online_run_ut_recompute = task_config.online_run_ut_recompute \
+ if isinstance(task_config.online_run_ut_recompute, bool) else False

  self.check()
- if self.step:
- self.step.sort()
- if self.level == "L2":
- if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
- raise ValueError("scope must be configured as a list with one api name")
- if isinstance(self.scope[0], str) and Const.BACKWARD in self.scope[0] and not self.backward_input_list:
- raise ValueError("backward_input must be configured when scope contains 'backward'")
- if Const.BACKWARD in self.scope[0]:
- self.is_forward_acl_dump = False
- for index, scope_spec in enumerate(self.scope):
- self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
- self.backward_input[self.scope[index]] = self.backward_input_list[index]
- seed_all(self.seed, self.is_deterministic)

  def check_kwargs(self):
  if self.task and self.task not in Const.TASK_LIST:
- raise Exception("task is invalid")
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
  if self.level and self.level not in Const.LEVEL_LIST:
- raise Exception("level is invalid")
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"The level <{self.level}> is not in the {Const.LEVEL_LIST}.")
  if not self.dump_path:
- raise Exception("Invalid dump path, please check your config")
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"The dump_path not found.")

  def check(self):
  self.check_kwargs()
- self._check_rank()
- self._check_step()
  return True

- def check_model(self, model):
- if self.level in ["L0", "mix"] and not model:
- raise Exception(
- f"For level {self.level}, PrecisionDebugger must receive a model argument."
- )
-
- def _check_rank(self):
- if self.rank:
- for rank_id in self.rank:
- if not isinstance(rank_id, int) or rank_id < 0:
- raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
- else:
- logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
+ def check_model(self, instance, start_model):
+ if self.level not in ["L0", "mix"]:
+ if instance.model is not None or start_model is not None:
+ logger.warning_on_rank_0(
+ f"The current level is not L0 or mix level, so the model parameters will not be used.")
+ return
+ if start_model is None:
+ if instance.model is None:
+ logger.error_on_rank_0(
+ f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
+ return
+ if isinstance(start_model, torch.nn.Module):
+ instance.model = start_model
+ else:
+ logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")

- def _check_step(self):
- if self.step:
- for s in self.step:
- if not isinstance(s, int) or s < 0:
- raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
+ def _check_and_adjust_config_with_l2(self):
+ if self.scope:
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"When level is set to L2, the scope cannot be configured.")
+ if not self.list or len(self.list) != 1:
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
+ f"When level is set to L2, the list must be configured as a list with one api name.")
+ api_name = self.list[0]
+ if api_name.endswith(Const.BACKWARD):
+ self.is_backward_kernel_dump = True
+ api_forward_name = api_name[:-len(Const.BACKWARD)] + Const.FORWARD
+ self.list.append(api_forward_name)
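For readers following the new L2 behaviour above, here is a minimal standalone sketch (not part of the package) of what the list adjustment in _check_and_adjust_config_with_l2 does, assuming Const.BACKWARD == "backward" and Const.FORWARD == "forward" as used in the diff:

    # Sketch only: mirrors the string manipulation for kernel dump at level L2.
    def adjust_l2_list(api_list):
        if not api_list or len(api_list) != 1:
            raise ValueError("When level is set to L2, list must hold exactly one api name.")
        api_name = api_list[0]
        if api_name.endswith("backward"):
            # a backward kernel dump also needs its forward counterpart
            api_list.append(api_name[:-len("backward")] + "forward")
        return api_list

    # e.g. (hypothetical api name) adjust_l2_list(["Functional.conv2d.0.backward"])
    # -> ["Functional.conv2d.0.backward", "Functional.conv2d.0.forward"]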
@@ -1,12 +1,34 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import namedtuple
+
  import torch
- from torch.utils.data import dataloader
- from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
- from msprobe.pytorch.service import Service
- from msprobe.pytorch.common.log import logger
- from msprobe.pytorch.pt_config import parse_json_config
+ from msprobe.core.common.const import Const, FileCheckConst, MsgConst
  from msprobe.core.common.exceptions import MsprobeException
- from msprobe.core.common.const import Const
+ from msprobe.core.common.file_utils import FileChecker
+ from msprobe.core.common.utils import get_real_step_or_rank
+ from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
  from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
+ from msprobe.pytorch.pt_config import parse_json_config
+ from msprobe.pytorch.service import Service
+ from torch.utils.data import dataloader
+
+ ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task",
+ "dump_path", "level", "model"])


  class PrecisionDebugger:
@@ -30,20 +52,26 @@ class PrecisionDebugger:
  step=None,
  ):
  if not hasattr(self, "initialized"):
+ config_params = ConfigParameters(config_path,
+ task,
+ dump_path,
+ level,
+ model)
+ self.check_input_params(config_params)
+
  self.api_origin = False
  self.initialized = True
- self.model = self.check_model_valid(model)
+ self.model = model
  common_config, task_config = parse_json_config(config_path, task)
- self.task = common_config.task
+ self.task = task if task else common_config.task
  if self.task == Const.GRAD_PROBE:
  self.gm = GradientMonitor(common_config, task_config)
  return
  if step:
- common_config.step = step
+ common_config.step = get_real_step_or_rank(step, Const.STEP)
  self.config = DebuggerConfig(
  common_config, task_config, task, dump_path, level
  )
- self.config.check_model(self.model)
  self.service = Service(self.config)
  self.enable_dataloader = self.config.enable_dataloader
  if self.enable_dataloader:
@@ -55,20 +83,40 @@ class PrecisionDebugger:
  return self._instance

  @staticmethod
- def check_model_valid(model):
- if not model or isinstance(model, torch.nn.Module):
- return model
- raise MsprobeException(
- MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
- )
+ def check_input_params(args):
+ if args.config_path is not None:
+ if not isinstance(args.config_path, str):
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, f"config_path must be a string")
+ file_checker = FileChecker(
+ file_path=args.config_path, path_type=FileCheckConst.FILE, file_type=FileCheckConst.JSON_SUFFIX)
+ file_checker.common_check()
+
+ if args.task is not None and args.task not in Const.TASK_LIST:
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, f"task must be one of {Const.TASK_LIST}")
+
+ if args.dump_path is not None:
+ if not isinstance(args.dump_path, str):
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, f"dump_path must be a string")
+
+ if args.level is not None and args.level not in Const.LEVEL_LIST:
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, f"level must be one of {Const.LEVEL_LIST}")
+
+ if args.model is not None and not isinstance(args.model, torch.nn.Module):
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")

  @classmethod
- def start(cls):
+ def start(cls, model=None):
  instance = cls._instance
+ if not instance:
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
  if instance.task in PrecisionDebugger.tasks_not_need_debugger:
  return
- if not instance:
- raise Exception("No instance of PrecisionDebugger found.")
+ instance.config.check_model(instance, model)
  if instance.enable_dataloader:
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
  else:
@@ -85,10 +133,10 @@ class PrecisionDebugger:
  @classmethod
  def stop(cls):
  instance = cls._instance
+ if not instance:
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
  if instance.task in PrecisionDebugger.tasks_not_need_debugger:
  return
- if not instance:
- raise Exception("PrecisionDebugger instance is not created.")
  if instance.enable_dataloader:
  logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.")
  else:
@@ -96,16 +144,16 @@ class PrecisionDebugger:

  @classmethod
  def step(cls):
+ if not cls._instance:
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
  if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
  return
- if not cls._instance:
- raise Exception("PrecisionDebugger instance is not created.")
  cls._instance.service.step()

  @classmethod
  def monitor(cls, model):
  if not cls._instance:
- raise Exception("PrecisionDebugger instance is not created.")
+ raise Exception(MsgConst.NOT_CREATED_INSTANCE)
  if cls._instance.task != Const.GRAD_PROBE:
  return
  cls._instance.gm.monitor(model)
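A minimal usage sketch of the reworked PrecisionDebugger interface above, where start() can now receive the model directly; the import path, config file, and call sequence below are assumptions based on the diff, not an excerpt from the package docs:

    import torch
    from msprobe.pytorch import PrecisionDebugger  # assumed public import path

    net = torch.nn.Linear(4, 4)  # stand-in for the model under test
    debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_output")
    debugger.start(model=net)  # the model may now be passed to start() instead of the constructor
    # ... run one training iteration ...
    debugger.stop()
    debugger.step()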
@@ -0,0 +1,33 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ from msprobe.core.common.file_utils import save_json
+
+
+ def create_kernel_config_json(dump_path, cur_rank):
+ kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
+ kernel_config_path = os.path.join(dump_path, kernel_config_name)
+ config_info = {
+ "dump": {
+ "dump_list": [],
+ "dump_path": dump_path,
+ "dump_mode": "all",
+ "dump_op_switch": "on"
+ }
+ }
+ save_json(kernel_config_path, config_info, indent=4)
+ return kernel_config_path
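A short, hypothetical call of the new helper above (the dump path and rank are illustrative):

    # Writes <dump_path>/kernel_config_0.json for rank 0 and returns its path.
    config_path = create_kernel_config_json("./kernel_dump", cur_rank=0)
    # Resulting file content:
    # {"dump": {"dump_list": [], "dump_path": "./kernel_dump", "dump_mode": "all", "dump_op_switch": "on"}}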
@@ -1,8 +1,23 @@
- from msprobe.pytorch.common.log import logger
- from msprobe.core.common.exceptions import FreeBenchmarkException
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ __all__ = ["FreeBenchmarkCheck", "UnequalRow"]
+
  from msprobe.core.common.const import Const
+ from msprobe.core.common.exceptions import FreeBenchmarkException
+ from msprobe.pytorch.common.log import logger

- from .main import FreeBenchmarkCheck
  from .common.params import UnequalRow
-
- __all__ = [FreeBenchmarkCheck, UnequalRow]
+ from .main import FreeBenchmarkCheck
@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from typing import Dict

  import numpy as np
@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from collections import defaultdict
  from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig

@@ -1,3 +1,21 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from msprobe.core.common.const import Const
+
+
  class PerturbationMode:
  ADD_NOISE = "add_noise"
  CHANGE_VALUE = "change_value"
@@ -35,3 +53,28 @@ class FuzzLevel:
  BASE_LEVEL = "L1"
  ADV_LEVEL = "L2"
  REAL_LEVEL = "L3"
+
+
+ class PytorchFreeBenchmarkConst:
+ PERTURBATION_MODE_LIST = [
+ PerturbationMode.ADD_NOISE,
+ PerturbationMode.CHANGE_VALUE,
+ PerturbationMode.IMPROVE_PRECISION,
+ PerturbationMode.NO_CHANGE,
+ PerturbationMode.BIT_NOISE,
+ PerturbationMode.TO_CPU,
+ ]
+ DEFAULT_MODE = PerturbationMode.IMPROVE_PRECISION
+ DEVICE_LIST = [DeviceType.NPU, DeviceType.CPU]
+ DEFAULT_DEVICE = DeviceType.NPU
+ HANDLER_LIST = [HandlerType.CHECK, HandlerType.FIX]
+ DEFAULT_HANDLER = HandlerType.CHECK
+ FUZZ_LEVEL_LIST = [FuzzLevel.BASE_LEVEL]
+ DEFAULT_FUZZ_LEVEL = FuzzLevel.BASE_LEVEL
+ FUZZ_STAGE_LIST = [Const.FORWARD, Const.BACKWARD]
+ FIX_MODE_LIST = [PerturbationMode.IMPROVE_PRECISION, PerturbationMode.TO_CPU]
+ DEFAULT_FUZZ_STAGE = Const.FORWARD
+ DEFAULT_PREHEAT_STEP = 15
+ DEFAULT_MAX_SAMPLE = 20
+ CPU_MODE_LIST = [PerturbationMode.TO_CPU]
+ FIX_STAGE_LIST = [Const.FORWARD]
@@ -1,7 +1,23 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  from dataclasses import dataclass
  from typing import Any, Callable, Dict, List, Optional, Tuple

  import torch
+ from msprobe.core.common.exceptions import FreeBenchmarkException
  from msprobe.pytorch.free_benchmark import logger
  from msprobe.pytorch.free_benchmark.common.enums import (
  DeviceType,
@@ -113,7 +129,13 @@ def make_unequal_row(
  row.max_rel = ratio - 1
  origin_tensor = data_params.original_result
  perturbed_tensor = data_params.perturbed_result
- if index:
+ if index is not None:
+ if index >= len(origin_tensor) or index >= len(perturbed_tensor):
+ err_msg = f"When generating unequal results, index {index} of output is out of bounds. please check!"
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.OutputIndexError,
+ error_info=err_msg,
+ )
  origin_tensor = origin_tensor[index]
  perturbed_tensor = perturbed_tensor[index]
  row.output_index = index
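The switch from "if index:" to "if index is not None:" above matters because an index of 0 is falsy in Python, so the first output of a multi-output api was previously skipped; a minimal illustration:

    index = 0
    if index:              # old check: False for index 0, branch never taken
        pass
    if index is not None:  # new check: True for index 0, and the diff adds an explicit bounds check
        pass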
@@ -1,4 +1,22 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
  import torch
+ from msprobe.core.common.exceptions import FreeBenchmarkException
+ from msprobe.core.common.utils import recursion_depth_decorator
  from msprobe.pytorch.free_benchmark.common.enums import DeviceType


@@ -36,6 +54,7 @@ class Tools:
  return api_name.rsplit(".", 2)[0]

  @staticmethod
+ @recursion_depth_decorator("FreeBenchmark: Tools.convert_device_and_dtype")
  def convert_device_and_dtype(
  tensor_seq, device: str = DeviceType.CPU, change_dtype: bool = False
  ):
@@ -58,24 +77,43 @@ class Tools:
  return tensor_seq

  @staticmethod
+ @recursion_depth_decorator("FreeBenchmark: Tools.convert_fuzz_output_to_origin")
  def convert_fuzz_output_to_origin(origin, perturbed):
- if isinstance(origin, torch.Tensor):
+ if isinstance(origin, torch.Tensor) and isinstance(perturbed, torch.Tensor):
  origin.data = perturbed.to(origin.dtype).to(origin.device)
  return origin
- if isinstance(origin, dict):
+ if isinstance(origin, dict) and isinstance(perturbed, dict):
  output = dict()
  for key, value in origin.items():
+ if key not in perturbed:
+ err_msg = f"'{key}' not in perturbed output."
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.InvalidPerturbedOutput,
+ error_info=err_msg,
+ )
  output[key] = Tools.convert_fuzz_output_to_origin(value, perturbed[key])
  return output
- if isinstance(origin, (tuple, list)):
+ if isinstance(origin, (tuple, list)) and isinstance(perturbed, (tuple, list)):
  result = list()
+ if len(perturbed) != len(origin):
+ err_msg = (
+ f"length of perturbed output ({len(perturbed)}) is different "
+ f"from the length of original output ({len(origin)})."
+ )
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.InvalidPerturbedOutput, error_info=err_msg
+ )
  for index_, value in enumerate(origin):
  result.append(
  Tools.convert_fuzz_output_to_origin(value, perturbed[index_])
  )
  return type(origin)(result)
- return origin
-
+ err_msg = f"conversion of two outputs with types ({type(origin)}, {type(perturbed)}) is not supported."
+ raise FreeBenchmarkException(
+ FreeBenchmarkException.UnsupportedType, error_info=err_msg
+ )
+
+
  class TorchC:
  sum = torch._C._VariableFunctionsClass.sum
  isinf = torch._C._VariableFunctionsClass.isinf
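For illustration only (tensor values and shapes are made up, and the Tools class above is assumed to be imported from the changed module), the stricter convert_fuzz_output_to_origin copies the perturbed result back into the original structure element by element and now raises FreeBenchmarkException on mismatched structures instead of silently returning the original:

    origin = (torch.zeros(2), {"loss": torch.zeros(1)})
    perturbed = (torch.ones(2), {"loss": torch.ones(1)})
    restored = Tools.convert_fuzz_output_to_origin(origin, perturbed)
    # restored keeps the structure, dtypes, and devices of `origin` with the values of `perturbed`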