mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,7 @@ from typing import List
21
21
  import numpy as np
22
22
  import torch
23
23
  from torch import distributed as dist
24
+ from torch.distributed.distributed_c10d import _get_default_group
24
25
 
25
26
  from msprobe.core.common.const import Const
26
27
  from msprobe.core.common.file_utils import path_len_exceeds_limit
@@ -40,7 +41,16 @@ except ImportError:
40
41
 
41
42
 
42
43
  class PytorchDataProcessor(BaseDataProcessor):
43
- pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, torch.memory_format, dist.ProcessGroup)
44
+ pytorch_special_type = (
45
+ torch.device,
46
+ torch.dtype,
47
+ torch.Size,
48
+ torch.Tensor,
49
+ torch.memory_format,
50
+ dist.ProcessGroup,
51
+ dist.P2POp,
52
+ dist.ReduceOp
53
+ )
44
54
  memory_format = {
45
55
  torch.contiguous_format: "contiguous_format",
46
56
  torch.channels_last: "channels_last",
@@ -54,6 +64,7 @@ class PytorchDataProcessor(BaseDataProcessor):
54
64
  "device": self.analyze_device_in_kwargs,
55
65
  "dtype": self.analyze_dtype_in_kwargs
56
66
  }
67
+ self._async_dump_cache = {}
57
68
 
58
69
  @staticmethod
59
70
  def get_md5_for_tensor(x):
@@ -82,49 +93,80 @@ class PytorchDataProcessor(BaseDataProcessor):
82
93
  return {"type": "torch.dtype", "value": str(element)}
83
94
 
84
95
  @staticmethod
85
- def get_stat_info(data):
96
+ def get_stat_info_async(data):
86
97
  tensor_stat = TensorStatInfo()
87
- if data.is_meta:
88
- return tensor_stat
89
- data_clone = data.detach()
90
- if data_clone.numel() == 0:
98
+ if torch.is_complex(data):
99
+ logger.warning("Async dump do not support complex data!")
91
100
  return tensor_stat
92
- elif data_clone.dtype == torch.bool:
93
- tensor_stat.max = torch._C._VariableFunctionsClass.any(data_clone).item()
94
- tensor_stat.min = torch._C._VariableFunctionsClass.all(data_clone).item()
95
- elif not data_clone.shape:
96
- tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
97
- elif torch.is_complex(data_clone):
98
- data_np = data_clone.cpu().numpy()
101
+ elif data.dtype == torch.bool:
102
+ tensor_stat.stack_tensor_stat = (["Max", "Min"], torch.stack(
103
+ [torch.any(data), torch.all(data)]))
104
+ elif not data.shape:
105
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([data, data, data, data]))
106
+ else:
107
+ if not data.is_floating_point() or data.dtype == torch.float64:
108
+ data = data.float()
109
+ tensor_stat.stack_tensor_stat = (["Max", "Min", "Mean", "Norm"], torch.stack([
110
+ torch.max(data),
111
+ torch.min(data),
112
+ torch.mean(data),
113
+ torch.norm(data)
114
+ ]))
115
+ return tensor_stat
116
+
117
+ @staticmethod
118
+ def get_stat_info_sync(data):
119
+ tensor_stat = TensorStatInfo()
120
+ if torch.is_complex(data):
121
+ data_np = data.cpu().numpy()
99
122
  data_abs = np.abs(data_np)
100
123
  tensor_stat.max = np.max(data_abs).item()
101
124
  tensor_stat.min = np.min(data_abs).item()
102
125
  tensor_stat.mean = np.mean(data_abs).item()
126
+ elif data.dtype == torch.bool:
127
+ tensor_stat.max = torch.any(data).item()
128
+ tensor_stat.min = torch.all(data).item()
129
+ elif not data.shape:
130
+ tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
103
131
  else:
104
- if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
105
- data_clone = data_clone.float()
106
- tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item()
107
- tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item()
108
- tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
109
- tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
132
+ if not data.is_floating_point() or data.dtype == torch.float64:
133
+ data = data.float()
134
+ tensor_stat.max = torch.max(data).item()
135
+ tensor_stat.min = torch.min(data).item()
136
+ tensor_stat.mean = torch.mean(data).item()
137
+ tensor_stat.norm = torch.norm(data).item()
110
138
  return tensor_stat
111
139
 
140
+ @staticmethod
141
+ def get_stat_info(data, async_dump=False):
142
+ tensor_stat = TensorStatInfo()
143
+ if data.is_meta:
144
+ return tensor_stat
145
+ data_clone = data.detach()
146
+ if data_clone.numel() == 0:
147
+ return tensor_stat
148
+ else:
149
+ if data_clone.device.type == Const.CPU_LOWERCASE or not async_dump:
150
+ return PytorchDataProcessor.get_stat_info_sync(data_clone)
151
+ else:
152
+ return PytorchDataProcessor.get_stat_info_async(data_clone)
153
+
112
154
  @staticmethod
113
155
  def handle_tensor_extremum_nan_inf(tensor, operator):
114
156
  data_clone = tensor.detach()
115
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
116
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
157
+ data_nan = torch.isnan(data_clone)
158
+ if int(torch.sum(data_nan)) == data_clone.numel():
117
159
  return float('nan')
118
160
 
119
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
120
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
121
- finite_values = getattr(torch._C._TensorBase, "__getitem__")(data_clone, finite_mask)
122
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
123
- torch._C._VariableFunctionsClass.min(finite_values).item()
161
+ finite_mask = torch.isfinite(data_clone)
162
+ if int(torch.sum(finite_mask)) > 0:
163
+ finite_values = data_clone[finite_mask]
164
+ return torch.max(finite_values).item() if operator == 'max' else \
165
+ torch.min(finite_values).item()
124
166
  else:
125
- data_no_nan = getattr(torch._C._TensorBase, "__getitem__")(data_clone, ~data_nan)
126
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
127
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
167
+ data_no_nan = data_clone[~data_nan]
168
+ return torch.max(data_no_nan).item() if operator == 'max' else \
169
+ torch.min(data_no_nan).item()
128
170
 
129
171
  @staticmethod
130
172
  def process_group_hash(arg):
@@ -132,6 +174,15 @@ class PytorchDataProcessor(BaseDataProcessor):
132
174
  group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
133
175
  return group_ranks_hash
134
176
 
177
+ @staticmethod
178
+ def is_distributed_op(module):
179
+ return getattr(module, "op_is_distributed", False)
180
+
181
+ @staticmethod
182
+ def is_hookable_element(element):
183
+ return (hasattr(element, "register_hook") and callable(element.register_hook)) and \
184
+ (hasattr(element, "requires_grad") and element.requires_grad)
185
+
135
186
  @staticmethod
136
187
  def _analyze_torch_size(arg):
137
188
  return {"type": "torch.Size", "value": list(arg)}
@@ -140,7 +191,6 @@ class PytorchDataProcessor(BaseDataProcessor):
140
191
  def _analyze_memory_format(arg):
141
192
  # 获取内存格式
142
193
  format_type = PytorchDataProcessor.memory_format.get(arg)
143
-
144
194
  return {"type": "torch.memory_format", "format": format_type}
145
195
 
146
196
  @staticmethod
@@ -152,9 +202,18 @@ class PytorchDataProcessor(BaseDataProcessor):
152
202
  group_id = PytorchDataProcessor.process_group_hash(arg)
153
203
  group_info.update({"group_id": group_id})
154
204
  except Exception as e:
155
- logger.warning(f"Failed to get process group(id: {group_id}) ranks info with error info: {e}.")
205
+ logger.warning(f"Failed to get process group ranks info with error info: {e}.")
156
206
  return group_info
157
207
 
208
+ @staticmethod
209
+ def _analyze_reduce_op(arg):
210
+ op_type = None
211
+ try:
212
+ op_type = str(arg)
213
+ except Exception as e:
214
+ logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
215
+ return {"type": "torch.distributed.ReduceOp", "value": op_type}
216
+
158
217
  @classmethod
159
218
  def get_special_types(cls):
160
219
  return super().get_special_types() + cls.pytorch_special_type
@@ -168,35 +227,65 @@ class PytorchDataProcessor(BaseDataProcessor):
168
227
  return self._analyze_memory_format(element)
169
228
  if isinstance(element, dist.ProcessGroup):
170
229
  return self._analyze_process_group(element)
230
+ if isinstance(element, dist.P2POp):
231
+ return self._analyze_p2pop(element)
232
+ if isinstance(element, dist.ReduceOp):
233
+ return self._analyze_reduce_op(element)
171
234
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
172
235
  if converted_numpy is not element:
173
- return self._analyze_numpy(converted_numpy, numpy_type)
236
+ return {"type": numpy_type, "value": converted_numpy}
174
237
  if isinstance(element, torch.Tensor):
175
- return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
238
+ return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
239
+ if isinstance(element, np.ndarray):
240
+ return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
176
241
  if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
177
242
  return self._analyze_builtin(element)
178
243
  return {}
179
244
 
245
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
246
+ if self.is_distributed_op(module):
247
+ module_input_output.update_output_with_args_and_kwargs()
248
+ return super().analyze_forward_output(name, module, module_input_output)
249
+
250
+ def _analyze_p2pop(self, arg):
251
+ p2pop_info = {"class_type": "torch.distributed.P2POp"}
252
+ try:
253
+ tensor_info = self._analyze_tensor(arg.tensor, [])
254
+ p2pop_info.update({"tensor": tensor_info})
255
+ p2pop_info.update({"op": arg.op.__name__})
256
+ p2pop_info.update({"peer": arg.peer})
257
+ p2pop_info.update({"tag": arg.tag})
258
+ group_id = PytorchDataProcessor.process_group_hash(
259
+ arg.group) if arg.group else PytorchDataProcessor.process_group_hash(_get_default_group())
260
+ p2pop_info.update({"group_id": group_id})
261
+ except Exception as e:
262
+ logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
263
+ return p2pop_info
264
+
180
265
  def _analyze_tensor(self, tensor, suffix):
181
- tensor_stat = self.get_stat_info(tensor)
266
+ tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
182
267
  tensor_json = {}
183
268
  tensor_json.update({'type': 'torch.Tensor'})
184
269
  tensor_json.update({'dtype': str(tensor.dtype)})
185
270
  tensor_json.update({"shape": tensor.shape})
186
- tensor_json.update({"Max": tensor_stat.max})
187
- tensor_json.update({"Min": tensor_stat.min})
188
- tensor_json.update({"Mean": tensor_stat.mean})
189
- tensor_json.update({"Norm": tensor_stat.norm})
190
- tensor_json.update({"requires_grad": tensor.requires_grad})
191
-
192
- if tensor_stat.max is not None:
193
- if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
194
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
195
- if tensor_stat.min is not None:
196
- if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
197
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
198
-
199
- if self.config.summary_mode == Const.MD5:
271
+ if tensor_stat.stack_tensor_stat is None:
272
+ tensor_json.update({"Max": tensor_stat.max})
273
+ tensor_json.update({"Min": tensor_stat.min})
274
+ tensor_json.update({"Mean": tensor_stat.mean})
275
+ tensor_json.update({"Norm": tensor_stat.norm})
276
+ tensor_json.update({"requires_grad": tensor.requires_grad})
277
+ if tensor_stat.max is not None:
278
+ if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
279
+ tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
280
+ if tensor_stat.min is not None:
281
+ if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
282
+ tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
283
+
284
+ else:
285
+ tensor_json.update({"requires_grad": tensor.requires_grad})
286
+ tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat})
287
+
288
+ if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
200
289
  tensor_md5 = self.get_md5_for_tensor(tensor)
201
290
  tensor_json.update({Const.MD5: tensor_md5})
202
291
  return tensor_json
@@ -207,13 +296,28 @@ class StatisticsDataProcessor(PytorchDataProcessor):
207
296
 
208
297
 
209
298
  class TensorDataProcessor(PytorchDataProcessor):
299
+ def dump_async_data(self):
300
+ for file_path, tensor in self._async_dump_cache.items():
301
+ save_pt(tensor.contiguous(), file_path)
302
+ self._async_dump_cache.clear()
303
+
210
304
  def _analyze_tensor(self, tensor, suffix):
211
305
  dump_data_name, file_path = self.get_save_file_path(suffix)
212
- saved_tensor = tensor.clone().contiguous().detach()
213
- save_pt(saved_tensor, file_path)
214
306
  single_arg = super()._analyze_tensor(tensor, suffix)
215
307
  single_arg.update({"data_name": dump_data_name})
308
+ if self.config.async_dump:
309
+ self._async_dump_cache[file_path] = tensor.clone().detach()
310
+ else:
311
+ saved_tensor = tensor.clone().contiguous().detach()
312
+ save_pt(saved_tensor, file_path)
216
313
  return single_arg
314
+
315
+ def _analyze_numpy(self, ndarray, suffix):
316
+ dump_data_name, file_path = self.get_save_file_path(suffix)
317
+ save_pt(torch.tensor(ndarray), file_path)
318
+ ndarray_json = super()._analyze_numpy(ndarray, suffix)
319
+ ndarray_json.update({"data_name": dump_data_name})
320
+ return ndarray_json
217
321
 
218
322
 
219
323
  class OverflowCheckDataProcessor(PytorchDataProcessor):
@@ -223,7 +327,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
223
327
  super().__init__(config, data_writer)
224
328
  self.has_overflow = False
225
329
  self.support_inf_nan = None
226
- self.cached_inplace_api_info = {}
330
+ self.cached_api_info = {}
227
331
  self.cached_tensors_and_file_paths = {}
228
332
  self.bits_for_overflow = 8
229
333
  self.real_overflow_nums = 0
@@ -237,21 +341,21 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
237
341
  return True
238
342
  return False
239
343
 
240
- def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
344
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
241
345
  self.has_overflow = False
242
346
  self._is_support_inf_nan()
243
- self.cached_inplace_api_info = super().analyze_pre_forward_inplace(name, module_input_output)
347
+ self.cached_api_info = super().analyze_forward_input(name, module, module_input_output)
244
348
  return None
245
349
 
246
- def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
350
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
247
351
  self._is_support_inf_nan()
248
- api_info_struct = super().analyze_forward_inplace(name, module_input_output)
249
- if name in self.cached_inplace_api_info and name in api_info_struct:
250
- self.cached_inplace_api_info[name].update(api_info_struct[name])
352
+ api_info_struct = super().analyze_forward_output(name, module, module_input_output)
353
+ if name in self.cached_api_info and name in api_info_struct:
354
+ self.cached_api_info[name].update(api_info_struct[name])
251
355
  elif name in api_info_struct:
252
- self.cached_inplace_api_info = api_info_struct
356
+ self.cached_api_info = api_info_struct
253
357
  self.handle_overflow()
254
- return self.cached_inplace_api_info if self.has_overflow else None
358
+ return self.cached_api_info if self.has_overflow else None
255
359
 
256
360
  def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
257
361
  self.has_overflow = False
@@ -267,6 +371,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
267
371
  self.handle_overflow()
268
372
  return api_info_struct if self.has_overflow else None
269
373
 
374
+ def analyze_params(self, name, param_name, grad):
375
+ self.has_overflow = False
376
+ self._is_support_inf_nan()
377
+ api_info_struct = super().analyze_params(name, param_name, grad)
378
+ self.handle_overflow()
379
+ return api_info_struct if self.has_overflow else None
380
+
270
381
  def handle_overflow(self):
271
382
  if not self.support_inf_nan:
272
383
  self._analyze_maybe_overflow_flag()
@@ -340,10 +451,10 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
340
451
  )
341
452
  return
342
453
 
343
- def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
454
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
344
455
  self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)
345
456
 
346
- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
457
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
347
458
  new_output, unequal_rows = self.checker.forward(
348
459
  name,
349
460
  module,
@@ -388,7 +499,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
388
499
  def _print_unsupported_log(api_name):
389
500
  logger.warning(f"The kernel dump does not support the {api_name} API.")
390
501
 
391
- def analyze_pre_forward(self, name, module, module_input_output):
502
+ def analyze_forward_input(self, name, module, module_input_output):
392
503
  if not self.enable_kernel_dump:
393
504
  return
394
505
  if is_gpu:
@@ -413,7 +524,7 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
413
524
  return
414
525
  self.start_kernel_dump(self.config.kernel_config_path)
415
526
 
416
- def analyze_forward(self, name, module, module_input_output):
527
+ def analyze_forward_output(self, name, module, module_input_output):
417
528
  if not self.enable_kernel_dump:
418
529
  return
419
530
  if self.config.is_backward_kernel_dump:
@@ -15,10 +15,13 @@
15
15
 
16
16
  import csv
17
17
  import os
18
+ import copy
19
+ import numpy as np
18
20
 
19
21
  from msprobe.core.common.const import Const, FileCheckConst
20
- from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
22
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json
21
23
  from msprobe.core.common.log import logger
24
+ from msprobe.core.common.exceptions import MsprobeException
22
25
 
23
26
 
24
27
  class DataWriter:
@@ -29,10 +32,12 @@ class DataWriter:
29
32
  self.construct_file_path = None
30
33
  self.free_benchmark_file_path = None
31
34
  self.dump_tensor_data_dir = None
35
+ self.debug_file_path = None
32
36
  self.flush_size = 1000
33
37
  self.cache_data = {}
34
38
  self.cache_stack = {}
35
39
  self.cache_construct = {}
40
+ self.cache_debug = {}
36
41
 
37
42
  @staticmethod
38
43
  def write_data_to_csv(result: list, result_header: tuple, file_path: str):
@@ -55,6 +60,13 @@ class DataWriter:
55
60
  self.cache_construct = {}
56
61
 
57
62
  def initialize_json_file(self, **kwargs):
63
+ if self.debug_file_path and not self.cache_debug:
64
+ # debug level case only create debug.json
65
+ debug_dict = copy.deepcopy(kwargs)
66
+ debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
67
+ self.cache_debug = debug_dict
68
+ save_json(self.debug_file_path, self.cache_debug, indent=1)
69
+ return
58
70
  if not self.cache_data:
59
71
  kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
60
72
  self.cache_data = kwargs
@@ -64,13 +76,13 @@ class DataWriter:
64
76
  if not self.cache_construct:
65
77
  save_json(self.construct_file_path, self.cache_construct, indent=1)
66
78
 
67
- def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
68
- free_benchmark_file_path):
69
- self.dump_file_path = dump_file_path
70
- self.stack_file_path = stack_file_path
71
- self.construct_file_path = construct_file_path
72
- self.dump_tensor_data_dir = dump_data_dir
73
- self.free_benchmark_file_path = free_benchmark_file_path
79
+ def update_dump_paths(self, dump_path_aggregation):
80
+ self.dump_file_path = dump_path_aggregation.dump_file_path
81
+ self.stack_file_path = dump_path_aggregation.stack_file_path
82
+ self.construct_file_path = dump_path_aggregation.construct_file_path
83
+ self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir
84
+ self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path
85
+ self.debug_file_path = dump_path_aggregation.debug_file_path
74
86
 
75
87
  def flush_data_periodically(self):
76
88
  dump_data = self.cache_data.get(Const.DATA)
@@ -98,6 +110,9 @@ class DataWriter:
98
110
  def update_construct(self, new_data):
99
111
  self.cache_construct.update(new_data)
100
112
 
113
+ def update_debug(self, new_data):
114
+ self.cache_debug['data'].update(new_data)
115
+
101
116
  def write_data_json(self, file_path):
102
117
  logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
103
118
  save_json(file_path, self.cache_data, indent=1)
@@ -108,6 +123,9 @@ class DataWriter:
108
123
  def write_construct_info_json(self, file_path):
109
124
  save_json(file_path, self.cache_construct, indent=1)
110
125
 
126
+ def write_debug_info_json(self, file_path):
127
+ save_json(file_path, self.cache_debug, indent=1)
128
+
111
129
  def write_json(self):
112
130
  if self.cache_data:
113
131
  self.write_data_json(self.dump_file_path)
@@ -115,3 +133,31 @@ class DataWriter:
115
133
  self.write_stack_info_json(self.stack_file_path)
116
134
  if self.cache_construct:
117
135
  self.write_construct_info_json(self.construct_file_path)
136
+ if self.cache_debug:
137
+ self.write_debug_info_json(self.debug_file_path)
138
+
139
+ def fill_stack_tensor_data(self):
140
+ self.process_stat_data_recursive(self.cache_data)
141
+
142
+ def process_stat_data_recursive(self, data, depth=0):
143
+ if depth > Const.MAX_DEPTH:
144
+ logger.error(f"The maximum depth of recursive process stat data, {Const.MAX_DEPTH} is reached.")
145
+ raise MsprobeException(MsprobeException.RECURSION_LIMIT_ERROR)
146
+ if isinstance(data, dict):
147
+ if "tensor_stat" in data.keys():
148
+ tensor_stat = data["tensor_stat"]
149
+ if len(tensor_stat) != Const.TENSOR_STAT_LEN or len(tensor_stat[0]) != len(tensor_stat[1]):
150
+ logger.warning("Some bad data in async dump")
151
+ else:
152
+ tensor_stat_index, tensor_stat_data = tensor_stat[0], tensor_stat[1]
153
+ if hasattr(tensor_stat_data, "device") and tensor_stat_data.device != Const.CPU_LOWERCASE:
154
+ tensor_stat_data = tensor_stat_data.cpu()
155
+ for index, stat in zip(tensor_stat_index, tensor_stat_data):
156
+ data.update({index: stat.item()})
157
+ del data["tensor_stat"]
158
+ else:
159
+ for key in data.keys():
160
+ self.process_stat_data_recursive(data[key], depth + 1)
161
+ elif isinstance(data, (list, tuple)):
162
+ for i in data:
163
+ self.process_stat_data_recursive(i, depth + 1)
@@ -45,7 +45,7 @@ class ScopeFactory:
45
45
 
46
46
  if self.level == Const.LEVEL_MIX:
47
47
  return mix_range_scope
48
-
48
+
49
49
  if not self.scope:
50
50
  return api_range_scope
51
51
  if api_range_scope.is_valid and module_range_scope.is_valid:
@@ -73,21 +73,21 @@ class BaseScope(ABC):
73
73
  def rectify_args(scope, api_list):
74
74
  if not isinstance(api_list, list):
75
75
  raise ScopeException(ScopeException.InvalidApiStr,
76
- f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
76
+ f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
77
77
  for api in api_list:
78
78
  if not isinstance(api, str):
79
79
  raise ScopeException(ScopeException.InvalidApiStr,
80
- f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
80
+ f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
81
81
  if isinstance(scope, str):
82
82
  scope = [scope]
83
83
  return scope, api_list
84
84
  if not isinstance(scope, list):
85
85
  raise ScopeException(ScopeException.InvalidScope,
86
- f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
86
+ f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
87
87
  for s in scope:
88
88
  if not isinstance(s, str):
89
89
  raise ScopeException(ScopeException.InvalidScope,
90
- f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
90
+ f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
91
91
  return scope, api_list
92
92
 
93
93
  @abstractmethod
@@ -108,7 +108,7 @@ class ListScope(BaseScope):
108
108
  def rectify_args(scope, api_list):
109
109
  if scope and api_list:
110
110
  raise ScopeException(ScopeException.ArgConflict,
111
- f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
111
+ f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
112
112
  return super(ListScope, ListScope).rectify_args(scope, api_list)
113
113
 
114
114
  def check(self, name):
@@ -123,6 +123,7 @@ class RangeScope(BaseScope, ABC):
123
123
  super().__init__(*args)
124
124
  self.in_scope = False
125
125
  self.in_list = False
126
+ self.start_name_set = set()
126
127
  self.is_valid = self.check_scope_is_valid()
127
128
 
128
129
  def check_name_pattern(self, name):
@@ -133,23 +134,23 @@ class RangeScope(BaseScope, ABC):
133
134
  if self.level == Const.LEVEL_L1:
134
135
  if not re.match(api_pattern, name):
135
136
  raise ScopeException(ScopeException.InvalidScope,
136
- f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")
137
-
137
+ f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")
138
+
138
139
  if self.level == Const.LEVEL_L0:
139
140
  if not re.match(module_pattern, name):
140
141
  raise ScopeException(ScopeException.InvalidScope,
141
- f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")
142
+ f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")
142
143
 
143
144
  if self.level == Const.LEVEL_MIX:
144
145
  if not re.match(api_pattern, name) and not re.match(module_pattern, name):
145
146
  raise ScopeException(ScopeException.InvalidScope,
146
- f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
147
+ f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
147
148
 
148
149
  def rectify_args(self, scope, api_list):
149
150
  scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
150
151
  if scope and len(scope) != 2:
151
152
  raise ScopeException(ScopeException.InvalidScope,
152
- f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
153
+ f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
153
154
  for name in scope:
154
155
  self.check_name_pattern(name)
155
156
  return scope, api_list
@@ -229,30 +230,31 @@ class ModuleRangeScope(RangeScope):
229
230
  class MixRangeScope(RangeScope):
230
231
  def check_scope_is_valid(self):
231
232
  return True if self.scope else False
232
-
233
+
233
234
  def begin_module(self, module_name):
234
235
  if self.scope and module_name == self.scope[0]:
235
236
  self.in_scope = True
236
237
  for name in self.api_list:
237
238
  if name in module_name:
238
239
  self.in_list = True
240
+ self.start_name_set.add(module_name) # 记录每一个开启in_list的module_name
239
241
 
240
242
  def end_module(self, module_name):
241
243
  if self.scope and module_name == self.scope[1]:
242
244
  self.in_scope = False
243
- for name in self.api_list:
244
- if name in module_name:
245
- self.in_list = False
245
+ self.start_name_set.discard(module_name) # 从集合中删除每一个module_name
246
+ if not self.start_name_set: # 如果集合为空,说明当前module_name是最后一个开启in_list的module_name
247
+ self.in_list = False # 关闭in_list
246
248
 
247
249
  def check_api_list(self, api_name):
248
250
  if not self.api_list:
249
251
  return True
250
-
252
+
251
253
  for name in self.api_list:
252
254
  if name in api_name:
253
255
  return True
254
256
  return False
255
-
257
+
256
258
  def check(self, name):
257
259
  """
258
260
  dump时调用的接口,根据scope和api_list判断是否需要dump
@@ -270,4 +272,3 @@ class MixRangeScope(RangeScope):
270
272
  if self.scope and name == self.scope[1]:
271
273
  self.in_scope = False
272
274
  return result
273
-
@@ -37,7 +37,11 @@ class AnomalyScene:
37
37
  @staticmethod
38
38
  def _has_anomaly(data: Union[Dict, Any]) -> bool:
39
39
  """检查张量是否包含异常值"""
40
- return has_nan_inf(data)
40
+ if isinstance(data, dict):
41
+ return has_nan_inf(data)
42
+ elif isinstance(data, list):
43
+ return any(AnomalyScene._has_anomaly(x) for x in data)
44
+ return False
41
45
 
42
46
  def get_details(self) -> Dict:
43
47
  """获取异常详情"""
@@ -70,14 +74,14 @@ class InputOutputAnomalyScene(AnomalyScene):
70
74
  def has_input_anomaly(self) -> bool:
71
75
  """检查输入是否有异常(包括args和kwargs)"""
72
76
  # args
73
- args_anomaly = any(self._has_anomaly(x) for x in self.api_data.input_args if isinstance(x, dict))
77
+ args_anomaly = any(self._has_anomaly(x) for x in self.api_data.input_args)
74
78
  # kwargs
75
- kwargs_anomaly = any(self._has_anomaly(x) for x in self.api_data.input_kwargs.values() if isinstance(x, dict))
79
+ kwargs_anomaly = any(self._has_anomaly(x) for x in self.api_data.input_kwargs.values())
76
80
  return args_anomaly or kwargs_anomaly
77
81
 
78
82
  def has_output_anomaly(self) -> bool:
79
83
  """检查输出是否有异常"""
80
- return any(self._has_anomaly(x) for x in self.api_data.output_data if isinstance(x, dict))
84
+ return any(self._has_anomaly(x) for x in self.api_data.output_data)
81
85
 
82
86
  def matches(self) -> bool:
83
87
  """判断是否匹配该场景"""
@@ -121,7 +125,7 @@ class NumericalMutationScene(AnomalyScene):
121
125
  """
122
126
  检查数值突变,统计输入args、kwargs中norm值,同时统计输出的norm最大值,计算差异,大于 threshold 则认为是异常情况
123
127
  """
124
- def __init__(self, api_info: APIInfo, threshold: float = 100000.0):
128
+ def __init__(self, api_info: APIInfo, threshold: float = 100.0):
125
129
  super().__init__(api_info)
126
130
  self.threshold = threshold
127
131