mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -0,0 +1,114 @@
1
+ from mindspore.common import dtype as mstype
2
+ import numpy as np
3
+ import mindspore
4
+ import torch
5
+
6
# Canonical dtype names; they are the string keys of the mapping tables in this module.
# Signed integers
INT4 = "Int4"
INT8 = "Int8"
INT16 = "Int16"
INT32 = "Int32"
INT64 = "Int64"
# Unsigned integers
UINT8 = "UInt8"
UINT16 = "UInt16"
UINT32 = "UInt32"
UINT64 = "UInt64"
# Floating point
FLOAT16 = "Float16"
FLOAT32 = "Float32"
FLOAT64 = "Float64"
BFLOAT16 = "BFloat16"
# Boolean
BOOL = "Bool"
20
+
21
+
22
# Canonical dtype name -> MindSpore dtype object.
dtype_str_to_ms_dtype = dict([
    (INT8, mstype.int8),
    (UINT8, mstype.uint8),
    (INT16, mstype.int16),
    (UINT16, mstype.uint16),
    (INT32, mstype.int32),
    (UINT32, mstype.uint32),
    (INT64, mstype.int64),
    (UINT64, mstype.uint64),
    (FLOAT16, mstype.float16),
    (FLOAT32, mstype.float32),
    (FLOAT64, mstype.float64),
    (BOOL, mstype.bool_),
    (BFLOAT16, mstype.bfloat16),
    # "Int4" is mapped to MindSpore's qint4x2 type.
    (INT4, mstype.qint4x2),
])
# Inverse lookup: MindSpore dtype object -> canonical dtype name.
ms_dtype_to_dtype_str = dict(zip(dtype_str_to_ms_dtype.values(), dtype_str_to_ms_dtype.keys()))
39
+
40
+
41
# Canonical dtype name -> NumPy scalar type. BFloat16 and Int4 have no entry here.
dtype_str_to_np_dtype = dict([
    (INT8, np.int8),
    (UINT8, np.uint8),
    (INT16, np.int16),
    (UINT16, np.uint16),
    (INT32, np.int32),
    (UINT32, np.uint32),
    (INT64, np.int64),
    (UINT64, np.uint64),
    (FLOAT16, np.float16),
    (FLOAT32, np.float32),
    (FLOAT64, np.float64),
    (BOOL, np.bool_),
])
# Inverse lookup: NumPy scalar type -> canonical dtype name.
np_dtype_to_dtype_str = dict(zip(dtype_str_to_np_dtype.values(), dtype_str_to_np_dtype.keys()))
56
+
57
# Canonical dtype name -> torch dtype. UInt16/UInt32/UInt64 and Int4 have no entry here.
dtype_str_to_torch_dtype = dict([
    (INT8, torch.int8),
    (UINT8, torch.uint8),
    (INT16, torch.int16),
    (INT32, torch.int32),
    (INT64, torch.int64),
    (FLOAT16, torch.float16),
    (FLOAT32, torch.float32),
    (FLOAT64, torch.float64),
    (BOOL, torch.bool),
    (BFLOAT16, torch.bfloat16),
])
# Inverse lookup: torch dtype -> canonical dtype name.
torch_dtype_to_dtype_str = dict(zip(dtype_str_to_torch_dtype.values(), dtype_str_to_torch_dtype.keys()))
70
+
71
# Type-name strings used by API info, and the Python / MindSpore types they stand for.
MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
BOOL_TYPE_STR = "bool"
INT_TYPE_STR = "int"
FLOAT_TYPE_STR = "float"
SLICE_TYPE_STR = "slice"
TUPLE_TYPE_STR = "tuple"
STR_TYPE_STR = "str"

# NOTE(review): TUPLE_TYPE_STR is declared above but is not part of this mapping -- confirm intentional.
api_info_type_str_to_type = dict([
    (MINDSPORE_TENSOR_TYPE_STR, mindspore.Tensor),
    (BOOL_TYPE_STR, bool),
    (INT_TYPE_STR, int),
    (FLOAT_TYPE_STR, float),
    (SLICE_TYPE_STR, slice),
    (STR_TYPE_STR, str),
])
# Inverse lookup: type object -> type-name string.
type_to_api_info_type_str = dict(zip(api_info_type_str_to_type.values(), api_info_type_str_to_type.keys()))
88
+
89
# Default NumPy dtypes for constructing data. The int and uint defaults are also float64.
# NOTE(review): presumably values are generated as float64 and cast afterwards -- confirm with callers.
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE = np.float64
DEFAULT_CONSTRUCT_NP_INT_DTYPE = DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE
DEFAULT_CONSTRUCT_NP_UINT_DTYPE = DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE
92
+
93
# Dtype-name groupings by numeric family. BOOL and INT4 are grouped with the signed integers.
float_dtype_str_list = [FLOAT16, FLOAT32, FLOAT64, BFLOAT16]
int_dtype_str_list = [INT8, INT16, INT32, INT64, BOOL, INT4]
uint_dtype_str_list = [UINT8, UINT16, UINT32, UINT64]
@@ -0,0 +1,63 @@
1
+ from msprobe.core.common.exceptions import ApiAccuracyCheckerException
2
+ from msprobe.core.common.log import logger
3
+
4
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
    '''
    Fetch dict_instance[key] and validate it; every failure is reported through
    logger.error_log_with_exp with an ApiAccuracyCheckerException(ParseJsonFailed).

    Args:
        dict_instance: dict, dict parsed from input json
        key: str, key to look up
        key_description: str, human-readable name of the key used in error messages
        accepted_type: tuple (or a single type), allowed types of the value; None disables the check
        accepted_value: Union[tuple, list], allowed values; None disables the check

    Return:
        value, the corresponding value of "key" in "dict_instance"

    Exception:
        ApiAccuracyCheckerException.ParseJsonFailed is passed to error_log_with_exp when
        1. dict_instance is not a dict
        2. value is None (missing key or JSON null)
        3. value is not of an accepted type
        4. value is not an accepted value
    '''
    parse_failed_exception = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
    if not isinstance(dict_instance, dict):
        logger.error_log_with_exp("check_and_get_from_json_dict failed: input is not a dict", parse_failed_exception)
    value = dict_instance.get(key)

    # At most one failure is reported; the checks are tried in priority order.
    failure_message = None
    if value is None:
        failure_message = f"check_and_get_from_json_dict failed: {key_description} is missing"
    elif accepted_type is not None and not isinstance(value, accepted_type):
        failure_message = \
            f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
    elif accepted_value is not None and value not in accepted_value:
        failure_message = \
            f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"

    if failure_message is not None:
        logger.error_log_with_exp(failure_message, parse_failed_exception)
    return value
39
+
40
def convert_to_tuple(input):
    """Return *input* as a tuple.

    Lists and tuples are converted element-wise; any other object (including
    strings and None) becomes the single element of a one-item tuple.
    """
    return tuple(input) if isinstance(input, (tuple, list)) else (input,)
46
+
47
class GlobalContext:
    """Mutable settings for the API accuracy checker, held in one shared instance."""

    def __init__(self):
        # Defaults until init() is called: no dump directory, constructed mode on.
        self.dump_data_dir = ""
        self.is_constructed = True

    def init(self, is_constructed, dump_data_dir):
        """Overwrite both settings at once."""
        self.dump_data_dir = dump_data_dir
        self.is_constructed = is_constructed

    def get_dump_data_dir(self):
        return self.dump_data_dir

    def get_is_constructed(self):
        return self.is_constructed


# Module-level instance shared by importers of this module.
global_context = GlobalContext()
@@ -0,0 +1,34 @@
1
+ from msprobe.core.data_dump.scope import ModuleRangeScope
2
+ from msprobe.core.common.const import Const
3
+ from msprobe.mindspore.common.log import logger
4
+
5
+
6
class CellProcessor:
    """Creates begin/end hooks for MindSpore cells and reports them to a ModuleRangeScope."""

    # Class-level: occurrence counters are shared by every CellProcessor instance.
    cell_count = {}

    def __init__(self, scope):
        # Only a ModuleRangeScope is used; anything else disables scope reporting.
        self.scope = scope if isinstance(scope, ModuleRangeScope) else None

    @staticmethod
    def set_cell_count(cell_name):
        """Advance and return the zero-based occurrence index for *cell_name*."""
        counters = CellProcessor.cell_count
        counters[cell_name] = counters.get(cell_name, -1) + 1
        return counters[cell_name]

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        """Return the begin hook when start_or_stop equals Const.START, else the end hook."""

        def begin_hook(cell, input):
            # Unique name "<prefix><SEP><index>", remembered on the cell for end_hook.
            full_name = name_prefix + Const.SEP + str(self.set_cell_count(name_prefix))
            cell.mindstudio_reserved_name = full_name
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(cell, input, output):
            if self.scope:
                self.scope.end_module(cell.mindstudio_reserved_name)

        if Const.START == start_or_stop:
            return begin_hook
        return end_hook
@@ -0,0 +1,87 @@
1
+ import numpy as np
2
+ import mindspore as ms
3
+
4
+ from msprobe.core.common.const import Const as CoreCost
5
+
6
+
7
class Const:
    """MindSpore-side constants: dump granularities and execution-mode names."""

    # Dump granularity names.
    CELL = "cell"
    API = "api"
    KERNEL = "kernel"

    # Core dump level -> granularity: L0 -> cell, L1 -> api, L2 -> kernel.
    TOOL_LEVEL_DICT = {CoreCost.LEVEL_L0: CELL, CoreCost.LEVEL_L1: API, CoreCost.LEVEL_L2: KERNEL}

    # Execution-mode names.
    PYNATIVE_MODE = "pynative"
    GRAPH_GE_MODE = "graph_ge"
    GRAPH_KBYK_MODE = "graph_kbyk"
19
+
20
+
21
class FreeBenchmarkConst:
    """Configuration constants for the MindSpore free-benchmark (self-check) tool."""

    # ---- perturbation types ----
    ADD_NOISE = "add_noise"
    BIT_NOISE = "bit_noise"
    NO_CHANGE = "no_change"
    IMPROVE_PRECISION = "improve_precision"

    # ---- result-handler types ----
    CHECK = "check"
    FIX = "fix"

    # ---- defaults ----
    DEFAULT_DEVICE = "npu"
    DEFAULT_STAGE = "forward"
    DEFAULT_DUMP_LEVEL = CoreCost.LEVEL_L1
    DEFAULT_PERT_TYPE = IMPROVE_PRECISION
    DEFAULT_HANDLER_TYPE = CHECK
    FIX_HANDLER_MODE = FIX

    # ---- accepted option values ----
    DEVICE_LIST = [DEFAULT_DEVICE]
    STAGE_LIST = [DEFAULT_STAGE]
    DUMP_LEVEL_LIST = [DEFAULT_DUMP_LEVEL]
    PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE]
    HANDLER_TYPE_LIST = [CHECK, FIX]

    # ---- communication APIs ----
    COMMUNICATION_API_LIST = [
        "mindspore.communication.comm_func.all_gather_into_tensor",
        "mindspore.communication.comm_func.gather_into_tensor",
        "mindspore.communication.comm_func.all_reduce",
        "mindspore.communication.comm_func.reduce",
        "mindspore.communication.comm_func.reduce_scatter_tensor",
    ]

    # ---- thresholds and ratios ----
    NO_CHANGE_ERROR_THRESHOLD = 1.0
    SYMBOL_FLIPPING_RATIO = 8.0

    # ---- API name prefixes, and the short key -> prefix lookup ----
    OPS_PREFIX = "mindspore.ops."
    Tensor_PREFIX = "mindspore.Tensor."
    MINT_PREFIX = "mindspore.mint."
    MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional."
    COMM_PREFIX = "mindspore.communication.comm_func."
    API_PREFIX_DICT = {
        "ops": OPS_PREFIX,
        "Tensor": Tensor_PREFIX,
        "mint": MINT_PREFIX,
        "mint.nn.functional": MINT_NN_FUNC_PREFIX,
        "communication": COMM_PREFIX,
    }

    # ---- per-dtype tables ----
    # Perturbation magnitude per floating dtype.
    PERT_VALUE_DICT = {ms.bfloat16: 1e-4, ms.float16: 1e-6, ms.float32: 1e-8, ms.float64: 1e-16}
    # Error thresholds; only float16 and float32 have an entry.
    ERROR_THRESHOLD = {ms.float16: 1.002, ms.float32: 1.0002}
    # Float dtype -> NumPy integer type of the same bit width
    # (presumably used to reinterpret bits for bit noise -- confirm).
    PERT_BIT_DICT = {ms.float16: np.int16, ms.float32: np.int32, ms.float64: np.int64}
    # MindSpore dtype -> NumPy dtype.
    MS_NUMPY_DTYPE_DICT = {
        ms.int16: np.int16,
        ms.int32: np.int32,
        ms.int64: np.int64,
        ms.float16: np.float16,
        ms.float32: np.float32,
        ms.float64: np.float64,
    }
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import os
17
+ import time
18
+ import sys
19
+
20
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
21
+ from msprobe.core.common.log import BaseLogger
22
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
23
+
24
+
25
class MindsporeLogger(BaseLogger):
    """BaseLogger whose get_rank() reports the MindSpore distributed rank."""

    def get_rank(self):
        """Return the current rank, or None when distributed is not initialized."""
        try:
            return get_rank_if_initialized()
        except DistributedNotInitializedError:
            return None


logger = MindsporeLogger()
@@ -0,0 +1,57 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import mindspore as ms
17
+
18
+ from msprobe.core.common.exceptions import DistributedNotInitializedError
19
+ from msprobe.core.common.file_check import path_len_exceeds_limit
20
+ from msprobe.core.common.utils import save_npy
21
+ from msprobe.core.common.log import logger
22
+
23
+
24
def get_rank_if_initialized():
    """Return this process's MindSpore communication rank.

    Raises:
        DistributedNotInitializedError: if GlobalComm is not initialized.
    """
    if not ms.communication.GlobalComm.INITED:
        raise DistributedNotInitializedError("mindspore distributed environment is not initialized")
    return ms.communication.get_rank()
29
+
30
+
31
def convert_bf16_to_fp32(tensor):
    """Return *tensor* cast to float32 if it is bfloat16, otherwise return it unchanged."""
    return tensor.to(ms.float32) if tensor.dtype == ms.bfloat16 else tensor
35
+
36
+
37
def save_tensor_as_npy(tensor, file_path):
    """Save *tensor* to *file_path* via save_npy.

    bfloat16 tensors are converted to float32 before asnumpy(). If the path
    exceeds the length limit, a warning is logged and nothing is written.
    """
    if path_len_exceeds_limit(file_path):
        logger.warning(f'The file path {file_path} length exceeds limit.')
        return
    save_npy(convert_bf16_to_fp32(tensor).asnumpy(), file_path)
44
+
45
+
46
class MsprobeStep(ms.train.Callback):
    """Training callback that runs the debugger once per step.

    Each training step calls debugger.start() on entry, then debugger.stop()
    followed by debugger.step() on exit.
    """

    def __init__(self, debugger):
        super().__init__()
        self.debugger = debugger

    def on_train_step_begin(self, run_context):
        self.debugger.start()

    def on_train_step_end(self, run_context):
        debugger = self.debugger
        debugger.stop()
        debugger.step()
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import os
18
+ from msprobe.core.common.utils import CompareException, check_compare_param, \
19
+ check_configuration_param, task_dumppath_get
20
+ from msprobe.core.common.file_check import create_directory
21
+ from msprobe.core.common.exceptions import FileCheckException
22
+ from msprobe.core.common.log import logger
23
+ from msprobe.mindspore.compare.ms_compare import MSComparator
24
+ from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
25
+ from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
26
+
27
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Compare each rank of an NPU dump with the corresponding rank of a bench dump.

    Entries returned by check_and_return_dir_contents(..., 'rank') for both dump
    directories are sorted and paired by position. Each pair is compared with
    MSComparator, and results go to *output_path* with the suffix
    ``_<npu_rank>-<bench_rank>``.

    Args:
        npu_dump_dir: directory containing the NPU rank sub-directories.
        bench_dump_dir: directory containing the bench rank sub-directories.
        output_path: directory for the comparison results (created if needed).
        **kwargs: stack_mode (default False), auto_analyze (default True) and
            fuzzy_match (default False) are validated; all kwargs are forwarded
            to MSComparator.compare_core. 'suffix' is not supported.

    Raises:
        CompareException: on an unsupported 'suffix', mismatched rank counts, or
            invalid arguments/paths (FileCheckException is re-raised as CompareException).
    """
    # A per-rank suffix is generated below, so a caller-supplied one is rejected.
    # pop() (instead of get()) also keeps a falsy 'suffix' (e.g. None) from being
    # forwarded through **kwargs, where it would collide with suffix= below.
    if kwargs.pop('suffix', None):
        logger.error("Argument 'suffix' is not supported for compare_distributed.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    # get the ranks and match by order
    npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
    bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
    if len(npu_ranks) != len(bench_ranks):
        logger.error('The number of ranks in the two runs are different. '
                     'Unable to match the ranks. Please use another folder to compare '
                     'or use compare() api and manually match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    for npu_rank, bench_rank in zip(npu_ranks, bench_ranks):
        npu_data_dir = os.path.join(npu_dump_dir, npu_rank)
        bench_data_dir = os.path.join(bench_dump_dir, bench_rank)
        dump_result_param = {
            'npu_json_path': extract_json(npu_data_dir, stack_json=False),
            'bench_json_path': extract_json(bench_data_dir, stack_json=False),
            'stack_json_path': extract_json(npu_data_dir, stack_json=True),
            'is_print_compare_log': True
        }
        try:
            summary_compare, md5_compare = task_dumppath_get(dump_result_param)
            check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
            create_directory(output_path)
            check_compare_param(dump_result_param, output_path, summary_compare=summary_compare,
                                md5_compare=md5_compare)
        except (CompareException, FileCheckException) as error:
            logger.error('Compare failed. Please check the arguments and do it again!')
            raise CompareException(error.code) from error
        ms_comparator = MSComparator()
        ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{npu_rank}-{bench_rank}',
                                   summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
66
+
67
+
68
def ms_graph_compare(inputs, outputs):
    """Run a graph-mode comparison of *inputs*, writing results under *outputs*.

    If *outputs* cannot be created, an error is logged and no comparison runs.
    """
    try:
        create_directory(outputs)
    except (CompareException, FileCheckException):
        logger.error('Compare failed. Please check the arguments and do it again!')
    else:
        comparator = GraphMSComparator(inputs, outputs)
        comparator.compare_core()
@@ -0,0 +1,117 @@
1
+ import os.path
2
+ from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
3
+ task_dumppath_get, load_yaml, load_npy
4
+ from msprobe.core.common.file_check import create_directory
5
+ from msprobe.core.common.const import Const
6
+ from msprobe.core.common.log import logger
7
+ from msprobe.core.common.exceptions import FileCheckException
8
+ from msprobe.core.compare.acc_compare import Comparator
9
+ from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
10
+
11
+
12
+ class MSComparator(Comparator):
13
+ def __init__(self, cell_mapping=None, api_mapping=None):
14
+ self.frame_name = MSComparator.__name__
15
+ self.cell_mapping = cell_mapping
16
+ self.api_mapping = api_mapping
17
+ self.cross_frame = cell_mapping is not None or api_mapping is not None
18
+ self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
19
+ self.api_mapping_dict = {}
20
+ if api_mapping is not None:
21
+ self.ms_to_pt_mapping = self.load_internal_api()
22
+
23
+ def load_internal_api(self):
24
+ cur_path = os.path.dirname(os.path.realpath(__file__))
25
+ yaml_path = os.path.join(cur_path,"ms_to_pt_api.yaml")
26
+ return load_yaml(yaml_path)
27
+
28
+ def load_mapping_file(self, mapping_file):
29
+ if isinstance(mapping_file, str):
30
+ mapping_dict = load_yaml(mapping_file)
31
+ else:
32
+ mapping_dict = {}
33
+ return mapping_dict
34
+
35
+ def process_cell_mapping(self, npu_op_name):
36
+ npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
37
+ if self.cell_mapping_dict:
38
+ for index, op_name in enumerate(npu_op_name):
39
+ # get cell name & class name from op_name
40
+ # Cell.fc1.Dense.forward.0.input.0
41
+ cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
42
+ if cell_name in self.cell_mapping_dict:
43
+ npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
44
+ return npu_op_name
45
+
46
+ def check_op(self, npu_dict, bench_dict, fuzzy_match):
47
+ npu_op_name = npu_dict["op_name"].copy()
48
+ bench_op_name = bench_dict["op_name"].copy()
49
+
50
+ if self.api_mapping is not None:
51
+ npu_op_name = self.process_api_mapping(npu_op_name, bench_op_name)
52
+ if self.cell_mapping is not None:
53
+ npu_op_name = self.process_cell_mapping(npu_op_name)
54
+
55
+ struct_match = check_struct_match(npu_dict, bench_dict, cross_frame=self.cross_frame)
56
+ if not fuzzy_match:
57
+ return npu_op_name == bench_op_name and struct_match
58
+ is_match = True
59
+ try:
60
+ is_match = fuzzy_check_op(npu_op_name, bench_op_name)
61
+ except Exception as err:
62
+ logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
63
+ is_match = False
64
+ return is_match and struct_match
65
+
66
+ def read_npy_data(self, dir_path, file_name, load_pt_file=False):
67
+ data_path = os.path.join(dir_path, file_name)
68
+ if load_pt_file:
69
+ import torch
70
+ from msprobe.pytorch.common.utils import load_pt
71
+ data_value = load_pt(data_path).detach()
72
+ if data_value.dtype == torch.bfloat16:
73
+ data_value = data_value.to(torch.float32)
74
+ data_value = data_value.numpy()
75
+ else:
76
+ data_value = load_npy(data_path)
77
+ return data_value
78
+
79
+ def api_replace(self, npu_op_name, target, para):
80
+ for idx, _ in enumerate(npu_op_name):
81
+ npu_op_name[idx] = npu_op_name[idx].replace(target, para)
82
+ return npu_op_name
83
+
84
+ def process_api_mapping(self, npu_op_name, bench_op_name):
85
+ # get api name & class name from op_name
86
+ # Functional.addcmul.0.forward.input.0
87
+ ms_api_name = npu_op_name[0].rsplit(Const.SEP, 4)[0]
88
+ pt_api_name = bench_op_name[0].rsplit(Const.SEP, 4)[0]
89
+ class_name = ms_api_name.split(Const.SEP)[0]
90
+ if class_name == "Mint":
91
+ return self.api_replace(npu_op_name, "Mint", "Torch")
92
+ elif class_name == "MintFunctional":
93
+ return self.api_replace(npu_op_name, "MintFunctional", "Functional")
94
+ elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
95
+ return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
96
+ else:
97
+ return npu_op_name
98
+
99
+
100
+ def ms_compare(input_param, output_path, **kwargs):
101
+ try:
102
+ stack_mode = kwargs.get('stack_mode', False)
103
+ auto_analyze = kwargs.get('auto_analyze', True)
104
+ fuzzy_match = kwargs.get('fuzzy_match', False)
105
+ cell_mapping = kwargs.get('cell_mapping', None)
106
+ api_mapping = kwargs.get('api_mapping', None)
107
+ summary_compare, md5_compare = task_dumppath_get(input_param)
108
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
109
+ create_directory(output_path)
110
+ check_compare_param(input_param, output_path, summary_compare, md5_compare)
111
+ except (CompareException, FileCheckException) as error:
112
+ logger.error('Compare failed. Please check the arguments and do it again!')
113
+ raise CompareException(error.code) from error
114
+ ms_comparator = MSComparator(cell_mapping, api_mapping)
115
+ ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
116
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
117
+ md5_compare=md5_compare)