mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ from msprobe.core.common.file_utils import load_npy
25
25
  from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
26
26
  ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
27
27
  dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
28
+ dtype_str_to_mindtorch_dtype,
28
29
  dtype_str_to_torch_dtype, type_to_api_info_type_str,
29
30
  DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
30
31
  MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
@@ -33,6 +34,15 @@ from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_s
33
34
  from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
34
35
  from msprobe.mindspore.common.log import logger
35
36
 
37
+ import msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer as env_module
38
+
39
+
40
+ if env_module.is_valid_pt_mt_env:
41
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
42
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
43
+ else:
44
+ import torch
45
+
36
46
 
37
47
  class MstensorMetaData:
38
48
  def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
@@ -86,6 +96,37 @@ class ComputeElement:
86
96
  torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
87
97
  return torch_tensor
88
98
 
99
@staticmethod
def transfer_to_mindtorch_tensor(ms_tensor):
    """
    Convert a MindSpore tensor into a MindTorch tensor of the corresponding dtype.

    The source dtype is mapped to a dtype string via ``ms_dtype_to_dtype_str`` and then
    to a MindTorch dtype via ``dtype_str_to_mindtorch_dtype``. The data is routed through
    NumPy using a wide intermediate dtype (int64 for dtypes listed in
    ``int_dtype_str_list``, float64 otherwise) before the final cast.

    Args:
        ms_tensor: mindspore.Tensor
    Return:
        mindtorch_tensor: mindtorch.Tensor
    Raises:
        Reports an ApiAccuracyCheckerException(UnsupportType) through
        ``logger.error_log_with_exp`` when the dtype has no MindTorch mapping
        (that helper is presumably the one that raises — confirm).
    """
    ms_dtype = ms_tensor.dtype
    dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)

    if dtype_str not in dtype_str_to_mindtorch_dtype:
        err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
        logger.error_log_with_exp(err_msg,
                                  ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
    else:
        mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)

    # Upcast before .numpy(): presumably so dtypes without a direct NumPy
    # equivalent (e.g. bfloat16) survive the round trip.
    if dtype_str in int_dtype_str_list:
        middle_dtype = mindspore.int64
    else:
        middle_dtype = mindspore.float64

    np_ndarray = ms_tensor.astype(middle_dtype).numpy()

    # Cast to the MindTorch dtype looked up above (not the MindSpore source dtype).
    mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(mindtorch_dtype)

    return mindtorch_tensor
129
+
89
130
  @staticmethod
90
131
  def transfer_to_mindspore_tensor(torch_tensor):
91
132
  '''
@@ -141,8 +182,11 @@ class ComputeElement:
141
182
  elif isinstance(self.parameter, DtypeMetaData):
142
183
  if tensor_platform == Const.MS_FRAMEWORK:
143
184
  parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
144
- else:
185
+ elif tensor_platform == Const.PT_FRAMEWORK:
145
186
  parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
187
+ elif tensor_platform == Const.MT_FRAMEWORK:
188
+ parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
189
+
146
190
  elif isinstance(self.parameter, MstensorMetaData):
147
191
  mstensor_meta_data = self.parameter
148
192
  ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
@@ -161,6 +205,8 @@ class ComputeElement:
161
205
  # if necessary, do transfer
162
206
  if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
163
207
  parameter = self.transfer_to_torch_tensor(parameter_tmp)
208
+ elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
209
+ parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
164
210
  elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
165
211
  parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
166
212
  else:
@@ -16,12 +16,13 @@
16
16
  import os
17
17
  import csv
18
18
 
19
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
19
+ from msprobe.core.common.const import Const, CompareConst
20
20
  from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
21
21
  from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
22
22
  from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
23
23
  from msprobe.core.common.file_utils import check_file_or_directory_path
24
24
  from msprobe.mindspore.common.log import logger
25
+ from msprobe.mindspore.common.const import MsCompareConst
25
26
 
26
27
 
27
28
  class ResultCsvEntry:
@@ -27,10 +27,11 @@ import numpy as np
27
27
  from tqdm import tqdm
28
28
 
29
29
  # 本地应用/库特定导入
30
- from msprobe.core.common.const import Const, CompareConst, MsCompareConst
30
+ from msprobe.core.common.const import Const, CompareConst
31
31
  from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker, BasicInfoAndStatus
32
32
  from msprobe.mindspore.api_accuracy_checker.multi_data_manager import MultiDataManager
33
33
  from msprobe.mindspore.common.log import logger
34
+ from msprobe.mindspore.common.const import MsCompareConst
34
35
 
35
36
 
36
37
  class MultiApiAccuracyChecker(ApiAccuracyChecker):
@@ -0,0 +1,130 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import gc
18
+ import sys
19
+ from pathlib import Path
20
+ import mindspore
21
+ from msprobe.mindspore.common.log import logger
22
+ from msprobe.core.common.const import Const, CompareConst
23
+ from msprobe.mindspore.common.const import MsCompareConst
24
+ import torch as mindtorch
25
+ from torch import Tensor as mindtorch_tensor
26
+ import torch.nn.functional as mindtorch_func
27
+ import torch.distributed as mindtorch_dist
28
+
29
+
30
# Module-level flag: whether a usable PyTorch + MindTorch pairing is available.
# Starts optimistic; set_pt_mt_env_invalid() flips it to False.
is_valid_pt_mt_env: bool = True
31
+
32
+
33
def is_mindtorch():
    """
    Report whether ``import torch`` currently resolves to MindTorch.

    A probe tensor is built with ``torch.tensor(0.0)`` and checked against
    ``mindspore.Tensor``. Returns False if either package cannot be imported.
    """
    try:
        import torch as test_torch
        from mindspore import Tensor as MindsporeTensor
    except ImportError:
        return False
    probe = test_torch.tensor(0.0)
    return isinstance(probe, MindsporeTensor)
45
+
46
+
47
def remove_torch_related_paths():
    """
    Remove the directory that provides MindTorch's ``torch`` package from ``sys.path``.

    Does nothing unless ``torch`` currently resolves to MindTorch (``is_mindtorch()``)
    and the imported package exposes a ``__file__``. The directory removed is the
    parent of the imported ``torch`` package directory.

    Every ``sys.path`` entry that *resolves* to that directory is dropped, so
    relative, symlinked or trailing-slash spellings are removed too, not only an
    exact string match. Entries that cannot be resolved are kept and logged at
    debug level. Returns None.
    """
    if not is_mindtorch():
        return
    try:
        import torch as remove_torch
    except ImportError:
        return

    torch_file = getattr(remove_torch, "__file__", None)
    if not torch_file:
        # e.g. namespace packages report __file__ as None; nothing to locate.
        return

    torch_dir_path = Path(os.path.dirname(torch_file)).resolve()
    target_dir = str(torch_dir_path.parent)

    kept_paths = []
    for entry in sys.path:
        try:
            resolved = str(Path(entry).resolve())
        except (OSError, RuntimeError, TypeError, ValueError) as error:
            logger.debug(f"Failed to resolve path {entry}: {error}")
            kept_paths.append(entry)
            continue
        if resolved != target_dir:
            kept_paths.append(entry)

    # Mutate in place so holders of the sys.path list object see the change.
    sys.path[:] = kept_paths
77
+
78
+
79
def clear_torch_from_sys_modules():
    """
    Drop ``torch`` and every ``torch.*`` submodule from ``sys.modules``.

    The next ``import torch`` is therefore re-resolved through ``sys.path``.
    Packages that merely share the prefix (e.g. ``torchvision``) are untouched.
    """
    stale_names = [name for name in list(sys.modules) if name.partition(".")[0] == "torch"]
    for name in stale_names:
        del sys.modules[name]
87
+
88
+
89
def set_pt_mt_env_invalid():
    """Mark the combined PyTorch/MindTorch environment as unusable (module-level flag)."""
    global is_valid_pt_mt_env
    is_valid_pt_mt_env = False
92
+
93
+
94
def delete_torch_paths():
    """
    Strip MindTorch from the import machinery so that a real PyTorch can be imported.

    If ``torch`` does not currently resolve to MindTorch, the combined environment is
    marked invalid and the function returns without touching ``sys.modules``.

    Otherwise the cached ``torch`` modules are dropped. Then the directory providing
    MindTorch's ``torch`` is removed from ``sys.path`` and the cache cleared again.
    This repeats for at most ``MsCompareConst.MAX_RECURSION_DEPTH`` rounds, until
    ``torch`` no longer resolves to MindTorch.

    Raises:
        Exception: if ``torch`` still resolves to MindTorch after the last round.
    """
    if not is_mindtorch():
        set_pt_mt_env_invalid()
        # Nothing to strip; clearing sys.modules here would only evict a real torch.
        return

    clear_torch_from_sys_modules()

    for _ in range(MsCompareConst.MAX_RECURSION_DEPTH):
        if not is_mindtorch():
            return
        remove_torch_related_paths()
        clear_torch_from_sys_modules()

    # Fail only if the final round did not already fix the environment.
    if is_mindtorch():
        raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment, and ensure "
                        f"the PYTHONPATH environment variable depth does not exceed "
                        f"{MsCompareConst.MAX_RECURSION_DEPTH}.")
112
+
113
+
114
if not is_mindtorch():
    # `torch` is not MindTorch, so there is no MindTorch/PyTorch pairing to set up.
    set_pt_mt_env_invalid()
else:
    # `torch` currently resolves to MindTorch (bound above as `mindtorch`).
    # Strip it from the import machinery so the real PyTorch can be imported as
    # `torch`. sys.path is restored afterwards even if stripping fails part-way.
    initial_sys_path = sys.path.copy()
    try:
        delete_torch_paths()

        gc.collect()

        import torch

        if is_mindtorch():
            # Re-import still yields MindTorch: presumably no separate PyTorch is installed.
            set_pt_mt_env_invalid()
    finally:
        sys.path = initial_sys_path
129
+
130
+
@@ -15,10 +15,18 @@
15
15
 
16
16
  import mindspore
17
17
  import numpy as np
18
- import torch
19
18
  from mindspore._c_expression import typing
20
19
  from mindspore.common import dtype as mstype
21
20
 
21
+ from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
22
+
23
+ if torch_mindtorch_importer.is_valid_pt_mt_env:
24
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
25
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
26
+ else:
27
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
28
+ import torch
29
+
22
30
  INT8 = "Int8"
23
31
  UINT8 = "UInt8"
24
32
  INT16 = "Int16"
@@ -82,6 +90,21 @@ dtype_str_to_torch_dtype = {
82
90
  }
83
91
  torch_dtype_to_dtype_str = {value: key for key, value in dtype_str_to_torch_dtype.items()}
84
92
 
93
+
94
# Dtype name string -> MindTorch dtype attribute of the same name.
dtype_str_to_mindtorch_dtype = {
    dtype_name: getattr(mindtorch, attr_name)
    for dtype_name, attr_name in (
        (INT8, "int8"),
        (UINT8, "uint8"),
        (INT16, "int16"),
        (INT32, "int32"),
        (INT64, "int64"),
        (FLOAT16, "float16"),
        (FLOAT32, "float32"),
        (FLOAT64, "float64"),
        (BOOL, "bool"),
        (BFLOAT16, "bfloat16"),
    )
}
# Reverse lookup: MindTorch dtype -> dtype name string.
mindtorch_dtype_to_dtype_str = {mt_dtype: dtype_name for dtype_name, mt_dtype in dtype_str_to_mindtorch_dtype.items()}
107
+
85
108
  MINDSPORE_TENSOR_TYPE_STR = "mindspore.Tensor"
86
109
  BOOL_TYPE_STR = "bool"
87
110
  INT_TYPE_STR = "int"
@@ -82,10 +82,12 @@ class GlobalContext:
82
82
  def __init__(self):
83
83
  self.is_constructed = True
84
84
  self.dump_data_dir = ""
85
+ self.framework = Const.MS_FRAMEWORK
85
86
 
86
def init(self, is_constructed, dump_data_dir, framework=Const.MS_FRAMEWORK):
    """
    Reconfigure the shared context.

    Args:
        is_constructed: value stored as ``self.is_constructed``.
        dump_data_dir: value stored as ``self.dump_data_dir``.
        framework: value stored as ``self.framework``. Defaults to
            ``Const.MS_FRAMEWORK``, the same initial value ``__init__`` uses, so
            callers that pass only the first two arguments keep working.
    """
    self.is_constructed = is_constructed
    self.dump_data_dir = dump_data_dir
    self.framework = framework
89
91
 
90
92
  def get_dump_data_dir(self):
91
93
  return self.dump_data_dir
@@ -93,5 +95,8 @@ class GlobalContext:
93
95
  def get_is_constructed(self):
94
96
  return self.is_constructed
95
97
 
98
def get_framework(self):
    """Return the framework identifier recorded on this context."""
    current_framework = self.framework
    return current_framework
100
+
96
101
 
97
102
  global_context = GlobalContext()
@@ -70,6 +70,67 @@ class Const:
70
70
  }
71
71
 
72
72
 
73
class MsCompareConst:
    """Constants shared by the MindSpore API accuracy checker modules."""

    # ---- api_info fields (MindSpore) ----
    MINT = "Mint"
    MINT_FUNCTIONAL = "MintFunctional"
    TENSOR_API = "Tensor"
    FUNCTIONAL_API = "Functional"
    FUSION_API = "FUSION"

    # ---- limits ----
    API_NAME_STR_LENGTH = 4
    # Also bounds the sys.path stripping rounds in torch_mindtorch_importer.
    MAX_RECURSION_DEPTH = 20

    # ---- api_info fields (MindTorch) ----
    MINDTORCH_TENSOR = "Tensor"
    MINDTORCH = "Torch"
    MINDTORCH_FUNC = "Functional"
    MINDTORCH_NPU = "NPU"
    MINDTORCH_DIST = "Distributed"

    # MindTorch API categories treated as valid (NPU / Distributed are not listed).
    MT_VALID_API_TYPES = [MINDTORCH, MINDTORCH_FUNC, MINDTORCH_TENSOR]
    SUPPORTED_FUSION_LIST = ["flash_attention_score"]

    # ---- task / config keys ----
    TASK_FIELD = "task"
    STATISTICS_TASK = "statistics"
    FRAMEWORK = "framework"
    TENSOR_TASK = "tensor"
    DUMP_DATA_DIR_FIELD = "dump_data_dir"
    DATA_FIELD = "data"

    # ---- supported api yaml ----
    SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
    SUPPORTED_TENSOR_LIST_KEY = "tensor"

    # ---- detail csv ----
    DETAIL_CSV_API_NAME = "API Name"
    DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
    DETAIL_CSV_TESTED_DTYPE = "Tested Dtype"
    DETAIL_CSV_SHAPE = "Shape"
    DETAIL_CSV_PASS_STATUS = "Status"
    DETAIL_CSV_MESSAGE = "Message"
    DETAIL_CSV_FILE_NAME = "accuracy_checking_details"

    # ---- result csv ----
    RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success"
    RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success"
    RESULT_CSV_FILE_NAME = "accuracy_checking_result"

    EPSILON = 1e-8

    class ProcessStatus:
        """Processing status labels."""
        SUCCESS = "success"
        API_NOT_FOUND = "api_not_found"
        EXCEPTION_SKIP = "exception_skip"
130
+
131
+
132
+
133
+
73
134
  class FreeBenchmarkConst:
74
135
  ADD_NOISE = "add_noise"
75
136
  BIT_NOISE = "bit_noise"
@@ -25,7 +25,31 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
25
25
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
26
26
  from msprobe.core.common.log import logger
27
27
  from msprobe.core.common.const import Const
28
- from msprobe.core.common.utils import CompareException, check_seed_all
28
+ from msprobe.core.common.utils import CompareException, check_seed_all, is_save_variable_valid
29
+
30
+
31
class MsprobeStep(ms.train.Callback):
    """
    MindSpore training callback that brackets each training step with the debugger.

    The debugger is started when a step begins. When the step ends it is stopped
    and then advanced with ``step()``.
    """

    def __init__(self, debugger):
        super().__init__()
        self.debugger = debugger

    def on_train_step_begin(self, run_context):
        self.debugger.start()

    def on_train_step_end(self, run_context):
        dbg = self.debugger
        dbg.stop()
        dbg.step()
42
+
43
+
44
class MsprobeInitStep(ms.train.Callback):
    """
    MindSpore training callback that passes the current step number to MindSpore's
    ``_set_init_iter`` when training begins.

    MindSpore builds without ``_set_init_iter`` only produce a warning.
    """

    def on_train_begin(self, run_context):
        try:
            # Import by the real package name: `ms` is this file's alias for mindspore,
            # so `from ms._c_expression import ...` can never resolve.
            from mindspore._c_expression import _set_init_iter
        except ImportError:
            logger.warning('MsprobeInitStep does not work on this version of MindSpore.')
            return
        cb_params = run_context.original_args()
        _set_init_iter(cb_params.cur_step_num)
29
53
 
30
54
 
31
55
  def get_rank_if_initialized():
@@ -93,20 +117,6 @@ def seed_all(seed=1234, mode=False, rm_dropout=True):
93
117
  remove_dropout()
94
118
 
95
119
 
96
- class MsprobeStep(ms.train.Callback):
97
-
98
- def __init__(self, debugger):
99
- super(MsprobeStep, self).__init__()
100
- self.debugger = debugger
101
-
102
- def on_train_step_begin(self, run_context):
103
- self.debugger.start()
104
-
105
- def on_train_step_end(self, run_context):
106
- self.debugger.stop()
107
- self.debugger.step()
108
-
109
-
110
120
  class Dropout(ops.Dropout):
111
121
  def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
112
122
  super().__init__(1., seed0, seed1)
@@ -151,11 +161,10 @@ def is_mindtorch():
151
161
  mindtorch_check_result = False
152
162
  try:
153
163
  import torch
154
- from mindspore._c_expression import Tensor
155
164
  except ImportError:
156
165
  return mindtorch_check_result
157
166
  tensor = torch.tensor(0.0)
158
- if isinstance(tensor, Tensor):
167
+ if isinstance(tensor, ms.Tensor):
159
168
  mindtorch_check_result = True
160
169
  return mindtorch_check_result
161
170
 
@@ -170,7 +179,7 @@ def set_register_backward_hook_functions():
170
179
  from msprobe.mindspore.mindtorch import (_call_impl,
171
180
  register_full_backward_pre_hook,
172
181
  register_full_backward_hook)
173
- if not hasattr(torch, "register_full_backward_hook"):
182
+ if not hasattr(torch.nn.Module, "register_full_backward_hook"):
174
183
  setattr(torch.nn.Module, "_call_impl", _call_impl)
175
184
  setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
176
185
  setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
@@ -179,3 +188,24 @@ def set_register_backward_hook_functions():
179
188
  else:
180
189
  register_backward_hook_functions["pre"] = ms.nn.Cell.register_backward_pre_hook
181
190
  register_backward_hook_functions["full"] = ms.nn.Cell.register_backward_hook
191
+
192
+
193
def check_save_param(variable, name, save_backward):
    """
    Validate the arguments of ``PrecisionDebugger.save``.

    Callers are expected to catch the ValueError and skip the current save.

    Args:
        variable: data to save; validated by ``is_save_variable_valid`` against
            mindspore.Tensor/int/float/str (the warning also lists dict/tuple/list
            as accepted containers).
        name: label of the saved data; must be a str.
        save_backward: must be a bool.

    Raises:
        ValueError: if any argument is invalid. A warning is logged first.
    """
    valid_data_types = (ms.Tensor, int, float, str)
    if not is_save_variable_valid(variable, valid_data_types):
        valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
        logger.warning("PrecisionDebugger.save variable type not valid, "
                       f"should be one of {valid_data_types_with_nested_types}. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save variable type not valid.")
    if not isinstance(name, str):
        logger.warning("PrecisionDebugger.save name not valid, "
                       "should be string. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save name not valid.")
    if not isinstance(save_backward, bool):
        logger.warning("PrecisionDebugger.save save_backward not valid, "
                       "should be bool. "
                       "Skip current save process.")
        raise ValueError("PrecisionDebugger.save save_backward not valid.")
@@ -22,10 +22,10 @@ import pandas as pd
22
22
 
23
23
  from msprobe.core.common.const import CompareConst, Const
24
24
  from msprobe.core.common.exceptions import FileCheckException
25
- from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml
25
+ from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml
26
26
  from msprobe.core.common.log import logger
27
27
  from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
28
- check_op_str_pattern_valid, get_dump_mode, set_dump_path
28
+ check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json
29
29
  from msprobe.core.compare.acc_compare import Comparator, ModeConfig
30
30
  from msprobe.core.compare.check import dtype_mapping
31
31
  from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
@@ -78,6 +78,11 @@ class MSComparator(Comparator):
78
78
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
79
79
  f"{type(self.data_mapping)}")
80
80
 
81
@staticmethod
def process_data_name(result):
    """
    Pair each row's NPU and bench data names into a single list-valued column.

    After the call every ``data_name_x`` cell holds ``[data_name_x, data_name_y]``
    of that row. ``result`` is modified in place and also returned.

    The column is built with ``zip`` rather than ``DataFrame.apply(axis=1)``. The
    row-wise apply is slower, and on an empty frame it returns a DataFrame that
    cannot be assigned to a single column.
    """
    paired = [[npu_name, bench_name]
              for npu_name, bench_name in zip(result['data_name_x'], result['data_name_y'])]
    # dtype=object keeps each two-element list as one cell instead of forming a 2-D array.
    result['data_name_x'] = pd.Series(paired, index=result.index, dtype=object)
    return result
85
+
81
86
  def calc_accuracy(self, result_df, header):
82
87
  condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
83
88
  result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
@@ -120,12 +125,13 @@ class MSComparator(Comparator):
120
125
  result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
121
126
  elif self.dump_mode == Const.SUMMARY:
122
127
  warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
123
- warning_flag = pd.DataFrame(warning_list).all()
128
+ warning_flag = pd.DataFrame(warning_list).any()
124
129
  result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
125
130
  result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
126
131
  result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
127
132
  else:
128
- fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
133
+ fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
134
+ CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
129
135
  CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
130
136
  CompareConst.ERROR_MESSAGE]
131
137
  result_df.loc[~condition_no_bench, fill_cols] = ''
@@ -139,6 +145,8 @@ class MSComparator(Comparator):
139
145
  header.append(CompareConst.STACK)
140
146
  if self.dump_mode == Const.ALL:
141
147
  header.append(CompareConst.DATA_NAME)
148
+ result = self.process_data_name(result)
149
+
142
150
  result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
143
151
  'op_name_y': CompareConst.BENCH_NAME,
144
152
  'dtype_x': CompareConst.NPU_DTYPE,
@@ -169,6 +177,7 @@ class MSComparator(Comparator):
169
177
 
170
178
  result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
171
179
  result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
180
+
172
181
  result_df = pd.DataFrame(columns=header)
173
182
  for h in header:
174
183
  if h in result.columns:
@@ -269,15 +278,15 @@ class MSComparator(Comparator):
269
278
  bench_dtype = match_result['dtype_y']
270
279
  if self.cross_frame:
271
280
  npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
272
- return ((npu_dtype == bench_dtype) |
273
- ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) |
274
- ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) |
275
- ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) |
276
- ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) |
277
- ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) |
278
- ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
279
- ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
280
- ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
281
+
282
+ equal_condition = npu_dtype == bench_dtype
283
+ match_condition = (
284
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
285
+ CompareConst.DTYPE_MATCH_GROUPS[0])) |
286
+ (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
287
+ CompareConst.DTYPE_MATCH_GROUPS[1]))
288
+ )
289
+ return equal_condition | match_condition
281
290
 
282
291
  match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
283
292
  return self.make_result_df(match_result)
@@ -382,12 +391,11 @@ class MSComparator(Comparator):
382
391
 
383
392
 
384
393
  def check_cross_framework(bench_json_path):
385
- pattern = r'"data_name":\s*"[^"]+\.pt"'
386
- with FileOpen(bench_json_path, 'r') as file:
387
- for line in file:
388
- if re.search(pattern, line):
389
- return True
390
- return False
394
+ framework = detect_framework_by_dump_json(bench_json_path)
395
+ if framework == Const.PT_FRAMEWORK:
396
+ return True
397
+ else:
398
+ return False
391
399
 
392
400
 
393
401
  def ms_compare(input_param, output_path, **kwargs):
@@ -195,11 +195,12 @@ class GraphMSComparator:
195
195
  if not error_flag:
196
196
  result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
197
197
  result_dict[CompareConst.COSINE] = result_list[0]
198
- result_dict[CompareConst.MAX_ABS_ERR] = result_list[1]
199
- result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[2]
200
- result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[3]
201
- result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[4]
202
- result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[1])
198
+ result_dict[CompareConst.EUC_DIST] = result_list[1]
199
+ result_dict[CompareConst.MAX_ABS_ERR] = result_list[2]
200
+ result_dict[CompareConst.MAX_RELATIVE_ERR] = result_list[3]
201
+ result_dict[CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result_list[4]
202
+ result_dict[CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result_list[5]
203
+ result_dict[CompareConst.ACCURACY] = check_accuracy(result_list[0], result_list[2])
203
204
  result_dict[CompareConst.ERROR_MESSAGE] = err_msg
204
205
 
205
206
  return pd.Series(result_dict)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +16,11 @@
16
16
  import os
17
17
 
18
18
  from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import MsprobeException
19
20
  from msprobe.core.common.file_utils import create_directory
20
21
  from msprobe.mindspore.common.const import Const as MsConst
21
22
  from msprobe.mindspore.common.const import FreeBenchmarkConst
23
+ from msprobe.core.common.log import logger
22
24
 
23
25
 
24
26
  class DebuggerConfig:
@@ -50,12 +52,14 @@ class DebuggerConfig:
50
52
  if not task_config.handler_type else task_config.handler_type)
51
53
  self.stage = FreeBenchmarkConst.DEFAULT_STAGE if not task_config.fuzz_stage else task_config.fuzz_stage
52
54
  if self.handler_type == FreeBenchmarkConst.FIX and \
53
- self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
54
- raise ValueError("pert_mode must be improve_precision or empty when handler_type is fix, "
55
- f"but got {self.pert_type}.")
55
+ self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE:
56
+ logger.error("pert_mode must be improve_precision or empty when handler_type is fix, "
57
+ f"but got {self.pert_type}.")
58
+ raise ValueError
56
59
  if self.stage == Const.BACKWARD and self.handler_type == FreeBenchmarkConst.FIX:
57
- raise ValueError("handler_type must be check or empty when fuzz_stage is backward, "
58
- f"but got {self.handler_type}.")
60
+ logger.error("handler_type must be check or empty when fuzz_stage is backward, "
61
+ f"but got {self.handler_type}.")
62
+ raise ValueError
59
63
  self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
60
64
 
61
65
  def check(self):
@@ -72,4 +76,25 @@ class DebuggerConfig:
72
76
  self.check_mode = "all"
73
77
  if not isinstance(self.async_dump, bool):
74
78
  raise Exception("The parameters async_dump should be bool.")
79
+ if self.async_dump and self.task == Const.TENSOR and not self.list:
80
+ raise Exception("The parameters async_dump is true in tensor task, the parameters list cannot be empty.")
81
+ if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
82
+ logger.warning_on_rank_0(
83
+ f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
84
+ f"If not, the default level is {Const.LEVEL_MIX}."
85
+ )
86
+ self.level_ori = Const.LEVEL_MIX
75
87
  return True
88
+
89
+ def check_config_with_l2(self):
90
+ if self.level_ori != Const.LEVEL_L2:
91
+ return
92
+ if self.task != Const.TENSOR:
93
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
94
+ f"When level is set to L2, the task must be set to tensor.")
95
+ if self.scope:
96
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
97
+ f"When level is set to L2, the scope cannot be configured.")
98
+ if not self.list or len(self.list) != 1:
99
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
100
+ f"When level is set to L2, the list must be configured as a list with one api name.")