mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249) hide show
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -14,13 +14,21 @@
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
16
  """
17
+ import logging
17
18
  import os
18
19
  import random
19
20
  import stat
21
+ import csv
22
+ import json
20
23
  import torch
24
+ import torch.distributed as dist
21
25
  import numpy as np
22
26
  from functools import wraps
23
27
  from msprobe.core.common.exceptions import DistributedNotInitializedError
28
+ from msprobe.core.common.log import logger as common_logger
29
+ from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, CompareException
30
+ from msprobe.core.common.file_check import FileCheckConst, change_mode, FileOpen
31
+
24
32
 
25
33
  try:
26
34
  import torch_npu
@@ -30,13 +38,8 @@ else:
30
38
  is_gpu = False
31
39
 
32
40
 
33
- torch_without_guard_version_list = ['2.1', '2.2']
34
- for version in torch_without_guard_version_list:
35
- if torch.__version__.startswith(version):
36
- torch_without_guard_version = True
37
- break
38
- else:
39
- torch_without_guard_version = False
41
# torch.__version__ is expected to be a torch.torch_version.TorchVersion (torch >= 1.10),
# whose comparison operators are version-aware, so e.g. "2.10.0" >= "2.1" holds.
# NOTE(review): a plain-str version would fall back to lexicographic ordering — confirm the
# minimum supported torch always provides TorchVersion.
torch_without_guard_version = bool(torch.__version__ >= "2.1")
42
+
40
43
 
41
44
  if not is_gpu and not torch_without_guard_version:
42
45
  from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
@@ -222,3 +225,76 @@ class Const:
222
225
  CONVERT_API = {
223
226
  "int32_to_int64": ["cross_entropy"]
224
227
  }
228
+
229
+
230
def get_tensor_rank(in_feat, out_feat):
    """Return the rank associated with an API call's inputs/outputs.

    When torch.distributed is initialized, the process rank is returned directly.
    Otherwise the device index of a non-CPU tensor found in *in_feat* is used,
    falling back to *out_feat*. Lists/tuples are searched through their first
    element only (recursively).

    Args:
        in_feat: input tensor, or a (possibly nested) list/tuple of them.
        out_feat: output tensor, or a (possibly nested) list/tuple of them.

    Returns:
        int or None: the rank / device index, or None when it cannot be
        determined (CPU tensors, empty containers, non-tensor values).
    """
    if dist.is_initialized():
        return dist.get_rank()

    def get_tensor_rank_single(x):
        # Only the first element of a list/tuple is inspected.
        if isinstance(x, (list, tuple)):
            return get_tensor_rank_single(x[0]) if x else None
        if isinstance(x, torch.Tensor):
            device = x.device
            if device.type != 'cpu':
                return device.index
        return None

    in_rank = get_tensor_rank_single(in_feat)
    # Compare against None explicitly: device index 0 is a valid rank but is falsy,
    # so a truthiness test would wrongly fall back to the output tensors.
    if in_rank is not None:
        return in_rank
    return get_tensor_rank_single(out_feat)
248
+
249
+
250
def get_rank_id():
    """Return this process's torch.distributed rank, or 0 if no process group is initialized."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
254
+
255
+
256
def print_rank_0(message):
    """Log *message* at INFO level once per job.

    With torch.distributed initialized, only rank 0 logs; otherwise the message
    is always logged.
    """
    # Short-circuit keeps dist.get_rank() from being called when no group exists.
    should_log = (not dist.is_initialized()) or dist.get_rank() == 0
    if should_log:
        logger.info(message)
262
+
263
+
264
def load_pt(pt_path, to_cpu=False):
    """Load a ``.pt`` file after resolving and validating its path.

    Args:
        pt_path: path to the file; it is resolved with ``os.path.realpath`` first.
        to_cpu: when True, storages are mapped onto the CPU via ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserializes from the file.

    Raises:
        RuntimeError: wrapping any exception raised while loading.
    """
    resolved = os.path.realpath(pt_path)
    check_file_or_directory_path(resolved)
    load_kwargs = {"map_location": torch.device("cpu")} if to_cpu else {}
    try:
        # NOTE(review): torch.load unpickles the file; only use it on trusted dump data.
        loaded = torch.load(resolved, **load_kwargs)
    except Exception as e:
        raise RuntimeError(f"load pt file {resolved} failed") from e
    return loaded
275
+
276
+
277
def save_pt(tensor, filepath):
    """Save *tensor* with ``torch.save`` to a validated path, then restrict its permissions.

    Args:
        tensor: object to serialize.
        filepath: destination path; resolved with ``os.path.realpath`` and
            checked with ``check_path_before_create`` before writing.

    Raises:
        RuntimeError: if ``torch.save`` fails (an explanatory error is logged first).
    """
    target = os.path.realpath(filepath)
    check_path_before_create(target)
    try:
        torch.save(tensor, target)
    except Exception as e:
        common_logger.error("Save pt file failed, please check according possible error causes: "
                            "1. out of disk space or disk error, "
                            "2. no permission to write files, etc.")
        raise RuntimeError(f"save pt file {target} failed") from e
    change_mode(target, FileCheckConst.DATA_FILE_AUTHORITY)
288
+
289
+
290
def _create_logger(level=logging.INFO, name=None):
    """Configure and return a logger that writes records to stderr at *level*.

    Args:
        level: logging level applied to the logger and to its stream handler.
        name: logger name; ``None`` (the default, as before) configures the root logger.

    Returns:
        logging.Logger: the configured logger.

    Repeated calls reuse the stream handler installed by an earlier call instead of
    attaching another one, which would otherwise print every record several times.
    """
    handler_name = "msprobe_stream_handler"
    logger_ = logging.getLogger(name)
    logger_.setLevel(level)
    ch = next((h for h in logger_.handlers if h.get_name() == handler_name), None)
    if ch is None:
        ch = logging.StreamHandler()
        ch.set_name(handler_name)
        logger_.addHandler(ch)
    # Keep the handler's level in sync if a later call changes it.
    ch.setLevel(level)
    return logger_
297
+
298
+
299
# API_ACCURACY_CHECK_LOG_LEVEL=1 switches the module logger to DEBUG; anything else means INFO.
if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1":
    log_level = logging.DEBUG
else:
    log_level = logging.INFO
logger = _create_logger(log_level)
@@ -15,62 +15,17 @@
15
15
  # limitations under the License.
16
16
  """
17
17
  import os
18
- import sys
19
- import re
20
18
  from msprobe.core.common.utils import CompareException, check_compare_param, \
21
- check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid
22
- from msprobe.pytorch.compare.acc_compare import compare_core
19
+ check_configuration_param, task_dumppath_get
23
20
  from msprobe.core.common.file_check import create_directory
24
- from msprobe.pytorch.common.log import logger
21
+ from msprobe.core.common.exceptions import FileCheckException
22
+ from msprobe.core.common.log import logger
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.pytorch.compare.pt_compare import PTComparator
25
+ from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
25
26
 
26
27
 
27
28
  def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
28
- def check_and_return_dir_contents(dump_dir, prefix):
29
- """
30
- check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
31
- pattern: ^{prefix}(?:0|[0-9][1-9]*)?$
32
-
33
- Args:
34
- dump_dir (str): dump dir
35
- prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only
36
-
37
- Returns:
38
- content [list]: dir contents
39
- Raises:
40
- CompareException: invalid path
41
- ValueError: prefix not match the patterns
42
-
43
- """
44
- check_regex_prefix_format_valid(prefix)
45
- check_file_or_directory_path(dump_dir, True)
46
- contents = os.listdir(dump_dir)
47
- pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$')
48
- for name in contents:
49
- if not pattern.match(name):
50
- logger.error(
51
- f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
52
- f"output. Please check and delete irrelevant files in {dump_dir} and try again."
53
- )
54
- raise CompareException(CompareException.INVALID_PATH_ERROR)
55
- return contents
56
-
57
- def extract_json(dirname, stack_json=False):
58
- json_path = ''
59
- for fname in os.listdir(dirname):
60
- full_path = os.path.join(dirname, fname)
61
- if full_path.endswith('.json'):
62
- json_path = full_path
63
- if not stack_json and 'stack' not in json_path:
64
- break
65
- if stack_json and 'stack' in json_path:
66
- break
67
-
68
- # Provide robustness on invalid directory inputs
69
- if not json_path:
70
- logger.error(f'No file is found in dump dir {dirname}. ')
71
- raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
72
- return json_path
73
-
74
29
  if kwargs.get('suffix'):
75
30
  logger.error("Argument 'suffix' is not supported for compare_distributed.")
76
31
  raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -86,26 +41,26 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
86
41
  'or use compare() api and manually match the ranks.')
87
42
  raise CompareException(CompareException.INVALID_PATH_ERROR)
88
43
  for nr, br in zip(npu_ranks, bench_ranks):
89
- n_dir = os.path.join(npu_dump_dir, nr)
90
- b_dir = os.path.join(bench_dump_dir, br)
91
- s_dir = b_dir
92
- npu_json_path = extract_json(n_dir, stack_json=False)
93
- bench_json_path = extract_json(b_dir, stack_json=False)
94
- stack_json_path = extract_json(s_dir, stack_json=True)
44
+ npu_data_dir = os.path.join(npu_dump_dir, nr)
45
+ bench_data_dir = os.path.join(bench_dump_dir, br)
46
+ npu_path = extract_json(npu_data_dir, stack_json=False)
47
+ bench_path = extract_json(bench_data_dir, stack_json=False)
48
+ stack_path = extract_json(npu_data_dir, stack_json=True)
95
49
 
96
50
  dump_result_param = {
97
- 'npu_json_path': npu_json_path,
98
- 'bench_json_path': bench_json_path,
99
- 'stack_json_path': stack_json_path,
51
+ 'npu_json_path': npu_path,
52
+ 'bench_json_path': bench_path,
53
+ 'stack_json_path': stack_path,
100
54
  'is_print_compare_log': True
101
55
  }
102
56
  try:
103
57
  summary_compare, md5_compare = task_dumppath_get(dump_result_param)
104
58
  check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
105
59
  create_directory(output_path)
106
- check_compare_param(dump_result_param, output_path, stack_mode=stack_mode, summary_compare=summary_compare)
107
- except CompareException as error:
60
+ check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
61
+ except (CompareException, FileCheckException) as error:
108
62
  logger.error('Compare failed. Please check the arguments and do it again!')
109
- sys.exit(error.code)
110
- compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
63
+ raise CompareException(error.code) from error
64
+ pt_comparator = PTComparator()
65
+ pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
111
66
  md5_compare=md5_compare, **kwargs)
@@ -1,16 +1,13 @@
1
1
  import os
2
- import yaml
3
- from msprobe.core.common.file_check import FileOpen
4
- from msprobe.core.common.utils import CompareException
2
+ from msprobe.core.common.utils import CompareException, load_yaml
5
3
 
6
4
 
7
5
  class AtenIrMapping():
8
6
  def __init__(self):
9
7
  cur_path = os.path.dirname(os.path.realpath(__file__))
10
8
  yaml_path = os.path.join(cur_path, "mapping.yaml")
11
- with FileOpen(yaml_path, 'r') as f:
12
- self.aten_mapping = yaml.safe_load(f)
13
-
9
+ self.aten_mapping = load_yaml(yaml_path)
10
+
14
11
  def match(self, op1, op2):
15
12
  if "Aten" in op1 and "Aten" not in op2:
16
13
  return self.match_op(op1, op2)
@@ -0,0 +1,40 @@
1
+ import os.path
2
+ import torch
3
+ from msprobe.core.common.const import FileCheckConst, Const
4
+ from msprobe.core.common.log import logger
5
+ from msprobe.core.common.exceptions import FileCheckException
6
+ from msprobe.core.compare.acc_compare import Comparator
7
+ from msprobe.core.common.utils import create_directory, check_configuration_param, task_dumppath_get, \
8
+ check_compare_param, FileChecker
9
+ from msprobe.core.common.utils import CompareException
10
+
11
+
12
class PTComparator(Comparator):
    """PyTorch flavour of the generic Comparator: loads dumped ``.pt`` tensors as numpy arrays."""

    def __init__(self):
        # NOTE(review): Comparator.__init__ is not invoked here (same as before) — confirm the
        # base class needs no initialisation.
        self.frame_name = PTComparator.__name__

    def read_npy_data(self, dir_path, file_name):
        """Load ``dir_path/file_name`` (a readable ``.pt`` file) and return it as a numpy array.

        bfloat16 tensors are upcast to float32 before conversion.
        """
        checker = FileChecker(os.path.join(dir_path, file_name), FileCheckConst.FILE,
                              FileCheckConst.READ_ABLE, FileCheckConst.PT_SUFFIX, False)
        checked_path = checker.common_check()
        # detach() drops autograd history so less memory is held.
        # NOTE(review): torch.load unpickles the file; only use it on trusted dump data.
        tensor = torch.load(checked_path, map_location=torch.device('cpu')).detach()
        # The original upcasts bfloat16 to float32 before calling .numpy().
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float32)
        return tensor.numpy()
26
+
27
+
28
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
    """Validate the arguments and run a PyTorch NPU-vs-bench accuracy comparison.

    Args:
        input_param: dict describing the dumps to compare; forwarded to
            ``task_dumppath_get``, ``check_compare_param`` and ``compare_core``.
        output_path: directory for the comparison result; created if needed.
        stack_mode: forwarded to the configuration check and to ``compare_core``.
        auto_analyze: forwarded to the configuration check and to ``compare_core``.
        fuzzy_match: forwarded to the configuration check and to ``compare_core``.

    Raises:
        CompareException: if validation fails. The underlying
            ``CompareException``/``FileCheckException`` is chained as the cause.
    """
    try:
        summary_compare, md5_compare = task_dumppath_get(input_param)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
        create_directory(output_path)
        # Pass the mode flags by keyword (as compare_distributed does) so the call stays
        # correct regardless of check_compare_param's positional parameter order.
        check_compare_param(input_param, output_path, summary_compare=summary_compare,
                            md5_compare=md5_compare)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error
    pt_comparator = PTComparator()
    pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
                               auto_analyze=auto_analyze, fuzzy_match=fuzzy_match,
                               summary_compare=summary_compare, md5_compare=md5_compare)
@@ -21,7 +21,7 @@ class DebuggerConfig:
21
21
  self.acl_config = common_config.acl_config if common_config.acl_config else ""
22
22
  self.is_forward_acl_dump = True
23
23
  self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
24
- self.overflow_num = task_config.overflow_num if task_config.overflow_num else 1
24
+ self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
25
25
  self.framework = Const.PT_FRAMEWORK
26
26
 
27
27
  if self.task == Const.FREE_BENCHMARK:
@@ -35,7 +35,16 @@ class DebuggerConfig:
35
35
  "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
36
36
  "max_sample": task_config.max_sample if task_config.max_sample else 20,
37
37
  }
38
-
38
+
39
+ self.online_run_ut = False
40
+ if self.task == Const.TENSOR:
41
+ # dump api tensor and collaborate with online run_ut
42
+ self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False
43
+ self.nfs_path = task_config.nfs_path if task_config.nfs_path else ""
44
+ self.tls_path = task_config.tls_path if task_config.tls_path else ""
45
+ self.host = task_config.host if task_config.host else ""
46
+ self.port = task_config.port if task_config.port else -1
47
+
39
48
  self.check()
40
49
  if self.step:
41
50
  self.step.sort()
@@ -4,11 +4,14 @@ from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
4
4
  from msprobe.pytorch.service import Service
5
5
  from msprobe.pytorch.common.log import logger
6
6
  from msprobe.pytorch.pt_config import parse_json_config
7
- from msprobe.core.common.exceptions import MsaccException
7
+ from msprobe.core.common.exceptions import MsprobeException
8
+ from msprobe.core.common.const import Const
9
+ from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor
8
10
 
9
11
 
10
12
  class PrecisionDebugger:
11
13
  _instance = None
14
+ tasks_not_need_debugger = [Const.GRAD_PROBE]
12
15
 
13
16
  def __new__(cls, *args, **kwargs):
14
17
  if cls._instance is None:
@@ -27,9 +30,14 @@ class PrecisionDebugger:
27
30
  step=None,
28
31
  ):
29
32
  if not hasattr(self, "initialized"):
33
+ self.api_origin = False
30
34
  self.initialized = True
31
35
  self.model = self.check_model_valid(model)
32
36
  common_config, task_config = parse_json_config(config_path, task)
37
+ self.task = common_config.task
38
+ if self.task == Const.GRAD_PROBE:
39
+ self.gm = GradientMonitor(common_config, task_config)
40
+ return
33
41
  if step:
34
42
  common_config.step = step
35
43
  self.config = DebuggerConfig(
@@ -50,23 +58,35 @@ class PrecisionDebugger:
50
58
  def check_model_valid(model):
51
59
  if not model or isinstance(model, torch.nn.Module):
52
60
  return model
53
- raise MsaccException(
54
- MsaccException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
61
+ raise MsprobeException(
62
+ MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。"
55
63
  )
56
64
 
57
65
  @classmethod
58
66
  def start(cls):
59
67
  instance = cls._instance
68
+ if instance.task in PrecisionDebugger.tasks_not_need_debugger:
69
+ return
60
70
  if not instance:
61
71
  raise Exception("No instance of PrecisionDebugger found.")
62
72
  if instance.enable_dataloader:
63
73
  logger.warning_on_rank_0("DataLoader is enabled, start() skipped.")
64
74
  else:
65
- instance.service.start(instance.model)
75
+ instance.service.start(instance.model, instance.api_origin)
76
+ instance.api_origin = False
77
+
78
# Marks the end of the forward/backward section to dump: computation after this call is
# ignored and will not be dumped.
@classmethod
def forward_backward_dump_end(cls):
    """Tell the Service that the dumped forward/backward section has ended.

    This sets ``api_origin`` so that the next ``start()`` forwards it to the Service.
    It is a no-op for tasks in ``tasks_not_need_debugger``, which never reach the
    Service setup, consistent with start()/step().

    Raises:
        Exception: if no PrecisionDebugger instance has been created.
    """
    instance = cls._instance
    if not instance:
        raise Exception("PrecisionDebugger instance is not created.")
    if instance.task in PrecisionDebugger.tasks_not_need_debugger:
        return
    instance.service.forward_backward_dump_end()
    instance.api_origin = True
66
84
 
67
85
  @classmethod
68
86
  def stop(cls):
69
87
  instance = cls._instance
88
+ if instance.task in PrecisionDebugger.tasks_not_need_debugger:
89
+ return
70
90
  if not instance:
71
91
  raise Exception("PrecisionDebugger instance is not created.")
72
92
  if instance.enable_dataloader:
@@ -76,10 +96,20 @@ class PrecisionDebugger:
76
96
 
77
97
  # Advance the dump service to the next step via service.step().
  # New version: returns early for tasks in tasks_not_need_debugger.
  # NOTE(review): `cls._instance.task` is read BEFORE the
  # `if not cls._instance` guard. With no instance, this raises AttributeError
  # instead of Exception("PrecisionDebugger instance is not created."). The
  # None check should come first, as it does in monitor().
  @classmethod
78
98
  def step(cls):
99
+ if cls._instance.task in PrecisionDebugger.tasks_not_need_debugger:
100
+ return
79
101
  if not cls._instance:
80
102
  raise Exception("PrecisionDebugger instance is not created.")
81
103
  cls._instance.service.step()
82
104
 
105
  # Register `model` with the gradient monitor, i.e. the GradientMonitor that
  # __init__ stores as `gm` when task == Const.GRAD_PROBE.
  # Raises Exception when no PrecisionDebugger instance exists; this is a
  # no-op for any task other than GRAD_PROBE.
+ @classmethod
106
+ def monitor(cls, model):
107
+ if not cls._instance:
108
+ raise Exception("PrecisionDebugger instance is not created.")
109
+ if cls._instance.task != Const.GRAD_PROBE:
110
+ return
111
+ cls._instance.gm.monitor(model)
112
+
83
113
 
84
114
  def iter_tracer(func):
85
115
  def func_wrapper(*args, **kwargs):
@@ -8,7 +8,7 @@
8
8
 
9
9
  **真实数据模式**:精度预检工具支持随机生成模式和真实数据模式,即在预检dump时可以选择由工具构造随机数进行输入获得dump数据或选择获取真实输入数据进行预检dump操作;随机生成模式执行效率高,可以快速获得结果,但数据精度低,只能大致判断精度问题;真实数据模式执行效率略低于随机生成模式,但是数据精度高,可以准确判断精度问题。
10
10
 
11
- **工具支持PyTorch版本**:2.0/2.1/2.2。
11
+ **工具支持PyTorch版本**:1.11/2.0/2.1/2.2。
12
12
 
13
13
  **工具特性**
14
14
 
@@ -21,7 +21,7 @@
21
21
  精度预检操作流程如下:
22
22
 
23
23
  1. 在NPU和GPU环境下分别安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
24
- 2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger采集待预检数据。详见《[精度数据采集](./dump.md)》。
24
+ 2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger,采集待预检数据。详见《[精度数据采集](./dump.md)》,注意需要配置level="L1"。
25
25
  3. 将NPU环境下dump的预检数据拷贝至GPU环境。
26
26
  4. 在NPU和GPU环境下分别执行run_ut,生成结果用于最终api_precision_compare操作的输入。详见“**run_ut预检操作**”。
27
27
  5. 将NPU和GPU执行run_ut生成的`accuracy_checking_details_{timestamp}.csv`结果文件拷贝至同一环境下。
@@ -51,10 +51,12 @@ run_ut预检操作包括如下场景:
51
51
  | -api_info或--api_info_file | 指定API信息文件dump.json。 | 是 |
52
52
  | -save_error_data | 保存精度未达标的API输入输出数据。 | 否 |
53
53
  | -o或--out_path | 指定run_ut执行结果存盘路径,默认“./”(相对于run_ut的路径)。 | 否 |
54
+ | | | |
54
55
  | -j或--jit_compile | 开启jit编译。 | 否 |
55
56
  | -d或--device | 指定Device ID,选择UT代码运行所在的卡,默认值为0。 | 否 |
56
57
  | -csv_path或--result_csv_path | 指定本次运行中断时生成的`accuracy_checking_result_{timestamp}.csv`文件路径,执行run_ut中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的`accuracy_checking_result_{timestamp}.csv`文件。详见“**断点续检**”。 | run_ut操作中断后继续执行场景下必选 |
57
58
  | -f或--filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的API。适用于模型较大且重复API较多的场景。 | 否 |
59
+ | -config或--config_path | 指定预检操作过程中的额外配置(包括黑名单、白名单等)的[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,默认未配置。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md#pytorch场景task配置为run_ut)》。 | 否 |
58
60
 
59
61
  run_ut执行结果包括`accuracy_checking_result_{timestamp}.csv`和`accuracy_checking_details_{timestamp}.csv`两个文件。`accuracy_checking_result_{timestamp}.csv`是API粒度的,标明每个API是否通过测试。建议用户先查看`accuracy_checking_result_{timestamp}.csv`文件,对于其中没有通过测试的或者特定感兴趣的API,根据其API name字段在`accuracy_checking_details_{timestamp}.csv`中查询其各个输出的达标情况以及比较指标。详细介绍请参见“**预检结果**”。
60
62
 
@@ -64,7 +66,7 @@ run_ut预检操作包括如下场景:
64
66
  msprobe -f pytorch run_ut -api_info ./dump.json -save_error_data
65
67
  ```
66
68
 
67
- 数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过修改mstt/debug/accuracy_tools/api_accuracy_checker目录下,config.yaml文件的error_data_path参数来配置保存路径,详见“config.yaml文件说明”。
69
+ 数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过error_data_path参数来配置保存路径,error_data_path参数在[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件或config.yaml文件中配置,config.json文件需要在run_ut操作时通过-config参数指定,config.yaml文件详见“**config.yaml文件说明**”。
68
70
 
69
71
  #### 使用multi_run_ut.py执行多线程预检
70
72
 
@@ -99,23 +101,65 @@ msprobe -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3
99
101
  msprobe -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv
100
102
  ```
101
103
 
102
- #### API预检白名单
104
+ #### API预检黑名单和白名单
103
105
 
104
- run_ut过程支持API预检白名单,操作方式如下:
106
+ run_ut过程支持API预检黑名单和白名单,通过如下文件配置black_list(黑名单)或white_list(白名单)参数来指定不需要或需要预检的API名称:
105
107
 
106
- 修改mstt/debug/accuracy_tools/api_accuracy_checker目录下config.yaml文件的white_list参数,配置需要预检的API名称,详见“config.yaml文件说明”。
108
+ - 配置[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,config.json文件需要在run_ut操作时通过-config参数指定。
109
+ - 配置config.yaml文件,详见“**config.yaml文件说明**”。
110
+
111
+ config.json文件的优先级高于config.yaml文件,即通过-config参数指定config.json文件时,config.yaml文件的配置不生效。
107
112
 
108
113
  ### config.yaml文件说明
109
114
 
110
- config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单等功能。
115
+ config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单、黑名单等功能。操作步骤如下:
116
+
117
+ 1. 查找msprobe工具安装路径。
118
+
119
+ ```bash
120
+ pip show mindstudio-probe
121
+ ```
122
+
123
+ 输出结果如下示例:
124
+
125
+ ```bash
126
+ Name: mindstudio-probe
127
+ Version: 1.0
128
+ Summary: This is a pytorch precision comparison tools
129
+ Home-page:
130
+ Author:
131
+ Author-email:
132
+ License:
133
+ Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
134
+ Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
135
+ Required-by:
136
+ ```
137
+
138
+ Location字段为msprobe工具的安装路径,那么config.yaml文件位置为/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
139
+
140
+ 2. 打开config.yaml文件。
141
+
142
+ ```bash
143
+ vi /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
144
+ ```
145
+
146
+ 3. 修改config.yaml文件参数。
147
+
148
+ ```yaml
149
+ white_list: []
150
+ black_list: []
151
+ error_data_path: './'
152
+ precision: 14
153
+ ```
111
154
 
112
- 文件路径为:mstt/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
155
+ | 参数名称 | 说明 | 是否必选 |
156
+ | --------------- | ------------------------------------------------------------ | -------- |
157
+ | white_list | API dump白名单,仅对指定的API进行dump。参数示例:white_list: ["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
158
+ | black_list | API dump黑名单,被指定的API不进行dump。参数示例:black_list: ["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 |
159
+ | error_data_path | 配置保存精度未达标的API输入输出数据路径。参数示例:error_data_path: './'。默认为当前路径。 | 否 |
160
+ | precision | 浮点数表示位数,默认取小数点后14位。 | 否 |
113
161
 
114
- | 参数名称 | 说明 | 是否必选 |
115
- | --------------- | ------------------------------------------------------------ | -------- |
116
- | white_list | API dump白名单,指定dump具体API数据,也可以直接配置预检的API白名单,详细请参见“**API预检白名单**”。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
117
- | error_data_path | 配置保存精度未达标的API输入输出数据路径。 | 否 |
118
- | precision | 浮点数表示位数,默认取小数点后14位。 | 否 |
162
+ 说明:white_list和black_list同时配置时,若二者配置的API名单无交集,则仅白名单生效;若存在交集,则白名单以外的API以及交集中的API均不进行dump。
119
163
 
120
164
  ## 预检结果
121
165