mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -28,7 +28,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
28
28
  from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
29
29
  check_path_executable, check_path_owner_consistent
30
30
  from msprobe.core.common.const import FileCheckConst
31
- from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type
31
+ from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files
32
32
  from msprobe.pytorch.common.log import logger
33
33
 
34
34
 
@@ -81,16 +81,8 @@ class Util:
81
81
 
82
82
  @staticmethod
83
83
  def get_subfiles_count(directory):
84
- file_count = 0
85
- for root, _, files in os.walk(directory, topdown=True):
86
- check_file_or_directory_path(root, isdir=True)
87
- file_count += len(files)
88
- path_depth = root.count(os.sep)
89
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
90
- yield root, _, files
91
- else:
92
- _[:] = []
93
- return file_count
84
+ files = os_walk_for_files(directory, Const.MAX_TRAVERSAL_DEPTH)
85
+ return len(files)
94
86
 
95
87
  @staticmethod
96
88
  def get_sorted_subdirectories_names(directory):
@@ -146,16 +138,10 @@ class Util:
146
138
 
147
139
  @staticmethod
148
140
  def dir_contains_only(path, endfix):
149
- for root, _, files in os.walk(path, topdown=True):
150
- check_file_or_directory_path(root, isdir=True)
151
- for file in files:
152
- if not file.endswith(endfix):
153
- return False
154
- path_depth = root.count(os.sep)
155
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
156
- yield root, _, files
157
- else:
158
- _[:] = []
141
+ files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
142
+ for file in files:
143
+ if not file['file'].endswith(endfix):
144
+ return False
159
145
  return True
160
146
 
161
147
  @staticmethod
@@ -273,20 +259,15 @@ class Util:
273
259
  self.check_path_valid(path)
274
260
  file_list = {}
275
261
  re_pattern = re.compile(pattern)
276
- for dir_path, _, file_names in os.walk(path, topdown=True):
277
- check_file_or_directory_path(dir_path, isdir=True)
278
- for name in file_names:
279
- match = re_pattern.match(name)
280
- if not match:
281
- continue
282
- if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
283
- continue
284
- file_list[name] = gen_info_func(name, match, dir_path)
285
- path_depth = dir_path.count(os.sep)
286
- if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
287
- yield dir_path, _, file_names
288
- else:
289
- _[:] = []
262
+ files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
263
+ for file in files:
264
+ name = file["file"]
265
+ match = re_pattern.match(name)
266
+ if not match:
267
+ continue
268
+ if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
269
+ continue
270
+ file_list[name] = gen_info_func(name, match, file["root"])
290
271
  return file_list
291
272
 
292
273
  def check_file_path_format(self, path, suffix):
@@ -65,6 +65,8 @@ class Visualization:
65
65
  self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line))
66
66
  self.util.log.warning("Please check the pkl file")
67
67
  raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e
68
+ if not isinstance(msg, list) or len(msg) == 0:
69
+ break
68
70
  info_prefix = msg[0]
69
71
  if not info_prefix.startswith(api_name):
70
72
  continue
@@ -14,11 +14,13 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ import re
17
18
 
18
19
  from msprobe.core.common.const import Const
19
20
  from msprobe.core.common.exceptions import MsprobeException
20
- from msprobe.core.common.file_utils import FileOpen, load_json
21
+ from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
21
22
  from msprobe.core.common.log import logger
23
+ from msprobe.core.common.utils import is_int
22
24
  from msprobe.core.common_config import BaseConfig, CommonConfig
23
25
  from msprobe.core.grad_probe.constant import level_adp
24
26
  from msprobe.core.grad_probe.utils import check_bounds
@@ -38,17 +40,38 @@ class TensorConfig(BaseConfig):
38
40
  self.host = json_config.get("host", "")
39
41
  self.port = json_config.get("port", -1)
40
42
  self.tls_path = json_config.get("tls_path", "./")
43
+ self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
41
44
  self.check_config()
42
45
  self._check_file_format()
43
- self._check_tls_path_config()
46
+ if self.online_run_ut:
47
+ self._check_online_run_ut()
44
48
 
45
49
  def _check_file_format(self):
46
50
  if self.file_format is not None and self.file_format not in ["npy", "bin"]:
47
51
  raise Exception("file_format is invalid")
48
52
 
49
- def _check_tls_path_config(self):
50
- if self.tls_path and not os.path.exists(self.tls_path):
51
- raise Exception("tls_path: %s does not exist" % self.tls_path)
53
+ def _check_online_run_ut(self):
54
+ if not isinstance(self.online_run_ut, bool):
55
+ raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
56
+
57
+ if not isinstance(self.online_run_ut_recompute, bool):
58
+ raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
59
+
60
+ if self.nfs_path:
61
+ check_file_or_directory_path(self.nfs_path, isdir=True)
62
+ return
63
+
64
+ if self.tls_path:
65
+ check_file_or_directory_path(self.tls_path, isdir=True)
66
+ check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
67
+ check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
68
+ check_crt_valid(os.path.join(self.tls_path, "client.crt"))
69
+
70
+ if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
71
+ raise Exception(f"host: {self.host} is invalid.")
72
+
73
+ if not isinstance(self.port, int) or not (0 < self.port <= 65535):
74
+ raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
52
75
 
53
76
 
54
77
  class StatisticsConfig(BaseConfig):
@@ -70,7 +93,7 @@ class OverflowCheckConfig(BaseConfig):
70
93
  self.check_overflow_config()
71
94
 
72
95
  def check_overflow_config(self):
73
- if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
96
+ if self.overflow_nums is not None and not is_int(self.overflow_nums):
74
97
  raise Exception("overflow_num is invalid")
75
98
  if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
76
99
  raise Exception("check_mode is invalid")
@@ -170,7 +193,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
170
193
  )
171
194
 
172
195
  def _check_preheat_config(self):
173
- if not isinstance(self.preheat_step, int):
196
+ if not is_int(self.preheat_step):
174
197
  msg = "preheat_step is invalid, it should be an integer"
175
198
  logger.error_log_with_exp(
176
199
  msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
@@ -180,7 +203,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
180
203
  logger.error_log_with_exp(
181
204
  msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
182
205
  )
183
- if not isinstance(self.max_sample, int):
206
+ if not is_int(self.max_sample):
184
207
  msg = "max_sample is invalid, it should be an integer"
185
208
  logger.error_log_with_exp(
186
209
  msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
@@ -15,8 +15,8 @@
15
15
 
16
16
  import functools
17
17
  import os
18
-
19
18
  from collections import namedtuple
19
+
20
20
  import torch
21
21
  from msprobe.core.common.const import Const
22
22
  from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
@@ -25,13 +25,14 @@ from msprobe.core.common.utils import print_tools_ends_info
25
25
  from msprobe.core.data_dump.data_collector import build_data_collector
26
26
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
27
27
  from msprobe.core.data_dump.scope import BaseScope
28
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
28
29
  from msprobe.pytorch.common.log import logger
29
30
  from msprobe.pytorch.common.utils import get_rank_if_initialized
31
+ from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
30
32
  from msprobe.pytorch.hook_module import remove_dropout
31
33
  from msprobe.pytorch.hook_module.api_registry import api_register
32
34
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
33
35
  from msprobe.pytorch.module_processer import ModuleProcesser
34
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
35
36
 
36
37
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
37
38
  if torch_version_above_or_equal_2:
@@ -159,10 +160,10 @@ class Service:
159
160
  if api_origin:
160
161
  api_register.api_modularity()
161
162
  if self.config.online_run_ut and torch_version_above_or_equal_2:
162
- run_ut_dispatch(self.attl, True)
163
+ run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
163
164
  self.switch = True
164
165
  logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
165
- if self.config.level != "L2" and not self.config.online_run_ut:
166
+ if not self.config.online_run_ut:
166
167
  self.create_dirs()
167
168
  logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
168
169
 
@@ -177,7 +178,7 @@ class Service:
177
178
  return
178
179
  self.switch = False
179
180
  if self.config.online_run_ut and torch_version_above_or_equal_2:
180
- run_ut_dispatch(self.attl, False)
181
+ run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
181
182
  return
182
183
  self.data_collector.write_json()
183
184
 
@@ -191,6 +192,9 @@ class Service:
191
192
  HOOKModule.reset_module_stats()
192
193
  self.data_collector.data_writer.reset_cache()
193
194
 
195
+ if self.config.level == Const.LEVEL_L2:
196
+ self.data_collector.data_processor.reset_status()
197
+
194
198
  def need_stop_service(self):
195
199
  if self.should_stop_service:
196
200
  return True
@@ -221,6 +225,12 @@ class Service:
221
225
  create_directory(self.config.dump_path)
222
226
  self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
223
227
  cur_rank = self.current_rank if self.current_rank is not None else ''
228
+ if self.config.level == Const.LEVEL_L2:
229
+ create_directory(self.dump_iter_dir)
230
+ kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
231
+ self.config.kernel_config_path = kernel_config_path
232
+ return
233
+
224
234
  dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
225
235
  create_directory(dump_dir)
226
236
  if self.config.task in self.data_collector.tasks_need_tensor_data:
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,165 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from msprobe.visualization.graph.graph import Graph
18
+ from msprobe.visualization.graph.node_op import NodeOp
19
+ from msprobe.visualization.utils import save_json_file, GraphConst
20
+ from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
+ from msprobe.core.common.file_utils import load_json
22
+
23
+
24
class GraphBuilder:
    """Build a model-structure Graph from msprobe dump artifacts and export it to a .vis file."""

    @staticmethod
    def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
        """
        Public entry point that builds a graph from dump artifacts.

        Args:
            construct_path: path to construct.json (its entries map a node name to its parent node name)
            data_path: path to dump.json
            stack_path: path to stack.json
            model_name: model name, supplied by the caller

        Returns:
            Graph: the data structure representing the model graph
        """
        construct_info = load_json(construct_path)
        dump_info = load_json(data_path)
        stack_info = load_json(stack_path)
        node_data = dump_info.get(GraphConst.DATA_KEY, {})
        dump_dir = dump_info.get('dump_data_dir', '')
        graph = Graph(model_name, data_path=dump_dir, dump_data=node_data)
        GraphBuilder._init_nodes(graph, construct_info, node_data, stack_info)
        GraphBuilder._collect_apis_between_modules(graph)
        return graph

    @staticmethod
    def to_json(filename, config):
        """
        Export the graph(s) described by `config` to the .vis file `filename`.

        When a benchmark graph is present, the NPU graph and the benchmark graph are
        stored under their own keys. Otherwise the single graph's dict is the top level.
        The optional extras (tool tip, node colors, micro steps, task) are written only
        when they are truthy.
        """
        if config.graph_b:
            payload = {
                GraphConst.JSON_NPU_KEY: config.graph_n.to_dict(),
                GraphConst.JSON_BENCH_KEY: config.graph_b.to_dict(),
            }
        else:
            payload = config.graph_n.to_dict()

        optional_fields = (
            (GraphConst.JSON_TIP_KEY, config.tool_tip),
            (GraphConst.COLORS, config.node_colors),
            (GraphConst.MICRO_STEPS, config.micro_steps),
            (GraphConst.JSON_TASK_KEY, config.task),
        )
        for key, value in optional_fields:
            if value:
                payload[key] = value
        save_json_file(filename, payload)

    @staticmethod
    def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
        """
        Recover a missing parent for a backward node from its forward counterpart.

        The lookup applies only when `subnode_id` ends with `.backward.<digits>` and
        `upnode_id` is falsy. In that case the parent of the same-named
        `.forward.<digits>` node is looked up, and its `.forward.` suffix is turned into
        `.backward.`. That candidate is returned only if it exists in `construct_dict`.
        In every other case `upnode_id` is returned unchanged.
        """
        # Matches names ending in ".backward." followed by one or more digits
        backward_suffix = r"(\.backward\.)(\d+)$"
        forward_suffix = r"(\.forward\.)(\d+)$"

        if upnode_id or not re.search(backward_suffix, subnode_id):
            return upnode_id

        forward_twin = re.sub(backward_suffix, r".forward.\2", subnode_id)
        forward_parent = construct_dict.get(forward_twin)
        if not forward_parent:
            return upnode_id

        candidate = re.sub(forward_suffix, r".backward.\2", forward_parent)
        return candidate if candidate in construct_dict else upnode_id

    @staticmethod
    def _init_nodes(graph, construct_dict, data_dict, stack_dict):
        """Create every node listed in `construct_dict` and attach it to its (possibly recovered) parent."""
        sources = [data_dict, stack_dict]
        for child_id, parent_id in construct_dict.items():
            parent_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, child_id, parent_id)
            # Nodes without a parent hang directly off the graph root
            parent = (GraphBuilder._create_or_get_node(graph, sources, NodeOp.get_node_op(parent_id), parent_id)
                      if parent_id else graph.root)
            GraphBuilder._create_or_get_node(graph, sources, NodeOp.get_node_op(child_id), child_id, parent)

    @staticmethod
    def _create_or_get_node(graph, data_stack_list, op, name, upnode=None):
        """
        Return the node called `name`, creating it first if the graph lacks it.

        The node's dump data, stack info and parent are then (re)attached.
        `data_stack_list` is [data_dict, stack_dict]. A missing entry falls back to an
        empty dict (data) or an empty list (stack).
        """
        if name not in graph.node_map:
            graph.add_node(op, name, upnode)
        node = graph.get_node(name)

        dump_entry = data_stack_list[0].get(name, {})
        stack_entry = data_stack_list[1].get(name, [])
        # Attach input/output data, then the call-stack info and the parent link
        inputs, outputs = get_input_output(dump_entry, node.id)
        node.set_input_output(inputs, outputs)
        node.stack_info = stack_entry
        node.add_upnode(upnode)
        return node

    @staticmethod
    def _collect_apis_between_modules(graph):
        """
        Group runs of consecutive top-level API nodes into collection nodes.

        On first expansion the root's direct children mix modules and APIs. Long runs
        of APIs stretch the rendered graph, so every run of two or more consecutive
        function-API nodes is wrapped in a new api_collection node parented to the root.
        Lone APIs and non-API nodes are kept as they are. The root's subnodes list is
        replaced with the regrouped result.

        Args:
            graph: the model graph

        Returns:
            None
        """
        children = graph.root.subnodes
        regrouped = []
        idx = 0
        while idx < len(children):
            if children[idx].op != NodeOp.function_api:
                # Modules (and any other non-API node) pass through untouched
                regrouped.append(children[idx])
                idx += 1
                continue

            # Gather the maximal run of consecutive API nodes starting at idx
            run_start = idx
            while idx < len(children) and children[idx].op == NodeOp.function_api:
                idx += 1
            api_run = children[run_start:idx]

            if len(api_run) < 2:
                # A single API is not worth wrapping
                regrouped.extend(api_run)
                continue

            collection_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
                                           id_accumulation=True)
            collection = graph.get_node(collection_id)
            collection.subnodes = api_run
            # Re-parent the wrapped APIs under the collection, which itself sits under the root
            for api_node in api_run:
                api_node.upnode = collection
            collection.upnode = graph.root
            regrouped.append(collection)

        graph.root.subnodes = regrouped
156
+
157
+
158
class GraphExportConfig:
    """
    Plain container for the inputs of a graph export.

    Attributes:
        graph_n: the primary graph; always required
        graph_b: optional second (benchmark) graph; None when only one graph is exported
        tool_tip: optional extra export data; defaults to None
        node_colors: optional extra export data; defaults to None
        micro_steps: optional extra export data; defaults to None
        task: task name; defaults to the empty string
    """

    def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''):
        # The graphs themselves
        self.graph_n, self.graph_b = graph_n, graph_b
        # Optional extras that accompany the graphs
        self.tool_tip = tool_tip
        self.node_colors = node_colors
        self.micro_steps = micro_steps
        self.task = task
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+ import math
17
+ from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
18
+ from msprobe.core.common.utils import set_dump_path, get_dump_mode
19
+ from msprobe.visualization.utils import GraphConst
20
+ from msprobe.core.common.const import Const
21
+
22
# Rules for mapping a node name to its NodeOp, indexed by position.
# The literal '.' separator after each prefix is escaped. An unescaped '.' matches any
# character, which would misclassify names such as "Torchvision_x" or "ModuleX".
op_patterns = [
    # NodeOp.module
    r'^(Module\.|Cell\.)',
    # NodeOp.function_api
    r'^(Tensor\.|Torch\.|Functional\.|NPU\.|VF\.|Distributed\.|Aten\.|Mint\.|Primitive\.|Jit\.|MintFunctional\.)'
]
29
+
30
+
31
def get_compare_mode(dump_path_param):
    """
    Determine the graph comparison mode (summary, MD5, or real data).

    Registers ``dump_path_param`` via ``set_dump_path``, reads its dump mode, and
    maps that dump mode to a graph-compare mode.

    Args:
        dump_path_param: the parameter object the acc_compare interface relies on.

    Returns: 0 for summary mode, 1 for md5 mode, 2 for real-data mode
        (``None`` if the dump mode has no entry in the mapping).
    """
    set_dump_path(dump_path_param)
    return GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.get(get_dump_mode(dump_path_param))
42
+
43
+
44
def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
    """
    Run the multi-process real-data comparison and write its results.

    Args:
        dump_path_param: the parameter object the acc_compare interface relies on.
        csv_path: path of the output file to generate.
        framework: framework type. ``Const.PT_FRAMEWORK`` selects the PyTorch comparator;
            any other value selects the MindSpore comparator.
        is_cross_frame: whether to run a cross-framework comparison. Only MindSpore vs.
            PyTorch is supported, with PyTorch as the benchmark. It is used only on the
            MindSpore path.

    Returns: whatever the selected comparator's ``do_multi_process`` returns.
    """
    # The comparator is imported lazily, inside the branch that needs it
    # (presumably so the other framework does not have to be importable — confirm).
    if framework != Const.PT_FRAMEWORK:
        from msprobe.mindspore.compare.ms_compare import MSComparator
        comparator = MSComparator()
        comparator.cross_frame = is_cross_frame
    else:
        from msprobe.pytorch.compare.pt_compare import PTComparator
        comparator = PTComparator()
    return comparator.do_multi_process(dump_path_param, csv_path)
61
+
62
+
63
def get_input_output(node_data, node_id):
    """
    Split a node's raw dump data into input and output entries.

    Each entry parsed by ``read_op`` is classified by its ``full_op_name``:
    - Names containing ``GraphConst.OUTPUT`` but not ``GraphConst.INPUT`` are outputs,
      keyed by ``full_op_name``.
    - Everything else is an input. Inputs are keyed by ``data_name`` with its last
      ``Const.SEP``-separated segment removed, when ``data_name`` is a string other
      than '-1'. Otherwise they are keyed by ``full_op_name``.
    Entries with an empty or missing ``full_op_name`` are skipped.

    Args:
        node_data: the dump data belonging to a single node.
        node_id: the node name.

    Returns: tuple ``(input_data, output_data)`` of dicts.
    """
    inputs, outputs = {}, {}
    for item in read_op(node_data, node_id):
        op_name = item.get('full_op_name', '')
        if not op_name:
            continue
        if GraphConst.OUTPUT in op_name and GraphConst.INPUT not in op_name:
            outputs[op_name] = item
            continue
        data_name = item.get('data_name')
        # Prefer the on-disk data name for the parameter key when it is available.
        if isinstance(data_name, str) and data_name != '-1':
            key = data_name.rsplit(Const.SEP, 1)[0]
        else:
            key = op_name
        inputs[key] = item
    return inputs, outputs
87
+
88
+
89
def compare_data(data_dict_list1, data_dict_list2):
    """
    Check whether two ``get_input_output`` results have the same structure.

    The two mappings must have the same length. Entries are paired positionally, in
    iteration order, and each pair must agree on the ``'type'`` and ``'shape'`` fields
    (a missing field counts as ``None``).

    Returns: True if the structures match, otherwise False.
    """
    if len(data_dict_list1) != len(data_dict_list2):
        return False
    # Fields that must agree for two entries to be considered equivalent.
    tag_keys = ['type', 'shape']
    mismatch_found = any(
        data_dict_list1[left_key].get(tag, None) != data_dict_list2[right_key].get(tag, None)
        for left_key, right_key in zip(data_dict_list1, data_dict_list2)
        for tag in tag_keys
    )
    return not mismatch_found
106
+
107
+
108
def format_node_data(data_dict):
    """
    Prepare a batch of node data entries for output.

    For every value that is a dict, the ``'requires_grad'`` and ``'full_op_name'``
    fields are removed and the entry is formatted in place by ``_format_data``.
    Non-dict values are left untouched.

    Returns: the same ``data_dict`` object, mutated in place.
    """
    removable_keys = ('requires_grad', 'full_op_name')
    for entry in data_dict.values():
        if isinstance(entry, dict):
            for removable in removable_keys:
                entry.pop(removable, None)
            _format_data(entry)
    return data_dict
121
+
122
+
123
def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
    """
    Compute accuracy-comparison metrics for a pair of nodes via ``get_accuracy``.

    ``node_ids[0]``/``data_dicts[0]`` and ``node_ids[1]``/``data_dicts[1]`` are each
    converted with ``_parse_node`` and then passed to ``get_accuracy``. Real-data mode
    cannot produce comparison metrics here; it needs the multi-process comparison
    interface instead.

    Returns: a list holding parameter info and comparison metrics (metrics absent in
        real-data mode).
    """
    merged_pair = tuple(
        _parse_node(node_ids[idx], data_dicts[idx], stack_json_data, compare_mode)
        for idx in (0, 1)
    )
    accuracy_rows = []
    get_accuracy(accuracy_rows, *merged_pair,
                 GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode))
    return accuracy_rows
135
+
136
+
137
def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
    """
    Convert one node into the form ``get_accuracy`` (acc_compare.py) expects as input.

    The node's dump entries (``data_dict[node_id]``, or ``{}`` if absent) are parsed with
    ``read_op``. A trailing entry is appended carrying the node's stack info
    (``stack_json_data[node_id]``, or ``None`` if absent). The result is merged with
    ``merge_tensor``.

    Returns: the merged dict. If ``merge_tensor`` yields an empty (falsy) dict, it is
        filled with ``op_name = []``.
    """
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    parsed_ops = read_op(data_dict.get(node_id, {}), node_id)
    stack_info = stack_json_data[node_id] if node_id in stack_json_data else None
    parsed_ops.append({'full_op_name': node_id, 'full_info': stack_info})
    merged = merge_tensor(parsed_ops, dump_mode)
    if not merged:
        merged['op_name'] = []
    return merged
152
+
153
+
154
+ def _format_decimal_string(s):
155
+ """
156
+ 使用正则表达式匹配包含数字、小数点和可选的百分号的字符串
157
+ """
158
+ pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')
159
+ matches = pattern.findall(s)
160
+ for match in matches:
161
+ is_percent = match.endswith('%')
162
+ number_str = match.rstrip('%')
163
+ decimal_part = number_str.split('.')[1]
164
+ # 如果小数位数大于6,进行处理
165
+ if len(decimal_part) > GraphConst.ROUND_TH:
166
+ number_float = float(number_str)
167
+ formatted_number = f"{number_float:.{GraphConst.ROUND_TH}f}"
168
+ # 如果原来是百分数,加回百分号
169
+ if is_percent:
170
+ formatted_number += '%'
171
+ # 替换原字符串中的数值部分
172
+ s = s.replace(match, formatted_number)
173
+ return s
174
+
175
+
176
def _format_data(data_dict):
    """
    Format a node's data dict in place so the front end can render it.

    Per value:
    - ``str``: single quotes are removed, ``GraphConst.NONE`` becomes ``GraphConst.NULL``,
      and embedded decimals are rounded by ``_format_decimal_string``.
    - ``None`` or ``' '``: becomes ``GraphConst.NULL``.
    - ``float`` whose ``str()`` is shorter than ``GraphConst.STR_MAX_LEN`` and is in
      scientific notation: reformatted as ``"{:.6e}"``, e.g. 1.123123123123e-11 ->
      1.123123e-11.
    - other ``float``: rounded to ``GraphConst.ROUND_TH`` places.
    Every value except the one under ``GraphConst.ERROR_KEY`` is then converted to
    ``str``. This also covers Inf and any unexpected types.
    If the ``Const.MAX`` entry ends up as ``GraphConst.NULL``, the parameter itself is
    null. The dict is then reduced to the single entry
    ``{GraphConst.VALUE: GraphConst.NULL}``.

    Returns: None (``data_dict`` is mutated in place).
    """
    sci_notation = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
    collapse_to_null = False
    for key, value in data_dict.items():
        if isinstance(value, str):
            # Drop single quotes and swap None for null to avoid front-end parse errors.
            cleaned = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
            value = _format_decimal_string(cleaned)
        elif value is None or value == ' ':
            value = GraphConst.NULL
        elif isinstance(value, float):
            as_text = str(value)
            if len(as_text) < GraphConst.STR_MAX_LEN and re.match(sci_notation, as_text):
                value = "{:.6e}".format(value)
            else:
                value = round(value, GraphConst.ROUND_TH)

        # The error key keeps its original type; everything else is stringified.
        if key != GraphConst.ERROR_KEY:
            value = str(value)
        if key == Const.MAX and value == GraphConst.NULL:
            collapse_to_null = True
        data_dict[key] = value

    if collapse_to_null:
        data_dict.clear()
        data_dict[GraphConst.VALUE] = GraphConst.NULL
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.