mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,40 +1,70 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
- from msprobe.core.common.exceptions import ScopeException
17
+ import re
18
+
3
19
  from msprobe.core.common.const import Const
20
+ from msprobe.core.common.exceptions import ScopeException
4
21
 
5
22
 
6
- def build_scope(scope_class, scope=None, api_list=None):
7
- if not scope and not api_list:
8
- return None
9
- if scope is None:
10
- scope = []
11
- if api_list is None:
12
- api_list = []
13
- if scope_class:
14
- return scope_class(scope, api_list)
15
- return build_range_scope_according_to_scope_name(scope, api_list)
16
-
17
-
18
- def build_range_scope_according_to_scope_name(scope, api_list):
19
- api_range_scope = APIRangeScope(scope, api_list)
20
- module_range_scope = ModuleRangeScope(scope, api_list)
21
- if not scope: # 如果没有scope参数则用哪类scope都一样
22
- return api_range_scope
23
- if api_range_scope.is_valid and module_range_scope.is_valid:
24
- raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
25
- elif api_range_scope.is_valid:
26
- return api_range_scope
27
- elif module_range_scope.is_valid:
28
- return module_range_scope
29
- else:
30
- raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
23
+ class ScopeFactory:
24
+ def __init__(self, config):
25
+ self.task = config.task
26
+ self.level = config.level
27
+ self.scope = config.scope
28
+ self.api_list = config.list
29
+
30
+ def build_scope(self):
31
+ if not self.scope and not self.api_list:
32
+ return None
33
+ if self.scope is None:
34
+ self.scope = []
35
+ if self.api_list is None:
36
+ self.api_list = []
37
+ if self.task == Const.FREE_BENCHMARK:
38
+ return ListScope(self.scope, self.api_list)
39
+ return self._build_range_scope()
40
+
41
+ def _build_range_scope(self):
42
+ api_range_scope = APIRangeScope(self.scope, self.api_list, self.level)
43
+ module_range_scope = ModuleRangeScope(self.scope, self.api_list, self.level)
44
+ mix_range_scope = MixRangeScope(self.scope, self.api_list, self.level)
45
+
46
+ if self.level == Const.LEVEL_MIX:
47
+ return mix_range_scope
48
+
49
+ if not self.scope:
50
+ return api_range_scope
51
+ if api_range_scope.is_valid and module_range_scope.is_valid:
52
+ raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}.")
53
+ elif api_range_scope.is_valid:
54
+ return api_range_scope
55
+ elif module_range_scope.is_valid:
56
+ return module_range_scope
57
+ else:
58
+ raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}")
31
59
 
32
60
 
33
61
  class BaseScope(ABC):
34
62
  Module_Type_Module = "Module"
35
63
  Module_Type_API = "api"
64
+ module_type = ["Module", "Cell"]
36
65
 
37
- def __init__(self, scope, api_list):
66
+ def __init__(self, scope, api_list, level=None):
67
+ self.level = level
38
68
  scope, api_list = self.rectify_args(scope, api_list)
39
69
  self.scope = scope
40
70
  self.api_list = api_list
@@ -81,9 +111,9 @@ class ListScope(BaseScope):
81
111
  f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
82
112
  return super(ListScope, ListScope).rectify_args(scope, api_list)
83
113
 
84
- def check(self, module_name):
85
- if not self.scope or module_name in self.scope:
86
- return self.check_api_list(module_name)
114
+ def check(self, name):
115
+ if not self.scope or name in self.scope:
116
+ return self.check_api_list(name)
87
117
  return False
88
118
 
89
119
 
@@ -92,19 +122,36 @@ class RangeScope(BaseScope, ABC):
92
122
  def __init__(self, *args):
93
123
  super().__init__(*args)
94
124
  self.in_scope = False
125
+ self.in_list = False
95
126
  self.is_valid = self.check_scope_is_valid()
96
127
 
128
+ def check_name_pattern(self, name):
129
+ options_pattern = "|".join(re.escape(option) for option in Const.DUMP_PREFIX)
130
+ api_pattern = rf"^({options_pattern})\..*\.\d+\.(forward|backward)$"
131
+ module_pattern = r"^(Cell|Module)\..*\.(forward|backward)\.\d+$"
97
132
 
98
- @staticmethod
99
- def rectify_args(scope, api_list):
100
- scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
101
- if isinstance(scope, list):
102
- if len(scope) == 1:
103
- scope.append(scope[0])
104
- elif len(scope) > 2:
133
+ if self.level == Const.LEVEL_L1:
134
+ if not re.match(api_pattern, name):
105
135
  raise ScopeException(ScopeException.InvalidScope,
106
- f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
136
+ f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")
137
+
138
+ if self.level == Const.LEVEL_L0:
139
+ if not re.match(module_pattern, name):
140
+ raise ScopeException(ScopeException.InvalidScope,
141
+ f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")
142
+
143
+ if self.level == Const.LEVEL_MIX:
144
+ if not re.match(api_pattern, name) and not re.match(module_pattern, name):
145
+ raise ScopeException(ScopeException.InvalidScope,
146
+ f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
107
147
 
148
+ def rectify_args(self, scope, api_list):
149
+ scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
150
+ if scope and len(scope) != 2:
151
+ raise ScopeException(ScopeException.InvalidScope,
152
+ f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
153
+ for name in scope:
154
+ self.check_name_pattern(name)
108
155
  return scope, api_list
109
156
 
110
157
  @abstractmethod
@@ -123,23 +170,23 @@ class APIRangeScope(RangeScope):
123
170
  if not self.scope:
124
171
  return True
125
172
  scope_start_type = self.scope[0].split(Const.SEP)[0]
126
- if scope_start_type == BaseScope.Module_Type_Module:
173
+ if scope_start_type in BaseScope.module_type:
127
174
  return False
128
175
  scope_stop_type = self.scope[1].split(Const.SEP)[0]
129
- if scope_stop_type == BaseScope.Module_Type_Module:
176
+ if scope_stop_type in BaseScope.module_type:
130
177
  return False
131
178
  return True
132
179
 
133
- def check(self, api_name):
134
- if self.scope and api_name == self.scope[0]:
180
+ def check(self, name):
181
+ if self.scope and name == self.scope[0]:
135
182
  self.in_scope = True
136
183
 
137
184
  if not self.scope or self.in_scope:
138
- result = self.check_api_list(api_name)
185
+ result = self.check_api_list(name)
139
186
  else:
140
187
  result = False
141
188
 
142
- if self.scope and api_name == self.scope[1]:
189
+ if self.scope and name == self.scope[1]:
143
190
  self.in_scope = False
144
191
  return result
145
192
 
@@ -150,13 +197,14 @@ class ModuleRangeScope(RangeScope):
150
197
  需要用pre_hook和full_backward_hook来精确控制module的开始和结束,
151
198
  在这些hook触发时调用begin_module和end_module做区间控制
152
199
  """
200
+
153
201
  def check_scope_is_valid(self):
154
202
  if not self.scope:
155
203
  return True
156
204
  scope_start_type = self.scope[0].split(Const.SEP)[0]
157
205
  scope_stop_type = self.scope[1].split(Const.SEP)[0]
158
- if scope_start_type == BaseScope.Module_Type_Module and \
159
- scope_stop_type == BaseScope.Module_Type_Module:
206
+ if scope_start_type in BaseScope.module_type and \
207
+ scope_stop_type in BaseScope.module_type:
160
208
  return True
161
209
  return False
162
210
 
@@ -172,7 +220,54 @@ class ModuleRangeScope(RangeScope):
172
220
  if module_name == self.scope[1]:
173
221
  self.in_scope = False
174
222
 
175
- def check(self, module_name):
223
+ def check(self, name):
176
224
  if not self.scope or self.in_scope:
177
- return self.check_api_list(module_name)
225
+ return self.check_api_list(name)
178
226
  return False
227
+
228
+
229
+ class MixRangeScope(RangeScope):
230
+ def check_scope_is_valid(self):
231
+ return True if self.scope else False
232
+
233
+ def begin_module(self, module_name):
234
+ if self.scope and module_name == self.scope[0]:
235
+ self.in_scope = True
236
+ for name in self.api_list:
237
+ if name in module_name:
238
+ self.in_list = True
239
+
240
+ def end_module(self, module_name):
241
+ if self.scope and module_name == self.scope[1]:
242
+ self.in_scope = False
243
+ for name in self.api_list:
244
+ if name in module_name:
245
+ self.in_list = False
246
+
247
+ def check_api_list(self, api_name):
248
+ if not self.api_list:
249
+ return True
250
+
251
+ for name in self.api_list:
252
+ if name in api_name:
253
+ return True
254
+ return False
255
+
256
+ def check(self, name):
257
+ """
258
+ dump时调用的接口,根据scope和api_list判断是否需要dump
259
+ """
260
+ result = False
261
+ if self.scope and name == self.scope[0]:
262
+ self.in_scope = True
263
+
264
+ if not self.scope or self.in_scope:
265
+ if self.in_list:
266
+ result = True
267
+ else:
268
+ result = self.check_api_list(name)
269
+
270
+ if self.scope and name == self.scope[1]:
271
+ self.in_scope = False
272
+ return result
273
+
@@ -1,3 +1,17 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  class GradConst:
3
17
 
@@ -33,6 +47,10 @@ class GradConst:
33
47
  # direction suffix
34
48
  DIR_SUFFIX = "dir.npy"
35
49
 
50
+ # bounds safety
51
+ BOUNDS_MINIMUM = -2**63
52
+ BOUNDS_MAXIMUM = 2**63 - 1
53
+
36
54
  # file safty
37
55
  DATA_DIR_AUTHORITY = 0o750
38
56
  DATA_FILE_AUTHORITY = 0o640
@@ -56,16 +74,16 @@ class GradConst:
56
74
  NORM = "norm"
57
75
 
58
76
  level_adp = {
59
- "L0": {
60
- "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
61
- "have_grad_direction": False
62
- },
63
- "L1": {
64
- "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
65
- "have_grad_direction": True
66
- },
67
- "L2": {
68
- "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
69
- "have_grad_direction": True
70
- },
71
- }
77
+ "L0": {
78
+ "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
79
+ "have_grad_direction": False
80
+ },
81
+ "L1": {
82
+ "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
83
+ "have_grad_direction": True
84
+ },
85
+ "L2": {
86
+ "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
87
+ "have_grad_direction": True
88
+ },
89
+ }
@@ -1,13 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from typing import List
3
18
 
4
19
  from tqdm import tqdm
5
- import pandas as pd
6
20
  import matplotlib.pyplot as plt
7
21
 
8
- from msprobe.core.common.file_utils import create_directory, check_path_before_create, check_file_or_directory_path
22
+ from msprobe.core.common.file_utils import create_directory, check_file_or_directory_path
9
23
  from msprobe.core.common.log import logger
10
- from msprobe.core.common.file_utils import remove_path, load_npy, write_csv
24
+ from msprobe.core.common.file_utils import remove_path, load_npy, write_csv, read_csv
11
25
  from msprobe.core.grad_probe.constant import GradConst
12
26
  from msprobe.core.grad_probe.utils import plt_savefig
13
27
 
@@ -21,7 +35,7 @@ class GradComparator:
21
35
  continue
22
36
  if not os.path.exists(os.path.join(path2, summary_file)):
23
37
  continue
24
- summary_csv = pd.read_csv(os.path.join(path1, summary_file))
38
+ summary_csv = read_csv(os.path.join(path1, summary_file))
25
39
  return summary_csv["param_name"]
26
40
  raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration")
27
41
 
@@ -34,6 +48,8 @@ class GradComparator:
34
48
 
35
49
  @classmethod
36
50
  def compare_distributed(cls, path1: str, path2: str, output_dir: str):
51
+ check_file_or_directory_path(path1, isdir=True)
52
+ check_file_or_directory_path(path2, isdir=True)
37
53
  ranks = cls._get_matched_dirs(path1, path2, "rank")
38
54
  logger.info(f"the following ranks will be compared: {ranks}")
39
55
  if not ranks:
@@ -1,8 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import re
2
17
  from msprobe.core.grad_probe.constant import GradConst
3
18
  from msprobe.core.common.log import logger
4
19
  from msprobe.core.common.file_utils import write_csv, check_path_before_create, change_mode
5
20
  from msprobe.core.common.const import FileCheckConst
21
+ from msprobe.core.common.utils import is_int
6
22
  import matplotlib.pyplot as plt
7
23
 
8
24
 
@@ -20,12 +36,37 @@ def check_numeral_list_ascend(lst):
20
36
def check_param(param_name):
    """Raise if ``param_name`` is rejected by ``GradConst.PARAM_VALID_PATTERN``.

    The check uses ``re.match``, so the pattern is anchored at the start of
    the string only (unless the pattern itself anchors the end).

    Raises:
        RuntimeError: when the pattern does not match ``param_name``.
    """
    matched = re.match(GradConst.PARAM_VALID_PATTERN, param_name)
    if matched is None:
        raise RuntimeError("The parameter name contains special characters.")
23
-
39
+
24
40
 
25
41
def check_str(string, variable_name):
    """Ensure ``string`` is a ``str`` instance.

    Args:
        string: value to validate.
        variable_name: name reported in the error message.

    Raises:
        ValueError: when ``string`` is not a ``str``.
    """
    if isinstance(string, str):
        return
    raise ValueError(f'The variable: "{variable_name}" is not a string.')
28
-
44
+
45
+
46
def check_bounds_element(bound):
    """Return whether ``bound`` lies in the inclusive range
    ``[GradConst.BOUNDS_MINIMUM, GradConst.BOUNDS_MAXIMUM]``."""
    lower = GradConst.BOUNDS_MINIMUM
    upper = GradConst.BOUNDS_MAXIMUM
    return lower <= bound <= upper
48
+
49
+
50
def check_param_element(param):
    """Return whether ``param`` matches ``GradConst.PARAM_VALID_PATTERN``.

    This is the non-raising counterpart of ``check_param``. It performs the
    same ``re.match`` test (anchored at the start of the string) but reports
    the result as a bool instead of raising.
    """
    # A match object is always truthy, so "matched" is exactly "not None".
    return re.match(GradConst.PARAM_VALID_PATTERN, param) is not None
55
+
56
+
57
def check_bounds(bounds):
    """Validate a list of numeric bounds.

    Every element must:
      * pass ``is_int`` or be a ``float``;
      * lie within ``[GradConst.BOUNDS_MINIMUM, GradConst.BOUNDS_MAXIMUM]``;
      * be strictly greater than the element before it.

    An empty list is accepted.

    Raises:
        Exception: on the first rule that is violated.
    """
    if not isinstance(bounds, list):
        raise Exception("bounds must be a list")
    # Start below the allowed minimum so the first element is not rejected
    # by the ascending check.
    previous = GradConst.BOUNDS_MINIMUM - 1
    for bound in bounds:
        is_numeric = is_int(bound) or isinstance(bound, float)
        if not is_numeric:
            raise Exception("bounds element is not int or float")
        if not check_bounds_element(bound):
            raise Exception("bounds element is out of int64 range")
        if bound <= previous:
            raise Exception("bounds list is not ascending")
        previous = bound
69
+
29
70
 
30
71
  class ListCache(list):
31
72
  threshold = 1000
@@ -50,7 +91,7 @@ class ListCache(list):
50
91
  list.append(self, data)
51
92
  if len(self) >= ListCache.threshold:
52
93
  self.flush()
53
-
94
+
54
95
  def set_output_file(self, output_file):
55
96
  self._output_file = output_file
56
97
 
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Dict, Union, Any
17
+
18
+ import numpy as np
19
+
20
+ from msprobe.core.overflow_check.api_info import APIInfo
21
+ from msprobe.core.overflow_check.level import OverflowLevel
22
+ from msprobe.core.overflow_check.utils import has_nan_inf
23
+
24
+
25
class AnomalyScene:
    """Base class for an anomaly scene evaluated on a single API call."""

    def __init__(self, api_info: APIInfo):
        self.api_name = api_info.api_name
        self.api_data = api_info

    @property
    def rank(self) -> OverflowLevel:
        """Severity level of this scene; subclasses must override."""
        raise NotImplementedError

    @staticmethod
    def _has_anomaly(data: Union[Dict, Any]) -> bool:
        """Return ``has_nan_inf(data)``, i.e. whether ``data`` is flagged as
        containing abnormal values."""
        return has_nan_inf(data)

    def get_details(self) -> Dict:
        """Return a summary of this scene for the wrapped API call.

        The dict holds the API name, the rank value, the scene class name,
        and the positions/keys of the anomalous positional inputs, keyword
        inputs and outputs.
        """
        return {
            'api_name': self.api_name,
            'rank': self.rank.value,
            'scene_type': self.__class__.__name__,
            'input_args_anomaly_indices': self._get_anomaly_indices_from_list(self.api_data.input_args),
            'input_kwargs_anomaly_keys': self._get_anomaly_keys_from_dict(self.api_data.input_kwargs),
            'output_anomaly_indices': self._get_anomaly_indices_from_list(self.api_data.output_data)
        }

    def matches(self) -> bool:
        """Return whether the wrapped API call matches this scene.

        Subclasses implement the actual matching logic.
        """
        raise NotImplementedError

    def _get_anomaly_indices_from_list(self, data_list: List[Dict]) -> List[int]:
        """Return the positions of anomalous entries in ``data_list``.

        ``None`` is treated as an empty list. ``APIInfo`` defaults its data
        fields to ``None``, so this case can occur.
        """
        return [i for i, data in enumerate(data_list or []) if self._has_anomaly(data)]

    def _get_anomaly_keys_from_dict(self, data_dict: Dict) -> List[str]:
        """Return the keys of anomalous values in ``data_dict``.

        ``None`` is treated as an empty dict.
        """
        return [key for key, data in (data_dict or {}).items() if self._has_anomaly(data)]
66
+
67
+
68
class InputOutputAnomalyScene(AnomalyScene):
    """Base for scenes that classify a call by whether its inputs and/or
    outputs are anomalous."""

    def has_input_anomaly(self) -> bool:
        """Return True if any dict entry among the positional or keyword
        inputs is anomalous. Missing (``None``) inputs count as empty."""
        input_args = self.api_data.input_args or []
        input_kwargs = self.api_data.input_kwargs or {}
        # args
        args_anomaly = any(self._has_anomaly(x) for x in input_args if isinstance(x, dict))
        # kwargs
        kwargs_anomaly = any(self._has_anomaly(x) for x in input_kwargs.values() if isinstance(x, dict))
        return args_anomaly or kwargs_anomaly

    def has_output_anomaly(self) -> bool:
        """Return True if any dict entry among the outputs is anomalous.
        Missing (``None``) outputs count as empty."""
        output_data = self.api_data.output_data or []
        return any(self._has_anomaly(x) for x in output_data if isinstance(x, dict))

    def matches(self) -> bool:
        """Return whether the call matches this scene; subclasses must override."""
        raise NotImplementedError
85
+
86
+
87
class InputAnomalyOutputNormalScene(InputOutputAnomalyScene):
    """Scene: at least one input is anomalous while the outputs are normal."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.MEDIUM``."""
        return OverflowLevel.MEDIUM

    def matches(self) -> bool:
        """Match when the inputs are anomalous and the outputs are not."""
        if not self.has_input_anomaly():
            return False
        return not self.has_output_anomaly()
96
+
97
+
98
class InputAnomalyOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene: both the inputs and the outputs contain anomalies."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.HIGH``."""
        return OverflowLevel.HIGH

    def matches(self) -> bool:
        """Match when the inputs and the outputs are both anomalous."""
        if not self.has_input_anomaly():
            return False
        return self.has_output_anomaly()
107
+
108
+
109
class InputNormalOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene: the inputs are normal but the outputs contain anomalies."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.CRITICAL``."""
        return OverflowLevel.CRITICAL

    def matches(self) -> bool:
        """Match when no input is anomalous but an output is."""
        if self.has_input_anomaly():
            return False
        return self.has_output_anomaly()
118
+
119
+
120
class NumericalMutationScene(AnomalyScene):
    """Detect a sudden jump in magnitude between an API call's inputs and outputs.

    The scene collects the ``Norm`` of every ``torch.Tensor`` entry in the
    positional inputs, keyword inputs and outputs. It matches when the
    largest output norm exceeds the largest input norm by more than
    ``threshold`` times. If the largest input norm is 0, it matches when the
    largest output norm itself exceeds ``threshold``.
    """

    def __init__(self, api_info: APIInfo, threshold: float = 100000.0):
        super().__init__(api_info)
        self.threshold = threshold

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene: ``OverflowLevel.HIGH``."""
        return OverflowLevel.HIGH

    @staticmethod
    def _get_tensor_norms(data_list: List[Dict]) -> List[float]:
        """Return the non-NaN ``Norm`` values of ``torch.Tensor`` entries in ``data_list``.

        Non-dict entries and entries without a ``Norm`` are skipped.
        ``None`` is treated as an empty list.
        """
        norms = []
        for data in data_list or []:
            if isinstance(data, dict) and data.get('type') == 'torch.Tensor':
                norm = data.get('Norm')
                if norm is not None and not np.isnan(norm):
                    norms.append(norm)
        return norms

    @staticmethod
    def _get_kwargs_norms(data_dict: Dict) -> List[float]:
        """Return the non-NaN ``Norm`` values of ``torch.Tensor`` entries among
        the values of ``data_dict``. ``None`` is treated as an empty dict."""
        # Same filtering as for positional data; reuse it instead of duplicating it.
        return NumericalMutationScene._get_tensor_norms(list((data_dict or {}).values()))

    def matches(self) -> bool:
        """Return True when the output/input max-norm ratio exceeds ``threshold``.

        Returns False when either side has no usable norm.
        """
        # Collect the norms of all inputs (args and kwargs).
        input_norms = (self._get_tensor_norms(self.api_data.input_args) +
                       self._get_kwargs_norms(self.api_data.input_kwargs))
        # Collect the norms of all outputs.
        output_norms = self._get_tensor_norms(self.api_data.output_data)

        if not input_norms or not output_norms:
            return False

        max_input = max(input_norms)
        max_output = max(output_norms)

        # Avoid dividing by zero: compare the output norm against the threshold directly.
        if max_input == 0:
            return max_output > self.threshold
        return max_output / max_input > self.threshold

    def get_details(self) -> Dict:
        """Extend the base details with the threshold and the match result."""
        details = super().get_details()
        details.update({
            'threshold': self.threshold,
            'scale_change_detected': self.matches()
        })
        return details
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ from typing import Dict, List
19
+
20
+ from msprobe.core.common.const import Const
21
+
22
+
23
@dataclass
class APIInfo:
    """Dumped data for one API call: its name plus its input and output entries.

    ``torch_api_name`` is derived from ``api_name`` when the object is built.
    Data arguments left as ``None`` are stored as empty containers, so
    consumers can iterate them without ``None`` checks.
    """
    api_name: str
    torch_api_name: str
    input_args: List[Dict]
    input_kwargs: Dict
    output_data: List[Dict]

    def __init__(self, api_name, input_args=None, input_kwargs=None, output_data=None):
        self.api_name = api_name
        # Fresh containers per instance; never share a mutable default.
        self.input_args = input_args if input_args is not None else []
        self.input_kwargs = input_kwargs if input_kwargs is not None else {}
        self.output_data = output_data if output_data is not None else []
        self.torch_api_name = self.extract_torch_api(self.api_name)

    @staticmethod
    def extract_torch_api(api_name) -> str:
        """Return the first two ``Const.SEP``-separated fields of ``api_name``, lowercased.

        Returns "" for an empty, whitespace-only or ``None`` name. A name
        without a separator is returned whole, lowercased.
        """
        if not api_name or not api_name.strip():
            return ""
        # str.split never returns an empty list, and joining a one-element
        # slice yields that element, so this single expression covers the
        # one-field and multi-field cases alike.
        return Const.SEP.join(api_name.split(Const.SEP)[:2]).lower()