mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,115 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+
19
+ from msprobe.core.config_check.utils.utils import config_checking_print
20
+ from msprobe.core.common.file_utils import FileOpen, load_yaml
21
+ from msprobe.core.common.const import Const, FileCheckConst
22
+
23
+
24
class Parser(ABC):
    """Abstract base for hyperparameter-file parsers.

    Subclasses implement ``parse``; callers should go through ``run``, which never raises.
    """

    @abstractmethod
    def parse(self, file_path: str) -> dict:
        """Parse ``file_path`` and return the hyperparameters found in it."""

    def run(self, file_path: str) -> dict:
        """
        Unified public entry point.

        :param file_path: path of the file to parse
        :return: the dict produced by ``parse``, or ``{}`` if parsing raised (the error is logged)
        """
        try:
            return self.parse(file_path)
        except Exception as exc:
            # Best-effort: a file that cannot be parsed is reported and skipped, not fatal.
            config_checking_print(f"{self.__class__} parsing error, skip file path: {file_path}, error: {exc}")
            return {}
41
+
42
+
43
class ShellParser(Parser):
    """Extract ``--flag [value]`` hyperparameters from a shell training-launch script."""

    def parse(self, file_path: str) -> dict:
        """
        Extracts arguments from bash script used to run a model training.

        Steps:
          1. Drop full-line comments and trailing ``#`` comments, and skip blank lines.
          2. Locate the launch command (``msrun``, ``torchrun`` or
             ``python -m torch.distributed.launch``) up to the next pipe / end of script.
          3. Inline ``XXX_ARGS="..."`` blocks referenced as ``$XXX_ARGS`` or ``${XXX_ARGS}``.
          4. Collect ``--key value`` / ``--key=value`` pairs. A flag without a value maps to ``True``.
             A ``$VAR`` / ``${VAR}`` value is replaced by the last ``VAR=...`` assignment
             found in the script, when there is one.

        :param file_path: path of the shell script
        :return: dict mapping flag name (without ``--``) to its raw string value or ``True``;
                 empty if no supported launch command is found
        """
        hyperparameters = {}
        script_content_list = []
        with FileOpen(file_path, 'r') as file:
            for line in file:
                if line.lstrip().startswith('#'):
                    continue
                line = line.split('#')[0].rstrip() + '\n'
                if line.strip():
                    script_content_list.append(line)
        script_content = ''.join(script_content_list)

        command_match = re.search(r'msrun\s[^|]*|torchrun\s[^|]*|python\d? -m torch.distributed.launch\s[^|]*',
                                  script_content,
                                  re.DOTALL)
        if not command_match:
            # No supported launcher: nothing to extract (previously this crashed on None).
            return hyperparameters
        command_line = command_match.group()

        blocks = re.findall(r'([a-zA-Z0-9_]{1,20}_ARGS)="(.*?)"', script_content, re.DOTALL)
        for block_name, block_content in blocks:
            block_content = block_content.replace('\n', ' ')
            # Support both ${NAME_ARGS} and $NAME_ARGS references.
            command_line = command_line.replace(f"${{{block_name}}}", block_content)
            command_line = command_line.replace(f"${block_name}", block_content)

        # The value may follow a space or '='. The (?!--) lookahead stops a value-less flag
        # (e.g. --use-flash-attn) from swallowing the next flag as its value.
        matches = re.findall(r'--([\w-]+)(?:(?:=|\s+)(?!--)([^\s\\]+))?', command_line)
        for key, value in matches:
            args_key = re.match(r'\$\{?(\w+)}?', value)
            if args_key:
                # \b anchors the variable name so that LR does not match MAX_LR=...
                env_vars = re.findall(rf'\b{args_key.group(1)}=\s*(.+)', script_content)
                if env_vars:
                    value = env_vars[-1]
            hyperparameters[key] = value if value else True

        return hyperparameters
83
+
84
+
85
class YamlParser(Parser):
    """Flatten a YAML file into ``{"a<SEP>b<SEP>c": scalar}`` pairs (SEP is ``Const.SEP``)."""

    def __init__(self):
        super().__init__()
        # Per-instance storage. A class-level dict would be shared by every instance and
        # every call, leaking keys from one parsed file into the next.
        self.hyperparameters = {}

    def parse(self, file_path: str) -> dict:
        """
        Load ``file_path`` as YAML and return its flattened scalar leaves.

        A fresh dict is built on every call. The factory shares one parser instance, so
        results from different files must not alias each other.
        """
        self.hyperparameters = {}
        ori_hyper = load_yaml(file_path)
        self.recursive_parse_parameters(ori_hyper, "")
        return self.hyperparameters

    def recursive_parse_parameters(self, parameters, prefix):
        """
        Walk ``parameters`` and record scalar leaves under their joined key path.

        Dict keys extend the prefix. List items reuse the parent prefix, so later items
        overwrite earlier ones. Leaves that are not int/float/str/bool (e.g. None) are ignored.
        """
        if isinstance(parameters, dict):
            for key, value in parameters.items():
                key = str(key)  # YAML allows non-string keys (ints, bools)
                new_prefix = prefix + Const.SEP + key if prefix else key
                self.recursive_parse_parameters(value, new_prefix)
        elif isinstance(parameters, list):
            for value in parameters:
                self.recursive_parse_parameters(value, prefix)
        elif isinstance(parameters, (int, float, str, bool)):
            # float included so values such as learning rates are not silently dropped.
            self.hyperparameters.update({prefix: parameters})
103
+
104
+
105
class ParserFactory:
    """Select the parser instance registered for a file suffix."""

    # Shared parser instances keyed by file suffix.
    __ParserDict = {
        FileCheckConst.SHELL_SUFFIX: ShellParser(),
        FileCheckConst.YAML_SUFFIX: YamlParser()
    }

    def get_parser(self, file_type: str) -> Parser:
        """
        Return the parser registered for ``file_type``.

        :raises ValueError: if no parser is registered for that suffix
        """
        selected = self.__ParserDict.get(file_type)
        if selected is None:
            raise ValueError(f'Invalid parser type: {file_type}')
        return selected
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ import hashlib
19
+
20
+ from msprobe.core.common.framework_adapter import FmkAdp
21
+ from msprobe.core.common.log import logger
22
+
23
+
24
def merge_keys(dir_0, dir_1):
    """Return the set union of the keys of two mappings."""
    return {*dir_0.keys(), *dir_1.keys()}
28
+
29
+
30
def compare_dict(bench_dict, cmp_dict):
    """
    Describe the differences between two flat dicts as human-readable lines.

    Output order is deterministic: keys in ``bench_dict`` order first (changed or
    deleted), then keys present only in ``cmp_dict``, in ``cmp_dict`` order.

    :param bench_dict: baseline mapping
    :param cmp_dict: mapping compared against the baseline
    :return: list of strings, one per differing key
    """
    result = []
    for key, bench_value in bench_dict.items():
        if key not in cmp_dict:
            result.append(f"{key}: [deleted] -> {bench_value}")
        elif bench_value != cmp_dict[key]:
            result.append(f"{key}: {bench_value} -> {cmp_dict[key]}")
    for key, cmp_value in cmp_dict.items():
        if key not in bench_dict:
            result.append(f"{key}: [added] -> {cmp_value}")
    return result
41
+
42
+
43
def config_checking_print(msg):
    """Log ``msg`` at info level with the config-checking prefix."""
    logger.info("[config checking log] {}".format(msg))
45
+
46
+
47
def tensor_to_hash(tensor):
    """
    Compute a short hash value of a tensor's contents.

    The tensor is detached and moved to CPU, then converted to a numpy array. The array's
    raw bytes are hashed with ``bytes_hash``.
    """
    # .cpu() already copies device tensors and .tobytes() copies the data, so no clone() is needed.
    host_tensor = tensor.detach().cpu()
    try:
        array = host_tensor.numpy()
    except TypeError:
        # numpy has no bfloat16 dtype, so torch refuses the conversion ("unsupported
        # ScalarType BFloat16"). Up-casting bf16 -> fp32 is exact, so distinct values
        # still produce distinct bytes.
        array = host_tensor.float().numpy()
    return bytes_hash(array.tobytes())
51
+
52
+
53
def get_tensor_features(tensor):
    """Return max/min/mean/norm of ``tensor``, computed through the framework adapter ``FmkAdp``."""
    statistic_funcs = (
        ("max", FmkAdp.tensor_max),
        ("min", FmkAdp.tensor_min),
        ("mean", FmkAdp.tensor_mean),
        ("norm", FmkAdp.tensor_norm),
    )
    return {stat_name: stat_func(tensor) for stat_name, stat_func in statistic_funcs}
62
+
63
+
64
def compare_dicts(dict1, dict2, path=''):
    """
    Recursively diff two (possibly nested) dicts.

    :param dict1: baseline mapping
    :param dict2: mapping compared against the baseline
    :param path: '/'-separated key prefix of the current nesting level, used in messages
    :return: tuple ``(deleted, added, changed, result)``. The first three are lists of
             message strings. ``result`` mirrors the key structure and contains only the
             differing entries.
    """
    deleted = []
    added = []
    changed = []
    result = {}

    for key, old_value in dict1.items():
        # f-string (not "+") so non-string keys such as ints do not raise TypeError.
        full_key = f"{path}{key}"
        if key not in dict2:
            deleted.append(f"[Deleted]: {full_key}")
            result[key] = "[deleted]"
            continue
        new_value = dict2[key]
        if isinstance(old_value, dict) and isinstance(new_value, dict):
            sub_deleted, sub_added, sub_changed, sub_result = compare_dicts(
                old_value, new_value, f"{full_key}/")
            deleted.extend(sub_deleted)
            added.extend(sub_added)
            changed.extend(sub_changed)
            if sub_result:
                result[key] = sub_result
        elif old_value != new_value:
            changed.append(f"[Changed]: {full_key} : {old_value} -> {new_value}")
            result[key] = f"[changed]: {old_value} -> {new_value}"
    for key in dict2:
        if key not in dict1:
            added.append(f"[Added]: {path}{key}")
            result[key] = "[added]"
    return deleted, added, changed, result
91
+
92
+
93
def bytes_hash(obj: bytes):
    """Return a 16-bit integer fingerprint of ``obj`` derived from its SHA-256 digest."""
    digest = hashlib.sha256(obj).digest()
    # Taking the value modulo 2**16 keeps the low 16 bits, i.e. the last two digest bytes (big-endian).
    return int.from_bytes(digest[-2:], byteorder="big")
97
+
98
+
99
def update_dict(ori_dict, new_dict):
    """
    Merge ``new_dict`` into ``ori_dict`` in place, recording conflicting values.

    - A new key, or an equal value, is stored as-is.
    - On the first conflict, the entry becomes
      ``{"description": "duplicate_value", "values": [old, new]}``.
    - On later conflicts for the same key, the new value is appended to that record's
      ``values`` list instead of nesting another record.
    """
    for key, value in new_dict.items():
        if key not in ori_dict or ori_dict[key] == value:
            ori_dict[key] = value
            continue
        existing = ori_dict[key]
        # Inspect the entry itself (not the outer dict) to see whether it is already a duplicate record.
        is_duplicate_record = (isinstance(existing, dict)
                               and existing.get("description") == "duplicate_value"
                               and isinstance(existing.get("values"), list))
        if is_duplicate_record:
            existing["values"].append(value)
        else:
            ori_dict[key] = {"description": "duplicate_value", "values": [existing, value]}
@@ -13,10 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import inspect
16
17
  from typing import Dict, Any, Optional, Callable, Union, List, Tuple
17
18
 
18
19
  from msprobe.core.common.const import Const
19
20
  from msprobe.core.common.file_utils import load_yaml
21
+ from msprobe.core.common.log import logger
20
22
 
21
23
 
22
24
  def _get_attr(module, attr_name):
@@ -32,7 +34,8 @@ def _get_attr(module, attr_name):
32
34
  class ApiWrapper:
33
35
  def __init__(
34
36
  self, api_types: Dict[str, Dict[str, Any]],
35
- api_list_paths: Union[str, List[str], Tuple[str]]
37
+ api_list_paths: Union[str, List[str], Tuple[str]],
38
+ backlist: Union[List[str], Tuple[str]] = None
36
39
  ):
37
40
  self.api_types = api_types
38
41
  if not isinstance(api_list_paths, (list, tuple)):
@@ -41,9 +44,42 @@ class ApiWrapper:
41
44
  raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
42
45
  "when api_list_paths is a list or tuple.")
43
46
  self.api_list_paths = api_list_paths
47
+ self.backlist = backlist if backlist else []
44
48
  self.api_names = self._get_api_names()
45
49
  self.wrapped_api_functions = dict()
46
50
 
51
+ @staticmethod
52
+ def deal_with_self_kwargs(api_name, api_func, args, kwargs):
53
+ if kwargs and 'self' in kwargs:
54
+ func_params = None
55
+ try:
56
+ func_params = inspect.signature(api_func).parameters
57
+ except Exception:
58
+ if api_name in Const.API_WITH_SELF_ARG:
59
+ func_params = inspect.signature(Const.API_WITH_SELF_ARG.get(api_name)).parameters
60
+ if func_params is None:
61
+ return False, args, kwargs
62
+
63
+ for name, param in func_params.items():
64
+ if name == 'self' and param.kind == inspect.Parameter.KEYWORD_ONLY:
65
+ return False, args, kwargs
66
+ args_ = list(args)
67
+ names_and_values = []
68
+ self_index = 0
69
+ for i, item in enumerate(func_params.items()):
70
+ names_and_values.append((item[0], item[1].default))
71
+ if item[0] == 'self':
72
+ self_index = i
73
+ break
74
+ for i in range(len(args), self_index + 1):
75
+ if names_and_values[i][0] in kwargs:
76
+ args_.append(kwargs.pop(names_and_values[i][0]))
77
+ else:
78
+ args_.append(names_and_values[i][1])
79
+ args = tuple(args_)
80
+
81
+ return True, args, kwargs
82
+
47
83
  def wrap_api(
48
84
  self, api_templates, hook_build_func: Optional[Callable]
49
85
  ):
@@ -68,6 +104,14 @@ class ApiWrapper:
68
104
  if callable(ori_api):
69
105
  def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
70
106
  def api_function(*args, **kwargs):
107
+ api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
108
+ enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix,
109
+ api_func, args, kwargs)
110
+ if not enable_wrap:
111
+ logger.warning(f'Cannot collect precision data of {api_name_with_prefix}. '
112
+ 'It may be fixed by passing the value of "self" '
113
+ 'as a positional argument instead of a keyword argument. ')
114
+ return api_func(*args, **kwargs)
71
115
  return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
72
116
  api_function.__name__ = api_name
73
117
  return api_function
@@ -84,9 +128,12 @@ class ApiWrapper:
84
128
  api_list = load_yaml(self.api_list_paths[index])
85
129
  valid_names = dict()
86
130
  for api_type, api_modules in self.api_types.get(framework, {}).items():
87
- api_from_file = api_list.get(Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type), [])
131
+ key_in_file = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type)
132
+ api_from_file = api_list.get(key_in_file, [])
88
133
  names = set()
89
134
  for api_name in api_from_file:
135
+ if f'{key_in_file}.{api_name}' in self.backlist:
136
+ continue
90
137
  target_attr = api_name
91
138
  target_module = api_modules[0]
92
139
  if Const.SEP in api_name:
@@ -105,7 +152,7 @@ class ApiRegistry:
105
152
  Base class for api registry.
106
153
  """
107
154
 
108
- def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates):
155
+ def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None):
109
156
  self.ori_api_attr = dict()
110
157
  self.wrapped_api_attr = dict()
111
158
  self.inner_used_ori_attr = dict()
@@ -114,6 +161,8 @@ class ApiRegistry:
114
161
  self.inner_used_api = inner_used_api
115
162
  self.supported_api_list_path = supported_api_list_path
116
163
  self.api_templates = api_templates
164
+ self.backlist = backlist if backlist else []
165
+ self.all_api_registered = False
117
166
 
118
167
  @staticmethod
119
168
  def store_ori_attr(ori_api_group, api_list, api_ori_attr):
@@ -131,7 +180,20 @@ class ApiRegistry:
131
180
  else:
132
181
  setattr(api_group, api, api_attr)
133
182
 
183
+ @staticmethod
184
+ def register_custom_api(module, api_name, api_prefix, hook_build_func, api_template):
185
+ def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
186
+ def api_function(*args, **kwargs):
187
+ return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
188
+
189
+ api_function.__name__ = api_name
190
+ return api_function
191
+
192
+ setattr(module, api_name,
193
+ wrap_api_func(api_name, getattr(module, api_name), api_prefix, hook_build_func, api_template))
194
+
134
195
  def register_all_api(self):
196
+ self.all_api_registered = True
135
197
  for framework, api_types in self.api_types.items():
136
198
  for api_type, api_modules in api_types.items():
137
199
  api_type_with_framework = framework + Const.SEP + api_type
@@ -143,6 +205,7 @@ class ApiRegistry:
143
205
  self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_wrapped_attr.get(api_type, {}))
144
206
 
145
207
  def restore_all_api(self):
208
+ self.all_api_registered = False
146
209
  for framework, api_types in self.api_types.items():
147
210
  for api_type, api_modules in api_types.items():
148
211
  api_type_with_framework = framework + Const.SEP + api_type
@@ -154,7 +217,7 @@ class ApiRegistry:
154
217
  self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
155
218
 
156
219
  def initialize_hook(self, hook_build_func):
157
- api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
220
+ api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist)
158
221
  wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
159
222
 
160
223
  for framework, api_types in self.api_types.items():