mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}精度工具数据dump标准性能基线报告.md +0 -0
@@ -1,3 +1,4 @@
1
+ import copy
1
2
  import os
2
3
  import zlib
3
4
  from dataclasses import asdict
@@ -5,18 +6,20 @@ from typing import List
5
6
 
6
7
  import numpy as np
7
8
  import torch
8
- from msprobe.core.common.exceptions import MsaccException
9
9
  from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode
10
10
  from msprobe.core.common.log import logger
11
11
  from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
12
12
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
13
13
  ModuleForwardInputsOutputs, TensorStatInfo
14
14
  from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
15
+ from msprobe.pytorch.common.utils import save_pt
16
+
15
17
 
16
18
  try:
17
19
  import torch_npu
20
+ is_gpu = False
18
21
  except ImportError:
19
- pass
22
+ is_gpu = True
20
23
 
21
24
 
22
25
  class PytorchDataProcessor(BaseDataProcessor):
@@ -68,6 +71,12 @@ class PytorchDataProcessor(BaseDataProcessor):
68
71
  tensor_stat.min = False not in data_clone
69
72
  elif not data_clone.shape:
70
73
  tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
74
+ elif torch.is_complex(data_clone):
75
+ data_np = data_clone.cpu().numpy()
76
+ data_abs = np.abs(data_np)
77
+ tensor_stat.max = np.max(data_abs).item()
78
+ tensor_stat.min = np.min(data_abs).item()
79
+ tensor_stat.mean = np.mean(data_abs).item()
71
80
  else:
72
81
  if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
73
82
  data_clone = data_clone.float()
@@ -76,7 +85,39 @@ class PytorchDataProcessor(BaseDataProcessor):
76
85
  tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
77
86
  tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
78
87
  return tensor_stat
79
-
88
+
89
+ @staticmethod
90
+ def handle_tensor_extremum_nan_inf(tensor, operator):
91
+ data_clone = tensor.detach()
92
+ data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
93
+ if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
94
+ return float('nan')
95
+ finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
96
+ if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
97
+ finite_values = data_clone[finite_mask]
98
+ return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
99
+ torch._C._VariableFunctionsClass.min(finite_values).item()
100
+ else:
101
+ data_no_nan = data_clone[~data_nan]
102
+ return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
103
+ torch._C._VariableFunctionsClass.min(data_no_nan).item()
104
+
105
+ @staticmethod
106
+ def _analyze_builtin(arg):
107
+ single_arg = {}
108
+ if isinstance(arg, slice):
109
+ single_arg.update({"type": "slice"})
110
+ # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型
111
+ values = [
112
+ value if not isinstance(value, torch.Tensor) else value.item()
113
+ for value in [arg.start, arg.stop, arg.step]
114
+ ]
115
+ single_arg.update({"value": values})
116
+ else:
117
+ single_arg.update({"type": type(arg).__name__})
118
+ single_arg.update({"value": arg})
119
+ return single_arg
120
+
80
121
  @staticmethod
81
122
  def _analyze_torch_size(arg):
82
123
  return {"type": "torch.Size", "value": list(arg)}
@@ -97,10 +138,7 @@ class PytorchDataProcessor(BaseDataProcessor):
97
138
  return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
98
139
  if isinstance(element, (bool, int, float, str, slice)):
99
140
  return self._analyze_builtin(element)
100
- return None
101
-
102
- def analyze_element(self, element):
103
- return self.recursive_apply_transform(element, self.analyze_single_element)
141
+ return {}
104
142
 
105
143
  def _analyze_tensor(self, tensor, suffix):
106
144
  tensor_stat = self.get_stat_info(tensor)
@@ -113,9 +151,17 @@ class PytorchDataProcessor(BaseDataProcessor):
113
151
  tensor_json.update({"Mean": tensor_stat.mean})
114
152
  tensor_json.update({"Norm": tensor_stat.norm})
115
153
  tensor_json.update({"requires_grad": tensor.requires_grad})
116
- if self.config.summary_mode == "md5":
154
+
155
+ if tensor_stat.max is not None:
156
+ if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
157
+ tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
158
+ if tensor_stat.min is not None:
159
+ if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
160
+ tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
161
+
162
+ if self.config.summary_mode == Const.MD5:
117
163
  tensor_md5 = self.get_md5_for_tensor(tensor)
118
- tensor_json.update({"md5": tensor_md5})
164
+ tensor_json.update({Const.MD5: tensor_md5})
119
165
  return tensor_json
120
166
 
121
167
 
@@ -126,11 +172,8 @@ class StatisticsDataProcessor(PytorchDataProcessor):
126
172
  class TensorDataProcessor(PytorchDataProcessor):
127
173
  def _analyze_tensor(self, tensor, suffix):
128
174
  dump_data_name, file_path = self.get_save_file_path(suffix)
129
- if not path_len_exceeds_limit(file_path):
130
- torch.save(tensor, file_path)
131
- change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
132
- else:
133
- logger.warning(f'The file path {file_path} length exceeds limit.')
175
+ saved_tensor = tensor.contiguous().detach()
176
+ save_pt(saved_tensor, file_path)
134
177
  single_arg = super()._analyze_tensor(tensor, suffix)
135
178
  single_arg.update({"data_name": dump_data_name})
136
179
  return single_arg
@@ -142,29 +185,36 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
142
185
  def __init__(self, config, data_writer):
143
186
  super().__init__(config, data_writer)
144
187
  self.cached_tensors_and_file_paths = {}
145
- self.real_overflow_dump_times = 0
146
- self.overflow_nums = config.overflow_num
147
188
  self.bits_for_overflow = 8
189
+ self.real_overflow_nums = 0
190
+ self.overflow_nums = config.overflow_nums
191
+ self.forward_inplace_inputs = None
192
+
193
+ @property
194
+ def is_terminated(self):
195
+ if self.overflow_nums == -1:
196
+ return False
197
+ if self.real_overflow_nums >= self.overflow_nums:
198
+ logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
199
+ return True
200
+ return False
148
201
 
149
202
  @staticmethod
150
203
  def overflow_debug_mode_enable():
151
204
  overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
152
205
  return overflow_mode == Const.ENV_ENABLE
153
206
 
154
- @staticmethod
155
- def handle_tensor_extremum_nan_inf(data_clone, operator):
156
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
157
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
158
- return float('nan')
159
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
160
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
161
- finite_values = data_clone[finite_mask]
162
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
163
- torch._C._VariableFunctionsClass.min(finite_values).item()
164
- else:
165
- data_no_nan = data_clone[~data_nan]
166
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
167
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
207
+ def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
208
+ self.forward_inplace_inputs = copy.deepcopy(module_input_output)
209
+ return None
210
+
211
+ def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
212
+ module_input_output.output = module_input_output.concat_args_and_kwargs()
213
+ module_input_output.args = self.forward_inplace_inputs.args
214
+ module_input_output.kwargs = self.forward_inplace_inputs.kwargs
215
+ # release memory used by forward inputs
216
+ self.forward_inplace_inputs = None
217
+ return self.analyze_forward(name, None, module_input_output)
168
218
 
169
219
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
170
220
  self.has_overflow = False
@@ -181,20 +231,12 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
181
231
  def maybe_save_overflow_data_and_check_overflow_times(self):
182
232
  if self.has_overflow:
183
233
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
184
- torch.save(tensor, file_path)
185
- change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
186
- self.inc_and_check_overflow_times()
234
+ save_pt(tensor, file_path)
235
+ self.real_overflow_nums += 1
187
236
  self.cached_tensors_and_file_paths = {}
188
237
 
189
- def inc_and_check_overflow_times(self):
190
- self.real_overflow_dump_times += 1
191
- if self.overflow_nums == -1:
192
- return
193
- if self.real_overflow_dump_times >= self.overflow_nums:
194
- raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times))
195
-
196
238
  def check_overflow_npu(self):
197
- if self.overflow_debug_mode_enalbe():
239
+ if self.overflow_debug_mode_enable():
198
240
  float_status = torch.zeros(self.bits_for_overflow).npu()
199
241
  result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
200
242
  if result.cpu()[0] != 0:
@@ -211,21 +253,22 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
211
253
  else:
212
254
  torch_npu._C._clear_overflow_npu()
213
255
 
214
- def _analyze_maybe_overflow_tensor(self, tensor_json, tensor):
215
- data_clone = tensor.detach()
216
- if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
256
+ def _analyze_maybe_overflow_tensor(self, tensor_json):
257
+ if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()):
217
258
  if tensor_json['Max'] is None:
218
259
  return
219
260
  if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
220
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max")
221
261
  self.has_overflow = True
222
262
  if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
223
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min")
224
263
  self.has_overflow = True
225
264
  else:
226
- self.has_overflow = self.check_overflow_npu()
227
- if self.has_overflow:
228
- self.clear_overflow_npu()
265
+ try:
266
+ self.has_overflow = self.check_overflow_npu()
267
+ if self.has_overflow:
268
+ self.clear_overflow_npu()
269
+ except Exception as e:
270
+ logger.error(f"Overflow check failed, the current environment may be abnormal.")
271
+ raise RuntimeError(f"overflow check failed") from e
229
272
 
230
273
  def _analyze_tensor(self, tensor, suffix):
231
274
  dump_data_name, file_path = self.get_save_file_path(suffix)
@@ -234,7 +277,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
234
277
  else:
235
278
  logger.warning(f'The file path {file_path} length exceeds limit.')
236
279
  single_arg = super()._analyze_tensor(tensor, suffix)
237
- self._analyze_maybe_overflow_tensor(single_arg, tensor)
280
+ self._analyze_maybe_overflow_tensor(single_arg)
238
281
  single_arg.update({"data_name": dump_data_name})
239
282
  return single_arg
240
283
 
@@ -280,7 +323,7 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
280
323
  self._forward_new_output = new_output
281
324
 
282
325
  def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
283
- self.checker.backward(name, module, module_input_output.grad_output)
326
+ self.checker.backward(name, module, module_input_output.grad_input)
284
327
 
285
328
 
286
329
  class KernelDumpDataProcessor(PytorchDataProcessor):
@@ -4,7 +4,7 @@ import fcntl
4
4
  import json
5
5
  from pathlib import Path
6
6
 
7
- from msprobe.core.common.file_check import change_mode
7
+ from msprobe.core.common.file_check import change_mode, FileOpen
8
8
  from msprobe.core.common.log import logger
9
9
  from msprobe.core.common.const import Const, FileCheckConst
10
10
 
@@ -30,20 +30,20 @@ class DataWriter:
30
30
  return
31
31
  is_exists = os.path.exists(file_path)
32
32
  append = "a+" if is_exists else "w+"
33
- with os.fdopen(
34
- os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline=""
35
- ) as csv_file:
33
+ with FileOpen(file_path, append) as csv_file:
36
34
  spawn_writer = csv.writer(csv_file)
37
35
  if not is_exists:
38
36
  spawn_writer.writerow(result_header)
39
37
  spawn_writer.writerows([result,])
38
+ is_new_file = not is_exists
39
+ if is_new_file:
40
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
40
41
 
41
42
  def initialize_json_file(self, **kwargs):
42
43
  kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
43
- with os.fdopen(
44
- os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
45
- ) as f:
44
+ with FileOpen(self.dump_file_path, 'w') as f:
46
45
  json.dump(kwargs, f)
46
+ change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
47
47
 
48
48
  if os.path.exists(self.stack_file_path):
49
49
  os.remove(self.stack_file_path)
@@ -83,7 +83,7 @@ class DataWriter:
83
83
  def write_data_json(self, file_path):
84
84
  logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
85
85
  if Path(file_path).exists() and os.path.getsize(file_path) > 0:
86
- with open(file_path, "r+") as f:
86
+ with FileOpen(file_path, "r+") as f:
87
87
  fcntl.flock(f, fcntl.LOCK_EX)
88
88
  data_to_write = json.load(f)
89
89
  fcntl.flock(f, fcntl.LOCK_UN)
@@ -91,7 +91,7 @@ class DataWriter:
91
91
  self.init_json['data_path'] = self.dump_tensor_data_dir
92
92
  data_to_write = self.init_json
93
93
  data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
94
- with open(file_path, 'w+') as f:
94
+ with FileOpen(file_path, 'w+') as f:
95
95
  fcntl.flock(f, fcntl.LOCK_EX)
96
96
  json.dump(data_to_write, f, indent=1)
97
97
  fcntl.flock(f, fcntl.LOCK_UN)
@@ -99,13 +99,13 @@ class DataWriter:
99
99
  self.cache_data[Const.DATA].clear()
100
100
 
101
101
  def write_stack_info_json(self, file_path):
102
- with open(file_path, 'w+') as f:
102
+ with FileOpen(file_path, 'w+') as f:
103
103
  fcntl.flock(f, fcntl.LOCK_EX)
104
104
  json.dump(self.cache_stack, f, indent=1)
105
105
  fcntl.flock(f, fcntl.LOCK_UN)
106
106
 
107
107
  def write_construct_info_json(self, file_path):
108
- with open(file_path, 'w+') as f:
108
+ with FileOpen(file_path, 'w+') as f:
109
109
  fcntl.flock(f, fcntl.LOCK_EX)
110
110
  json.dump(self.cache_construct, f, indent=1)
111
111
  fcntl.flock(f, fcntl.LOCK_UN)
File without changes
@@ -0,0 +1,71 @@
1
+
2
class GradConst:
    """Namespace of constants shared by the grad_probe tooling."""

    # Framework names; FRAMEWORKS is the set of all of them.
    PYTORCH = "PyTorch"
    MindSpore = "MindSpore"
    FRAMEWORKS = {PYTORCH, MindSpore}

    # Gradient file suffixes; GRAD_FILE_SUFFIX is the set of accepted ones.
    NPY_SUFFIX = "npy"
    PT_SUFFIX = "pt"
    GRAD_FILE_SUFFIX = {NPY_SUFFIX, PT_SUFFIX}

    # for callback
    CURRENT_STEP = "current_step"

    # Key names.
    PARAM_LIST = "param_list"
    RANK = "rank"
    STEP = "step"
    BOUNDS = "bounds"
    OUTPUT_PATH = "output_path"

    # Dump levels; SUPPORTED_LEVEL is the set of all of them.
    LEVEL = "level"
    LEVEL0 = "L0"
    LEVEL1 = "L1"
    LEVEL2 = "L2"
    SUPPORTED_LEVEL = {LEVEL0, LEVEL1, LEVEL2}

    # numpy coding
    STEP_IDX = 0
    SHAPE_DIM_IDX = 4
    MAX_SIZE = 10 * 1024 * 1024 * 1024  # 10 * 2**30 (presumably bytes)

    # Suffix of direction files.
    DIR_SUFFIX = "dir.npy"

    # File safety: octal permission modes, length limits and allowed-character regexes.
    DATA_DIR_AUTHORITY = 0o750
    DATA_FILE_AUTHORITY = 0o640
    DIRECTORY_LENGTH = 4096
    FILE_NAME_LENGTH = 255
    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
    PARAM_VALID_PATTERN = r"^[a-zA-Z0-9_.]+$"
    DIR = "dir"
    FILE = "file"

    STEP_FINISH = "step_finish"

    SUMMARY = "summary"

    # CSV header entries.
    MD5 = "MD5"
    DISTRIBUTION = "distribution"
    SHAPE = "shape"
    MAX = "max"
    MIN = "min"
    NORM = "norm"
57
+
58
# Per-level settings keyed by level name ("L0"/"L1"/"L2"):
#   header              -- list of CSV header entries for that level
#   have_grad_direction -- boolean flag for that level (consumers not shown here)
level_adp = {
    GradConst.LEVEL0: dict(
        header=[GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=False,
    ),
    GradConst.LEVEL1: dict(
        header=[GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=True,
    ),
    GradConst.LEVEL2: dict(
        header=[GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
        have_grad_direction=True,
    ),
}
@@ -0,0 +1,175 @@
1
+ import os
2
+ from typing import List
3
+
4
+ from tqdm import tqdm
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+
8
+ from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
9
+ from msprobe.core.common.file_check import create_directory
10
+ from msprobe.core.common.log import logger
11
+ from msprobe.core.common.utils import remove_path, write_csv, load_npy
12
+ from msprobe.core.grad_probe.constant import GradConst
13
+
14
+
15
class GradComparator:
    """Compare boolean gradient dumps of two runs, step by step and parameter by parameter.

    Layout read by the methods below: a compared directory holds ``step<N>``
    sub-directories containing ``<param_name>.npy`` / ``<param_name>.pt`` files,
    plus at least one ``*.csv`` summary whose ``param_name`` column fixes the
    parameter order. For every step and parameter the fraction of equal elements
    is computed; the curves are plotted and written to ``similarities.csv``.
    ``compare_distributed`` applies ``compare`` to every matching ``rank<N>`` directory.
    """

    @staticmethod
    def _get_grad_weight_order(path1, path2):
        """Return the ``param_name`` column of a summary CSV present in both directories.

        The first ``.csv`` entry of ``os.listdir(path1)`` that also exists in *path2*
        is read from *path1*.

        Raises:
            RuntimeError: if no such CSV exists.
        """
        for summary_file in os.listdir(path1):
            if not summary_file.endswith(".csv"):
                continue
            if not os.path.exists(os.path.join(path2, summary_file)):
                continue
            summary_csv = pd.read_csv(os.path.join(path1, summary_file))
            return summary_csv["param_name"]
        raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration")

    @staticmethod
    def _get_name_matched_grad_file(param_name, grad_files):
        """Return the first entry of *grad_files* whose name without extension equals *param_name*.

        Raises:
            RuntimeError: if no file matches.
        """
        for grad_file in grad_files:
            if param_name == grad_file[:grad_file.rfind('.')]:
                return grad_file
        raise RuntimeError("no matched grad_file for comparison, please dump data in same configuration")

    @staticmethod
    def _index_grad_files(grad_files):
        """Map ``name without extension -> first matching file``, using the same rule as
        ``_get_name_matched_grad_file`` but allowing O(1) lookups per parameter."""
        index = {}
        for grad_file in grad_files:
            # setdefault keeps the first occurrence, reproducing the first-match scan.
            index.setdefault(grad_file[:grad_file.rfind('.')], grad_file)
        return index

    @classmethod
    def compare_distributed(cls, path1: str, path2: str, output_dir: str):
        """Run ``compare`` on every ``rank<N>`` directory present under both *path1* and *path2*.

        Results for rank N go to ``<output_dir>/rank<N>``.

        Raises:
            RuntimeError: if the two paths share no ``rank<N>`` directory.
        """
        ranks = cls._get_matched_dirs(path1, path2, "rank")
        logger.info(f"the following ranks will be compared: {ranks}")
        if not ranks:
            raise RuntimeError("no matched ranks for comparison, please dump data in same configuration")
        if not os.path.isdir(output_dir):
            create_directory(output_dir)
        for rank in tqdm(ranks, desc="rank"):
            logger.info(f"now comparing rank {rank}:")
            cls.compare(os.path.join(path1, f"rank{rank}"),
                        os.path.join(path2, f"rank{rank}"),
                        os.path.join(output_dir, f"rank{rank}"))

    @classmethod
    def compare(cls, path1: str, path2: str, output_dir: str):
        """Compare all ``step<N>`` directories shared by *path1* and *path2* and save results to *output_dir*.

        Raises:
            RuntimeError: if the two paths share no ``step<N>`` directory.
        """
        steps = cls._get_matched_dirs(path1, path2, "step")
        if not steps:
            raise RuntimeError("no matched steps for comparison, please dump data in same configuration")
        similarities = cls._calculate_separated_similarities(path1, path2, steps)
        if not os.path.isdir(output_dir):
            create_directory(output_dir)
        cls._save_similarities(similarities, steps, output_dir)

    @classmethod
    def _get_matched_dirs(cls, path1: str, path2: str, dir_prefix):
        """Return the sorted integer suffixes N of entries ``<dir_prefix>N`` in *path1*
        for which ``<path2>/<dir_prefix>N`` is a directory."""
        check_file_or_directory_path(path1, isdir=True)
        check_file_or_directory_path(path2, isdir=True)
        dirs = []
        for dir_name in os.listdir(path1):
            if not dir_name.startswith(dir_prefix):
                continue
            index = dir_name[len(dir_prefix):]
            if not index.isdigit():
                continue
            if not os.path.isdir(os.path.join(path2, dir_name)):
                continue
            dirs.append(int(index))
        return sorted(dirs)

    @classmethod
    def _save_similarities(cls, similarities: dict, steps: List[int], output_dir: str):
        """Plot every similarity series and write them all to ``<output_dir>/similarities.csv``.

        Args:
            similarities: mapping ``name -> list of per-step ratios`` (one entry per
                element of *steps*), as built by ``_calculate_separated_similarities``.
            steps: step numbers; x axis of the plots and header row of the CSV.
            output_dir: receives ``similarities.csv`` and ``similarities_picture/``.

        Raises:
            ValueError: if *similarities* is empty.
            RuntimeError: if a series length differs from ``len(steps)`` or a figure cannot be saved.
        """
        if not similarities:
            raise ValueError("length of similarities is 0")
        # Validate every series before creating any output files.
        for key, value in similarities.items():
            if len(value) != len(steps):
                raise RuntimeError(f"similarities length of {key}:{len(value)} not equal steps:{len(steps)}")
        picture_dir = os.path.join(output_dir, "similarities_picture")
        if not os.path.isdir(picture_dir):
            create_directory(picture_dir)

        result = [['step'] + [str(step) for step in steps]]
        for key, value in tqdm(similarities.items(), desc="save similarities (by param)"):
            cls._plot_similarity(key, steps, value, picture_dir)
            result.append([key] + value)

        result_csv_path = os.path.join(output_dir, "similarities.csv")
        if os.path.exists(result_csv_path):
            logger.warning(f"{result_csv_path} will be recoverd")
            remove_path(result_csv_path)
        write_csv(result, result_csv_path)

    @staticmethod
    def _plot_similarity(key, steps, values, picture_dir):
        """Save the similarity curve of *key* as ``<picture_dir>/<key>_similarities.png``.

        The current pyplot figure is always closed, so a failed plot/save does not
        leave stale curves on the figure used for the next parameter.

        Raises:
            RuntimeError: if ``plt.savefig`` fails.
        """
        fig_save_path = os.path.join(picture_dir, f"{key}_similarities.png")
        check_path_before_create(fig_save_path)
        try:
            plt.plot(steps, values)
            plt.xlabel('steps')
            plt.ylabel('similarities')
            plt.title(f'{key}_similarities')
            try:
                plt.savefig(fig_save_path)
            except Exception as e:
                raise RuntimeError(f"save plt figure {fig_save_path} failed") from e
        finally:
            plt.close()

    @classmethod
    def _calculate_separated_similarities(cls, path1, path2, steps):
        """Compute per-parameter and overall equality ratios for every step.

        Returns:
            dict mapping each parameter name (in summary-CSV order) and
            ``GradConst.SUMMARY`` to a list with one ratio per step; a ratio is
            ``equal elements / total elements`` or 0 when there are no elements.

        Raises:
            RuntimeError: if a parameter listed in the summary CSV has no grad file in a step.
        """
        similarities = {}
        logger.info(f"{len(steps)} steps will be compared")
        grad_weight_order = cls._get_grad_weight_order(path1, path2)
        for step in tqdm(steps, desc="culculate similarities (by step)"):
            # One dict per step instead of a linear scan per parameter.
            grad_file_index = cls._index_grad_files(cls._get_matched_grad_files(path1, path2, step))
            same_count_summary = 0
            total_count_summary = 0
            for grad_name in grad_weight_order:
                grad_file = grad_file_index.get(grad_name)
                if grad_file is None:
                    raise RuntimeError("no matched grad_file for comparison, please dump data in same configuration")
                grad1 = os.path.join(path1, f"step{step}", grad_file)
                grad2 = os.path.join(path2, f"step{step}", grad_file)
                same_count, total_count = cls._calculate_similarity(grad1, grad2)
                same_count_summary += same_count
                total_count_summary += total_count
                param_name = grad_file[:grad_file.rfind(".")]
                ratio = same_count / total_count if total_count != 0 else 0
                similarities.setdefault(param_name, []).append(ratio)
            summary_ratio = same_count_summary / total_count_summary if total_count_summary != 0 else 0
            similarities.setdefault(GradConst.SUMMARY, []).append(summary_ratio)
        return similarities

    @classmethod
    def _get_matched_grad_files(cls, path1: str, path2: str, step: int):
        """Return the sorted names of ``.npy``/``.pt`` files in ``step<step>`` present under both paths."""
        path1 = os.path.join(path1, f"step{step}")
        path2 = os.path.join(path2, f"step{step}")
        check_file_or_directory_path(path1, isdir=True)
        check_file_or_directory_path(path2, isdir=True)
        grad_files = []
        for grad_file in os.listdir(path1):
            # str.split always yields at least one element, so [-1] is safe.
            if grad_file.split('.')[-1] not in GradConst.GRAD_FILE_SUFFIX:
                continue
            if not os.path.exists(os.path.join(path2, grad_file)):
                continue
            grad_files.append(grad_file)
        return sorted(grad_files)

    @classmethod
    def _calculate_similarity(cls, grad_file1: str, grad_file2: str):
        """Return ``(number of equal elements, total element count)`` for two grad files."""
        npy1, npy2 = cls._load_grad_files(grad_file1, grad_file2)
        same_count = (npy1 == npy2).sum()
        total_count = npy1.size
        return same_count, total_count

    @classmethod
    def _load_grad_files(cls, grad_file1: str, grad_file2: str):
        """Load both files via ``load_npy`` and check they are same-shaped boolean arrays.

        Raises:
            RuntimeError: on shape mismatch.
            TypeError: if either array's dtype is not ``bool``.
        """
        grad1 = load_npy(grad_file1)
        grad2 = load_npy(grad_file2)
        if grad1.shape != grad2.shape:
            raise RuntimeError(f"tensor shape is not equal: {grad_file1}, {grad_file2}")
        if grad1.dtype != bool:
            raise TypeError(f"tensor type is not bool: {grad_file1}")
        if grad2.dtype != bool:
            raise TypeError(f"tensor type is not bool: {grad_file2}")
        return grad1, grad2
174
+
175
+
@@ -0,0 +1,52 @@
1
+ import re
2
+ from msprobe.core.grad_probe.constant import GradConst
3
+ from msprobe.core.common.log import logger
4
+ from msprobe.core.common.utils import write_csv
5
+
6
def data_in_list_target(data, lst):
    """Return True if *lst* is empty or None (no filtering), otherwise whether *data* is in *lst*."""
    no_filter = not lst or len(lst) == 0
    return no_filter or data in lst
8
+
9
+
10
def check_numeral_list_ascend(lst):
    """Validate that *lst* contains only ``int``/``float`` values in ascending order.

    The input is materialized once, so tuples and generators are checked
    correctly. Previously a tuple never equaled ``sorted()``'s list, and a
    generator was exhausted by the first check.

    Raises:
        TypeError: if an element is not an ``int`` or ``float`` (subclass of ``Exception``).
        ValueError: if the elements are not in ascending order (subclass of ``Exception``).
    """
    items = list(lst)
    if any(not isinstance(item, (int, float)) for item in items):
        raise TypeError("The input list should only contain numbers")
    if items != sorted(items):
        raise ValueError("The input list should be ascending")
15
+
16
+
17
def check_param(param_name):
    """Raise ``RuntimeError`` unless *param_name* fully matches ``GradConst.PARAM_VALID_PATTERN``.

    ``re.fullmatch`` is used because with ``re.match`` the pattern's ``$`` also
    matches just before a trailing newline, so a name such as ``"w\\n"`` slipped through.
    """
    if not re.fullmatch(GradConst.PARAM_VALID_PATTERN, param_name):
        raise RuntimeError("The parameter name contains special characters.")
20
+
21
+
22
def check_str(string, variable_name):
    """Raise ``ValueError`` naming *variable_name* unless *string* is a ``str``."""
    if isinstance(string, str):
        return
    raise ValueError(f'The variable: "{variable_name}" is not a string.')
25
+
26
+
27
class ListCache(list):
    """A list that buffers rows and writes them out with ``write_csv``.

    Rows are flushed once ``ListCache.threshold`` items have accumulated, and
    again when the object is finalized (``__del__``).
    """
    threshold = 1000

    def __init__(self, *args):
        super().__init__(*args)
        self._output_file = None

    def __del__(self):
        self.flush()

    def flush(self):
        """Write all buffered rows to the output file, log the count, and empty the buffer.

        If no output file has been set, log a warning and keep the rows buffered
        so a later ``set_output_file`` followed by ``flush`` can still write them.
        Previously the code fell through to ``write_csv(self, None)``.
        """
        if not self:
            return
        if not self._output_file:
            # NOTE(review): write_csv presumably cannot handle a None path; keep the data instead.
            logger.warning("dumpfile path is not setted")
            return
        write_csv(self, self._output_file)
        logger.info(f"write {len(self)} items to {self._output_file}.")
        self.clear()

    def append(self, data):
        """Append *data*, flushing once the buffer reaches ``ListCache.threshold`` items."""
        super().append(data)
        if len(self) >= ListCache.threshold:
            self.flush()

    def set_output_file(self, output_file):
        """Set the file path that ``flush`` writes to."""
        self._output_file = output_file