mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -25,16 +25,18 @@ from tqdm import tqdm
25
25
  from msprobe.core.advisor.advisor import Advisor
26
26
  from msprobe.core.common.const import CompareConst, Const
27
27
  from msprobe.core.common.exceptions import FileCheckException
28
- from msprobe.core.common.file_utils import load_json, remove_path, create_directory
28
+ from msprobe.core.common.file_utils import load_json, remove_path, create_directory, save_excel, save_json
29
29
  from msprobe.core.common.log import logger
30
30
  from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, \
31
- set_dump_path, get_dump_mode, check_compare_param, check_configuration_param, load_stack_json, get_file_type
32
- from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping
33
- from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, \
34
- reorder_op_x_list, set_stack_json_path, check_api_info_len
31
+ set_dump_path, get_dump_mode, check_compare_param, load_stack_json, get_file_type, add_time_with_json
32
+ from msprobe.core.compare.check import check_dump_json_str, check_stack_json_str, cross_dtype_mapping, \
33
+ check_configuration_param
34
+ from msprobe.core.compare.utils import merge_tensor, print_compare_ends_info, read_op, set_stack_json_path, \
35
+ reorder_index
35
36
  from msprobe.core.compare.config import ModeConfig, MappingConfig, MappingDict
36
37
  from msprobe.core.compare.multiprocessing_compute import CompareRealData
37
38
  from msprobe.core.compare.highlight import HighLight
39
+ from msprobe.core.compare.diff_analyze.first_diff_analyze import FirstDiffAnalyze
38
40
 
39
41
 
40
42
  @dataclass
@@ -43,12 +45,15 @@ class ComparisonConfig:
43
45
  stack_mode: bool
44
46
  auto_analyze: bool
45
47
  fuzzy_match: bool
48
+ highlight: bool
46
49
  data_mapping: dict
47
50
  suffix: str
48
51
  cell_mapping: dict
49
52
  api_mapping: dict
50
53
  layer_mapping: dict
51
54
  compared_file_type: str
55
+ first_diff_analyze: bool
56
+ is_print_compare_log: bool
52
57
 
53
58
 
54
59
  class Comparator:
@@ -57,17 +62,18 @@ class Comparator:
57
62
  self.mode_config = mode_config
58
63
  self.mapping_config = mapping_config
59
64
  self.cross_frame = is_cross_framework
60
-
61
65
  self.mapping_dict = MappingDict(mapping_config)
62
66
 
63
- @staticmethod
64
- def process_output_file(output_path, suffix, compared_file_type):
67
+ def process_output_file(self, output_path, suffix, compared_file_type):
65
68
  file_name_prefix_mapping = {
66
69
  Const.DUMP_JSON_FILE: "compare_result",
67
70
  Const.DEBUG_JSON_FILE: "debug_compare_result"
68
71
  }
69
72
  file_name_prefix = file_name_prefix_mapping.get(compared_file_type, "compare_result")
70
- file_name = add_time_with_xlsx(file_name_prefix + suffix)
73
+ if self.mode_config.first_diff_analyze:
74
+ file_name = add_time_with_json("compare_result" + suffix)
75
+ else:
76
+ file_name = add_time_with_xlsx(file_name_prefix + suffix)
71
77
  file_path = os.path.join(os.path.realpath(output_path), file_name)
72
78
  if os.path.exists(file_path):
73
79
  logger.warning(f"{file_path} will be deleted.")
@@ -95,6 +101,7 @@ class Comparator:
95
101
 
96
102
  # get kwargs or set default value
97
103
  suffix = kwargs.get('suffix', '')
104
+ rank = suffix[1:]
98
105
 
99
106
  # process output file
100
107
  file_path = self.process_output_file(output_path, suffix, self.mode_config.compared_file_type)
@@ -103,22 +110,45 @@ class Comparator:
103
110
  npu_json = input_param.get("npu_json_path")
104
111
  bench_json = input_param.get("bench_json_path")
105
112
  stack_json = input_param.get("stack_json_path")
106
- result_df = self.compare_statistics([npu_json, bench_json, stack_json])
113
+ parse_data = ParseData(self.mode_config, rank) # load and parse json data
114
+ npu_df, bench_df = parse_data.parse([npu_json, bench_json, stack_json])
115
+ result_df = self.compare_statistics(npu_df, bench_df)
107
116
  if not result_df.values.tolist():
108
117
  logger.warning("Can`t match any op. No compare result file generated.")
109
118
  return
110
119
 
120
+ if self.mode_config.first_diff_analyze:
121
+ # add P2POp additional info from npu_df and bench_df to result_df
122
+ result_df['NPU P2POp op'] = npu_df['op']
123
+ result_df['Bench P2POp op'] = bench_df['op']
124
+ result_df['NPU P2POp peer'] = npu_df['peer']
125
+ result_df['Bench P2POp peer'] = bench_df['peer']
126
+
127
+ first_diff_analyze = FirstDiffAnalyze(self.mode_config, rank)
128
+ check_result = first_diff_analyze.check(result_df)
129
+ save_json(file_path, check_result, indent=4)
130
+ logger.info(f"Saving json file to disk: {file_path}")
131
+ return
132
+
111
133
  # compare real data
112
134
  if self.mode_config.dump_mode == Const.ALL:
113
135
  compare_real_data = CompareRealData(self.file_reader, self.mode_config, self.cross_frame)
114
136
  result_df = compare_real_data.do_multi_process(input_param, result_df)
115
137
 
116
- # highlight suspicious API
117
- highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
118
- highlight = HighLight(self.mode_config)
119
- if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
120
- highlight.find_compare_result_error_rows(result_df, highlight_dict)
121
- highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path)
138
+ # save result excel file
139
+ logger.info(f'Saving result excel file in progress. The file path is: {file_path}.')
140
+ if self.mode_config.highlight and len(result_df) <= CompareConst.MAX_EXCEL_LENGTH:
141
+ # highlight if not too long
142
+ highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
143
+ highlight = HighLight(self.mode_config, rank)
144
+ if self.mode_config.compared_file_type == Const.DUMP_JSON_FILE:
145
+ highlight.find_compare_result_error_rows(result_df, highlight_dict)
146
+ result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘
147
+ highlight.highlight_rows_xlsx(result_df, highlight_dict, file_path)
148
+ else:
149
+ # fallback to simple save without highlight
150
+ result_df.drop(columns=['state', 'api_origin_name'], inplace=True) # 删除中间数据,两列不落盘
151
+ save_excel(file_path, result_df)
122
152
 
123
153
  # output compare analysis suggestions
124
154
  if self.mode_config.auto_analyze:
@@ -127,11 +157,7 @@ class Comparator:
127
157
 
128
158
  print_compare_ends_info()
129
159
 
130
- def compare_statistics(self, file_list):
131
- # load and parse json data
132
- parse_data = ParseData(self.mode_config)
133
- npu_df, bench_df = parse_data.parse(file_list)
134
-
160
+ def compare_statistics(self, npu_df, bench_df):
135
161
  npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
136
162
  bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
137
163
 
@@ -149,6 +175,8 @@ class Comparator:
149
175
  match_result.loc[~match.gen_dtype_condition(match_result), bench_columns] = CompareConst.N_A
150
176
 
151
177
  # organize compare result table by renaming columns
178
+ if self.mode_config.dump_mode == Const.ALL and self.mode_config.first_diff_analyze:
179
+ self.mode_config.dump_mode = Const.SUMMARY
152
180
  create_table = CreateTable(self.mode_config)
153
181
  result_df, header = create_table.make_result_df(match_result)
154
182
 
@@ -158,8 +186,9 @@ class Comparator:
158
186
 
159
187
 
160
188
  class ParseData:
161
- def __init__(self, mode_config: ModeConfig):
189
+ def __init__(self, mode_config: ModeConfig, rank):
162
190
  self.mode_config = mode_config
191
+ self.rank = rank
163
192
 
164
193
  def parse(self, file_list):
165
194
  npu_json_path, bench_json_path, stack_json_path = file_list
@@ -168,21 +197,24 @@ class ParseData:
168
197
  stack_json_data = load_stack_json(stack_json_path) if self.mode_config.stack_mode else None
169
198
 
170
199
  # parse json data and generate df
171
- npu_df = self.gen_data_df(npu_json_data, stack_json_data)
172
- bench_df = self.gen_data_df(bench_json_data, stack_json_data)
200
+ npu_df = self.gen_data_df(npu_json_data, stack_json_data, 'NPU')
201
+ bench_df = self.gen_data_df(bench_json_data, stack_json_data, 'Bench')
173
202
 
174
203
  return npu_df, bench_df
175
204
 
176
- def gen_data_df(self, data_json, stack_json_data):
205
+ def gen_data_df(self, data_json, stack_json_data, device: str):
177
206
  result = {
178
207
  CompareConst.OP_NAME: [],
179
208
  Const.DTYPE: [],
180
209
  Const.SHAPE: [],
181
210
  Const.SUMMARY: [],
182
- Const.STACK_INFO: []
211
+ Const.STACK_INFO: [],
212
+ Const.STATE: [],
213
+ Const.API_ORIGIN_NAME: [],
214
+ Const.REQ_GRAD: []
183
215
  }
184
216
  if self.mode_config.dump_mode == Const.ALL:
185
- result['data_name'] = []
217
+ result[Const.DATA_NAME] = []
186
218
  elif self.mode_config.dump_mode == Const.MD5:
187
219
  result[Const.MD5] = []
188
220
 
@@ -192,56 +224,50 @@ class ParseData:
192
224
  return pd.DataFrame(result)
193
225
 
194
226
  api_nums = len(apis_data)
195
- progress_bar = tqdm(total=api_nums, desc="API/Module Read Progress", unit="api/module", ncols=100)
227
+ default_bar_desc = f'{device} API/Module Read Progress'
228
+ bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc
229
+ progress_bar = tqdm(total=api_nums, desc=bar_desc_add_rank, unit="api/module", ncols=100)
196
230
 
197
231
  # 从json中循环解析API数据,遍历所有API
198
232
  for data_name in apis_data:
199
233
  check_op_str_pattern_valid(data_name)
200
- merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
201
- if not merge_list:
234
+ op_parsed_list = self.gen_merge_list(data_json, data_name, stack_json_data)
235
+ if not op_parsed_list:
202
236
  continue
203
-
204
- op_name_list = merge_list.get(CompareConst.OP_NAME)
205
- summary_list = merge_list.get(Const.SUMMARY)
206
- data_name_list = merge_list.get('data_name')
207
- op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
208
- summary_list,
209
- data_name_list)
210
- # 遍历单个API的所有item
211
- for index, op_name in enumerate(op_name_reorder):
212
- result[CompareConst.OP_NAME].append(op_name)
213
- if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
214
- info_list = merge_list[CompareConst.INPUT_STRUCT]
215
- elif CompareConst.OUTPUT_PATTERN in op_name:
216
- info_list = merge_list[CompareConst.OUTPUT_STRUCT]
217
- elif CompareConst.PARAMS_PATTERN in op_name:
218
- info_list = merge_list[CompareConst.PARAMS_STRUCT]
219
- elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
220
- info_list = merge_list[CompareConst.PARAMS_GRAD_STRUCT]
221
- else:
222
- info_list = merge_list[CompareConst.DEBUG_STRUCT]
223
- check_api_info_len(op_name, info_list, 1)
224
- struct = info_list.pop(0)
225
-
226
- check_api_info_len(op_name, struct, 2)
227
- result[Const.DTYPE].append(struct[0])
228
- result[Const.SHAPE].append(struct[1])
237
+ reordered_index_list = reorder_index(op_parsed_list)
238
+ for i, index in enumerate(reordered_index_list):
239
+ op_item = op_parsed_list[index]
240
+
241
+ # common key
242
+ result[CompareConst.OP_NAME].append(op_item.get('full_op_name'))
243
+ result[Const.DTYPE].append(op_item.get(Const.DTYPE))
244
+ result[Const.SHAPE].append(op_item.get(Const.SHAPE))
245
+ result[Const.STATE].append(op_item.get(Const.STATE))
246
+ result[Const.REQ_GRAD].append(op_item.get(Const.REQ_GRAD))
247
+ result[Const.API_ORIGIN_NAME].append(data_name)
248
+ summary_data = [
249
+ str(op_item.get(key)) if op_item.get(key) is None else op_item.get(key)
250
+ for key in Const.SUMMARY_METRICS_LIST
251
+ ]
252
+ result[Const.SUMMARY].append(summary_data)
253
+
254
+ # dump_mode differ key
229
255
  if self.mode_config.dump_mode == Const.MD5:
230
- check_api_info_len(op_name, struct, 3)
231
- result[Const.MD5].append(struct[2])
232
-
233
- check_api_info_len(op_name, summary_reorder, 1)
234
- result[Const.SUMMARY].append(summary_reorder.pop(0))
256
+ result[Const.MD5].append(op_parsed_list[index].get(Const.MD5))
257
+ if self.mode_config.dump_mode == Const.ALL:
258
+ result[Const.DATA_NAME].append(op_item.get(Const.DATA_NAME))
235
259
 
236
- if index == 0 and self.mode_config.stack_mode:
237
- check_api_info_len(op_name, merge_list[Const.STACK_INFO], 1)
238
- result[Const.STACK_INFO].append(merge_list[Const.STACK_INFO][0])
260
+ # mode_config stack_mode addition key
261
+ if i == 0 and self.mode_config.stack_mode:
262
+ result[Const.STACK_INFO].append(op_parsed_list[-1].get('full_info'))
239
263
  else:
240
264
  result[Const.STACK_INFO].append(None)
241
265
 
242
- if self.mode_config.dump_mode == Const.ALL:
243
- check_api_info_len(op_name, data_name_reorder, 1)
244
- result['data_name'].append(data_name_reorder.pop(0))
266
+ # mode_config first_diff_analyze addition key
267
+ if self.mode_config.first_diff_analyze:
268
+ result.setdefault('op', []).append(op_item.get('op', str(None)))
269
+ result.setdefault('peer', []).append(op_item.get('peer', str(None)))
270
+
245
271
  progress_bar.update(1)
246
272
  progress_bar.close()
247
273
  return pd.DataFrame(result)
@@ -256,14 +282,14 @@ class ParseData:
256
282
  stack_info = stack_json_data.get(op_name)
257
283
  if stack_info is not None:
258
284
  check_stack_json_str(stack_info, op_name)
259
- # append only when stack_mode is True,
260
- op_parsed_list.append({
261
- 'full_op_name': op_name,
262
- 'full_info': stack_info
263
- })
264
-
265
- merge_list = merge_tensor(op_parsed_list, self.mode_config.dump_mode)
266
- return merge_list
285
+ else:
286
+ stack_info = None
287
+ # always add stack_info whether stack_mode is True
288
+ op_parsed_list.append({
289
+ 'full_op_name': op_name,
290
+ 'full_info': stack_info
291
+ })
292
+ return op_parsed_list
267
293
 
268
294
 
269
295
  class ProcessDf:
@@ -327,13 +353,17 @@ class ProcessDf:
327
353
  return npu_op_name
328
354
 
329
355
  def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
356
+ def remove_prefix(string, prefix):
357
+ if string.startswith(prefix):
358
+ return string[len(prefix):]
359
+ return string
360
+
330
361
  def gen_input_compare_key(pattern, term):
331
362
  is_unmatched = True
332
363
  for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
333
- if op_name.split(pattern)[1].startswith(str(prefix)):
364
+ if remove_prefix(op_name, api_origin_name + pattern) == str(prefix):
334
365
  npu_df.loc[index, CompareConst.CMP_KEY] = (
335
- op_name.replace(pattern + str(prefix),
336
- pattern + str(mapping_dict.get(f'pt_{term}')[i])))
366
+ op_name.replace(pattern + str(prefix), pattern + str(mapping_dict.get(f'pt_{term}')[i])))
337
367
  is_unmatched = False
338
368
  return is_unmatched
339
369
 
@@ -355,15 +385,17 @@ class ProcessDf:
355
385
  continue
356
386
  for index in ms_api_indices_dict.get(ms_api):
357
387
  op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
358
- if CompareConst.INPUT_PATTERN in op_name:
388
+ state = npu_df.loc[index, Const.STATE]
389
+ api_origin_name = npu_df.loc[index, Const.API_ORIGIN_NAME].replace(ms_api, pt_api, 1)
390
+ if state == Const.INPUT:
359
391
  is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
360
- elif CompareConst.KWARGS_PATTERN in op_name:
392
+ elif state == Const.KWARGS:
361
393
  is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
362
- elif CompareConst.OUTPUT_PATTERN in op_name:
394
+ elif state == Const.OUTPUT:
363
395
  is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
364
- elif CompareConst.PARAMS_PATTERN in op_name:
396
+ elif state == Const.PARAMS:
365
397
  is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
366
- elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
398
+ elif state == Const.PARAMS_GRAD:
367
399
  is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
368
400
  else:
369
401
  logger.error(f'Excepted op_name: {op_name}')
@@ -413,8 +445,8 @@ class Match:
413
445
  @staticmethod
414
446
  def put_unmatched_in_table(match_result, npu_op_item):
415
447
  npu_columns = npu_op_item.index.tolist()[:-2]
416
- new_columns = [name[:-1] + 'y' for name in npu_columns]
417
- na_series = pd.Series([CompareConst.N_A] * len(new_columns), index=new_columns)
448
+ bench_columns = [name + '_y' for name in npu_columns]
449
+ na_series = pd.Series([CompareConst.N_A] * len(bench_columns), index=bench_columns)
418
450
  new_result_item = pd.concat([npu_op_item, na_series]).to_frame().T
419
451
  new_result_item.columns = CompareConst.MATCH_RESULT_COLUMNS
420
452
  match_result = pd.concat([match_result, new_result_item])
@@ -610,12 +642,21 @@ class CreateTable:
610
642
  'md5_x': CompareConst.NPU_MD5,
611
643
  'md5_y': CompareConst.BENCH_MD5,
612
644
  'data_name_x': CompareConst.DATA_NAME,
613
- 'stack_info_x': CompareConst.STACK}, inplace=True)
645
+ 'stack_info_x': CompareConst.STACK,
646
+ 'state_x': Const.STATE,
647
+ 'api_origin_name_x': Const.API_ORIGIN_NAME,
648
+ 'requires_grad_x': CompareConst.NPU_REQ_GRAD,
649
+ 'requires_grad_y': CompareConst.BENCH_REQ_GRAD
650
+ },
651
+ inplace=True)
614
652
 
615
653
  # process summary data
616
654
  npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
617
655
  bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
618
656
  CompareConst.BENCH_NORM]
657
+ # process requires_grad
658
+ result[CompareConst.REQ_GRAD_CONSIST] = result[CompareConst.NPU_REQ_GRAD] == result[CompareConst.BENCH_REQ_GRAD]
659
+
619
660
  if result.empty:
620
661
  result[npu_summary] = pd.DataFrame(columns=npu_summary)
621
662
  result[bench_summary] = pd.DataFrame(columns=bench_summary)
@@ -623,6 +664,7 @@ class CreateTable:
623
664
  result[npu_summary] = result['summary_x'].apply(self.set_summary).tolist()
624
665
  result[bench_summary] = result['summary_y'].apply(self.set_summary).tolist()
625
666
 
667
+ header.extend([Const.STATE, Const.API_ORIGIN_NAME])
626
668
  result_df = pd.DataFrame(columns=header)
627
669
  for h in header:
628
670
  if h in result.columns:
@@ -667,13 +709,13 @@ class CalcStatsDiff:
667
709
  result_df.loc[cond_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
668
710
 
669
711
  cond_not_nan_diff = cond_valid_stat & ~cond_diff_nan
670
- condition_pt_zero = bench_val == 0
712
+ condition_pt_zero = self.get_number(bench_val) == 0
671
713
  result_df.loc[cond_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.N_A
672
714
 
673
715
  # 相对误差转成百分比字符串
674
716
  cond_ref_err = cond_not_nan_diff & ~condition_pt_zero
675
717
  result_df.loc[cond_ref_err, rel_err_name] = (
676
- result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err] * 100)
718
+ result_df.loc[cond_ref_err, diff_name] / bench_val[cond_ref_err].astype(float) * 100)
677
719
  result_df.loc[cond_ref_err, rel_err_name] = (result_df.loc[cond_ref_err, rel_err_name].abs().astype(str) + '%')
678
720
 
679
721
  magnitude = self.get_number(result_df[diff_name]).abs() / (pd.Series(
@@ -685,12 +727,13 @@ class CalcStatsDiff:
685
727
  condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
686
728
  result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
687
729
  result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
730
+ condition_req_grad_consist = result_df[CompareConst.NPU_REQ_GRAD] == result_df[CompareConst.BENCH_REQ_GRAD]
688
731
 
689
732
  if self.mode_config.dump_mode == Const.MD5:
690
733
  condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
691
734
  result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
692
735
  result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
693
- elif self.mode_config.dump_mode == Const.SUMMARY:
736
+ elif self.mode_config.first_diff_analyze or self.mode_config.dump_mode == Const.SUMMARY:
694
737
  warning_list = [
695
738
  self.calc_summary_diff(result_df, condition_no_bench, stats_index)
696
739
  for stats_index in ['max', 'min', 'mean', 'l2norm']
@@ -698,14 +741,16 @@ class CalcStatsDiff:
698
741
  warning_flag = pd.DataFrame(warning_list).any()
699
742
  result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
700
743
  result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
701
- result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
744
+ result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy. '
745
+ result_df.loc[~condition_req_grad_consist, CompareConst.ERROR_MESSAGE] += 'Requires_grad inconsistent. '
702
746
  else:
703
747
  fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
704
748
  CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
705
749
  CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
706
750
  CompareConst.ERROR_MESSAGE]
707
- result_df.loc[~condition_no_bench, fill_cols] = ''
751
+ result_df.loc[~condition_no_bench, fill_cols] = '' # 默认填充'', df默认省缺值为nan,不便后续处理,容易出现意外情况
708
752
  result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
753
+ result_df.loc[~condition_req_grad_consist, CompareConst.ERROR_MESSAGE] = 'Requires_grad inconsistent. '
709
754
 
710
755
  return result_df[header]
711
756
 
@@ -718,12 +763,15 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
718
763
  stack_mode=False,
719
764
  auto_analyze=kwargs.get('auto_analyze', True),
720
765
  fuzzy_match=kwargs.get('fuzzy_match', False),
766
+ highlight=kwargs.get('highlight', False),
721
767
  data_mapping=kwargs.get('data_mapping', {}),
722
768
  suffix=kwargs.get('suffix', ''),
723
769
  cell_mapping=kwargs.get('cell_mapping', {}),
724
770
  api_mapping=kwargs.get('api_mapping', {}),
725
771
  layer_mapping=kwargs.get('layer_mapping', {}),
772
+ first_diff_analyze=kwargs.get('first_diff_analyze', False),
726
773
  compared_file_type='',
774
+ is_print_compare_log=input_param.get('is_print_compare_log', True)
727
775
  )
728
776
 
729
777
  set_dump_path(input_param)
@@ -736,8 +784,7 @@ def setup_comparison(input_param, output_path, **kwargs) -> ComparisonConfig:
736
784
  else:
737
785
  config.stack_mode = set_stack_json_path(input_param)
738
786
 
739
- check_configuration_param(config.stack_mode, config.auto_analyze, config.fuzzy_match,
740
- input_param.get('is_print_compare_log', True))
787
+ check_configuration_param(config)
741
788
  create_directory(output_path)
742
789
  check_compare_param(input_param, output_path, config.dump_mode, config.stack_mode)
743
790
 
@@ -13,6 +13,8 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import os
17
+
16
18
  from msprobe.core.common.log import logger
17
19
  from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
18
20
  from msprobe.core.common.const import Const
@@ -106,3 +108,14 @@ def check_stack_json_str(stack_info, op_name):
106
108
  else:
107
109
  logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'")
108
110
  raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
111
+
112
+
113
+ def check_configuration_param(config):
114
+ arg_list = [config.stack_mode, config.auto_analyze, config.fuzzy_match,
115
+ config.highlight, config.first_diff_analyze, config.is_print_compare_log]
116
+ arg_names = ['stack_mode', 'auto_analyze', 'fuzzy_match',
117
+ 'highlight', 'first_diff_analyze', 'is_print_compare_log']
118
+ for arg, name in zip(arg_list, arg_names):
119
+ if not isinstance(arg, bool):
120
+ logger.error(f"Invalid input parameter, {name} which should be only bool type.")
121
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,28 +13,40 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
16
+ import os
17
+
17
18
  from msprobe.core.common.file_utils import check_file_type, load_json, check_file_or_directory_path
18
19
  from msprobe.core.common.const import FileCheckConst, Const
19
20
  from msprobe.core.common.utils import CompareException
20
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.compare.utils import get_paired_dirs
23
+
21
24
 
25
+ def compare_cli(args, depth=1):
26
+ if depth > 2:
27
+ logger.error("Recursive compare error, depth exceeds 2.")
28
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
22
29
 
23
- def compare_cli(args):
24
- input_param = load_json(args.input_path)
30
+ if isinstance(args.input_path, dict): # special for dyn-graph mix compare
31
+ input_param = args.input_path
32
+ else:
33
+ input_param = load_json(args.input_path)
25
34
  if not isinstance(input_param, dict):
26
35
  logger.error("input_param should be dict, please check!")
27
36
  raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
37
+
28
38
  npu_path = input_param.get("npu_path", None)
29
39
  bench_path = input_param.get("bench_path", None)
30
40
  if not npu_path:
31
- logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!")
41
+ logger.error(f"Missing npu_path in input configuration file, please check!")
32
42
  raise CompareException(CompareException.INVALID_PATH_ERROR)
33
43
  if not bench_path:
34
- logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!")
44
+ logger.error(f"Missing bench_path in input configuration file, please check!")
35
45
  raise CompareException(CompareException.INVALID_PATH_ERROR)
46
+
36
47
  frame_name = args.framework
37
48
  auto_analyze = not args.compare_only
49
+
38
50
  if frame_name == Const.PT_FRAMEWORK:
39
51
  from msprobe.pytorch.compare.pt_compare import compare
40
52
  from msprobe.pytorch.compare.distributed_compare import compare_distributed
@@ -46,7 +58,9 @@ def compare_cli(args):
46
58
  common_kwargs = {
47
59
  "auto_analyze": auto_analyze,
48
60
  "fuzzy_match": args.fuzzy_match,
61
+ "highlight": args.highlight,
49
62
  "data_mapping": args.data_mapping,
63
+ "diff_analyze": args.diff_analyze
50
64
  }
51
65
 
52
66
  if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
@@ -75,6 +89,12 @@ def compare_cli(args):
75
89
  elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
76
90
  check_file_or_directory_path(npu_path, isdir=True)
77
91
  check_file_or_directory_path(bench_path, isdir=True)
92
+
93
+ if depth == 1:
94
+ mix_compare_success = mix_compare(args, input_param, depth)
95
+ if mix_compare_success:
96
+ return
97
+
78
98
  kwargs = {
79
99
  **common_kwargs,
80
100
  "stack_mode": args.stack_mode,
@@ -90,6 +110,13 @@ def compare_cli(args):
90
110
  if isinstance(common, bool) and common:
91
111
  common_dir_compare(input_param, args.output_path)
92
112
  return
113
+
114
+ if common_kwargs.get('diff_analyze', False):
115
+ logger.info("Start finding first diff node......")
116
+ from msprobe.core.compare.find_first.analyzer import DiffAnalyzer
117
+ DiffAnalyzer(npu_path, bench_path, args.output_path, frame_name).analyze()
118
+ return
119
+
93
120
  if frame_name == Const.PT_FRAMEWORK:
94
121
  compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
95
122
  else:
@@ -97,3 +124,34 @@ def compare_cli(args):
97
124
  else:
98
125
  logger.error("The npu_path and bench_path need to be of the same type.")
99
126
  raise CompareException(CompareException.INVALID_COMPARE_MODE)
127
+
128
+
129
+ def mix_compare(args, input_param, depth):
130
+ npu_path = input_param.get("npu_path", None)
131
+ bench_path = input_param.get("bench_path", None)
132
+
133
+ npu_bench_same_dirs_set = set(get_paired_dirs(npu_path, bench_path))
134
+ compare_cross_set = npu_bench_same_dirs_set & Const.MIX_DUMP_NAMES
135
+
136
+ if compare_cross_set:
137
+ logger.info("Start mix compare.")
138
+ origin_output = args.output_path
139
+
140
+ for folder_name in list(compare_cross_set):
141
+ new_npu_path = os.path.join(npu_path, folder_name)
142
+ new_bench_path = os.path.join(bench_path, folder_name)
143
+ paired_steps = get_paired_dirs(new_npu_path, new_bench_path)
144
+
145
+ for step_name in paired_steps:
146
+ logger.info(f"[mix compare] Start comparing {folder_name}/{step_name}")
147
+ npu_dir = os.path.join(new_npu_path, step_name)
148
+ bench_dir = os.path.join(new_bench_path, step_name)
149
+ args.input_path = {
150
+ "npu_path": npu_dir,
151
+ "bench_path": bench_dir,
152
+ "is_print_compare_log": input_param.get("is_print_compare_log", True)
153
+ }
154
+ args.output_path = os.path.join(origin_output, folder_name, step_name)
155
+ compare_cli(args, depth + 1)
156
+ return True
157
+ return False
@@ -20,13 +20,15 @@ from msprobe.core.common.file_utils import load_yaml
20
20
 
21
21
 
22
22
  class ModeConfig:
23
- def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY,
24
- compared_file_type=Const.DUMP_JSON_FILE):
25
- self.stack_mode = stack_mode
26
- self.auto_analyze = auto_analyze
27
- self.fuzzy_match = fuzzy_match
28
- self.dump_mode = dump_mode
29
- self.compared_file_type = compared_file_type
23
+ def __init__(self, **kwargs):
24
+ self.stack_mode = kwargs.get('stack_mode', False)
25
+ self.auto_analyze = kwargs.get('auto_analyze', True)
26
+ self.fuzzy_match = kwargs.get('fuzzy_match', False)
27
+ self.highlight = kwargs.get('highlight', False)
28
+ self.dump_mode = kwargs.get('dump_mode', Const.SUMMARY)
29
+ self.first_diff_analyze = kwargs.get('first_diff_analyze', False)
30
+ self.diff_analyze = kwargs.get('diff_analyze', False)
31
+ self.compared_file_type = kwargs.get('compared_file_type', Const.DUMP_JSON_FILE)
30
32
 
31
33
 
32
34
  class MappingConfig:
@@ -69,4 +71,4 @@ class MappingDict:
69
71
  else:
70
72
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
71
73
  f"{type(data_mapping)}")
72
- return data_mapping_dict
74
+ return data_mapping_dict
@@ -0,0 +1,14 @@
1
+ compare_metrics:
2
+ - MaxRelativeErr
3
+ - MinRelativeErr
4
+ - MeanRelativeErr
5
+ - NormRelativeErr
6
+
7
+ MaxRelativeErr:
8
+ - 0.5
9
+ MinRelativeErr:
10
+ - 0.5
11
+ MeanRelativeErr:
12
+ - 0.5
13
+ NormRelativeErr:
14
+ - 0.5