mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -14,113 +14,46 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.core.common.log import logger
17
- from msprobe.core.compare.utils import rename_api
18
17
  from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
19
- from msprobe.core.common.const import CompareConst, Const
20
-
21
- dtype_mapping = {
22
- "Int8": "torch.int8",
23
- "UInt8": "torch.uint8",
24
- "Int16": "torch.int16",
25
- "UInt16": "torch.uint16",
26
- "Int32": "torch.int32",
27
- "UInt32": "torch.uint32",
28
- "Int64": "torch.int64",
29
- "UInt64": "torch.uint64",
30
- "Float16": "torch.float16",
31
- "Float32": "torch.float32",
32
- "Float64": "torch.float64",
33
- "Bool": "torch.bool",
34
- "BFloat16": "torch.bfloat16",
35
- "Complex64": "torch.complex64",
36
- "Complex128": "torch.complex128"
18
+ from msprobe.core.common.const import Const
19
+
20
# Maps every known dtype spelling to a coarse dtype family ("int", "float", "bool", "complex").
# Both capitalized names such as "Int8" (presumably MindSpore-side dtype strings — confirm) and
# PyTorch-style names such as "torch.int8" are covered, so dtypes from either side can be compared
# by family. Each family lists its members as (capitalized, torch) pairs; key order matches the
# original literal.
cross_dtype_mapping = {
    dtype_name: family
    for family, dtype_names in (
        ("int", (
            "Int8", "torch.int8",
            "UInt8", "torch.uint8",
            "Int16", "torch.int16",
            "UInt16", "torch.uint16",
            "Int32", "torch.int32",
            "UInt32", "torch.uint32",
            "Int64", "torch.int64",
            "UInt64", "torch.uint64",
        )),
        ("float", (
            "Float16", "torch.float16",
            "Float32", "torch.float32",
            "Float64", "torch.float64",
            "BFloat16", "torch.bfloat16",
        )),
        ("bool", (
            "Bool", "torch.bool",
        )),
        ("complex", (
            "Complex64", "torch.complex64",
            "Complex128", "torch.complex128",
        )),
    )
    for dtype_name in dtype_names
}
38
55
 
39
56
 
40
- def compare_op_dict_struct(npu_dict, bench_dict):
41
- return all(npu_dict.get(key) == bench_dict.get(key) for key in CompareConst.STRUCT_COMPARE_KEY)
42
-
43
-
44
- def check_struct_match(npu_dict, bench_dict):
45
- is_match = compare_op_dict_struct(npu_dict, bench_dict)
46
- if not is_match:
47
- struct_match_list = []
48
- try:
49
- for i, key in enumerate(CompareConst.STRUCT_COMPARE_KEY):
50
- # 首先额外检查input_struct是否空,input_struct不可能为空
51
- if i == 0 and (not npu_dict.get(key, []) or not bench_dict.get(key, [])):
52
- return False
53
- struct_match_list.append(check_type_shape_match(npu_dict.get(key, []), bench_dict.get(key, [])))
54
- except CompareException as error:
55
- err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
56
- f'npu_dict: {npu_dict}' \
57
- f'bench_dict: {bench_dict}'
58
- logger.error(err_msg)
59
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
60
- is_match = all(struct_match_list)
61
- return is_match
62
-
63
-
64
- def check_type_shape_match(npu_struct, bench_struct):
65
- """
66
- further check dtypes with a dtype mapping list when dtypes are not entirely consistent.
67
- """
68
- if len(npu_struct) != len(bench_struct):
69
- return False
70
- if not npu_struct and not bench_struct:
71
- return True
72
-
73
- struct_match = False
74
- for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
75
- try:
76
- npu_type = npu_type_shape[0]
77
- npu_shape = npu_type_shape[1]
78
- bench_type = bench_type_shape[0]
79
- bench_shape = bench_type_shape[1]
80
- except IndexError as error:
81
- logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
82
- f'should both be 2, please check!')
83
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
84
- shape_match = npu_shape == bench_shape
85
- type_match = ((npu_type == bench_type) or
86
- any(npu_type in group and bench_type in group for group in CompareConst.DTYPE_MATCH_GROUPS))
87
- struct_match = shape_match and type_match
88
- if not struct_match:
89
- return False
90
- return struct_match
91
-
92
-
93
- def check_graph_mode(a_op_name, b_op_name):
94
- if Const.ATEN in a_op_name and Const.ATEN not in b_op_name:
95
- return True
96
- if Const.ATEN not in a_op_name and Const.ATEN in b_op_name:
97
- return True
98
- return False
99
-
100
-
101
- def fuzzy_check_op(npu_name_list, bench_name_list):
102
- # 先检查api里的item长度是否相等,如果不是parameters_grad, 必然有input或者output,长度不可能为0
103
- # 如果是parameters_grad, "parameters_grad"字段的字典不会是空字典,因此len>=1
104
- if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
105
- return False
106
- is_match = True
107
- for npu_name, bench_name in zip(npu_name_list, bench_name_list):
108
- is_match = fuzzy_check_name(npu_name, bench_name)
109
- if not is_match:
110
- break
111
- return is_match
112
-
113
-
114
- def fuzzy_check_name(npu_name, bench_name):
115
- if Const.FORWARD in npu_name and Const.FORWARD in bench_name:
116
- is_match = rename_api(npu_name, Const.FORWARD) == rename_api(bench_name, Const.FORWARD)
117
- elif Const.BACKWARD in npu_name and Const.BACKWARD in bench_name:
118
- is_match = rename_api(npu_name, Const.BACKWARD) == rename_api(bench_name, Const.BACKWARD)
119
- else:
120
- is_match = npu_name == bench_name
121
- return is_match
122
-
123
-
124
57
  def check_dump_json_str(op_data, op_name):
125
58
  input_list = op_data.get(Const.INPUT_ARGS, None) if op_data.get(Const.INPUT_ARGS, None) else op_data.get(
126
59
  Const.INPUT, None)
@@ -38,6 +38,7 @@ def compare_cli(args):
38
38
  else:
39
39
  from msprobe.mindspore.compare.ms_compare import ms_compare
40
40
  from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
41
+ from msprobe.mindspore.compare.common_dir_compare import common_dir_compare
41
42
 
42
43
  common_kwargs = {
43
44
  "auto_analyze": auto_analyze,
@@ -78,6 +79,9 @@ def compare_cli(args):
78
79
  if input_param.get("rank_id") is not None:
79
80
  ms_graph_compare(input_param, args.output_path)
80
81
  return
82
+ if input_param.get('common', False):
83
+ common_dir_compare(input_param, args.output_path)
84
+ return
81
85
  if frame_name == Const.PT_FRAMEWORK:
82
86
  compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
83
87
  else:
@@ -0,0 +1,72 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.const import Const, CompareConst
19
+ from msprobe.core.common.file_utils import load_yaml
20
+
21
+
22
class ModeConfig:
    """Plain holder for the switches that select how a comparison run behaves.

    Every constructor argument is stored unchanged on an attribute of the same
    name; no validation or normalization happens here.

    Args:
        stack_mode: boolean switch, off by default.
        auto_analyze: boolean switch, on by default.
        fuzzy_match: boolean switch, off by default.
        dump_mode: dump-mode selector; defaults to ``Const.SUMMARY``.
        compared_file_type: kind of file being compared; defaults to ``Const.DUMP_JSON_FILE``.
    """

    def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.SUMMARY,
                 compared_file_type=Const.DUMP_JSON_FILE):
        # Selectors whose defaults come from the shared Const table.
        self.dump_mode = dump_mode
        self.compared_file_type = compared_file_type
        # Boolean switches.
        self.stack_mode = stack_mode
        self.fuzzy_match = fuzzy_match
        self.auto_analyze = auto_analyze


class MappingConfig:
    """Raw mapping inputs for a comparison run.

    Each argument is stored as given; ``None`` means that mapping was not supplied.

    Args:
        cell_mapping: cell-mapping input (a YAML path string is loaded by ``MappingDict``).
        api_mapping: API-mapping input (a YAML path string is loaded by ``MappingDict``; any
            non-None value also enables the built-in internal API mapping).
        data_mapping: data-mapping input (YAML path string, ready-made dict, or None).
    """

    def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
        self.cell_mapping = cell_mapping
        self.api_mapping = api_mapping
        self.data_mapping = data_mapping


class MappingDict:
    """Loaded mapping dictionaries derived from a ``MappingConfig``.

    Attributes:
        cell_mapping_dict: YAML contents when ``cell_mapping`` is a path string, else ``{}``.
        api_mapping_dict: YAML contents when ``api_mapping`` is a path string, else ``{}``.
        ms_to_pt_mapping: contents of the packaged ``CompareConst.INTERNAL_API_MAPPING_FILE``
            when ``api_mapping`` is not None, else ``{}``. It is always defined, so readers
            never hit AttributeError.
        data_mapping_dict: see ``init_data_mapping``.
    """

    def __init__(self, mapping_config: MappingConfig):
        self.cell_mapping_dict = self.load_mapping_file(mapping_config.cell_mapping)
        self.api_mapping_dict = self.load_mapping_file(mapping_config.api_mapping)
        # Any non-None api_mapping (not only a path string) requests the built-in table.
        # NOTE(review): presumably a flag value such as True means "built-in mapping only";
        # confirm with callers.
        # Previously this attribute was only created in that case, so reading it otherwise
        # raised AttributeError; default to an empty mapping instead.
        if mapping_config.api_mapping is not None:
            self.ms_to_pt_mapping = self.load_internal_api()
        else:
            self.ms_to_pt_mapping = {}
        self.data_mapping_dict = self.init_data_mapping(mapping_config.data_mapping)

    @staticmethod
    def load_internal_api():
        """Load the internal API mapping YAML shipped next to this module.

        The file name comes from ``CompareConst.INTERNAL_API_MAPPING_FILE`` and is resolved
        relative to the real (symlink-resolved) directory of this source file.
        """
        cur_path = os.path.dirname(os.path.realpath(__file__))
        yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
        return load_yaml(yaml_path)

    @staticmethod
    def load_mapping_file(mapping_file):
        """Return the parsed YAML when ``mapping_file`` is a path string, otherwise ``{}``.

        Non-string values (None, booleans, ...) are treated as "no user mapping file".
        """
        if isinstance(mapping_file, str):
            mapping_dict = load_yaml(mapping_file)
        else:
            mapping_dict = {}
        return mapping_dict

    def init_data_mapping(self, data_mapping):
        """Build ``data_mapping_dict`` from the configured data mapping.

        Args:
            data_mapping: a YAML path string (loaded), ``None`` (yields ``{}``), or a dict
                (used as-is, not copied).

        Returns:
            The resulting mapping dict.

        Raises:
            TypeError: if ``data_mapping`` is of any other type.
        """
        if isinstance(data_mapping, str) or data_mapping is None:
            data_mapping_dict = self.load_mapping_file(data_mapping)
        elif isinstance(data_mapping, dict):
            data_mapping_dict = data_mapping
        else:
            raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
                            f"{type(data_mapping)}")
        return data_mapping_dict