mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,6 +20,7 @@ import mindspore as ms
20
20
 
21
21
  from mindspore import ops
22
22
  from mindspore.mint import nn
23
+
23
24
  from msprobe.core.common.exceptions import DistributedNotInitializedError
24
25
  from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
25
26
  from msprobe.core.common.log import logger
@@ -43,7 +44,7 @@ def convert_bf16_to_fp32(tensor):
43
44
  def save_tensor_as_npy(tensor, file_path):
44
45
  if not path_len_exceeds_limit(file_path):
45
46
  tensor = convert_bf16_to_fp32(tensor)
46
- saved_tensor = tensor.contiguous().asnumpy()
47
+ saved_tensor = tensor.asnumpy()
47
48
  save_npy(saved_tensor, file_path)
48
49
  else:
49
50
  logger.warning(f'The file path {file_path} length exceeds limit.')
@@ -56,6 +57,11 @@ def convert_to_int(value):
56
57
  return -1
57
58
 
58
59
 
60
+ def clean_input_kwargs(cell):
61
+ if hasattr(cell, 'input_kwargs'):
62
+ del cell.input_kwargs
63
+
64
+
59
65
  def list_lowest_level_directories(root_dir):
60
66
  check_path_exists(root_dir)
61
67
  lowest_level_dirs = []
@@ -77,7 +83,7 @@ def list_lowest_level_directories(root_dir):
77
83
 
78
84
 
79
85
  def seed_all(seed=1234, mode=False, rm_dropout=True):
80
- check_seed_all(seed, mode)
86
+ check_seed_all(seed, mode, rm_dropout)
81
87
  os.environ['PYTHONHASHSEED'] = str(seed)
82
88
  ms.set_seed(seed)
83
89
  random.seed(seed)
@@ -102,8 +108,8 @@ class MsprobeStep(ms.train.Callback):
102
108
 
103
109
 
104
110
  class Dropout(ops.Dropout):
105
- def __init__(self, keep_prob=0.5, Seed0=0, Seed1=1):
106
- super().__init__(1., Seed0, Seed1)
111
+ def __init__(self, keep_prob=0.5, seed0=0, seed1=1):
112
+ super().__init__(1., seed0, seed1)
107
113
 
108
114
 
109
115
  class Dropout2D(ops.Dropout2D):
@@ -134,3 +140,42 @@ def remove_dropout():
134
140
  ops.operations.Dropout3D = Dropout3D
135
141
  nn.Dropout = DropoutExt
136
142
  nn.functional.dropout = dropout_ext
143
+
144
+
145
+ mindtorch_check_result = None
146
+
147
+
148
+ def is_mindtorch():
149
+ global mindtorch_check_result
150
+ if mindtorch_check_result is None:
151
+ mindtorch_check_result = False
152
+ try:
153
+ import torch
154
+ from mindspore._c_expression import Tensor
155
+ except ImportError:
156
+ return mindtorch_check_result
157
+ tensor = torch.tensor(0.0)
158
+ if isinstance(tensor, Tensor):
159
+ mindtorch_check_result = True
160
+ return mindtorch_check_result
161
+
162
+
163
+ register_backward_hook_functions = {}
164
+
165
+
166
+ def set_register_backward_hook_functions():
167
+ global register_backward_hook_functions
168
+ if is_mindtorch():
169
+ import torch
170
+ from msprobe.mindspore.mindtorch import (_call_impl,
171
+ register_full_backward_pre_hook,
172
+ register_full_backward_hook)
173
+ if not hasattr(torch, "register_full_backward_hook"):
174
+ setattr(torch.nn.Module, "_call_impl", _call_impl)
175
+ setattr(torch.nn.Module, "register_full_backward_pre_hook", register_full_backward_pre_hook)
176
+ setattr(torch.nn.Module, "register_full_backward_hook", register_full_backward_hook)
177
+ register_backward_hook_functions["pre"] = torch.nn.Module.register_full_backward_pre_hook
178
+ register_backward_hook_functions["full"] = torch.nn.Module.register_full_backward_hook
179
+ else:
180
+ register_backward_hook_functions["pre"] = ms.nn.Cell.register_backward_pre_hook
181
+ register_backward_hook_functions["full"] = ms.nn.Cell.register_backward_hook
@@ -41,12 +41,10 @@ def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
41
41
  bench_data_dir = os.path.join(bench_dump_dir, br)
42
42
  npu_path = extract_json(npu_data_dir, stack_json=False)
43
43
  bench_path = extract_json(bench_data_dir, stack_json=False)
44
- stack_path = extract_json(npu_data_dir, stack_json=True)
45
44
 
46
45
  dump_result_param = {
47
46
  'npu_json_path': npu_path,
48
47
  'bench_json_path': bench_path,
49
- 'stack_json_path': stack_path,
50
48
  'is_print_compare_log': is_print_compare_log
51
49
  }
52
50
  ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +15,6 @@
15
15
 
16
16
  import os
17
17
  import re
18
-
19
18
  from collections import defaultdict
20
19
 
21
20
  import numpy as np
@@ -23,15 +22,21 @@ import pandas as pd
23
22
 
24
23
  from msprobe.core.common.const import CompareConst, Const
25
24
  from msprobe.core.common.exceptions import FileCheckException
26
- from msprobe.core.common.file_utils import (FileOpen, create_directory, load_json,
27
- load_npy, load_yaml)
25
+ from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml
28
26
  from msprobe.core.common.log import logger
29
- from msprobe.core.common.utils import (CompareException, check_compare_param,
30
- check_configuration_param,
31
- get_dump_mode, set_dump_path, check_op_str_pattern_valid)
27
+ from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
28
+ check_op_str_pattern_valid, get_dump_mode, set_dump_path
29
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig
32
30
  from msprobe.core.compare.check import dtype_mapping
33
- from msprobe.core.compare.acc_compare import Comparator
34
31
  from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
32
+ from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list
33
+
34
+
35
+ class MappingConfig:
36
+ def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
37
+ self.cell_mapping = cell_mapping
38
+ self.api_mapping = api_mapping
39
+ self.data_mapping = data_mapping
35
40
 
36
41
 
37
42
  class MSComparator(Comparator):
@@ -42,18 +47,27 @@ class MSComparator(Comparator):
42
47
  data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
43
48
  is_cross_framework: 是否跨框架。
44
49
  """
45
- def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
50
+ def __init__(self, mode_config, mapping_config=None, is_cross_framework=False):
51
+ super().__init__(mode_config)
46
52
  self.frame_name = MSComparator.__name__
47
- self.cell_mapping = cell_mapping
48
- self.api_mapping = api_mapping
49
- self.data_mapping = data_mapping
50
- if data_mapping:
53
+
54
+ self.stack_mode = mode_config.stack_mode
55
+ self.auto_analyze = mode_config.auto_analyze
56
+ self.fuzzy_match = mode_config.fuzzy_match
57
+ self.dump_mode = mode_config.dump_mode
58
+
59
+ if mapping_config:
60
+ self.cell_mapping = mapping_config.cell_mapping
61
+ self.api_mapping = mapping_config.api_mapping
62
+ self.data_mapping = mapping_config.data_mapping
63
+
64
+ if self.data_mapping:
51
65
  self.cross_frame = is_cross_framework
52
66
  else:
53
- self.cross_frame = cell_mapping is not None or api_mapping is not None
67
+ self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None
54
68
  self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
55
69
  self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
56
- if api_mapping is not None:
70
+ if self.api_mapping is not None:
57
71
  self.ms_to_pt_mapping = self.load_internal_api()
58
72
 
59
73
  if isinstance(self.data_mapping, str) or self.data_mapping is None:
@@ -63,9 +77,8 @@ class MSComparator(Comparator):
63
77
  else:
64
78
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
65
79
  f"{type(self.data_mapping)}")
66
-
67
- @classmethod
68
- def calc_accuracy(cls, result_df, dump_mode, header):
80
+
81
+ def calc_accuracy(self, result_df, header):
69
82
  condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
70
83
  result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
71
84
  result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
@@ -76,10 +89,10 @@ class MSComparator(Comparator):
76
89
  val_str = val.astype(str)
77
90
  check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
78
91
  return check_series
79
-
92
+
80
93
  def get_number(val):
81
94
  return pd.to_numeric(val.astype(str), errors='coerce')
82
-
95
+
83
96
  ms_val = result_df['NPU ' + data_type]
84
97
  pt_val = result_df['Bench ' + data_type]
85
98
  diff_name = data_type.capitalize() + ' diff'
@@ -93,7 +106,7 @@ class MSComparator(Comparator):
93
106
  condition_pt_zero = pt_val == 0
94
107
  result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
95
108
  condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
96
- result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
109
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
97
110
  pt_val[condition_ref_err] * 100)
98
111
  result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
99
112
  .abs().astype(str) + '%')
@@ -101,31 +114,30 @@ class MSComparator(Comparator):
101
114
  pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
102
115
  return magnitude > CompareConst.MAGNITUDE
103
116
 
104
- if dump_mode == Const.MD5:
117
+ if self.dump_mode == Const.MD5:
105
118
  condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
106
119
  result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
107
120
  result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
108
- elif dump_mode == Const.SUMMARY:
121
+ elif self.dump_mode == Const.SUMMARY:
109
122
  warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
110
123
  warning_flag = pd.DataFrame(warning_list).all()
111
124
  result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
112
125
  result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
113
126
  result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
114
127
  else:
115
- fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
128
+ fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
116
129
  CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
117
130
  CompareConst.ERROR_MESSAGE]
118
131
  result_df.loc[~condition_no_bench, fill_cols] = ''
119
132
  result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
120
133
  return result_df[header]
121
134
 
122
- @classmethod
123
- def make_result_df(cls, result, stack_mode, dump_mode):
124
- header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode]
135
+ def make_result_df(self, result):
136
+ header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
125
137
 
126
- if stack_mode:
138
+ if self.stack_mode:
127
139
  header.append(CompareConst.STACK)
128
- if dump_mode == Const.ALL:
140
+ if self.dump_mode == Const.ALL:
129
141
  header.append(CompareConst.DATA_NAME)
130
142
  result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
131
143
  'op_name_y': CompareConst.BENCH_NAME,
@@ -137,10 +149,11 @@ class MSComparator(Comparator):
137
149
  'md5_y': CompareConst.BENCH_MD5,
138
150
  'data_name_x': CompareConst.DATA_NAME,
139
151
  'stack_info_x': CompareConst.STACK}, inplace=True)
140
-
152
+
141
153
  npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
142
- bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
154
+ bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
143
155
  CompareConst.BENCH_NORM]
156
+
144
157
  def set_summary(summary):
145
158
  if summary == CompareConst.N_A:
146
159
  return [CompareConst.N_A] * 4
@@ -153,14 +166,14 @@ class MSComparator(Comparator):
153
166
  else:
154
167
  summary_list.append(i)
155
168
  return summary_list
156
-
169
+
157
170
  result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
158
171
  result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
159
172
  result_df = pd.DataFrame(columns=header)
160
173
  for h in header:
161
174
  if h in result.columns:
162
175
  result_df[h] = result[h]
163
- return cls.calc_accuracy(result_df, dump_mode, header)
176
+ return self.calc_accuracy(result_df, header)
164
177
 
165
178
  def load_internal_api(self):
166
179
  cur_path = os.path.dirname(os.path.realpath(__file__))
@@ -175,13 +188,16 @@ class MSComparator(Comparator):
175
188
  return mapping_dict
176
189
 
177
190
  def process_cell_mapping(self, npu_op_name):
178
- if not npu_op_name or not re.match(r'.+(?:for|back)ward\..+', npu_op_name):
191
+ if not npu_op_name:
192
+ return CompareConst.N_A
193
+ param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
194
+ if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
179
195
  return CompareConst.N_A
180
196
  npu_op_name = npu_op_name.replace("Cell", "Module", 1)
181
197
  if self.cell_mapping_dict:
182
198
  # get cell name & class name from op_name
183
199
  # Cell.fc1.Dense.forward.0.input.0
184
- cell_name = re.split(r'\.(?:for|back)ward\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
200
+ cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
185
201
  if cell_name in self.cell_mapping_dict:
186
202
  npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
187
203
  return npu_op_name
@@ -198,7 +214,7 @@ class MSComparator(Comparator):
198
214
  data_value = data_value.to(torch.float32)
199
215
  data_value = data_value.numpy()
200
216
  else:
201
- data_value = load_npy(data_path)
217
+ data_value = load_npy(data_path)
202
218
  return data_value
203
219
 
204
220
  def process_internal_api_mapping(self, npu_op_name):
@@ -214,7 +230,7 @@ class MSComparator(Comparator):
214
230
  return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
215
231
  else:
216
232
  return npu_op_name
217
-
233
+
218
234
  def get_api_name(self, api_list):
219
235
  try:
220
236
  api_name = api_list[0] + Const.SEP + api_list[1]
@@ -223,14 +239,14 @@ class MSComparator(Comparator):
223
239
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
224
240
  return api_name
225
241
 
226
- def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
242
+ def compare_process(self, file_lists):
227
243
  npu_json_path, bench_json_path, stack_json_path = file_lists
228
244
  npu_json_data = load_json(npu_json_path)
229
245
  bench_json_data = load_json(bench_json_path)
230
- stack_json_data = load_json(stack_json_path)
246
+ stack_json_data = load_json(stack_json_path) if self.stack_mode else None
231
247
 
232
- npu_df = self.gen_data_df(npu_json_data, stack_json_data, dump_mode)
233
- bench_df = self.gen_data_df(bench_json_data, stack_json_data, dump_mode)
248
+ npu_df = self.gen_data_df(npu_json_data, stack_json_data)
249
+ bench_df = self.gen_data_df(bench_json_data, stack_json_data)
234
250
  if self.cell_mapping:
235
251
  npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
236
252
  elif self.api_mapping:
@@ -242,8 +258,8 @@ class MSComparator(Comparator):
242
258
  npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
243
259
  bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
244
260
  npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
245
- bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
246
261
  bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
262
+ bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
247
263
  match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
248
264
  how='outer')
249
265
  match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
@@ -262,9 +278,9 @@ class MSComparator(Comparator):
262
278
  ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
263
279
  ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
264
280
  ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
265
-
281
+
266
282
  match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
267
- return MSComparator.make_result_df(match_result, stack_mode, dump_mode)
283
+ return self.make_result_df(match_result)
268
284
 
269
285
  def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
270
286
  def get_api_indices_dict(op_name_df):
@@ -288,11 +304,17 @@ class MSComparator(Comparator):
288
304
  return flag
289
305
 
290
306
  for mapping_dict in self.api_mapping_dict:
291
- if (len(mapping_dict.get('ms_args')) != len(mapping_dict.get('pt_args')) or
292
- len(mapping_dict.get('ms_output')) != len(mapping_dict.get('pt_output'))):
307
+ keys_to_compare = [
308
+ ('ms_args', 'pt_args'),
309
+ ('ms_output', 'pt_output'),
310
+ ('ms_parameters', 'pt_parameters'),
311
+ ('ms_parameters_grad', 'pt_parameters_grad'),
312
+ ]
313
+ if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare):
293
314
  logger.warning('The user-defined mapping table is incorrect,\
294
315
  make sure that the number of parameters is equal')
295
316
  continue
317
+
296
318
  ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
297
319
  if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
298
320
  continue
@@ -304,13 +326,17 @@ class MSComparator(Comparator):
304
326
  is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
305
327
  elif CompareConst.OUTPUT_PATTERN in op_name:
306
328
  is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
329
+ elif CompareConst.PARAMS_PATTERN in op_name:
330
+ is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
331
+ elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
332
+ is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
307
333
  else:
308
334
  logger.error(f'Excepted op_name: {op_name}')
309
335
  raise CompareException(CompareException.INVALID_DATA_ERROR)
310
336
  if is_abandoned:
311
337
  npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
312
338
 
313
- def gen_data_df(self, data_json, stack_json, dump_mode):
339
+ def gen_data_df(self, data_json, stack_json_data):
314
340
  result = {
315
341
  CompareConst.OP_NAME: [],
316
342
  Const.DTYPE: [],
@@ -318,29 +344,40 @@ class MSComparator(Comparator):
318
344
  Const.SUMMARY: [],
319
345
  'stack_info': []
320
346
  }
321
- if dump_mode == Const.ALL:
347
+ if self.dump_mode == Const.ALL:
322
348
  result['data_name'] = []
323
- elif dump_mode == Const.MD5:
349
+ elif self.dump_mode == Const.MD5:
324
350
  result[Const.MD5] = []
325
351
  for data_name in data_json['data']:
326
352
  check_op_str_pattern_valid(data_name)
327
- merge_list = self.gen_merge_list(data_json, data_name, stack_json, dump_mode)
353
+ merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
328
354
  if not merge_list:
329
355
  continue
330
- for op_name in merge_list[CompareConst.OP_NAME]:
356
+
357
+ op_name_list = merge_list.get(CompareConst.OP_NAME)
358
+ summary_list = merge_list.get(Const.SUMMARY)
359
+ data_name_list = merge_list.get('data_name')
360
+ op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
361
+ summary_list,
362
+ data_name_list)
363
+ for op_name in op_name_reorder:
331
364
  result[CompareConst.OP_NAME].append(op_name)
332
365
  if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
333
366
  struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
334
- else:
367
+ elif CompareConst.OUTPUT_PATTERN in op_name:
335
368
  struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
369
+ elif CompareConst.PARAMS_PATTERN in op_name:
370
+ struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
371
+ else:
372
+ struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
336
373
  result[Const.DTYPE].append(struct[0])
337
374
  result[Const.SHAPE].append(struct[1])
338
- if dump_mode == Const.MD5:
375
+ if self.dump_mode == Const.MD5:
339
376
  result[Const.MD5].append(struct[2])
340
- result[Const.SUMMARY].append(merge_list[Const.SUMMARY].pop(0))
341
- result['stack_info'].append(merge_list['stack_info'][0])
342
- if dump_mode == Const.ALL:
343
- result['data_name'].append(merge_list['data_name'].pop(0))
377
+ result[Const.SUMMARY].append(summary_reorder.pop(0))
378
+ result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None)
379
+ if self.dump_mode == Const.ALL:
380
+ result['data_name'].append(data_name_reorder.pop(0))
344
381
  return pd.DataFrame(result)
345
382
 
346
383
 
@@ -355,7 +392,6 @@ def check_cross_framework(bench_json_path):
355
392
 
356
393
  def ms_compare(input_param, output_path, **kwargs):
357
394
  try:
358
- stack_mode = kwargs.get('stack_mode', False)
359
395
  auto_analyze = kwargs.get('auto_analyze', True)
360
396
  fuzzy_match = kwargs.get('fuzzy_match', False)
361
397
  cell_mapping = kwargs.get('cell_mapping', None)
@@ -366,15 +402,21 @@ def ms_compare(input_param, output_path, **kwargs):
366
402
 
367
403
  set_dump_path(input_param)
368
404
  dump_mode = get_dump_mode(input_param)
405
+ if 'stack_json_path' in input_param:
406
+ stack_mode = kwargs.get('stack_mode', False)
407
+ else:
408
+ stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
369
409
  check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
370
410
  create_directory(output_path)
371
- check_compare_param(input_param, output_path, dump_mode)
411
+ check_compare_param(input_param, output_path, dump_mode, stack_mode)
372
412
  except (CompareException, FileCheckException) as error:
373
413
  logger.error('Compare failed. Please check the arguments and do it again!')
374
414
  raise CompareException(error.code) from error
375
415
  if layer_mapping:
376
416
  data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
377
- is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
378
- ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
379
- ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, suffix=suffix,
380
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)
417
+
418
+ mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
419
+ mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping)
420
+ is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
421
+ ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework)
422
+ ms_comparator.compare_core(input_param, output_path, suffix=suffix)
@@ -25,7 +25,7 @@ from msprobe.core.common.file_utils import load_npy, read_csv, save_excel
25
25
  from msprobe.core.common.log import logger
26
26
  from msprobe.core.common.utils import add_time_with_xlsx, CompareException
27
27
  from msprobe.core.compare.multiprocessing_compute import _ms_graph_handle_multi_process, check_accuracy
28
- from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check, reshape_value, compare_ops_apply
28
+ from msprobe.core.compare.npy_compare import npy_data_check, statistics_data_check, compare_ops_apply
29
29
  from msprobe.mindspore.common.utils import convert_to_int, list_lowest_level_directories
30
30
 
31
31
 
@@ -144,10 +144,16 @@ def generate_data_name(data_path):
144
144
  mode = GraphMode.STATISTIC_MODE
145
145
  else:
146
146
  mode = GraphMode.ERROR_MODE
147
- logger.error(f"Error mode.")
147
+ logger.error("Error mode.")
148
148
  return mode, data_list
149
149
 
150
150
 
151
+ def transform_special_string_into_float(data_frame):
152
+ data_frame[data_frame == "null"] = '0'
153
+ data_frame[data_frame == "False"] = '0'
154
+ data_frame[data_frame == "True"] = '1'
155
+
156
+
151
157
  class GraphMSComparator:
152
158
  def __init__(self, input_param, output_path):
153
159
  self.output_path = output_path
@@ -187,7 +193,6 @@ class GraphMSComparator:
187
193
  result_dict[CompareConst.ERROR_MESSAGE] = error_message
188
194
 
189
195
  if not error_flag:
190
- n_value, b_value = reshape_value(n_value, b_value)
191
196
  result_list, err_msg = compare_ops_apply(n_value, b_value, False, "")
192
197
  result_dict[CompareConst.COSINE] = result_list[0]
193
198
  result_dict[CompareConst.MAX_ABS_ERR] = result_list[1]
@@ -334,13 +339,17 @@ class GraphMSComparator:
334
339
  CompareConst.BENCH_NORM])
335
340
 
336
341
  npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
337
- npu_data_df[npu_float_type] = npu_data_df[npu_float_type].astype(float)
342
+ npu_float_data_df = npu_data_df[npu_float_type].astype(str)
343
+ transform_special_string_into_float(npu_float_data_df)
344
+ npu_data_df[npu_float_type] = npu_float_data_df.astype(float)
338
345
 
339
346
  bench_float_type = [
340
347
  CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
341
348
  CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM
342
349
  ]
343
- bench_data_df[bench_float_type] = bench_data_df[bench_float_type].astype(float)
350
+ bench_float_data_df = bench_data_df[bench_float_type].astype(str)
351
+ transform_special_string_into_float(bench_float_data_df)
352
+ bench_data_df[bench_float_type] = bench_float_data_df.astype(float)
344
353
 
345
354
  npu_data_df['Local Index'] = npu_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
346
355
  bench_data_df['Local Index'] = bench_data_df.sort_values('TimeStamp').groupby('Compare Key').cumcount()
@@ -39,6 +39,7 @@ class DebuggerConfig:
39
39
  self.check_mode = task_config.check_mode
40
40
  self.framework = Const.MS_FRAMEWORK
41
41
  self.summary_mode = task_config.summary_mode
42
+ self.async_dump = common_config.async_dump if common_config.async_dump else False
42
43
  self.check()
43
44
  create_directory(self.dump_path)
44
45
 
@@ -69,4 +70,6 @@ class DebuggerConfig:
69
70
  self.file_format = "npy"
70
71
  if not self.check_mode:
71
72
  self.check_mode = "all"
73
+ if not isinstance(self.async_dump, bool):
74
+ raise Exception("The parameters async_dump should be bool.")
72
75
  return True