mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153):
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -70,6 +70,67 @@ class Const:
70
70
  }
71
71
 
72
72
 
73
class MsCompareConst:
    """String keys, file names and numeric settings used by MindSpore accuracy-checking code."""

    # ---- api_info field values (MindSpore API families) ----
    MINT = "Mint"
    MINT_FUNCTIONAL = "MintFunctional"
    TENSOR_API = "Tensor"
    FUNCTIONAL_API = "Functional"
    FUSION_API = "FUSION"

    # ---- numeric settings ----
    API_NAME_STR_LENGTH = 4
    MAX_RECURSION_DEPTH = 20

    # ---- MindTorch api_info field values ----
    MINDTORCH_TENSOR = "Tensor"
    MINDTORCH = "Torch"
    MINDTORCH_FUNC = "Functional"
    MINDTORCH_NPU = "NPU"
    MINDTORCH_DIST = "Distributed"
    # Only these three MindTorch API types are listed as valid (NPU / Distributed are not).
    MT_VALID_API_TYPES = [MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR]

    # Fusion operators listed as supported.
    SUPPORTED_FUSION_LIST = ["flash_attention_score"]

    # ---- task-related field names and values ----
    TASK_FIELD = "task"
    STATISTICS_TASK = "statistics"
    FRAMEWORK = "framework"
    TENSOR_TASK = "tensor"
    DUMP_DATA_DIR_FIELD = "dump_data_dir"
    DATA_FIELD = "data"

    # ---- supported-API yaml ----
    SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
    SUPPORTED_TENSOR_LIST_KEY = "tensor"

    # ---- detail CSV: column headers and base file name ----
    DETAIL_CSV_API_NAME = "API Name"
    DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
    DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
    DETAIL_CSV_SHAPE = "Shape"
    DETAIL_CSV_PASS_STATUS = "Status"
    DETAIL_CSV_MESSAGE = "Message"
    DETAIL_CSV_FILE_NAME = "accuracy_checking_details"

    # ---- result CSV: column headers and base file name ----
    RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
    RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
    RESULT_CSV_FILE_NAME = "accuracy_checking_result"

    # Small tolerance constant; NOTE(review): presumably guards divisions — confirm at use sites.
    EPSILON = 1e-8
125
+
126
class ProcessStatus:
    """Status strings describing how the processing of a single API ended."""

    SUCCESS = "success"
    API_NOT_FOUND = "api_not_found"
    EXCEPTION_SKIP = "exception_skip"
130
+
131
+
132
+
133
+
73
134
  class FreeBenchmarkConst:
74
135
  ADD_NOISE = "add_noise"
75
136
  BIT_NOISE = "bit_noise"
@@ -25,7 +25,31 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
25
25
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
26
26
  from msprobe.core.common.log import logger
27
27
  from msprobe.core.common.const import Const
28
- from msprobe.core.common.utils import CompareException, check_seed_all
28
+ from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
29
+
30
+
31
class MsprobeStep(ms.train.Callback):
    """MindSpore training callback that brackets every train step with debugger calls.

    ``start()`` is called when a step begins; ``stop()`` and then ``step()`` are
    called when it ends.
    """

    def __init__(self, debugger):
        super().__init__()
        # Object exposing start()/stop()/step(); presumably a PrecisionDebugger.
        self.debugger = debugger

    def on_train_step_begin(self, run_context):
        """Start the debugger before the step runs."""
        self.debugger.start()

    def on_train_step_end(self, run_context):
        """Stop the debugger for the finished step, then call its step()."""
        debugger = self.debugger
        debugger.stop()
        debugger.step()
42
+
43
+
44
class MsprobeInitStep(ms.train.Callback):
    """MindSpore training callback that forwards the current step number to
    MindSpore's ``_set_init_iter`` when training begins.
    """

    def on_train_begin(self, run_context):
        """Pass ``cur_step_num`` from the callback params to ``_set_init_iter``.

        ``_set_init_iter`` is a private MindSpore API that is not available in
        every MindSpore version; when it cannot be imported, a warning is logged
        and the callback does nothing.
        """
        try:
            # Import through the real package name. ``ms`` is only a local alias
            # for ``mindspore`` in this module, so ``from ms._c_expression ...``
            # can never be resolved by the import system and would always land
            # in the ImportError branch.
            from mindspore._c_expression import _set_init_iter
        except ImportError:
            logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
            return
        cb_params = run_context.original_args()
        _set_init_iter(cb_params.cur_step_num)
29
53
 
30
54
 
31
55
  def get_rank_if_initialized():
@@ -93,20 +117,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
93
117
  remove_dropout()
94
118
 
95
119
 
96
- class MsprobeStep(ms.train.Callback):
97
-
98
- def __init__(self, debugger):
99
- super(MsprobeStep, self).__init__()
100
- self.debugger = debugger
101
-
102
- def on_train_step_begin(self, run_context):
103
- self.debugger.start()
104
-
105
- def on_train_step_end(self, run_context):
106
- self.debugger.stop()
107
- self.debugger.step()
108
-
109
-
110
120
  class Dropout(ops.Dropout):
111
121
  def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
112
122
  super().__init__(1., seed0, seed1)
@@ -169,7 +179,7 @@ def set_register_backward_hook_functions():
169
179
  from msprobe.mindspore.mindtorch import (_call_impl,
170
180
  register_full_backward_pre_hook,
171
181
  register_full_backward_hook)
172
- if not hasattr(torch, "register_full_backward_hook"):
182
+ if not hasattr(torch.nn.Module, "register_full_backward_hook"):
173
183
  setattr(torch.nn.Module, "_call_impl", _call_impl)
174
184
  setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
175
185
  setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
@@ -182,9 +192,11 @@ def set_register_backward_hook_functions():
182
192
 
183
193
  def check_save_param(variable, name, save_backward):
184
194
  # try catch this api to skip invalid call
185
- if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)):
195
+ valid_data_types = tuple([ms.Tensor, int, float, str])
196
+ if not is_save_variable_valid(variable, valid_data_types):
197
+ valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
186
198
  logger.warning("PrecisionDebugger.save variable type not valid, "
187
- "should be one of list, dict, ms.Tensor, int, float or string. "
199
+ f"should be one of {valid_data_types_with_nested_types}"
188
200
  "Skip current save process.")
189
201
  raise ValueError
190
202
  if not isinstance(name, str):
@@ -196,4 +208,4 @@ def check_save_param(variable, name, save_backward):
196
208
  logger.warning("PrecisionDebugger.save_backward name not valid, "
197
209
  "should be bool. "
198
210
  "Skip current save process.")
199
- raise ValueError
211
+ raise ValueError
@@ -22,10 +22,10 @@ import pandas as pd
22
22
 
23
23
  from msprobe.core.common.const import CompareConst, Const
24
24
  from msprobe.core.common.exceptions import FileCheckException
25
- from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml
25
+ from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml
26
26
  from msprobe.core.common.log import logger
27
27
  from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
28
- check_op_str_pattern_valid, get_dump_mode, set_dump_path
28
+ check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json
29
29
  from msprobe.core.compare.acc_compare import Comparator, ModeConfig
30
30
  from msprobe.core.compare.check import dtype_mapping
31
31
  from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
@@ -78,6 +78,11 @@ class MSComparator(Comparator):
78
78
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
79
79
  f"{type(self.data_mapping)}")
80
80
 
81
@staticmethod
def process_data_name(result):
    """Merge both data-name columns into ``data_name_x``, in place.

    After the call every row's ``data_name_x`` holds the list
    ``[data_name_x, data_name_y]``; ``data_name_y`` is left untouched.

    Args:
        result: pandas DataFrame with ``data_name_x`` and ``data_name_y`` columns.

    Returns:
        The same DataFrame object that was passed in.
    """
    def _pair_names(row):
        return [row['data_name_x'], row['data_name_y']]

    result['data_name_x'] = result.apply(_pair_names, axis=1)
    return result
85
+
81
86
  def calc_accuracy(self, result_df, header):
82
87
  condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
83
88
  result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
@@ -120,12 +125,13 @@ class MSComparator(Comparator):
120
125
  result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
121
126
  elif self.dump_mode == Const.SUMMARY:
122
127
  warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
123
- warning_flag = pd.DataFrame(warning_list).all()
128
+ warning_flag = pd.DataFrame(warning_list).any()
124
129
  result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
125
130
  result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
126
131
  result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
127
132
  else:
128
- fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
133
+ fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
134
+ CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
129
135
  CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
130
136
  CompareConst.ERROR_MESSAGE]
131
137
  result_df.loc[~condition_no_bench, fill_cols] = ''
@@ -139,6 +145,8 @@ class MSComparator(Comparator):
139
145
  header.append(CompareConst.STACK)
140
146
  if self.dump_mode == Const.ALL:
141
147
  header.append(CompareConst.DATA_NAME)
148
+ result = self.process_data_name(result)
149
+
142
150
  result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
143
151
  'op_name_y': CompareConst.BENCH_NAME,
144
152
  'dtype_x': CompareConst.NPU_DTYPE,
@@ -169,6 +177,7 @@ class MSComparator(Comparator):
169
177
 
170
178
  result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
171
179
  result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
180
+
172
181
  result_df = pd.DataFrame(columns=header)
173
182
  for h in header:
174
183
  if h in result.columns:
@@ -269,15 +278,15 @@ class MSComparator(Comparator):
269
278
  bench_dtype = match_result['dtype_y']
270
279
  if self.cross_frame:
271
280
  npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
272
- return ((npu_dtype == bench_dtype) |
273
- ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) |
274
- ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) |
275
- ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) |
276
- ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) |
277
- ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) |
278
- ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
279
- ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
280
- ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
281
+
282
+ equal_condition = npu_dtype == bench_dtype
283
+ match_condition = (
284
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
285
+ CompareConst.DTYPE_MATCH_GROUPS[0])) |
286
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
287
+ CompareConst.DTYPE_MATCH_GROUPS[1]))
288
+ )
289
+ return equal_condition | match_condition
281
290
 
282
291
  match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
283
292
  return self.make_result_df(match_result)
@@ -382,12 +391,11 @@ class MSComparator(Comparator):
382
391
 
383
392
 
384
393
  def check_cross_framework(bench_json_path):
385
- pattern = r'"data_name":\s*"[^"]+\.pt"'
386
- with FileOpen(bench_json_path, 'r') as file:
387
- for line in file:
388
- if re.search(pattern, line):
389
- return True
390
- return False
394
+ framework = detect_framework_by_dump_json(bench_json_path)
395
+ if framework == Const.PT_FRAMEWORK:
396
+ return True
397
+ else:
398
+ return False
391
399
 
392
400
 
393
401
  def ms_compare(input_param, output_path, **kwargs):
@@ -195,11 +195,12 @@ class GraphMSComparator:
195
195
  if not error_flag:
196
196
  result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
197
197
  result_dict[CompareConst.COSINE] = result_list[0]
198
- result_dict[CompareConst.MAX_ABS_ERR] = result_list[1]
199
- result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[2]
200
- result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[3]
201
- result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[4]
202
- result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[1])
198
+ result_dict[CompareConst.EUC_DIST] = result_list[1]
199
+ result_dict[CompareConst.MAX_ABS_ERR] = result_list[2]
200
+ result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3]
201
+ result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4]
202
+ result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5]
203
+ result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2])
203
204
  result_dict[CompareConst.ERROR_MESSAGE] = err_msg
204
205
 
205
206
  return pd.Series(result_dict)
@@ -53,11 +53,13 @@ class DebuggerConfig:
53
53
  self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage
54
54
  if self.handler_type == FreeBenchmarkConst.FIX and \
55
55
  self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
56
- raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
57
- f"but got {self.pert_type}.")
56
+ logger.error("pert_mode must be improve_precision or empty when handler_type is fix, "
57
+ f"but got {self.pert_type}.")
58
+ raise ValueError
58
59
  if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
59
- raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
60
- f"but got {self.handler_type}.")
60
+ logger.error("handler_type must be check or empty when fuzz_stage is backward, "
61
+ f"but got {self.handler_type}.")
62
+ raise ValueError
61
63
  self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
62
64
 
63
65
  def check(self):
@@ -22,12 +22,12 @@ from mindspore._c_expression import MSContext
22
22
  from msprobe.core.common.const import Const, FileCheckConst, MsgConst
23
23
  from msprobe.core.common.exceptions import MsprobeException
24
24
  from msprobe.core.common.file_utils import FileChecker
25
- from msprobe.core.common.utils import get_real_step_or_rank
25
+ from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
26
26
  from msprobe.mindspore.cell_processor import CellProcessor
27
27
  from msprobe.mindspore.common.const import Const as MsConst
28
28
  from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param
29
29
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
30
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
30
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
31
31
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
32
32
  from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor
33
33
  from msprobe.mindspore.ms_config import parse_json_config
@@ -84,7 +84,7 @@ class PrecisionDebugger:
84
84
  common_config.dump_path = dump_path if dump_path else common_config.dump_path
85
85
  self.config = DebuggerConfig(common_config, task_config)
86
86
 
87
- if _msprobe_c:
87
+ if self._need_msprobe_c() and _msprobe_c:
88
88
  _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
89
89
 
90
90
  self.config.execution_mode = self._get_execution_mode()
@@ -151,7 +151,7 @@ class PrecisionDebugger:
151
151
  instance = cls._instance
152
152
  if not instance:
153
153
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
154
- if _msprobe_c:
154
+ if cls._need_msprobe_c() and _msprobe_c:
155
155
  _msprobe_c._PrecisionDebugger().start()
156
156
  if instance.task in PrecisionDebugger.task_not_need_service:
157
157
  return
@@ -163,7 +163,7 @@ class PrecisionDebugger:
163
163
  instance.service.start(model)
164
164
  else:
165
165
  if not instance.first_start:
166
- api_register.api_set_ori_func()
166
+ get_api_register().restore_all_api()
167
167
  handler = TaskHandlerFactory.create(instance.config)
168
168
  handler.handle()
169
169
 
@@ -180,8 +180,6 @@ class PrecisionDebugger:
180
180
  instance = cls._instance
181
181
  if not instance:
182
182
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
183
- if _msprobe_c:
184
- _msprobe_c._PrecisionDebugger().stop()
185
183
  if instance.task == Const.GRAD_PROBE:
186
184
  instance.gm.stop()
187
185
  if instance.task in PrecisionDebugger.task_not_need_service:
@@ -195,8 +193,6 @@ class PrecisionDebugger:
195
193
  instance = cls._instance
196
194
  if not instance:
197
195
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
198
- if _msprobe_c:
199
- _msprobe_c._PrecisionDebugger().step()
200
196
  if instance.task in PrecisionDebugger.task_not_need_service:
201
197
  return
202
198
  if instance.service:
@@ -233,6 +229,15 @@ class PrecisionDebugger:
233
229
  instance.service = Service(instance.config)
234
230
  instance.service.save(variable, name, save_backward)
235
231
 
232
@classmethod
def set_init_step(cls, step):
    """Set the service's ``init_step`` to *step* and reset its ``loop`` counter to 0.

    *step* is validated by ``check_init_step`` before anything is changed.

    Raises:
        Exception: if no PrecisionDebugger instance has been created yet.
    """
    instance = cls._instance
    if not instance:
        raise Exception(MsgConst.NOT_CREATED_INSTANCE)
    check_init_step(step)
    # NOTE(review): ``instance.service`` appears to be created lazily elsewhere
    # (see ``save``); if it is still None here, the assignment below raises
    # AttributeError — confirm whether that is intended.
    service = instance.service
    service.init_step = step
    service.loop = 0
240
+
236
241
  @classmethod
237
242
  def _need_service(cls):
238
243
  instance = cls._instance
@@ -241,4 +246,11 @@ class PrecisionDebugger:
241
246
  if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
242
247
  return False
243
248
  else:
244
- return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
249
+ return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
250
+
251
@classmethod
def _need_msprobe_c(cls):
    """Tell whether the ``_msprobe_c`` extension should be driven.

    The extension is driven only when the originally configured level
    (``config.level_ori``) is ``Const.LEVEL_L2``.

    Raises:
        Exception: if no PrecisionDebugger instance has been created yet.
    """
    instance = cls._instance
    if instance:
        return instance.config.level_ori == Const.LEVEL_L2
    raise Exception(MsgConst.NOT_CREATED_INSTANCE)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.mindspore.common.const import Const
17
+ from msprobe.core.common.log import logger
17
18
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
19
  from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump
19
20
  from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump
@@ -47,6 +48,7 @@ class DumpToolFactory:
47
48
  raise Exception("Valid level is needed.")
48
49
  tool = tool.get(config.execution_mode)
49
50
  if not tool:
50
- raise Exception(f"Data dump is not supported in {config.execution_mode} mode "
51
- f"when dump level is {config.level}.")
51
+ logger.error(f"Data dump is not supported in {config.execution_mode} mode "
52
+ f"when dump level is {config.level}.")
53
+ raise ValueError
52
54
  return tool(config)
@@ -0,0 +1,142 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from mindspore import Tensor, ops, mint
19
+ from mindspore.mint.nn import functional
20
+ from mindspore.communication import comm_func
21
+
22
+ from msprobe.core.common.file_utils import load_yaml
23
+ from msprobe.core.common.utils import Const
24
+ from msprobe.core.data_dump.api_registry import ApiRegistry
25
+ from msprobe.mindspore.common.const import Const as MsConst
26
+ from msprobe.mindspore.common.utils import is_mindtorch
27
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
28
+
29
+
30
# StubTensor is a private MindSpore class that may be missing in some versions.
try:
    from mindspore.common._stub_tensor import StubTensor
except ImportError:
    stub_tensor_existed = False
else:
    stub_tensor_existed = True

cur_path = os.path.dirname(os.path.realpath(__file__))

# _api_types maps framework -> {api type -> (object, tuple of objects)} and is
# handed to ApiRegistry in get_api_register.
# NOTE(review): the exact role of each tuple slot is defined by ApiRegistry — confirm there.
if is_mindtorch():
    # MindTorch mode: hook the torch-style namespaces and reuse the PyTorch
    # hook_module YAML list of supported APIs.
    import torch
    import torch_npu
    _api_types = {
        Const.MT_FRAMEWORK: {
            Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
            Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
            Const.PT_API_TYPE_TORCH: (torch, (torch,)),
            Const.PT_API_TYPE_NPU: (torch_npu, (torch_npu,)),
            Const.PT_API_TYPE_DIST: (torch.distributed, (torch.distributed, torch.distributed.distributed_c10d))
        }
    }
    _supported_api_list_path = (
        os.path.join(cur_path, '../../../pytorch/hook_module', MsConst.SUPPORTED_API_LIST_FILE),
    )
else:
    # Native MindSpore mode.
    _api_types = {
        Const.MS_FRAMEWORK: {
            Const.MS_API_TYPE_OPS: (ops, (ops,)),
            Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)),
            Const.MS_API_TYPE_MINT: (mint, (mint,)),
            Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
            Const.MS_API_TYPE_COM: (comm_func, (comm_func,))
        }
    }
    if stub_tensor_existed:
        _api_types[Const.MS_FRAMEWORK][Const.MS_API_TYPE_STUB_TENSOR] = (StubTensor, (StubTensor,))
    _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)

# Keyed by "<framework><SEP><api type>": (owning object, *api names).
# NOTE(review): presumably APIs that msprobe itself calls internally — confirm in ApiRegistry.
_inner_used_api = {
    Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
        ops, "norm", "square", "sqrt", "is_complex", "stack", "is_floating_point"
    ),
    Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_TENSOR: (
        Tensor, "to", "numel"
    ),
    Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_MINT: (
        mint, "max", "min", "mean", "norm"
    )
}
79
+
80
+
81
class ApiTemplate(HOOKCell):
    """HOOKCell wrapper that executes one registered API function.

    It is passed to ApiRegistry as the template class (see get_api_register).
    Dropout APIs are bypassed. Asynchronous distributed calls are waited on and
    given a no-op handle.
    """

    def __init__(self, api_name, api_func, prefix, hook_build_func):
        # HOOKCell.__init__ reads ``prefix_api_name`` via hasattr, so all
        # attributes must be assigned before delegating to it.
        self.api_name = api_name
        self.api_func = api_func
        api_short_name = str(api_name.split(Const.SEP)[-1])
        self.prefix_api_name = Const.SEP.join((prefix, api_short_name, ""))
        super().__init__(hook_build_func)

    @staticmethod
    def async_to_sync(output):
        """Wait on an async communication result and swap its handle for a no-op one.

        Accepts either a ``(result, handle)`` pair or a bare handle (anything
        with ``wait``); any other value is returned unchanged.
        """
        # Stand-in handle: calling wait() on it afterwards does nothing.
        fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
        is_result_with_handle = (
            isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait")
        )
        if is_result_with_handle:
            output[1].wait()
            return output[0], fake_handle
        if hasattr(output, "wait"):
            output.wait()
            return fake_handle
        return output

    def construct(self, *args, **kwargs):
        # Dropout APIs are not executed: the first positional argument (or the
        # ``Const.INPUT`` keyword argument) is returned unchanged.
        if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
            return args[0] if args else kwargs.get(Const.INPUT)

        result = self.api_func(*args, **kwargs)
        if not self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
            return result
        # Communication APIs: force completion of explicit async calls and isend/irecv.
        needs_sync = kwargs.get("async_op") or self.api_name in ["isend", "irecv"]
        return self.async_to_sync(result) if needs_sync else result

    def forward(self, *args, **kwargs):
        # Same dropout bypass as construct(), without the distributed handling.
        if not self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
            return self.api_func(*args, **kwargs)
        return args[0] if args else kwargs.get(Const.INPUT)
115
+
116
+
117
api_register = None
stub_tensor_set = False


def get_api_register(return_new=False):
    """Return the module-wide ApiRegistry, creating it on first use.

    On the first call in native MindSpore mode with StubTensor available, every
    callable StubTensor attribute named in the tensor section of the supported-API
    YAML is replaced by a thin forwarding wrapper. This patching happens only once
    per process (tracked by ``stub_tensor_set``).

    Args:
        return_new: when True, build and return a fresh ApiRegistry instead of the
            shared one; the shared instance is neither created nor replaced.

    Returns:
        An ApiRegistry built from this module's API tables and ApiTemplate.
    """
    global api_register, stub_tensor_set

    def stub_method(method):
        def wrapped_method(*args, **kwargs):
            return method(*args, **kwargs)
        return wrapped_method

    should_patch_stub_tensor = not is_mindtorch() and stub_tensor_existed and not stub_tensor_set
    if should_patch_stub_tensor:
        api_names = load_yaml(_supported_api_list_path[0]).get(Const.MS_API_TYPE_TENSOR, [])
        for attr_name in dir(StubTensor):
            attr = getattr(StubTensor, attr_name)
            if callable(attr) and attr_name in api_names:
                setattr(StubTensor, attr_name, stub_method(attr))
        stub_tensor_set = True

    if return_new:
        return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)

    if api_register is None:
        api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
    return api_register
@@ -28,23 +28,22 @@ def get_cell_count(name):
28
28
  return HOOKCell.cell_count[name]
29
29
 
30
30
 
31
- def __init__(self, build_hook) -> None:
31
+ def __init__(self, hook_build_func) -> None:
32
32
  super(HOOKCell, self).__init__()
33
33
  self.changed_status = False
34
34
  self.input_kwargs = {}
35
- self.prefix = ""
36
35
  if not HOOKCell.g_stop_hook:
37
36
  HOOKCell.g_stop_hook = True
38
37
  self.changed_status = True
39
- if hasattr(self, "prefix_api_name"):
40
- self.prefix = self.prefix_api_name
41
-
42
38
  self.forward_data_collected = False
43
- forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = build_hook(self.prefix)
44
- self.register_forward_pre_hook(forward_pre_hook)
45
- self.register_forward_hook(forward_hook)
46
- register_backward_hook_functions["full"](self, backward_hook)
47
- register_backward_hook_functions["pre"](self, backward_pre_hook)
39
+
40
+ prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
41
+ if callable(hook_build_func):
42
+ forward_pre_hook, forward_hook, backward_hook, backward_pre_hook = hook_build_func(prefix)
43
+ self.register_forward_pre_hook(forward_pre_hook)
44
+ self.register_forward_hook(forward_hook)
45
+ register_backward_hook_functions["full"](self, backward_hook)
46
+ register_backward_hook_functions["pre"](self, backward_pre_hook)
48
47
 
49
48
 
50
49
  # Override __call__ to add the global flag.