mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,478 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import argparse
19
+ import json
20
+ import os
21
+ import re
22
+
23
+ import math
24
+ import numpy as np
25
+ import torch
26
+
27
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_standard_api, absolute_standard_api, \
28
+ ulp_standard_api, thousandth_standard_api
29
+ from msprobe.core.common.file_utils import FileOpen, load_json, save_json
30
+ from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
31
+ from msprobe.core.common.const import Const, MonitorConst, MsgConst
32
+ from msprobe.core.common.log import logger
33
+ from msprobe.core.common.file_utils import make_dir
34
+ from msprobe.core.common.utils import recursion_depth_decorator
35
+
36
# Type strings for tensor-like entries (not referenced in the visible part of this module).
TENSOR_DATA_LIST = [
    "torch.Tensor",
    "torch.nn.parameter.Parameter",
]

# torch dtype names, grouped by category.
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = [
    "torch.uint8", "torch.int8", "torch.int16", "torch.short",
    "torch.int32", "torch.int", "torch.int64", "torch.long",
]
TORCH_FLOAT_TYPE = [
    "torch.float16", "torch.half", "torch.bfloat16",
    "torch.float32", "torch.float", "torch.float64", "torch.double",
]
TORCH_COMPLEX_TYPE = [
    "torch.complex32", "torch.chalf", "torch.complex64",
    "torch.cfloat", "torch.complex128", "torch.cdouble",
]

# Accepted values of the first segment of an API full name (see APIInfo.is_supported_type).
OPERATOR_TYPE = ("Functional", "Tensor", "Torch")

# Maximum number of top-level entries in an extracted API json: one forward + one backward record.
API_INFO = 2
# Segment counts of an API full name handled by extract_detailed_api_segments.
FOUR_SEGMENT = 4
FIVE_SEGMENT = 5
# Record key whose value is joined with the dump data directory (see APIExtractor.update_data_name).
DATA_NAME = "data_name"
# Upper bound on the length of the configured api_name.
API_MAX_LENGTH = 30
# Accepted values of the "propagation" and "data_mode" config fields.
PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
DATAMODE_LIST = ["random_data", "real_data"]
53
+
54
+
55
class APIInfo:
    """Dump information of one API: its forward record and, optionally, its backward record.

    Attributes:
        api_full_name: key of the record in the extracted json (presumably of the form
            ``Functional.conv2d.0.forward`` — confirm against the dump).
        api_info_dict: the record stored under ``api_full_name``.
        backward_info: ``APIInfo`` of the backward record, or ``None``.
    """

    def __init__(self, api_full_name, api_info_dict, backward_info=None):
        self.api_full_name = api_full_name
        self.api_info_dict = api_info_dict
        self.backward_info = backward_info

    @property
    def api_type(self):
        """First ``Const.SEP``-separated segment of ``api_full_name``."""
        return self.api_full_name.split(Const.SEP)[0]

    @classmethod
    def from_json(cls, json_content, propagation):
        """Build an ``APIInfo`` from the content of an extracted-API json.

        The first entry of ``json_content`` is the forward record. When ``propagation`` is
        ``Const.BACKWARD`` the second entry is attached as ``backward_info``.

        Raises:
            ValueError: if ``json_content`` has no entries, if backward propagation is
                requested but the backward entry is missing, or if the API type is not in
                ``OPERATOR_TYPE``.
        """
        items = list(json_content.items())
        # Fail with a descriptive error instead of an IndexError on malformed content.
        if not items:
            raise ValueError("json content is empty, no API info found!")
        forward_name, forward_dict = items[0]
        forward_info = cls(api_full_name=forward_name, api_info_dict=forward_dict)

        if propagation == Const.BACKWARD:
            if len(items) < API_INFO:
                raise ValueError("Backward propagation requires both forward and backward API info!")
            backward_name, backward_dict = items[1]
            forward_info.backward_info = cls(api_full_name=backward_name, api_info_dict=backward_dict)

        if not forward_info.is_supported_type():
            raise ValueError(f"type {forward_info.api_type} of API is not supported!")

        return forward_info

    def is_supported_type(self):
        """Return True when ``api_type`` is one of ``OPERATOR_TYPE``."""
        return self.api_type in OPERATOR_TYPE
82
+
83
+
84
class CommonConfig:
    """User settings for the operator-script generator.

    Values are read from the ``json_config`` mapping, and a missing key becomes ``None``.
    They are validated immediately by ``_check_config``, which also creates the parent
    directory of ``extract_api_path`` when the path has one.
    """

    def __init__(self, json_config):
        self.dump_json_path = json_config.get('dump_json_path')
        self.api_name = json_config.get('api_name')
        self.extract_api_path = json_config.get('extract_api_path')
        self.propagation = json_config.get('propagation')
        self.data_mode = json_config.get('data_mode')
        self.random_seed = json_config.get('random_seed')
        self.iter_times = json_config.get('iter_times')
        self._check_config()

    def check_user_settings(self):
        """Validate ``iter_times`` and the extracted API json, then return the json content.

        Returns:
            dict: content of ``extract_api_path``. The first entry is the forward record;
            for backward propagation the second entry is the backward record.

        Raises:
            ValueError: if ``iter_times`` is not positive, or if the json is empty, is not a
                dict, or has more than ``API_INFO`` entries. Also raised if the json lacks a
                dict-valued forward record, or (for backward propagation) a dict-valued
                backward record.
        """
        # _check_config has already verified iter_times with is_int.
        if self.iter_times <= 0:
            raise ValueError("iter_times should be an integer bigger than zero!")

        propagation = self.propagation
        json_content = load_json(self.extract_api_path)

        # ensure the content is not empty
        if not json_content:
            raise ValueError('json file is empty!')

        # ensure json_content is of type dict
        if not isinstance(json_content, dict):
            raise ValueError('content of json file is not a dict!')

        # at most one forward and one backward entry are allowed
        if len(json_content) > API_INFO:
            raise ValueError('json file has more than one API, the API only contains forward and backward info')

        # the first entry is the forward record and must be a dict
        forward_item = next(iter(json_content.items()), None)
        if not forward_item or not isinstance(forward_item[1], dict):
            raise ValueError('Invalid forward API data in json_content!')

        if propagation == Const.BACKWARD:
            # backward propagation needs both the forward and the backward entry
            if len(json_content) < API_INFO:
                raise ValueError('Backward propagation requires contains forward and backward info!')
            backward_item = list(json_content.items())[1]
            if not isinstance(backward_item[1], dict):
                raise ValueError('Invalid backward API data in json_content!')

        return json_content

    def _check_config(self):
        """Validate the raw settings and create the output directory of ``extract_api_path``.

        Raises:
            ValueError: if ``api_name`` is too long, if ``extract_api_path`` is missing, if
                ``propagation`` or ``data_mode`` is not an accepted value, or if
                ``random_seed`` or ``iter_times`` is not an int. The project helpers called
                here may raise their own errors for invalid paths or names.
        """
        if self.dump_json_path:
            check_file_or_directory_path(self.dump_json_path)
        if self.api_name:
            check_op_str_pattern_valid(self.api_name)
            if len(self.api_name) > API_MAX_LENGTH:
                raise ValueError(f'API name {self.api_name} is too long!')
        # extract_api_path is always needed (check_user_settings loads it). Fail with a clear
        # message instead of the TypeError that os.path.dirname(None) would raise.
        if not self.extract_api_path:
            raise ValueError('extract_api_path is required, it should be the path of a json file')
        # A bare file name has an empty dirname: there is no directory to create, so do not
        # hand an empty path to make_dir.
        extract_api_dir = os.path.dirname(self.extract_api_path)
        if extract_api_dir:
            make_dir(extract_api_dir)
        if self.propagation and self.propagation not in PROPAGATION_LIST:
            raise ValueError(f'propagation is invalid, it should be one of {PROPAGATION_LIST}')
        if self.data_mode and self.data_mode not in DATAMODE_LIST:
            raise ValueError(f'data_mode is invalid, it should be one of {DATAMODE_LIST}')
        if not is_int(self.random_seed):
            raise ValueError('random_seed is invalid, it should be an int')
        if not is_int(self.iter_times):
            raise ValueError('iter_times is invalid, it should be an int')
152
+
153
+
154
class APIExtractor:
    """Copy the dump records of one API from a dump json into a separate json file."""

    def __init__(self, api_name, dump_json_path, output_file):
        self.api_name = api_name
        self.dump_json_path = dump_json_path
        self.output_file = output_file
        self.data = None

    def extract_op(self):
        """Save every entry of the dump's ``data`` section whose key starts with ``<api_name>.``.

        If the dump has a non-empty ``dump_data_dir``, the ``data_name`` fields of the copied
        records are rewritten to paths under that directory. The result is written to
        ``output_file``. If no key matches, an error is logged and nothing is written.
        """
        self.data = load_json(self.dump_json_path)
        # Raw f-string: "\." must reach the regex engine as an escaped dot. In a plain string
        # literal it is an invalid escape sequence (SyntaxWarning on Python 3.12+).
        extract_key_pattern = re.compile(rf"^{re.escape(self.api_name)}\..+")
        real_data_path = self.data.get('dump_data_dir', '')
        new_data = {}
        for key, value in self.data.get('data', {}).items():
            if not extract_key_pattern.match(key):
                continue
            if real_data_path:
                value = self.load_real_data_path(value, real_data_path)
            new_data[key] = value
        if not new_data:
            logger.error(f"Error: The api '{self.api_name}' does not exist in the file.")
        else:
            save_json(self.output_file, new_data, indent=4)
            logger.info(
                f"The api '{self.api_name}' has been successfully extracted and saved in: {self.output_file}")

    def load_real_data_path(self, value, dump_data_dir):
        """Prefix the ``data_name`` fields in the argument/output lists of ``value`` with dump_data_dir.

        ``value`` is modified in place and also returned.
        """
        parameters = [Const.INPUT_ARGS, Const.GRAD_INPUT, Const.INPUT, Const.OUTPUT, Const.GRAD_OUTPUT]
        for parameter in parameters:
            # "or []" also covers keys that are present in the json but null.
            for v in value.get(parameter) or []:
                if v is not None:
                    self.update_data_name(v, dump_data_dir)
        return value

    def update_data_name(self, data, dump_data_dir):
        """Join ``data_name`` with dump_data_dir for every dict reached through nested lists.

        Items that are neither lists nor dicts (e.g. null entries inside a nested list) are
        skipped. Previously they hit ``DATA_NAME in data``, which fails for None and numbers.
        """
        if isinstance(data, list):
            for item in data:
                self.update_data_name(item, dump_data_dir)
        elif isinstance(data, dict) and DATA_NAME in data:
            data[DATA_NAME] = os.path.join(dump_data_dir, data[DATA_NAME])
192
+
193
+
194
+ class OperatorScriptGenerator:
195
def __init__(self, common_config, args_info_forward, kwargs_info_forward, args_info_backward):
    """Keep the generator settings and the parsed argument info used to build the script.

    Args:
        common_config: validated ``CommonConfig`` with the user settings.
        args_info_forward: positional-argument info of the forward call.
        kwargs_info_forward: keyword-argument info of the forward call.
        args_info_backward: argument info of the backward call.
    """
    (self.common_config,
     self.args_info_forward,
     self.kwargs_info_forward,
     self.args_info_backward) = (common_config, args_info_forward, kwargs_info_forward, args_info_backward)
200
+
201
@staticmethod
def get_compare_standard(api_name):
    """Return the name of the CompareStandard member used to compare ``api_name``.

    The standard lists are checked in a fixed order (binary, absolute, ulp, thousandth), and
    the first list containing ``api_name`` wins. APIs found in none of them use the benchmark
    standard. The result is a string such as ``"CompareStandard.ULP_ERROR_STANDARD"``, not an
    enum member; it is stored in the generator settings.
    """
    # Reference the imported lists directly instead of looking them up by name through
    # globals(). The string lookup hid their use from linters (making these imports look
    # unused and easy to delete), and a renamed import only failed with KeyError at call time.
    standard_table = (
        (binary_standard_api, "CompareStandard.BINARY_EQUALITY_STANDARD"),
        (absolute_standard_api, "CompareStandard.ABSOLUTE_THRESHOLD_STANDARD"),
        (ulp_standard_api, "CompareStandard.ULP_ERROR_STANDARD"),
        (thousandth_standard_api, "CompareStandard.THOUSANDTH_STANDARD"),
    )
    for standard_apis, standard_value in standard_table:
        if api_name in standard_apis:
            return standard_value
    return "CompareStandard.BENCHMARK_STANDARD"
213
+
214
@staticmethod
def extract_detailed_api_segments(full_api_name):
    """
    Function Description:
        Split the full name of an API into its type, name and call ordinal.
    Parameter:
        full_api_name: Full name of the API, split on Const.SEP.
            Four segments are read as type.name.ordinal.<last>, e.g. Functional.conv2d.0.forward.
            Five segments are read as type.prefix.name.ordinal.<last>; prefix and name are
            re-joined with Const.SEP to form api_name.
    Return:
        api_type: First segment. Example: Functional, Tensor, Torch.
        api_name: Name of api. Example: conv2d.
        api_order: Call ordinal segment, as a string. Example: "0".
        All three are None when the name has neither four nor five segments.
    """
    api_parts = full_api_name.split(Const.SEP)
    api_type, api_name, api_order = None, None, None
    if len(api_parts) == FOUR_SEGMENT:
        api_type, api_name, api_order, _ = api_parts
    elif len(api_parts) == FIVE_SEGMENT:
        api_type, prefix, api_name, api_order, _ = api_parts
        api_name = Const.SEP.join([prefix, api_name])
    return api_type, api_name, api_order
235
+
236
def get_settings(self, api_full_name):
    """Collect the values needed to generate the operator program for ``api_full_name``.

    Keys of the returned dict:
        propagation, direction_status: the configured propagation (both hold the same value).
        api_full_name: the name passed in.
        api_type: "torch.nn.functional" for Functional, "torch.Tensor" for Tensor,
            otherwise "torch".
        api_name, ordinal_number: parsed from api_full_name.
        compare_standard: CompareStandard member name from get_compare_standard.
        random_seed: the configured random seed.
        iter_times: 1 in real_data mode, otherwise the configured iter_times.
        args_element_assignment, args_list_generator_device, args_list_generator_bench:
            code for the forward positional arguments.
        kwargs_value_assignment, kwargs_dict_generator_device, kwargs_dict_generator_bench:
            code for the forward keyword arguments.
        args_element_assignment_backward, args_list_generator_device_backward,
        args_list_generator_bench_backward: code for the backward arguments when propagation
            is backward, otherwise ''.
    """
    config = self.common_config
    api_type, api_name, ordinal_number = self.extract_detailed_api_segments(api_full_name)
    type_to_module = {"Functional": "torch.nn.functional", "Tensor": "torch.Tensor"}

    # A dict literal evaluates its values in order. The element-assignment code is therefore
    # still generated before the list code (recursive_args_element_assignment stores each
    # arg's "parameter_name", which the list generators presumably read).
    settings = {
        "propagation": config.propagation,
        "api_full_name": api_full_name,
        "api_type": type_to_module.get(api_type, "torch"),
        "api_name": api_name,
        "compare_standard": self.get_compare_standard(api_name),
        "ordinal_number": ordinal_number,
        "direction_status": config.propagation,
        "random_seed": config.random_seed,
        "iter_times": 1 if config.data_mode == "real_data" else config.iter_times,
        "args_element_assignment": self.generate_args_element_assignment_code(self.args_info_forward),
        "args_list_generator_device": self.generate_args_list(self.args_info_forward, flag_device=True),
        "args_list_generator_bench": self.generate_args_list(self.args_info_forward, flag_device=False),
        "kwargs_value_assignment": self.generate_kwargs_value_assignment_code(self.kwargs_info_forward),
        "kwargs_dict_generator_device": self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=True),
        "kwargs_dict_generator_bench": self.generate_kwargs_dict(self.kwargs_info_forward, flag_device=False),
    }

    if config.propagation == Const.BACKWARD:
        settings.update({
            "args_element_assignment_backward":
                self.generate_args_element_assignment_code(self.args_info_backward),
            "args_list_generator_device_backward":
                self.generate_args_list(self.args_info_backward, flag_device=True),
            "args_list_generator_bench_backward":
                self.generate_args_list(self.args_info_backward, flag_device=False),
        })
    else:
        settings.update(dict.fromkeys(
            ("args_element_assignment_backward",
             "args_list_generator_device_backward",
             "args_list_generator_bench_backward"),
            ''))

    return settings
301
+
302
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_args_element_assignment")
def recursive_args_element_assignment(self, args_info, name_number):
    """Emit the generated-script assignments that rebuild each positional argument.

    Walks ``args_info`` depth-first. Nested lists/tuples are recursed into with
    ``_<index>`` appended to the name suffix. Every other element gets two
    generated lines:

        arg_info<suffix> = <str(element)>
        arg<suffix> = generate_data(arg_info<suffix>)

    Side effect: each non-sequence element (expected to be a dict) receives a
    ``parameter_name`` key holding ``arg<suffix>``. The key is written before the
    element is stringified, so it also appears in the emitted ``arg_info`` literal.

    Args:
        args_info: iterable of argument descriptors, possibly nested in lists/tuples.
        name_number: suffix accumulated from enclosing indices ("" at top level).

    Returns:
        str: the generated lines, each terminated by ``MsgConst.SPECIAL_CHAR[0]``.
    """
    line_end = MsgConst.SPECIAL_CHAR[0]
    indent = " "
    fragments = []
    for position, element in enumerate(args_info):
        suffix = name_number + "_" + str(position)
        if isinstance(element, (list, tuple)):
            fragments.append(self.recursive_args_element_assignment(element, suffix))
            continue
        # Record the variable name first; recursive_args_list reads it back later.
        element["parameter_name"] = "arg" + suffix
        fragments.append(indent + "arg_info" + suffix + " = " + str(element) + line_end)
        fragments.append(indent + "arg" + suffix + " = " + "generate_data(arg_info" + suffix + ")" + line_end)
    return "".join(fragments)
317
+
318
+
319
def generate_args_element_assignment_code(self, args_info):
    """Return the generated assignment lines for all positional arguments.

    Entry point over ``recursive_args_element_assignment``. It starts with an empty
    name suffix, so top-level arguments become ``arg_0``, ``arg_1``, ...

    Args:
        args_info: the dumped positional-argument descriptors.

    Returns:
        str: the concatenated assignment source lines.
    """
    return self.recursive_args_element_assignment(args_info, "")
322
+
323
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_args_list")
def recursive_args_list(self, args_info, flag_device=False, flag_bench=False):
    """Render the positional-argument expression list for the generated call.

    Each leaf descriptor contributes its ``parameter_name``, which is expected to
    have been set by ``recursive_args_element_assignment`` beforehand. Leaves whose
    ``type`` is in ``TENSOR_DATA_LIST`` get a suffix:

    * ``flag_device``: ``.to(device)``.
    * ``flag_bench``: ``.to(torch.device("cpu"))`` plus a cast through
      ``RAISE_PRECISION``.

    ``device`` and ``RAISE_PRECISION`` are names resolved inside the generated
    script. Nested lists/tuples are rendered with matching brackets. Every element
    (leaf or container) is followed by ``Const.COMMA``.

    Args:
        args_info: iterable of argument descriptors, possibly nested.
        flag_device: render tensors for the device run.
        flag_bench: render tensors for the CPU benchmark run.

    Returns:
        str: the rendered argument expressions.
    """
    rendered = ""
    for element in args_info:
        if isinstance(element, (list, tuple)):
            opener, closer = ("[", "]") if isinstance(element, list) else ("(", ")")
            inner = self.recursive_args_list(element, flag_device=flag_device, flag_bench=flag_bench)
            rendered += opener + inner + closer
        else:
            var_name = element.get("parameter_name")
            rendered += var_name
            if element.get("type") in TENSOR_DATA_LIST:
                if flag_device:
                    rendered += ".to(device)"
                if flag_bench:
                    rendered += '.to(torch.device("cpu"))'
                    rendered += ".to(RAISE_PRECISION.get(str(" + var_name + ".dtype), " + var_name + ".dtype))"
        rendered += Const.COMMA
    return rendered
344
+
345
def generate_args_list(self, args_info, flag_device):
    """Render positional arguments for either the device run or the CPU bench run.

    Exactly one mode is active: ``flag_device`` truthy selects the device rendering.
    Otherwise the bench rendering is used.

    Args:
        args_info: the positional-argument descriptors.
        flag_device: True for the device variant, False for the bench variant.

    Returns:
        str: the rendered argument list from ``recursive_args_list``.
    """
    on_device = bool(flag_device)
    return self.recursive_args_list(args_info, flag_device=on_device, flag_bench=not on_device)
351
+
352
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_kwargs_value_assignment")
def recursive_kwargs_value_assignment(self, info, key_name, name_number):
    """Emit the generated-script assignments that rebuild one keyword argument.

    For a dict descriptor:

    * ``type`` is ``torch.device`` or ``torch.dtype``: emits
      ``kwarg_<key><suffix> = <info["value"]>``, with the value pasted verbatim as
      source text.
    * otherwise: emits a ``kwarg_info_<key><suffix>`` metadata literal followed by
      ``kwarg_<key><suffix> = generate_data(kwarg_info_<key><suffix>)``.

    Each emitted line ends with ``MsgConst.SPECIAL_CHAR[0]``. After emission the
    dict receives ``parameter_name`` = ``kwarg_<key><suffix>``. The key is added
    after stringification, so it is absent from the emitted literal.

    Any non-dict value is iterated and each item recursed into with ``_<index>``
    appended to the suffix.

    Args:
        info: keyword-argument descriptor (dict) or a sequence of them.
        key_name: the keyword name.
        name_number: suffix accumulated from enclosing indices ("" at top level).

    Returns:
        str: the generated assignment lines.
    """
    line_end = MsgConst.SPECIAL_CHAR[0]
    kwargs_value_assignment = ""
    if isinstance(info, dict):
        variable_name = "kwarg_" + key_name + name_number
        if info.get("type") == "torch.device" or info.get("type") == "torch.dtype":
            # Terminate this line like every other generated assignment. Previously the
            # separator was missing, so the next assignment was glued onto this line
            # and the generated script did not parse.
            kwargs_value_assignment += "  " + variable_name + " = " + info.get("value") + line_end
        else:
            info_name = "kwarg_info_" + key_name + name_number
            kwargs_value_assignment += "  " + info_name + " = " + str(info) + line_end
            kwargs_value_assignment += "  " + variable_name + " = " + \
                "generate_data(" + info_name + ")" + line_end
        # Set after stringification so the name is not embedded in the metadata literal.
        info["parameter_name"] = variable_name
    else:
        for index, arg in enumerate(info):
            kwargs_value_assignment += self.recursive_kwargs_value_assignment(
                arg, key_name, name_number + "_" + str(index))
    return kwargs_value_assignment
370
+
371
def generate_kwargs_value_assignment_code(self, kwargs_info):
    """Return the generated assignment lines for every keyword argument.

    Each ``(name, descriptor)`` pair of ``kwargs_info`` is rendered with
    ``recursive_kwargs_value_assignment``, starting from an empty suffix. The
    pieces are concatenated in dict iteration order.

    Args:
        kwargs_info: mapping of keyword name to its dumped descriptor.

    Returns:
        str: the concatenated assignment source lines ("" when there are no kwargs).
    """
    return "".join(
        self.recursive_kwargs_value_assignment(kwarg_value, kwarg_name, "")
        for kwarg_name, kwarg_value in kwargs_info.items()
    )
376
+
377
@recursion_depth_decorator("OpGenerator: OperatorScriptGenerator.recursive_kwargs_dict")
def recursive_kwargs_dict(self, info, flag_device=False, flag_bench=False):
    """Render the value expression of one keyword argument for the generated call.

    A dict descriptor renders as its ``parameter_name``, which is expected to have
    been set by ``recursive_kwargs_value_assignment``. If its ``type`` is in
    ``TENSOR_DATA_LIST``, the same device/bench suffixes as for positional
    arguments are appended.

    Any other value is treated as a sequence. It renders as ``[...]`` when it is a
    list and ``(...)`` otherwise, and each item is followed by ``Const.COMMA``.

    Args:
        info: keyword-argument descriptor (dict) or a sequence of them.
        flag_device: render tensors for the device run.
        flag_bench: render tensors for the CPU benchmark run.

    Returns:
        str: the rendered value expression.
    """
    if not isinstance(info, dict):
        opener, closer = ("[", "]") if isinstance(info, list) else ("(", ")")
        pieces = [opener]
        for item in info:
            pieces.append(self.recursive_kwargs_dict(item, flag_device=flag_device, flag_bench=flag_bench))
            pieces.append(Const.COMMA)
        pieces.append(closer)
        return "".join(pieces)

    var_name = info.get("parameter_name")
    rendered = ""
    rendered += var_name
    if info.get("type") in TENSOR_DATA_LIST:
        if flag_device:
            rendered += ".to(device)"
        if flag_bench:
            rendered += '.to(torch.device("cpu"))'
            rendered += ".to(RAISE_PRECISION.get(str(" + var_name + ".dtype), " + var_name + ".dtype))"
    return rendered
397
+
398
+
399
def generate_kwargs_dict(self, kwargs_info, flag_device):
    """Render the entries of the keyword-argument dict literal for the generated call.

    Each keyword becomes ``"<name>":<rendered value>`` followed by ``Const.COMMA``.
    The value is rendered by ``recursive_kwargs_dict`` in device mode when
    ``flag_device`` is truthy, otherwise in CPU-bench mode.

    Args:
        kwargs_info: mapping of keyword name to its dumped descriptor.
        flag_device: True for the device variant, False for the bench variant.

    Returns:
        str: the concatenated dict entries ("" when there are no kwargs).
    """
    # The generated text is a Python dict literal, so the key/value separator must be
    # a colon. The previous code borrowed MonitorConst.VPP_SEP, a monitor-module
    # constant unrelated to dict syntax, which silently coupled code generation to it.
    key_value_sep = ":"
    kwargs_dict_generator = ""
    for key, value in kwargs_info.items():
        if flag_device:
            rendered_value = self.recursive_kwargs_dict(value, flag_device=True)
        else:
            rendered_value = self.recursive_kwargs_dict(value, flag_bench=True)
        kwargs_dict_generator += '"' + key + '"' + key_value_sep + rendered_value + Const.COMMA
    return kwargs_dict_generator
408
+
409
+
410
+
411
+ def _op_generator_parser(parser):
412
+ parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str,
413
+ help="<Optional> Path of config json file", required=True)
414
+ parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
415
+ help="<Required> Path of extract api_name.json.",
416
+ required=True)
417
+
418
+
419
def parse_json_config(json_file_path):
    """Load the generator's JSON configuration and wrap it in a ``CommonConfig``.

    Args:
        json_file_path: path to the configuration JSON. When empty/falsy, the
            fallback is ``config.json`` in the parent of the directory that
            contains this module (resolved through ``os.path.realpath``).

    Returns:
        CommonConfig: built from the loaded JSON content.
    """
    config_path = json_file_path
    if not config_path:
        module_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(os.path.dirname(module_dir), "config.json")
    return CommonConfig(load_json(config_path))
426
+
427
+
428
def _run_operator_generate_commond(cmd_args):
    """Generate an operator-replication script from parsed command-line arguments.

    Flow:
      1. Load the config named by ``cmd_args.config_input`` into a CommonConfig.
      2. If ``dump_json_path`` is configured, extract the target API's info into
         ``extract_api_path`` first.
      3. Validate the extracted-API path and the output directory, then build an
         APIInfo from the checked user settings.
      4. Build the template substitutions with OperatorScriptGenerator. Forward
         args/kwargs are always supplied; backward inputs are added when
         propagation is backward.
      5. Render ``operator_replication.template`` (next to this module) into
         ``<api_output_path>/<api_full_name>.py``.

    Two failures are logged and make the function return without logging success:
    a backward dump with neither a grad-input nor an input entry, and an OSError
    while opening/writing the files.

    Args:
        cmd_args: argparse namespace providing ``config_input`` and ``api_output_path``.
    """
    common_config = parse_json_config(cmd_args.config_input)

    if common_config.dump_json_path:
        api_extract = APIExtractor(common_config.api_name, common_config.dump_json_path,
                                   common_config.extract_api_path)
        api_extract.extract_op()
    check_file_or_directory_path(common_config.extract_api_path)
    check_file_or_directory_path(cmd_args.api_output_path, isdir=True)
    json_content = common_config.check_user_settings()
    api_info = APIInfo.from_json(json_content, common_config.propagation)

    # Forward args/kwargs are passed to the generator in both modes.
    api_info_dict_forward = api_info.api_info_dict
    args_info_forward = api_info_dict_forward.get(Const.INPUT_ARGS)
    kwargs_info_forward = api_info_dict_forward.get(Const.INPUT_KWARGS)

    if common_config.propagation == Const.BACKWARD:
        api_full_name_backward = api_info.backward_info.api_full_name
        api_info_dict_backward = api_info.backward_info.api_info_dict
        # The backward inputs may be stored under either key; GRAD_INPUT wins.
        if Const.GRAD_INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.GRAD_INPUT)
        elif Const.INPUT in api_info_dict_backward:
            args_info_backward = api_info_dict_backward.get(Const.INPUT)
        else:
            # Previously args_info_backward stayed unbound here and the constructor
            # call below failed with an UnboundLocalError; report the real cause.
            logger.error(f"Backward info of {api_full_name_backward} contains neither "
                         f"'{Const.GRAD_INPUT}' nor '{Const.INPUT}'; cannot generate the operator script.")
            return
        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward,
                                              args_info_backward)
        internal_settings = op_generate.get_settings(api_full_name_backward)
    else:
        op_generate = OperatorScriptGenerator(common_config, args_info_forward, kwargs_info_forward, None)
        internal_settings = op_generate.get_settings(api_info.api_full_name)

    template_path = os.path.join(os.path.dirname(__file__), "operator_replication.template")
    operator_script_path = os.path.join(cmd_args.api_output_path,
                                        "{0}.py".format(internal_settings.get("api_full_name")))

    try:
        with FileOpen(template_path, 'r') as ftemp, FileOpen(operator_script_path, 'w') as fout:
            code_template = ftemp.read()
            fout.write(code_template.format(**internal_settings))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
        # Do not fall through to the success message below.
        return

    logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
472
+
473
+
474
if __name__ == "__main__":
    # Standalone entry point: build the CLI, parse sys.argv and generate the script.
    cli_parser = argparse.ArgumentParser()
    _op_generator_parser(cli_parser)
    _run_operator_generate_commond(cli_parser.parse_args())