mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -0
@@ -1,7 +1,6 @@
1
-
2
1
  import os
3
2
 
4
- from msprobe.core.data_dump.scope import build_scope, ListScope
3
+ from msprobe.core.data_dump.scope import build_scope, ListScope
5
4
  from msprobe.core.data_dump.json_writer import DataWriter
6
5
  from msprobe.core.common.log import logger
7
6
  from msprobe.core.common.const import Const
@@ -21,7 +20,8 @@ class DataCollector:
21
20
  self.config = config
22
21
  self.data_writer = DataWriter()
23
22
  self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
24
- self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) if self.config.framework == Const.PT_FRAMEWORK else None
23
+ self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
24
+ if self.config.framework == Const.PT_FRAMEWORK else None
25
25
  self.module_count = {}
26
26
  if self.config.task == Const.FREE_BENCHMARK:
27
27
  self.scope = build_scope(ListScope, self.config.scope, self.config.list)
@@ -35,7 +35,7 @@ class DataCollector:
35
35
  @property
36
36
  def dump_file_path(self):
37
37
  return self.data_writer.dump_file_path
38
-
38
+
39
39
  @staticmethod
40
40
  def check_scope_and_pid(scope, name, pid):
41
41
  return (not scope or scope.check(name)) and pid == os.getpid()
@@ -43,10 +43,10 @@ class DataCollector:
43
43
  @staticmethod
44
44
  def is_inplace(module):
45
45
  return getattr(module, "op_is_inplace", False)
46
-
46
+
47
47
  def if_return_forward_new_output(self):
48
48
  return self.data_processor.if_return_forward_new_output()
49
-
49
+
50
50
  def get_forward_new_output(self):
51
51
  return self.data_processor.get_forward_new_output()
52
52
 
@@ -71,12 +71,11 @@ class DataCollector:
71
71
  backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
72
72
  if self.check_scope_and_pid(self.scope, backward_name, pid):
73
73
  self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
74
- if not self.is_inplace(module):
74
+ if not self.is_inplace(module) or not self.check_scope_and_pid(self.scope, name, pid):
75
75
  return
76
76
  logger.info(f"API {name} is inplace.")
77
- if self.check_scope_and_pid(self.scope, name, pid):
78
- data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
79
- self.update_data(data_info)
77
+ data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
78
+ self.handle_data(name, data_info)
80
79
 
81
80
  def forward_data_collect(self, name, module, pid, module_input_output):
82
81
  self.update_construct(name)
@@ -88,8 +87,11 @@ class DataCollector:
88
87
  else:
89
88
  data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
90
89
  if self.config.level == "L2":
91
- return
90
+ return
92
91
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
92
+ if self.data_processor.is_terminated:
93
+ self.handle_data(name, data_info, use_buffer=False)
94
+ raise Exception("[msprobe] exit")
93
95
  self.handle_data(name, data_info)
94
96
 
95
97
  def backward_data_collect(self, name, module, pid, module_input_output):
@@ -98,43 +100,45 @@ class DataCollector:
98
100
  return
99
101
 
100
102
  data_info = self.data_processor.analyze_backward(name, module, module_input_output)
103
+ if self.data_processor.is_terminated:
104
+ self.handle_data(name, data_info, use_buffer=False)
105
+ raise Exception("[msprobe] exit")
106
+ self.handle_data(name, data_info)
107
+
108
+ def backward_input_data_collect(self, name, module, pid, module_input_output):
109
+ self.update_construct(name)
110
+ if not self.check_scope_and_pid(self.scope, name, pid):
111
+ return
112
+
113
+ data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
114
+ self.handle_data(name, data_info)
115
+
116
+ def backward_output_data_collect(self, name, module, pid, module_input_output):
117
+ self.update_construct(name)
118
+ if not self.check_scope_and_pid(self.scope, name, pid):
119
+ return
120
+
121
+ data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
101
122
  self.handle_data(name, data_info)
102
123
 
103
124
  def update_construct(self, name):
104
- if self.config.level not in DataCollector.level_without_construct:
125
+ if self.config.framework == Const.PT_FRAMEWORK and self.config.level not in DataCollector.level_without_construct:
105
126
  self.data_writer.update_construct({name: self.module_processor.api_parent_node})
106
127
  self.data_writer.update_construct(self.module_processor.module_node)
107
128
 
108
- def handle_data(self, name, data_info):
109
- msg = f"msProbe is collecting data on {name}. "
129
+ def handle_data(self, name, data_info, use_buffer=True):
110
130
  if data_info:
131
+ msg = f"msprobe is collecting data on {name}. "
111
132
  msg = self.update_data(data_info, msg)
112
133
  logger.info(msg)
113
- self.data_writer.flush_data_when_buffer_is_full()
114
-
115
- def module_count_func(self, name, name_template):
116
- module_name = name.split(Const.SEP)[-3]
117
- if "forward" in name_template:
118
- if module_name not in self.module_count:
119
- self.module_count[module_name] = [0, [0]]
120
- else:
121
- if self.module_count[module_name][-1] and \
122
- self.module_count[module_name][0] != self.module_count[module_name][-1][-1]:
123
- self.module_count[module_name][-1].pop()
124
- self.module_count[module_name][0] += 1
125
- self.module_count[module_name][-1].append(self.module_count[module_name][0])
126
- index = self.module_count[module_name][0]
134
+ if use_buffer:
135
+ self.data_writer.flush_data_when_buffer_is_full()
127
136
  else:
128
- backward_stack = self.module_count[module_name][-1] if module_name in self.module_count else []
129
- if not backward_stack:
130
- index = "abnormal"
131
- else:
132
- index = backward_stack.pop()
133
- return index
137
+ self.write_json()
134
138
 
135
139
  def update_dump_paths(self, *args):
136
140
  self.data_writer.update_dump_paths(*args)
137
141
  self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
138
-
142
+
139
143
  def update_iter(self, current_iter):
140
144
  self.data_processor.update_iter(current_iter)
@@ -35,11 +35,29 @@ class ModuleBackwardInputsOutputs:
35
35
  @property
36
36
  def grad_input_tuple(self):
37
37
  return convert_tuple(self.grad_input)
38
-
38
+
39
39
  @property
40
40
  def grad_output_tuple(self):
41
- return convert_tuple(self.grad_output)
42
-
41
+ return convert_tuple(self.grad_output)
42
+
43
+
44
+ @dataclass
45
+ class ModuleBackwardInputs:
46
+ grad_input: Optional[Tuple]
47
+
48
+ @property
49
+ def grad_input_tuple(self):
50
+ return convert_tuple(self.grad_input)
51
+
52
+
53
+ @dataclass
54
+ class ModuleBackwardOutputs:
55
+ grad_output: Optional[Tuple]
56
+
57
+ @property
58
+ def grad_output_tuple(self):
59
+ return convert_tuple(self.grad_output)
60
+
43
61
 
44
62
  class TensorStatInfo:
45
63
  def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
@@ -53,7 +71,7 @@ class BaseDataProcessor:
53
71
  _recursive_key_stack = []
54
72
  special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
55
73
  bool, int, float, str, slice)
56
-
74
+
57
75
  def __init__(self, config, data_writer):
58
76
  self.data_writer = data_writer
59
77
  self.config = config
@@ -65,11 +83,15 @@ class BaseDataProcessor:
65
83
  self.current_iter = 0
66
84
  self._return_forward_new_output = False
67
85
  self._forward_new_output = None
68
-
86
+
69
87
  @property
70
88
  def data_path(self):
71
89
  return self.data_writer.dump_tensor_data_dir
72
-
90
+
91
+ @property
92
+ def is_terminated(self):
93
+ return False
94
+
73
95
  @staticmethod
74
96
  def analyze_api_call_stack(name):
75
97
  stack_str = []
@@ -87,7 +109,17 @@ class BaseDataProcessor:
87
109
  stack_str.append(stack_line)
88
110
  stack_info_struct = {name: stack_str}
89
111
  return stack_info_struct
90
-
112
+
113
+ @staticmethod
114
+ def transfer_type(data):
115
+ dtype = str(type(data))
116
+ if 'int' in dtype:
117
+ return int(data)
118
+ elif 'float' in dtype:
119
+ return float(data)
120
+ else:
121
+ return data
122
+
91
123
  @staticmethod
92
124
  def _convert_numpy_to_builtin(arg):
93
125
  type_mapping = {
@@ -103,26 +135,15 @@ class BaseDataProcessor:
103
135
  if isinstance(arg, numpy_type):
104
136
  return builtin_type(arg), type(arg).__name__
105
137
  return arg, ''
106
-
138
+
107
139
  @staticmethod
108
140
  def _analyze_numpy(value, numpy_type):
109
141
  return {"type": numpy_type, "value": value}
110
-
111
- @staticmethod
112
- def _analyze_builtin(arg):
113
- single_arg = {}
114
- if isinstance(arg, slice):
115
- single_arg.update({"type": "slice"})
116
- single_arg.update({"value": [arg.start, arg.stop, arg.step]})
117
- else:
118
- single_arg.update({"type": type(arg).__name__})
119
- single_arg.update({"value": arg})
120
- return single_arg
121
-
142
+
122
143
  @classmethod
123
144
  def get_special_types(cls):
124
145
  return cls.special_type
125
-
146
+
126
147
  @classmethod
127
148
  def recursive_apply_transform(cls, args, transform):
128
149
  if isinstance(args, cls.get_special_types()):
@@ -177,13 +198,17 @@ class BaseDataProcessor:
177
198
  return (Const.ALL in self.config.data_mode or
178
199
  forward_backward in self.config.data_mode or
179
200
  input_output in self.config.data_mode)
180
-
181
- def analyze_pre_forward(self, name, module,module_input_output: ModuleForwardInputsOutputs):
201
+
202
+ def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
182
203
  pass
183
204
 
205
+ def analyze_element(self, element):
206
+ return self.recursive_apply_transform(element, self.analyze_single_element)
207
+
184
208
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
185
209
  api_info_struct = {}
186
- if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input
210
+ # check whether data_mode contains forward or input
211
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
187
212
  api_info_struct[name] = {}
188
213
  self.api_data_category = Const.INPUT
189
214
  args_info_list = self.analyze_element(module_input_output.args_tuple)
@@ -192,13 +217,14 @@ class BaseDataProcessor:
192
217
  kwargs_info_list = self.analyze_element(module_input_output.kwargs)
193
218
  api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
194
219
 
195
- if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output
220
+ # check whether data_mode contains forward or output
221
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
196
222
  api_info_struct[name] = api_info_struct.get(name, {})
197
223
  self.api_data_category = Const.OUTPUT
198
224
  output_info_list = self.analyze_element(module_input_output.output_tuple)
199
225
  api_info_struct[name][Const.OUTPUT] = output_info_list
200
226
  return api_info_struct
201
-
227
+
202
228
  def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
203
229
  api_info_struct = {}
204
230
  if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -210,7 +236,7 @@ class BaseDataProcessor:
210
236
  kwargs_info_list = self.analyze_element(module_input_output.kwargs)
211
237
  api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
212
238
  return api_info_struct
213
-
239
+
214
240
  def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
215
241
  concat_args = module_input_output.concat_args_and_kwargs()
216
242
  api_info_struct = {}
@@ -220,26 +246,48 @@ class BaseDataProcessor:
220
246
  output_info_list = self.analyze_element(concat_args)
221
247
  api_info_struct[name][Const.OUTPUT] = output_info_list
222
248
  return api_info_struct
223
-
249
+
224
250
  def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
225
251
  api_info_struct = {}
226
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
252
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
227
253
  api_info_struct[name] = {}
228
- self.api_data_category = Const.OUTPUT
254
+ self.api_data_category = Const.INPUT
229
255
  input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
230
- api_info_struct[name][Const.GRAD_INPUT] = input_info_list
256
+ api_info_struct[name][Const.INPUT] = input_info_list
231
257
 
232
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
258
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
233
259
  api_info_struct[name] = api_info_struct.get(name, {})
234
- self.api_data_category = Const.INPUT
260
+ self.api_data_category = Const.OUTPUT
235
261
  output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
236
- api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
262
+ api_info_struct[name][Const.OUTPUT] = output_info_list
263
+
264
+ return api_info_struct
265
+
266
+ def analyze_backward_input(self, name, module,
267
+ module_input_output: ModuleBackwardInputs):
268
+ api_info_struct = {}
269
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
270
+ api_info_struct[name] = {}
271
+ self.api_data_category = Const.INPUT
272
+
273
+ input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
274
+ api_info_struct[name][Const.INPUT] = input_info_list
275
+ return api_info_struct
237
276
 
277
+ def analyze_backward_output(self, name, module,
278
+ module_input_output: ModuleBackwardOutputs):
279
+ api_info_struct = {}
280
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
281
+ api_info_struct[name] = {}
282
+ self.api_data_category = Const.OUTPUT
283
+
284
+ output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
285
+ api_info_struct[name][Const.OUTPUT] = output_info_list
238
286
  return api_info_struct
239
287
 
240
288
  def get_save_file_path(self, suffix):
241
- file_format = "pt" if self.config.framework == Const.PT_FRAMEWORK else "npy"
289
+ file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
242
290
  dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
243
- suffix + Const.SEP + file_format)
291
+ suffix + file_format)
244
292
  file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
245
- return dump_data_name, file_path
293
+ return dump_data_name, file_path
@@ -4,7 +4,7 @@ from msprobe.core.common.const import Const
4
4
  class DataProcessorFactory:
5
5
  _data_processor = {}
6
6
  _module_processor = {}
7
-
7
+
8
8
  @classmethod
9
9
  def register_processor(cls, framework, task, processor_class):
10
10
  key = (framework, task)
@@ -13,7 +13,7 @@ class DataProcessorFactory:
13
13
  @classmethod
14
14
  def register_module_processor(cls, framework, processor_class):
15
15
  cls._module_processor[framework] = processor_class
16
-
16
+
17
17
  @classmethod
18
18
  def get_module_processor(cls, framework):
19
19
  processor_class = cls._module_processor.get(framework)
@@ -39,7 +39,7 @@ class DataProcessorFactory:
39
39
  TensorDataProcessor as PytorchTensorDataProcessor,
40
40
  OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
41
41
  FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
42
- KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
42
+ KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
43
43
  )
44
44
  from ....pytorch.module_processer import ModuleProcesser
45
45
  cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
@@ -47,15 +47,13 @@ class DataProcessorFactory:
47
47
  cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
48
48
  cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
49
49
  cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
50
- cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
50
+ cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
51
51
  elif framework == Const.MS_FRAMEWORK:
52
52
  from .mindspore_processor import (
53
53
  StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
54
54
  TensorDataProcessor as MindsporeTensorDataProcessor,
55
- OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor,
56
- FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor
55
+ OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
57
56
  )
58
57
  cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
59
58
  cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
60
59
  cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
61
- cls.register_processor(Const.MS_FRAMEWORK, Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor)
@@ -0,0 +1,198 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ import zlib
17
+
18
+ import mindspore as ms
19
+ from mindspore import ops
20
+ import numpy as np
21
+
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
24
+ ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
25
+ from msprobe.core.common.file_check import path_len_exceeds_limit
26
+ from msprobe.mindspore.dump.hook_cell.wrap_functional import load_ops_functions
27
+ from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy
28
+ from msprobe.mindspore.common.log import logger
29
+ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
30
+
31
+
32
class MindsporeDataProcessor(BaseDataProcessor):
    """Data processor that summarises MindSpore tensors and builtin arguments.

    Tensors are reduced to a JSON-friendly record (dtype, shape and
    Max/Min/Mean/Norm statistics, plus a checksum when
    ``config.summary_mode`` equals ``Const.MD5``).
    """

    # Tensor type appended to the base class's special types in get_special_types().
    mindspore_special_type = tuple([ms.Tensor])
    # Operator lookup tables from load_ops_functions(); the third value is unused here.
    ops_func, mint_ops_func, _ = load_ops_functions()

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # Maps the last suffix of an element (e.g. the kwarg name "dtype")
        # to a dedicated analyser, consulted first in analyze_single_element().
        self.mindspore_object_key = {
            "dtype": self.analyze_dtype_in_kwargs
        }

    @staticmethod
    def get_md5_for_tensor(x):
        """Return a checksum of the tensor's raw bytes as a hex string.

        Despite the name, the value is a CRC32 (``zlib.crc32``) formatted as
        zero-padded lowercase hex, not an MD5 digest. The tensor is passed
        through ``convert_bf16_to_fp32`` before being converted to numpy.
        """
        x = convert_bf16_to_fp32(x)
        tensor_bytes = x.asnumpy().tobytes()
        crc32_hash = zlib.crc32(tensor_bytes)
        return f"{crc32_hash:08x}"

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Describe a dtype argument as ``{"type": "mindspore.dtype", "value": str(element)}``."""
        return {"type": "mindspore.dtype", "value": str(element)}

    @staticmethod
    def _analyze_builtin(arg):
        """Describe a builtin value as a ``{"type": ..., "value": ...}`` dict.

        Slices are expanded to ``[start, stop, step]``; any other value is
        stored as-is under its type name.
        """
        single_arg = {}
        if isinstance(arg, slice):
            single_arg.update({"type": "slice"})
            # Slice bounds may be tensors; convert them to Python numbers so
            # the record can be serialised to JSON.
            values = [
                value if not isinstance(value, ms.Tensor) else value.item()
                for value in [arg.start, arg.stop, arg.step]
            ]
            single_arg.update({"value": values})
        else:
            single_arg.update({"type": type(arg).__name__})
            single_arg.update({"value": arg})
        return single_arg

    @classmethod
    def get_special_types(cls):
        """Return the base class's special types extended with ``ms.Tensor``."""
        return super().get_special_types() + cls.mindspore_special_type

    def get_stat_info(self, data):
        """Compute max/min/mean/norm statistics for a tensor.

        Branches, in order:
          * empty tensor: a default ``TensorStatInfo`` is returned;
          * bool tensor: only max and min are filled, computed with numpy;
          * scalar (empty shape): all four fields are the scalar value;
          * complex64/complex128: statistics of the element-wise magnitude,
            computed with numpy;
          * otherwise: bfloat16 and non-floating tensors are cast to float32
            and the statistics come from the MindSpore operator tables.
        """
        tensor_stat = TensorStatInfo()
        if data.numel() == 0:
            return tensor_stat
        elif data.dtype == ms.bool_:
            data_np = data.asnumpy()
            tensor_stat.max = np.max(data_np).item()
            tensor_stat.min = np.min(data_np).item()
        elif not data.shape:
            tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
        elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
            data_abs = np.abs(data.asnumpy())
            tensor_stat.max = np.max(data_abs).item()
            tensor_stat.min = np.min(data_abs).item()
            tensor_stat.mean = np.mean(data_abs).item()
            tensor_stat.norm = np.linalg.norm(data_abs).item()
        else:
            if data.dtype == ms.bfloat16 or not ops.is_floating_point(data):
                data = data.to(ms.float32)
            # NOTE(review): presumably this swaps the ops used inside norm back
            # to their original (unhooked) functions so that computing the
            # statistics is not itself captured — confirm against api_register.
            api_register.norm_inner_op_set_ori_func()
            try:
                tensor_stat.max = self.mint_ops_func["max"](data).item()
                tensor_stat.min = self.mint_ops_func["min"](data).item()
                tensor_stat.mean = self.mint_ops_func["mean"](data).item()
                tensor_stat.norm = self.ops_func["norm"](data).item()
            finally:
                # Always undo the switch above, even when a statistic fails,
                # so a single error does not leave the registry in the
                # "original function" state.
                api_register.norm_inner_op_set_hook_func()
        return tensor_stat

    def analyze_single_element(self, element, suffix_stack):
        """Convert one leaf element into its JSON-friendly description.

        Resolution order: a handler registered for the last suffix in
        ``mindspore_object_key``, then numpy values, then MindSpore tensors,
        then builtin ``bool``/``int``/``float``/``str``/``slice`` values.
        Anything else yields an empty dict.
        """
        if suffix_stack and suffix_stack[-1] in self.mindspore_object_key:
            return self.mindspore_object_key[suffix_stack[-1]](element)

        converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
        # A different object coming back means the element was a numpy value.
        if converted_numpy is not element:
            return self._analyze_numpy(converted_numpy, numpy_type)
        if isinstance(element, ms.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))

        if isinstance(element, (bool, int, float, str, slice)):
            return self._analyze_builtin(element)
        return {}

    def _analyze_tensor(self, tensor, suffix):
        """Build the summary record for a tensor.

        ``suffix`` is unused here; subclasses use it to name dump files.
        """
        tensor_stat = self.get_stat_info(tensor)
        tensor_json = {
            'type': 'mindspore.Tensor',
            'dtype': str(tensor.dtype),
            'shape': tensor.shape,
            'Max': self.transfer_type(tensor_stat.max),
            'Min': self.transfer_type(tensor_stat.min),
            'Mean': self.transfer_type(tensor_stat.mean),
            'Norm': self.transfer_type(tensor_stat.norm),
        }
        if self.config.summary_mode == Const.MD5:
            tensor_md5 = self.get_md5_for_tensor(tensor)
            tensor_json.update({Const.MD5: tensor_md5})
        return tensor_json
129
+
130
+
131
class StatisticsDataProcessor(MindsporeDataProcessor):
    """Statistics-only processor; adds nothing to MindsporeDataProcessor."""
133
+
134
+
135
class TensorDataProcessor(MindsporeDataProcessor):
    """Processor that saves every analysed tensor to an ``.npy`` dump file."""

    def _analyze_tensor(self, tensor, suffix):
        """Summarise ``tensor``, record its dump file name and save it to disk.

        Returns the parent class's summary dict extended with ``"data_name"``.
        """
        # Resolve the target file before analysing, as the original does.
        data_name, target_path = self.get_save_file_path(suffix)
        tensor_info = super()._analyze_tensor(tensor, suffix)
        tensor_info["data_name"] = data_name
        save_tensor_as_npy(tensor, target_path)
        return tensor_info
142
+
143
+
144
class OverflowCheckDataProcessor(MindsporeDataProcessor):
    """Processor that only reports and dumps API calls containing inf/nan.

    Tensors seen during one forward/backward analysis are cached; they are
    written to disk, and the API record is returned, only if at least one of
    them has an infinite or NaN Max/Min.
    """

    # NOTE(review): __init__ also assigns attributes that are not listed here,
    # so instances presumably still carry a __dict__ from the base class —
    # confirm whether this declaration has any effect.
    __slots__ = ["cached_tensors_and_file_paths"]

    def __init__(self, config, data_writer):
        super().__init__(config, data_writer)
        # file path -> tensor, pending until the current API call is judged.
        self.cached_tensors_and_file_paths = {}
        # Initialised here so maybe_save_overflow_data() and _analyze_tensor()
        # are safe to call before the first analyze_forward/analyze_backward.
        self.has_overflow = False
        self.real_overflow_nums = 0
        self.overflow_nums = config.overflow_nums

    @property
    def is_terminated(self):
        """Whether the configured overflow budget has been reached.

        ``overflow_nums == -1`` disables the limit. A message is logged each
        time the property evaluates to ``True``.
        """
        if self.overflow_nums == -1:
            return False
        if self.real_overflow_nums >= self.overflow_nums:
            logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
            return True
        return False

    def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
        """Analyse a forward call; return its record only if an overflow was seen."""
        self.has_overflow = False
        api_info_struct = super().analyze_forward(name, module, module_input_output)
        self.maybe_save_overflow_data()
        return api_info_struct if self.has_overflow else None

    def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
        """Analyse a backward call; return its record only if an overflow was seen."""
        self.has_overflow = False
        api_info_struct = super().analyze_backward(name, module, module_input_output)
        self.maybe_save_overflow_data()
        return api_info_struct if self.has_overflow else None

    def maybe_save_overflow_data(self):
        """Flush cached tensors to disk if an overflow was seen, then clear the cache.

        The overflow counter is incremented once per flush, not once per tensor.
        """
        if self.has_overflow:
            for file_path, tensor in self.cached_tensors_and_file_paths.items():
                save_tensor_as_npy(tensor, file_path)
            self.real_overflow_nums += 1
        # The cache is dropped regardless, so tensors never leak into the next call.
        self.cached_tensors_and_file_paths = {}

    def _analyze_maybe_overflow_tensor(self, tensor_json):
        """Set ``has_overflow`` when the record's Max or Min is inf or NaN.

        Records whose ``'Max'`` is ``None`` are ignored.
        """
        if tensor_json['Max'] is None:
            return
        if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
            self.has_overflow = True
        if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
            self.has_overflow = True

    def _analyze_tensor(self, tensor, suffix):
        """Summarise ``tensor``, cache it for a possible dump and check for overflow.

        Tensors whose target path is too long are not cached (a warning is
        logged), but their ``"data_name"`` is still recorded in the result.
        """
        dump_data_name, file_path = self.get_save_file_path(suffix)
        if not path_len_exceeds_limit(file_path):
            self.cached_tensors_and_file_paths.update({file_path: tensor})
        else:
            logger.warning(f'The file path {file_path} length exceeds limit.')
        single_arg = super()._analyze_tensor(tensor, suffix)
        self._analyze_maybe_overflow_tensor(single_arg)
        single_arg.update({"data_name": dump_data_name})
        return single_arg