mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,7 @@ import torch
25
25
  from msprobe.core.common.const import MonitorConst, Const
26
26
  from msprobe.pytorch.common.log import logger
27
27
  from msprobe.core.common.utils import is_int
28
- from msprobe.core.common.file_utils import check_file_or_directory_path
28
+ from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod
29
29
 
30
30
 
31
31
  device = "cpu"
@@ -105,6 +105,15 @@ def validate_ops(ops):
105
105
  return valid_ops
106
106
 
107
107
 
108
+ def validate_ndigits(ndigits):
109
+ if not ndigits:
110
+ return
111
+ if not is_int(ndigits) or ndigits <= 0:
112
+ raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.")
113
+ if ndigits > MonitorConst.MAX_NDIGITS:
114
+ raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.")
115
+
116
+
108
117
  def validate_ranks(ranks):
109
118
  if not isinstance(ranks, list):
110
119
  raise TypeError("module_ranks should be a list")
@@ -206,9 +215,17 @@ def validate_step_count_per_record(step_count_per_record):
206
215
  raise ValueError("step_count_per_record must smaller than 1e6")
207
216
 
208
217
 
218
+ def validate_dynamic_on(dynamic_on):
219
+ if not isinstance(dynamic_on, bool):
220
+ raise TypeError('dynamic_on should be a bool')
221
+
222
+
209
223
  def validate_config(config):
210
224
  config['ops'] = validate_ops(config.get('ops', []))
211
225
 
226
+ ndigits = config.get('ndigits')
227
+ validate_ndigits(ndigits)
228
+
212
229
  eps = config.get('eps', 1e-8)
213
230
  if not isinstance(eps, float):
214
231
  raise TypeError("eps should be a float")
@@ -246,9 +263,20 @@ def validate_config(config):
246
263
  step_count_per_record = config.get('step_count_per_record', 1)
247
264
  validate_step_count_per_record(step_count_per_record)
248
265
 
266
+ config["start_step"] = validate_int_arg(config.get("start_step"), "start_step",
267
+ MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP)
268
+ config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times",
269
+ MonitorConst.DEFAULT_MIN_COLLECT_TIMES,
270
+ MonitorConst.DEFAULT_MAX_COLLECT_TIMES)
271
+ config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval",
272
+ MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL)
273
+
249
274
  squash_name = config.get('squash_name', True)
250
275
  validate_squash_name(squash_name)
251
276
 
277
+ dynamic_on = config.get('dynamic_on', False)
278
+ validate_dynamic_on(dynamic_on)
279
+
252
280
  if not targets:
253
281
  if xy_distribution:
254
282
  config["all_xy"] = True
@@ -257,6 +285,8 @@ def validate_config(config):
257
285
 
258
286
  def time_str2time_digit(time_str):
259
287
  time_format = '%b%d_%H-%M-%S'
288
+ if not isinstance(time_str, str):
289
+ raise TypeError(f"time_str:{time_str} should be a str")
260
290
  try:
261
291
  time_digit = datetime.strptime(time_str, time_format)
262
292
  except Exception as e:
@@ -284,3 +314,40 @@ def get_target_output_dir(monitor_path, time_start, time_end):
284
314
  if start_ok and end_ok:
285
315
  result[rank] = os.path.join(monitor_path, dirname)
286
316
  return result
317
+
318
+
319
+ def chmod_tensorboard_dir(path):
320
+ """
321
+ format配置为tensorboard时,需要补充文件权限设置
322
+ """
323
+ try:
324
+ recursive_chmod(path)
325
+ except Exception as e:
326
+ logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!")
327
+
328
+
329
+ def validate_set_monitor(grad_acc_steps, start_iteration):
330
+ """
331
+ validate parameters of set_monitor.
332
+ """
333
+ grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps",
334
+ MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS)
335
+
336
+ start_iteration = validate_int_arg(start_iteration, "start_iteration",
337
+ MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION)
338
+ return grad_acc_steps, start_iteration
339
+
340
+
341
+ def validate_int_arg(value, name, minimum, default_value):
342
+ """Validate int args, if any exception occurs, use the default value."""
343
+ if value is None:
344
+ return default_value
345
+ try:
346
+ if not is_int(value):
347
+ raise TypeError(f"{name} must be int")
348
+ if value < minimum:
349
+ raise ValueError(f"{name} must greater than {minimum}")
350
+ except Exception as e:
351
+ value = default_value
352
+ logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.")
353
+ return value
@@ -125,8 +125,6 @@ class Saver:
125
125
 
126
126
  def write_summary_csv(self, test_result):
127
127
  test_rows = []
128
- if self.stack_info:
129
- test_rows[0].append(self.COLUMN_STACK_INFO)
130
128
 
131
129
  check_op_str_pattern_valid(test_result.api_name)
132
130
  df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
@@ -16,6 +16,7 @@
16
16
  import json
17
17
  import os
18
18
  import time
19
+ import multiprocessing
19
20
  from multiprocessing import Pool
20
21
 
21
22
  import torch
@@ -52,6 +53,7 @@ class PtdbgDispatch(TorchDispatchMode):
52
53
  return
53
54
  if dump_path is None:
54
55
  logger.error("Please set dump_path when dump_mode is config!")
56
+ raise DispatchException("Please set dump_path when dump_mode is config!")
55
57
  check_file_or_directory_path(dump_path, True)
56
58
 
57
59
  self.device_id = torch_npu._C._npu_getDevice()
@@ -85,6 +87,11 @@ class PtdbgDispatch(TorchDispatchMode):
85
87
  self.get_ops(yaml_path)
86
88
 
87
89
  self.lock = None
90
+ max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
91
+ if process_num > max_process_num:
92
+ logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
93
+ raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
94
+ f'but got {process_num}!')
88
95
  if process_num > 0:
89
96
  self.pool = Pool(process_num)
90
97
  if debug:
@@ -115,6 +122,8 @@ class PtdbgDispatch(TorchDispatchMode):
115
122
  if len(json_line_data) == 0:
116
123
  break
117
124
  msg = json.loads(json_line_data)
125
+ if len(msg) < 2:
126
+ raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
118
127
  self.all_summary[msg[0]] = msg[1]
119
128
  fp_handle.close()
120
129
 
@@ -19,6 +19,8 @@ import os
19
19
  from datetime import datetime, timezone
20
20
 
21
21
  import torch
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.core.common.decorator import recursion_depth_decorator
22
24
  from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
23
25
  from msprobe.pytorch.common.log import logger
24
26
 
@@ -91,6 +93,7 @@ def support_basic_type(data):
91
93
  return False
92
94
 
93
95
 
96
+ @recursion_depth_decorator("dump_data")
94
97
  def dump_data(data, prefix, dump_path):
95
98
  if isinstance(data, (tuple, list)) and data:
96
99
  for i, item in enumerate(data):
@@ -27,8 +27,10 @@ else:
27
27
  pta_cpu_device = torch.device("cpu")
28
28
 
29
29
  from msprobe.core.common.const import CompareConst
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
30
31
  from msprobe.pytorch.common.log import logger
31
32
 
33
+
32
34
  cpu_device = torch._C.device("cpu")
33
35
  COLOR_RED = '\033[31m'
34
36
  COLOR_GREEN = '\033[32m'
@@ -85,6 +87,7 @@ def get_callstack():
85
87
  return callstack
86
88
 
87
89
 
90
+ @recursion_depth_decorator("data_to_cpu")
88
91
  def data_to_cpu(data, deep, data_cpu):
89
92
  global cpu_device
90
93
  list_cpu = []
@@ -45,12 +45,7 @@ class InteractiveCli(cmd.Cmd):
45
45
 
46
46
  @catch_exception
47
47
  def default(self, line=""):
48
- self.util.execute_command(line)
49
- return False
50
-
51
- @catch_exception
52
- def do_run(self, line=""):
53
- self.util.execute_command(line)
48
+ self.stdout.write("Command invalid, Only support command start with cad/vc/dc/pk/cn/pt\n")
54
49
 
55
50
  @catch_exception
56
51
  def do_vc(self, line=""):
@@ -119,6 +119,7 @@ class Util:
119
119
 
120
120
  @staticmethod
121
121
  def deal_with_dir_or_file_inconsistency(output_path):
122
+ logger.warning(f"Trying to delete {output_path}")
122
123
  remove_path(output_path)
123
124
  raise ParseException("Inconsistent directory structure or file.")
124
125
 
@@ -264,7 +265,7 @@ class Util:
264
265
  match = re_pattern.match(name)
265
266
  if not match:
266
267
  continue
267
- if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
268
+ if extern_pattern != '' and re_pattern.match(extern_pattern) and not name.startswith(extern_pattern):
268
269
  continue
269
270
  file_list[name] = gen_info_func(name, match, file["root"])
270
271
  return file_list
@@ -16,9 +16,10 @@
16
16
  import os
17
17
  import re
18
18
 
19
- from msprobe.core.common.const import Const
19
+ from msprobe.core.common.const import Const, FileCheckConst
20
20
  from msprobe.core.common.exceptions import MsprobeException
21
- from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
21
+ from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid, \
22
+ FileChecker
22
23
  from msprobe.core.common.log import logger
23
24
  from msprobe.core.common.utils import is_int
24
25
  from msprobe.core.common_config import BaseConfig, CommonConfig
@@ -66,6 +67,7 @@ class TensorConfig(BaseConfig):
66
67
  check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
67
68
  check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
68
69
  check_crt_valid(os.path.join(self.tls_path, "client.crt"))
70
+ check_crt_valid(os.path.join(self.tls_path, "client.key"), True)
69
71
 
70
72
  if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
71
73
  raise Exception(f"host: {self.host} is invalid.")
@@ -95,6 +97,8 @@ class OverflowCheckConfig(BaseConfig):
95
97
  def check_overflow_config(self):
96
98
  if self.overflow_nums is not None and not is_int(self.overflow_nums):
97
99
  raise Exception("overflow_num is invalid")
100
+ if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
101
+ raise Exception("overflow_nums should be -1 or positive integer")
98
102
  if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
99
103
  raise Exception("check_mode is invalid")
100
104
 
@@ -148,7 +152,7 @@ class FreeBenchmarkCheckConfig(BaseConfig):
148
152
  self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
149
153
  ):
150
154
  msg = (
151
- f"You neet to and can only set fuzz_device as {DeviceType.CPU} "
155
+ f"You need to and can only set fuzz_device as {DeviceType.CPU} "
152
156
  f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
153
157
  )
154
158
  logger.error_log_with_exp(
@@ -271,13 +275,13 @@ class RunUTConfig(BaseConfig):
271
275
 
272
276
  @classmethod
273
277
  def check_nfs_path_config(cls, nfs_path):
274
- if nfs_path and not os.path.exists(nfs_path):
275
- raise Exception("nfs_path: %s does not exist" % nfs_path)
278
+ if nfs_path:
279
+ FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
276
280
 
277
281
  @classmethod
278
282
  def check_tls_path_config(cls, tls_path):
279
- if tls_path and not os.path.exists(tls_path):
280
- raise Exception("tls_path: %s does not exist" % tls_path)
283
+ if tls_path:
284
+ FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
281
285
 
282
286
  def check_run_ut_config(self):
283
287
  RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -30,7 +30,7 @@ from msprobe.pytorch.common.log import logger
30
30
  from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
31
31
  from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
32
32
  from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
33
- from msprobe.pytorch.hook_module.api_registry import api_register
33
+ from msprobe.pytorch.hook_module.api_register import get_api_register
34
34
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
35
35
  from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
36
36
 
@@ -50,6 +50,8 @@ class Service:
50
50
  self.switch = False
51
51
  self.inner_switch = False
52
52
  self.current_iter = 0
53
+ self.loop = 0
54
+ self.init_step = 0
53
55
  self.first_start = True
54
56
  self.current_rank = None
55
57
  self.dump_iter_dir = None
@@ -58,6 +60,7 @@ class Service:
58
60
  self.params_grad_info = {}
59
61
  self.hook_handle_dict = {}
60
62
  # 提前注册,确保注册尽可能多的API hook
63
+ self.api_register = get_api_register()
61
64
  self.register_api_hook()
62
65
  self.init_for_debug_level()
63
66
 
@@ -246,6 +249,8 @@ class Service:
246
249
  return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
247
250
 
248
251
  def start(self, model):
252
+ self.current_iter = self.loop + self.init_step
253
+ self.data_collector.update_iter(self.current_iter)
249
254
  if self.config.level == Const.LEVEL_DEBUG:
250
255
  return
251
256
  if self.need_stop_service():
@@ -304,8 +309,7 @@ class Service:
304
309
  if self.config.task == Const.TENSOR:
305
310
  self.data_collector.data_processor.dump_async_data()
306
311
  self.data_collector.write_json()
307
- self.current_iter += 1
308
- self.data_collector.update_iter(self.current_iter)
312
+ self.loop += 1
309
313
  self.reset_status()
310
314
 
311
315
  def need_stop_service(self):
@@ -370,11 +374,10 @@ class Service:
370
374
  def register_api_hook(self):
371
375
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
372
376
  logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
373
- api_register.initialize_hook(
374
- functools.partial(self.build_hook, BaseScope.Module_Type_API),
375
- self.config.online_run_ut
377
+ self.api_register.initialize_hook(
378
+ functools.partial(self.build_hook, BaseScope.Module_Type_API)
376
379
  )
377
- api_register.api_modularity()
380
+ self.api_register.register_all_api()
378
381
 
379
382
  def register_module_hook(self):
380
383
  if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
@@ -409,7 +412,7 @@ class Service:
409
412
  if self.config.nfs_path:
410
413
  self.attl.upload("end")
411
414
  elif self.attl.socket_manager is not None:
412
- logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
415
+ logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
413
416
  self.attl.socket_manager.send_stop_signal()
414
417
 
415
418
  def reset_status(self):
@@ -16,19 +16,19 @@
16
16
  import re
17
17
 
18
18
  from msprobe.core.common.const import Const
19
- from msprobe.core.common.file_utils import load_json
19
+ from msprobe.core.common.file_utils import load_json, save_json
20
20
  from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
21
  from msprobe.visualization.builder.msprobe_adapter import op_patterns
22
22
  from msprobe.visualization.graph.graph import Graph
23
23
  from msprobe.visualization.graph.node_op import NodeOp
24
- from msprobe.visualization.utils import save_json_file, GraphConst
24
+ from msprobe.visualization.utils import GraphConst
25
25
 
26
26
 
27
27
  class GraphBuilder:
28
28
  backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
29
29
  forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
30
- # 匹配以大写字母开头,后接任意字母,并以Template(结尾
31
- template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
30
+ # 匹配以大写字母开头,后接任意字母,并以Template(结尾,或包含api_template(的字符串
31
+ template_pattern = re.compile(r'\b([A-Z][a-zA-Z]*Template|api_template)\(')
32
32
 
33
33
  @staticmethod
34
34
  def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
@@ -51,6 +51,7 @@ class GraphBuilder:
51
51
  graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
52
52
  GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
53
53
  GraphBuilder._collect_apis_between_modules(graph)
54
+ GraphBuilder._add_parameters_grad(graph, data_dict)
54
55
  return graph
55
56
 
56
57
  @staticmethod
@@ -73,7 +74,7 @@ class GraphBuilder:
73
74
  if config.task:
74
75
  result[GraphConst.JSON_TASK_KEY] = config.task
75
76
  result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
76
- save_json_file(filename, result)
77
+ save_json(filename, result, indent=4)
77
78
 
78
79
  @staticmethod
79
80
  def _simplify_stack(stack_dict):
@@ -235,6 +236,44 @@ class GraphBuilder:
235
236
 
236
237
  graph.root.subnodes = output
237
238
 
239
+ @staticmethod
240
+ def _add_parameters_grad(graph, data_dict):
241
+ """
242
+ 将parameters_grad信息添加到graph中,
243
+ 对应模块的parameters_grad节点添加到对应模块的最后一次backward节点(backward计数最大)内作为子节点
244
+
245
+ 例如,graph有节点Module.a.backward.0, Module.a.backward.1, Module.a.backward.2
246
+ 则Module.a.parameters_grad添加在Module.a.backward.2内作为子节点
247
+ """
248
+ prefixes = []
249
+ suffix = Const.SEP + Const.PARAMS_GRAD
250
+ for node_id in data_dict.keys():
251
+ if node_id not in graph.node_map and node_id.endswith(suffix):
252
+ prefixes.append(node_id.replace(suffix, ''))
253
+
254
+ max_info = {prefix: 0 for prefix in prefixes}
255
+
256
+ for key in graph.node_map.keys():
257
+ for prefix in prefixes:
258
+ # 构建正则表达式,匹配以 "backward.数字" 结尾的键
259
+ pattern = re.compile(r'^' + re.escape(prefix) + r'\.backward\.(\d+)$')
260
+ match = pattern.match(key)
261
+ if match:
262
+ num = int(match.group(1))
263
+ if num > max_info[prefix]:
264
+ max_info[prefix] = num
265
+
266
+ for prefix, num in max_info.items():
267
+ node_id = prefix + Const.SEP + Const.BACKWARD + Const.SEP + str(num)
268
+ node = graph.get_node(node_id)
269
+ if node:
270
+ parameters_grad_node_id = graph.add_node(NodeOp.module, prefix + suffix, up_node=node)
271
+ # 添加输入输出数据
272
+ node_data = data_dict.get(parameters_grad_node_id, {})
273
+ input_data, output_data = get_input_output(node_data, parameters_grad_node_id)
274
+ # 更新数据
275
+ graph.get_node(parameters_grad_node_id).set_input_output(input_data, output_data)
276
+
238
277
 
239
278
  class GraphExportConfig:
240
279
  def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
  import re
16
- import math
17
16
  from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
18
17
  from msprobe.core.common.utils import set_dump_path, get_dump_mode
19
18
  from msprobe.visualization.utils import GraphConst
@@ -17,12 +17,14 @@ import re
17
17
  from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
18
18
  from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
19
19
  from msprobe.visualization.graph.graph import Graph, NodeOp
20
- from msprobe.visualization.graph.node_colors import NodeColors
21
20
  from msprobe.visualization.compare.mode_adapter import ModeAdapter
22
21
  from msprobe.core.common.const import Const
22
+ from msprobe.core.common.decorator import recursion_depth_decorator
23
23
 
24
24
 
25
25
  class GraphComparator:
26
+ MAX_DEPTH = 1000
27
+
26
28
  def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
27
29
  self.graph_n = graphs[0]
28
30
  self.graph_b = graphs[1]
@@ -41,7 +43,7 @@ class GraphComparator:
41
43
  else:
42
44
  self._compare_nodes(self.graph_n.root)
43
45
  self._postcompare()
44
-
46
+
45
47
  def add_compare_result_to_node(self, node, compare_result_list):
46
48
  """
47
49
  将比对结果添加到节点的输入输出数据中
@@ -66,43 +68,8 @@ class GraphComparator:
66
68
  self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
67
69
  node.data[GraphConst.JSON_INDEX_KEY] = precision_index
68
70
  node.data.update(other_dict)
69
-
70
- def _parse_param(self, dump_path_param, output_path):
71
- self.dump_path_param = dump_path_param
72
- self.output_path = output_path
73
- compare_mode = get_compare_mode(self.dump_path_param)
74
- self.ma = ModeAdapter(compare_mode)
75
- self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
76
- self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
77
- self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
78
-
79
- def _postcompare(self):
80
- self._handle_api_collection_index()
81
- if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
82
- return
83
- df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
84
- df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
85
- compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
86
- for node in self.ma.compare_nodes:
87
- precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
88
- node.data[GraphConst.JSON_INDEX_KEY] = precision_index
89
-
90
- def _handle_api_collection_index(self):
91
- """
92
- api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
93
- md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
94
- """
95
- for node in self.graph_n.root.subnodes:
96
- if node.op == NodeOp.api_collection:
97
- precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
98
- else GraphConst.MIN_INDEX_KEY
99
- for api in node.subnodes:
100
- precision_index = min(precision_index,
101
- api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
102
- if self.ma.compare_mode == GraphConst.MD5_COMPARE \
103
- else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
104
- node.data[GraphConst.JSON_INDEX_KEY] = precision_index
105
71
 
72
+ @recursion_depth_decorator('GraphComparator._compare_nodes', max_depth=MAX_DEPTH)
106
73
  def _compare_nodes(self, node_n):
107
74
  """
108
75
  递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
@@ -126,6 +93,7 @@ class GraphComparator:
126
93
  for subnode in node_n.subnodes:
127
94
  self._compare_nodes(subnode)
128
95
 
96
+ @recursion_depth_decorator('GraphComparator._compare_nodes_fuzzy', max_depth=MAX_DEPTH)
129
97
  def _compare_nodes_fuzzy(self, node_n):
130
98
  if node_n.op != NodeOp.function_api:
131
99
  # 模块经过模糊匹配
@@ -146,6 +114,42 @@ class GraphComparator:
146
114
  for sub_node in node_n.subnodes:
147
115
  self._compare_nodes_fuzzy(sub_node)
148
116
 
117
+ def _parse_param(self, dump_path_param, output_path):
118
+ self.dump_path_param = dump_path_param
119
+ self.output_path = output_path
120
+ compare_mode = get_compare_mode(self.dump_path_param)
121
+ self.ma = ModeAdapter(compare_mode)
122
+ self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
123
+ self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
124
+ self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
125
+
126
+ def _postcompare(self):
127
+ self._handle_api_collection_index()
128
+ if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
129
+ return
130
+ df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
131
+ df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
132
+ compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
133
+ for node in self.ma.compare_nodes:
134
+ precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
135
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
136
+
137
+ def _handle_api_collection_index(self):
138
+ """
139
+ api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
140
+ md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
141
+ """
142
+ for node in self.graph_n.root.subnodes:
143
+ if node.op == NodeOp.api_collection:
144
+ precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
145
+ else GraphConst.MIN_INDEX_KEY
146
+ for api in node.subnodes:
147
+ precision_index = min(precision_index,
148
+ api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
149
+ if self.ma.compare_mode == GraphConst.MD5_COMPARE \
150
+ else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
151
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
152
+
149
153
  def _get_and_add_result(self, node_n, node_b):
150
154
  compare_result_list = compare_node([node_n.id, node_b.id],
151
155
  [self.data_n_dict, self.data_b_dict],
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- import math
18
17
  from msprobe.core.common.const import CompareConst, Const
19
18
  from msprobe.visualization.utils import ToolTip, GraphConst, str2float
20
19
 
@@ -157,24 +156,6 @@ class ModeAdapter:
157
156
  return
158
157
  self.csv_data.extend(compare_result_list)
159
158
 
160
- def add_error_key(self, node_data):
161
- """
162
- 根据不同的模式进行提供不同错误信息
163
- """
164
- for key, value in node_data.items():
165
- if not isinstance(value, dict):
166
- continue
167
- if self.compare_mode == GraphConst.SUMMARY_COMPARE:
168
- message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
169
- CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
170
- elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
171
- message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
172
- else:
173
- # 输出件优化
174
- message = []
175
- value[GraphConst.ERROR_KEY] = message
176
- node_data[key] = value
177
-
178
159
  def get_tool_tip(self):
179
160
  """
180
161
  用于前端展示字段的具体含义
@@ -12,10 +12,11 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+
15
16
  from msprobe.core.overflow_check.level import OverflowLevel
16
- from msprobe.visualization.graph.node_op import NodeOp
17
17
  from msprobe.visualization.utils import GraphConst
18
18
  from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
19
+ from msprobe.core.common.log import logger
19
20
 
20
21
 
21
22
  class BaseNode:
@@ -114,7 +115,13 @@ class BaseNode:
114
115
  """
115
116
  ancestors = []
116
117
  current_node = self.upnode
118
+ seen_nodes = set()
117
119
  while current_node:
120
+ if current_node.id in seen_nodes:
121
+ logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, '
122
+ f'current node is {current_node.id}.')
123
+ return []
124
+ seen_nodes.add(current_node.id)
118
125
  ancestors.append(current_node.id)
119
126
  current_node = current_node.upnode
120
127
  return list(reversed(ancestors))
@@ -107,15 +107,6 @@ class DistributedAnalyzer:
107
107
  return None, None
108
108
  return group_ranks, group_id
109
109
 
110
- @staticmethod
111
- def _get_batch_group_info(node, rank):
112
- for data in node.input_data.values():
113
- group_id = data.get('group_id')
114
- if group_id is not None:
115
- return group_id
116
- logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
117
- return None
118
-
119
110
  def distributed_match(self):
120
111
  for rank, graph in self.graphs.items():
121
112
  nodes = graph.node_map
@@ -377,7 +368,7 @@ class DistributedAnalyzer:
377
368
  target_api_name = self.config.get(api_name)[0]
378
369
  target_rank = int(id_info[1].replace(Const.RANK, ''))
379
370
  except Exception as e:
380
- logger.warning(f'Failed to parsing batch p2p parameter with error info: {e}.')
371
+ logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
381
372
  continue
382
373
  target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
383
374
  if not target_node: