mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -14,36 +14,48 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from abc import ABC, abstractmethod
17
+ import re
17
18
 
18
19
  from msprobe.core.common.const import Const
19
20
  from msprobe.core.common.exceptions import ScopeException
20
21
 
21
22
 
22
class ScopeFactory:
    """
    Build the scope object that decides which apis/modules are dumped.

    The factory reads `task`, `level`, `scope` and `list` from the debugger config.
    """

    def __init__(self, config):
        self.task = config.task
        self.level = config.level
        self.scope = config.scope
        self.api_list = config.list

    def build_scope(self):
        """
        Return the scope matching the config, or None when neither `scope` nor `list` is set.

        The free-benchmark task gets a ListScope; every other task gets a range scope.

        Raises:
            ScopeException: if `scope` cannot be interpreted as a range for the configured level.
        """
        if not self.scope and not self.api_list:
            return None
        # Normalize a missing value to an empty list so the scope classes always receive lists.
        if self.scope is None:
            self.scope = []
        if self.api_list is None:
            self.api_list = []
        if self.task == Const.FREE_BENCHMARK:
            return ListScope(self.scope, self.api_list)
        return self._build_range_scope()

    def _build_range_scope(self):
        """Pick the range-scope class that fits the configured level and scope boundaries."""
        # The mix level always uses MixRangeScope, so the other candidates are never needed;
        # building them would only repeat the same boundary validation.
        if self.level == Const.LEVEL_MIX:
            return MixRangeScope(self.scope, self.api_list, self.level)

        api_range_scope = APIRangeScope(self.scope, self.api_list, self.level)
        # Without boundaries every range-scope type behaves the same, so use the api one.
        if not self.scope:
            return api_range_scope

        module_range_scope = ModuleRangeScope(self.scope, self.api_list, self.level)
        # The boundaries must identify exactly one kind of range (api or module).
        if api_range_scope.is_valid and module_range_scope.is_valid:
            raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}.")
        if api_range_scope.is_valid:
            return api_range_scope
        if module_range_scope.is_valid:
            return module_range_scope
        raise ScopeException(ScopeException.InvalidScope, f"scope={self.scope}")
47
59
 
48
60
 
49
61
  class BaseScope(ABC):
@@ -51,7 +63,8 @@ class BaseScope(ABC):
51
63
  Module_Type_API = "api"
52
64
  module_type = ["Module", "Cell"]
53
65
 
54
- def __init__(self, scope, api_list):
66
+ def __init__(self, scope, api_list, level=None):
67
+ self.level = level
55
68
  scope, api_list = self.rectify_args(scope, api_list)
56
69
  self.scope = scope
57
70
  self.api_list = api_list
@@ -109,17 +122,36 @@ class RangeScope(BaseScope, ABC):
109
122
  def __init__(self, *args):
110
123
  super().__init__(*args)
111
124
  self.in_scope = False
125
+ self.in_list = False
112
126
  self.is_valid = self.check_scope_is_valid()
113
127
 
114
- @staticmethod
115
- def rectify_args(scope, api_list):
116
- scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
117
- if isinstance(scope, list):
118
- if len(scope) == 1:
119
- scope.append(scope[0])
120
- elif len(scope) > 2:
128
+ def check_name_pattern(self, name):
129
+ options_pattern = "|".join(re.escape(option) for option in Const.DUMP_PREFIX)
130
+ api_pattern = rf"^({options_pattern})\..*\.\d+\.(forward|backward)$"
131
+ module_pattern = r"^(Cell|Module)\..*\.(forward|backward)\.\d+$"
132
+
133
+ if self.level == Const.LEVEL_L1:
134
+ if not re.match(api_pattern, name):
135
+ raise ScopeException(ScopeException.InvalidScope,
136
+ f"scope参数格式错误,要求格式为api完整命名,实际为{name}.")
137
+
138
+ if self.level == Const.LEVEL_L0:
139
+ if not re.match(module_pattern, name):
140
+ raise ScopeException(ScopeException.InvalidScope,
141
+ f"scope参数格式错误,要求格式为模块完整命名,实际为{name}.")
142
+
143
+ if self.level == Const.LEVEL_MIX:
144
+ if not re.match(api_pattern, name) and not re.match(module_pattern, name):
121
145
  raise ScopeException(ScopeException.InvalidScope,
122
- f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
146
+ f"scope参数格式错误,要求格式为api或模块完整命名,实际为{name}.")
147
+
148
+ def rectify_args(self, scope, api_list):
149
+ scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
150
+ if scope and len(scope) != 2:
151
+ raise ScopeException(ScopeException.InvalidScope,
152
+ f"scope参数指定区间断点,须传入长度为2的列表,实际长度为{len(scope)}.")
153
+ for name in scope:
154
+ self.check_name_pattern(name)
123
155
  return scope, api_list
124
156
 
125
157
  @abstractmethod
@@ -192,3 +224,50 @@ class ModuleRangeScope(RangeScope):
192
224
  if not self.scope or self.in_scope:
193
225
  return self.check_api_list(name)
194
226
  return False
227
+
228
+
229
class MixRangeScope(RangeScope):
    """
    Range scope for the mix level, whose boundaries may be api or module names.

    Apis are filtered by `check`. Module boundaries are tracked by `begin_module` /
    `end_module`. While a module whose name contains an `api_list` entry is open,
    every api inside the scope is dumped.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # Number of currently open modules whose names contain an api_list entry.
        # With a counter instead of a plain flag, `in_list` stays set until the
        # outermost matching module ends when matching modules are nested.
        self._open_list_modules = 0

    def check_scope_is_valid(self):
        return bool(self.scope)

    def _matches_api_list(self, name):
        # Substring match of `name` against the configured api_list entries.
        return any(item in name for item in self.api_list)

    def begin_module(self, module_name):
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True
        if self._matches_api_list(module_name):
            self._open_list_modules += 1
            self.in_list = True

    def end_module(self, module_name):
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False
        if self._matches_api_list(module_name):
            # Clamp at zero so an unmatched end can never make the counter negative.
            self._open_list_modules = max(self._open_list_modules - 1, 0)
            self.in_list = self._open_list_modules > 0

    def check_api_list(self, api_name):
        # An empty api_list selects every api.
        if not self.api_list:
            return True
        return self._matches_api_list(api_name)

    def check(self, name):
        """
        Interface called during dump: decide from scope and api_list whether `name` is dumped.

        The scope boundaries themselves are inclusive. `in_scope` is switched on before
        the begin name is evaluated and switched off after the end name is evaluated.
        """
        if self.scope and name == self.scope[0]:
            self.in_scope = True

        result = False
        if not self.scope or self.in_scope:
            # Inside a matching module everything is dumped; otherwise filter by api_list.
            result = self.in_list or self.check_api_list(name)

        if self.scope and name == self.scope[1]:
            self.in_scope = False
        return result
273
+
@@ -1,3 +1,17 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  class GradConst:
3
17
 
@@ -60,16 +74,16 @@ class GradConst:
60
74
  NORM = "norm"
61
75
 
62
76
  level_adp = {
63
- "L0": {
64
- "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
65
- "have_grad_direction": False
66
- },
67
- "L1": {
68
- "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
69
- "have_grad_direction": True
70
- },
71
- "L2": {
72
- "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
73
- "have_grad_direction": True
74
- },
75
- }
77
+ "L0": {
78
+ "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
79
+ "have_grad_direction": False
80
+ },
81
+ "L1": {
82
+ "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
83
+ "have_grad_direction": True
84
+ },
85
+ "L2": {
86
+ "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE],
87
+ "have_grad_direction": True
88
+ },
89
+ }
@@ -1,10 +1,25 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from typing import List
3
18
 
4
19
  from tqdm import tqdm
5
20
  import matplotlib.pyplot as plt
6
21
 
7
- from msprobe.core.common.file_utils import create_directory, check_path_before_create, check_file_or_directory_path
22
+ from msprobe.core.common.file_utils import create_directory, check_file_or_directory_path
8
23
  from msprobe.core.common.log import logger
9
24
  from msprobe.core.common.file_utils import remove_path, load_npy, write_csv, read_csv
10
25
  from msprobe.core.grad_probe.constant import GradConst
@@ -33,6 +48,8 @@ class GradComparator:
33
48
 
34
49
  @classmethod
35
50
  def compare_distributed(cls, path1: str, path2: str, output_dir: str):
51
+ check_file_or_directory_path(path1, isdir=True)
52
+ check_file_or_directory_path(path2, isdir=True)
36
53
  ranks = cls._get_matched_dirs(path1, path2, "rank")
37
54
  logger.info(f"the following ranks will be compared: {ranks}")
38
55
  if not ranks:
@@ -1,8 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import re
2
17
  from msprobe.core.grad_probe.constant import GradConst
3
18
  from msprobe.core.common.log import logger
4
19
  from msprobe.core.common.file_utils import write_csv, check_path_before_create, change_mode
5
20
  from msprobe.core.common.const import FileCheckConst
21
+ from msprobe.core.common.utils import is_int
6
22
  import matplotlib.pyplot as plt
7
23
 
8
24
 
@@ -26,13 +42,24 @@ def check_str(string, variable_name):
26
42
  if not isinstance(string, str):
27
43
  raise ValueError(f'The variable: "{variable_name}" is not a string.')
28
44
 
45
+
29
46
  def check_bounds_element(bound):
30
- return GradConst.BOUNDS_MINIMUM <= bound and bound <= GradConst.BOUNDS_MAXIMUM
47
+ return GradConst.BOUNDS_MINIMUM <= bound <= GradConst.BOUNDS_MAXIMUM
48
+
49
+
50
def check_param_element(param):
    """
    Return whether `param` is a string matching GradConst.PARAM_VALID_PATTERN.

    A non-string value is reported as invalid (False) instead of letting
    `re.match` raise TypeError.
    """
    if not isinstance(param, str):
        return False
    return re.match(GradConst.PARAM_VALID_PATTERN, param) is not None
55
+
31
56
 
32
57
  def check_bounds(bounds):
58
+ if not isinstance(bounds, list):
59
+ raise Exception(f"bounds must be a list")
33
60
  prev = GradConst.BOUNDS_MINIMUM - 1
34
61
  for element in bounds:
35
- if not isinstance(element, (int, float)):
62
+ if not is_int(element) and not isinstance(element, float):
36
63
  raise Exception("bounds element is not int or float")
37
64
  if not check_bounds_element(element):
38
65
  raise Exception("bounds element is out of int64 range")
@@ -40,6 +67,7 @@ def check_bounds(bounds):
40
67
  raise Exception("bounds list is not ascending")
41
68
  prev = element
42
69
 
70
+
43
71
  class ListCache(list):
44
72
  threshold = 1000
45
73
 
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Dict, Union, Any
17
+
18
+ import numpy as np
19
+
20
+ from msprobe.core.overflow_check.api_info import APIInfo
21
+ from msprobe.core.overflow_check.level import OverflowLevel
22
+ from msprobe.core.overflow_check.utils import has_nan_inf
23
+
24
+
25
class AnomalyScene:
    """Base class of the anomaly scenes evaluated on a single API call."""

    def __init__(self, api_info: APIInfo):
        self.api_name = api_info.api_name
        self.api_data = api_info

    @property
    def rank(self) -> OverflowLevel:
        """Severity level of this scene; subclasses must override."""
        raise NotImplementedError

    @staticmethod
    def _has_anomaly(data: Union[Dict, Any]) -> bool:
        """Return whether `data` contains abnormal values, as decided by `has_nan_inf`."""
        return has_nan_inf(data)

    def get_details(self) -> Dict:
        """
        Summarize the scene.

        Returns:
            dict with the api name, the severity value, the scene class name, and the
            positions/keys of the inputs and outputs that contain anomalies.
        """
        return {
            'api_name': self.api_name,
            'rank': self.rank.value,
            'scene_type': self.__class__.__name__,
            'input_args_anomaly_indices': self._get_anomaly_indices_from_list(self.api_data.input_args),
            'input_kwargs_anomaly_keys': self._get_anomaly_keys_from_dict(self.api_data.input_kwargs),
            'output_anomaly_indices': self._get_anomaly_indices_from_list(self.api_data.output_data)
        }

    def matches(self) -> bool:
        """
        Return whether the API call matches this scene; subclasses implement the check.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _get_anomaly_indices_from_list(self, data_list: List[Dict]) -> List[int]:
        # APIInfo may leave inputs/outputs unset (None); treat that as "no data".
        return [i for i, data in enumerate(data_list or []) if self._has_anomaly(data)]

    def _get_anomaly_keys_from_dict(self, data_dict: Dict) -> List[str]:
        # A missing (None) kwargs dict has no anomalous keys.
        return [key for key, data in (data_dict or {}).items() if self._has_anomaly(data)]
66
+
67
+
68
class InputOutputAnomalyScene(AnomalyScene):
    """Base class for scenes classified by whether the inputs and/or outputs contain anomalies."""

    def has_input_anomaly(self) -> bool:
        """Return whether any dict-typed positional or keyword input contains an anomaly."""
        # Unset (None) inputs are treated as empty.
        input_args = self.api_data.input_args or []
        input_kwargs = self.api_data.input_kwargs or {}
        if any(self._has_anomaly(x) for x in input_args if isinstance(x, dict)):
            return True
        return any(self._has_anomaly(x) for x in input_kwargs.values() if isinstance(x, dict))

    def has_output_anomaly(self) -> bool:
        """Return whether any dict-typed output contains an anomaly."""
        output_data = self.api_data.output_data or []
        return any(self._has_anomaly(x) for x in output_data if isinstance(x, dict))

    def matches(self) -> bool:
        """Return whether the API call matches this scene; concrete scenes implement it."""
        raise NotImplementedError
85
+
86
+
87
class InputAnomalyOutputNormalScene(InputOutputAnomalyScene):
    """Scene: some input contains an anomaly while no output does."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene."""
        return OverflowLevel.MEDIUM

    def matches(self) -> bool:
        # Clean inputs rule the scene out without inspecting the outputs.
        if not self.has_input_anomaly():
            return False
        return not self.has_output_anomaly()
96
+
97
+
98
class InputAnomalyOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene: both some input and some output contain an anomaly."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene."""
        return OverflowLevel.HIGH

    def matches(self) -> bool:
        # The outputs are only inspected once an input anomaly has been found.
        if not self.has_input_anomaly():
            return False
        return self.has_output_anomaly()
107
+
108
+
109
class InputNormalOutputAnomalyScene(InputOutputAnomalyScene):
    """Scene: every input is clean but some output contains an anomaly."""

    @property
    def rank(self) -> OverflowLevel:
        """Severity of this scene."""
        return OverflowLevel.CRITICAL

    def matches(self) -> bool:
        # An abnormal input means the anomaly did not originate in this API.
        if self.has_input_anomaly():
            return False
        return self.has_output_anomaly()
118
+
119
+
120
class NumericalMutationScene(AnomalyScene):
    """
    Detect a sudden numerical jump between the inputs and the outputs.

    The largest 'Norm' among the tensor inputs (args and kwargs) is compared with the
    largest 'Norm' among the tensor outputs; the scene matches when the output/input
    ratio exceeds `threshold`.
    """

    def __init__(self, api_info: APIInfo, threshold: float = 100000.0):
        super().__init__(api_info)
        self.threshold = threshold

    @property
    def rank(self) -> OverflowLevel:
        return OverflowLevel.HIGH

    @staticmethod
    def _get_tensor_norms(data_list: List[Dict]) -> List[float]:
        """Collect the non-NaN 'Norm' values of the 'torch.Tensor' entries in `data_list` (None = empty)."""
        norms = []
        for data in data_list or []:
            if not isinstance(data, dict) or data.get('type') != 'torch.Tensor':
                continue
            norm = data.get('Norm')
            if norm is not None and not np.isnan(norm):
                norms.append(norm)
        return norms

    @staticmethod
    def _get_kwargs_norms(data_dict: Dict) -> List[float]:
        """Collect the tensor norms of the keyword-argument values in `data_dict` (None = empty)."""
        return NumericalMutationScene._get_tensor_norms(list((data_dict or {}).values()))

    def matches(self) -> bool:
        """Return whether the max output norm exceeds the max input norm by more than `threshold` times."""
        input_norms = (self._get_tensor_norms(self.api_data.input_args) +
                       self._get_kwargs_norms(self.api_data.input_kwargs))
        output_norms = self._get_tensor_norms(self.api_data.output_data)

        # Without norms on both sides there is nothing to compare.
        if not input_norms or not output_norms:
            return False

        max_input = max(input_norms)
        max_output = max(output_norms)

        # Avoid dividing by zero: with all-zero inputs compare the output norm itself.
        if max_input == 0:
            return max_output > self.threshold
        return max_output / max_input > self.threshold

    def get_details(self) -> Dict:
        """Extend the base details with the threshold and the result of `matches()`."""
        details = super().get_details()
        details.update({
            'threshold': self.threshold,
            'scale_change_detected': self.matches()
        })
        return details
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+
18
+ from typing import Dict, List
19
+
20
+ from msprobe.core.common.const import Const
21
+
22
+
23
@dataclass
class APIInfo:
    """
    Dump record of one API call: its name, inputs and outputs.

    `torch_api_name` is derived from `api_name` by `extract_torch_api`.
    """
    api_name: str
    torch_api_name: str
    input_args: List[Dict]
    input_kwargs: Dict
    output_data: List[Dict]

    def __init__(self, api_name, input_args=None, input_kwargs=None, output_data=None):
        self.api_name = api_name
        # Missing inputs/outputs become fresh empty containers so consumers can iterate them directly.
        self.input_args = [] if input_args is None else input_args
        self.input_kwargs = {} if input_kwargs is None else input_kwargs
        self.output_data = [] if output_data is None else output_data
        self.torch_api_name = self.extract_torch_api(self.api_name)

    @staticmethod
    def extract_torch_api(api_name) -> str:
        """
        Return the first two `Const.SEP`-separated fields of `api_name`, lower-cased.

        A name with a single field is returned lower-cased as it is; an empty,
        whitespace-only or None name yields "".
        """
        if not api_name or not api_name.strip():
            return ""
        # str.split with a separator always returns at least one element, and joining a
        # one-element slice returns that element, so a single expression covers all cases.
        parts = api_name.split(Const.SEP)
        return Const.SEP.join(parts[:2]).lower()
+ return Const.SEP.join(parts[:2]).lower()