mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,25 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.common.const import Const
17
+
18
+
19
class Runtime:
    """Holder of runtime state exposed as class-level attributes.

    The attributes are plain class variables, so they are read and written
    directly on the class (``Runtime.step_count`` and so on). Who updates them
    is not visible in this module.
    """

    step_count: int = 0
    rank_id: int = -1
    is_running: bool = False
    run_mode: str = Const.PYNATIVE_MODE
    current_iter: int = 0
    # An explicit default is required here: a bare ``current_rank: None``
    # annotation only records the annotation and never creates the attribute,
    # so reading ``Runtime.current_rank`` before the first assignment raised
    # AttributeError.
    current_rank = None
@@ -18,6 +18,7 @@ import os
18
18
  import re
19
19
  import subprocess
20
20
  import time
21
+ import inspect
21
22
  from datetime import datetime, timezone
22
23
 
23
24
  import numpy as np
@@ -26,10 +27,15 @@ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_pa
26
27
  from msprobe.core.common.const import Const, CompareConst
27
28
  from msprobe.core.common.log import logger
28
29
  from msprobe.core.common.exceptions import MsprobeException
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
29
31
 
30
32
 
31
33
  device = collections.namedtuple('device', ['type', 'index'])
32
34
  prefixes = ['api_stack', 'list', 'range', 'acl']
35
+ file_suffix_to_file_type = {
36
+ "dump.json": Const.DUMP_JSON_FILE,
37
+ "debug.json": Const.DEBUG_JSON_FILE,
38
+ }
33
39
 
34
40
 
35
41
  class MsprobeBaseException(Exception):
@@ -74,6 +80,7 @@ class MsprobeBaseException(Exception):
74
80
  NAMES_STRUCTS_MATCH_ERROR = 34
75
81
  INVALID_STATE_ERROR = 35
76
82
  INVALID_API_NAME_ERROR = 36
83
+ CROSS_FRAME_ERROR = 37
77
84
 
78
85
  def __init__(self, code, error_info: str = ""):
79
86
  super(MsprobeBaseException, self).__init__()
@@ -190,27 +197,6 @@ def check_regex_prefix_format_valid(prefix):
190
197
  raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
191
198
 
192
199
 
193
- def execute_command(cmd):
194
- """
195
- Function Description:
196
- run the following command
197
- Parameter:
198
- cmd: command
199
- Exception Description:
200
- when invalid command throw exception
201
- """
202
- logger.info('Execute command:%s' % cmd)
203
- process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
204
- while process.poll() is None:
205
- line = process.stdout.readline()
206
- line = line.strip()
207
- if line:
208
- logger.info(line)
209
- if process.returncode != 0:
210
- logger.error('Failed to execute command:%s' % " ".join(cmd))
211
- raise CompareException(CompareException.INVALID_DATA_ERROR)
212
-
213
-
214
200
  def add_time_as_suffix(name):
215
201
  return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
216
202
 
@@ -231,17 +217,33 @@ def format_value(value):
231
217
  return float('{:.12f}'.format(value))
232
218
 
233
219
 
234
- def md5_find(data):
235
- for key_op in data:
236
- for api_info in data[key_op]:
237
- if isinstance(data[key_op][api_info], list):
238
- for data_detail in data[key_op][api_info]:
239
- if data_detail and 'md5' in data_detail:
240
- return True
241
- if isinstance(data[key_op][api_info], bool):
242
- continue
243
- elif data[key_op][api_info] and 'md5' in data[key_op][api_info]:
220
+ @recursion_depth_decorator('msprobe.core.common.utils.md5_find', max_depth=Const.DUMP_MAX_DEPTH)
221
+ def md5_find(data, json_type=Const.DUMP_JSON_FILE):
222
+ if json_type == Const.DUMP_JSON_FILE:
223
+ for key_op in data:
224
+ for api_info in data[key_op]:
225
+ if isinstance(data[key_op][api_info], list):
226
+ for data_detail in data[key_op][api_info]:
227
+ if data_detail and Const.MD5 in data_detail:
228
+ return True
229
+ if isinstance(data[key_op][api_info], bool):
230
+ continue
231
+ elif data[key_op][api_info] and Const.MD5 in data[key_op][api_info]:
232
+ return True
233
+ elif json_type == Const.DEBUG_JSON_FILE:
234
+ if isinstance(data, dict):
235
+ if Const.MD5 in data:
244
236
  return True
237
+ else:
238
+ for _, data_info in data.items():
239
+ if md5_find(data_info, Const.DEBUG_JSON_FILE):
240
+ return True
241
+ elif isinstance(data, list):
242
+ for data_info in data:
243
+ if md5_find(data_info, Const.DEBUG_JSON_FILE):
244
+ return True
245
+ else:
246
+ return False
245
247
  return False
246
248
 
247
249
 
@@ -279,13 +281,41 @@ def get_stack_construct_by_dump_json_path(dump_json_path):
279
281
  def set_dump_path(input_param):
280
282
  npu_path = input_param.get("npu_json_path", None)
281
283
  bench_path = input_param.get("bench_json_path", None)
282
- npu_path_valid = npu_path is not None and npu_path.endswith("dump.json")
283
- bench_path_valid = bench_path is not None and bench_path.endswith("dump.json")
284
- if not npu_path_valid or not bench_path_valid:
284
+ dump_json_path_valid = npu_path is not None and npu_path.endswith("dump.json") and \
285
+ bench_path is not None and bench_path.endswith("dump.json")
286
+ debug_json_path_valid = npu_path is not None and npu_path.endswith("debug.json") and \
287
+ bench_path is not None and bench_path.endswith("debug.json")
288
+ if not dump_json_path_valid and not debug_json_path_valid:
285
289
  logger.error(f"Please check the json path is valid and ensure that neither npu_path nor bench_path is None.")
286
290
  raise CompareException(CompareException.INVALID_PATH_ERROR)
287
- input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
288
- input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
291
+ input_param[CompareConst.NPU_DUMP_DATA_DIR] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
292
+ input_param[CompareConst.BENCH_DUMP_DATA_DIR] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
293
+
294
+
295
def get_file_type(file_path):
    """Map a json path to its file-type constant.

    The last component of ``file_path`` (after splitting on
    ``Const.SCOPE_SEPARATOR``) is looked up in ``file_suffix_to_file_type``.

    Parameters:
        file_path: path string expected to end in ``dump.json`` or ``debug.json``.
    Returns:
        The value registered in ``file_suffix_to_file_type`` for that component.
    Raises:
        CompareException(INVALID_PATH_ERROR): ``file_path`` is not a ``str``,
            or its last component is not a registered file name.
    """
    if isinstance(file_path, str):
        last_component = file_path.split(Const.SCOPE_SEPARATOR)[-1]
        resolved_type = file_suffix_to_file_type.get(last_component)
        if resolved_type is not None:
            return resolved_type
        logger.error("get_file_type failed, file_path is neither dump.json nor debug.json.")
    else:
        logger.error("get_file_type failed, check the type of file_path.")
    # Both failure modes surface as the same exception code.
    raise CompareException(CompareException.INVALID_PATH_ERROR)
304
+
305
+
306
def check_dump_json_key(json_data, device_type):
    """Validate the top-level keys of a loaded dump json and return them.

    Parameters:
        json_data: mapping loaded from a dump json file.
        device_type: label used only in the error messages (e.g. ``'npu'``).
    Returns:
        A ``(task, api_data)`` tuple taken from the ``'task'`` and ``'data'`` keys.
    Raises:
        CompareException(INVALID_TASK_ERROR): ``'task'`` is missing or falsy.
        CompareException(INVALID_DATA_ERROR): ``'data'`` is missing or is not a dict.
    """
    task_name = json_data.get('task', None)
    if not task_name:
        logger.error(f"Task for {device_type} is empty, please check.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if 'data' in json_data:
        payload = json_data.get('data')
        if isinstance(payload, dict):
            return task_name, payload
        logger.error(f"Invalid type for 'data': expected a dict. Please check dump.json of {device_type}.")
    else:
        logger.error(f"Missing 'data' in dump.json, please check dump.json of {device_type}.")
    # A missing key and a wrongly typed value share one error code.
    raise CompareException(CompareException.INVALID_DATA_ERROR)
289
319
 
290
320
 
291
321
  def get_dump_mode(input_param):
@@ -293,13 +323,10 @@ def get_dump_mode(input_param):
293
323
  bench_path = input_param.get("bench_json_path", None)
294
324
  npu_json_data = load_json(npu_path)
295
325
  bench_json_data = load_json(bench_path)
326
+ json_type = get_file_type(file_path=npu_path)
296
327
 
297
- npu_task = npu_json_data.get('task', None)
298
- bench_task = bench_json_data.get('task', None)
299
-
300
- if not npu_task or not bench_task:
301
- logger.error(f"Please check the dump task is correct, npu's task is {npu_task}, bench's task is {bench_task}.")
302
- raise CompareException(CompareException.INVALID_TASK_ERROR)
328
+ npu_task, npu_api_data = check_dump_json_key(npu_json_data, 'npu')
329
+ bench_task, bench_api_data = check_dump_json_key(bench_json_data, 'bench')
303
330
 
304
331
  if npu_task != bench_task:
305
332
  logger.error(f"Please check the dump task is consistent.")
@@ -312,8 +339,8 @@ def get_dump_mode(input_param):
312
339
  return Const.STRUCTURE
313
340
 
314
341
  if npu_task == Const.STATISTICS:
315
- npu_md5_compare = md5_find(npu_json_data['data'])
316
- bench_md5_compare = md5_find(bench_json_data['data'])
342
+ npu_md5_compare = md5_find(npu_api_data, json_type)
343
+ bench_md5_compare = md5_find(bench_api_data, json_type)
317
344
  if npu_md5_compare == bench_md5_compare:
318
345
  return Const.MD5 if npu_md5_compare else Const.SUMMARY
319
346
  else:
@@ -436,6 +463,28 @@ def check_init_step(step):
436
463
  f"{step} must be greater than or equal to 0")
437
464
 
438
465
 
466
def check_token_range(token_range):
    """Validate an optional ``[start, end]`` token range.

    ``None`` is accepted as "not configured". Otherwise the value must be a
    list or tuple of exactly two ints with ``0 <= start <= end``.

    Parameters:
        token_range: ``None`` or a two-element list/tuple of ints.
    Raises:
        MsprobeException(INVALID_PARAM_ERROR): any of the checks fails; the
            reason is logged first.
    """
    if token_range is None:
        return

    def _fail(reason):
        # Every violation is reported the same way: log, then raise.
        logger.error(reason)
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)

    if not isinstance(token_range, (list, tuple)):
        _fail("Token_range must be a list or tuple.")
    if len(token_range) != 2:
        _fail("Token_range must contains exactly 2 elements.")

    lower, upper = token_range
    if not (isinstance(lower, int) and isinstance(upper, int)):
        _fail("Start and end in token_range must be integer.")
    if lower > upper:
        _fail("Start in token_range must less than the end.")
    if lower < 0:
        _fail("Start in token_range must >= 0.")
486
+
487
+
439
488
  def check_seed_all(seed, mode, rm_dropout):
440
489
  if is_int(seed):
441
490
  if seed < 0 or seed > Const.MAX_SEED_VALUE:
@@ -505,4 +554,46 @@ def is_save_variable_valid(variable, valid_special_types, depth=0):
505
554
  return all(isinstance(key, str) and is_save_variable_valid(value, valid_special_types, depth + 1)
506
555
  for key, value in variable.items())
507
556
  else:
508
- return False
557
+ return False
558
+
559
+
560
def replace_last_occurrence(text, old, new):
    """Replace only the right-most occurrence of ``old`` in ``text``.

    Parameters:
        text: string to search; ``None`` is passed through unchanged.
        old: substring to look for.
        new: replacement for the last match.
    Returns:
        ``text`` with its last ``old`` replaced by ``new``, or ``text`` itself
        when it is ``None`` or contains no match.
    """
    if text is None:
        return text
    position = text.rfind(old)
    if position == -1:
        return text
    # Stitch: everything before the match + replacement + everything after it.
    return text[:position] + new + text[position + len(old):]
567
+
568
+
569
def load_stack_json(stack_path):
    """Load a stack json file and return it as an ``api name -> stack`` dict.

    When the loaded dict has no truthy value under ``Const.NEW_STACK_FLAG`` it
    is returned as loaded. Otherwise every value that has exactly two elements
    is unpacked as ``(api_list, stack_str)`` and each name in ``api_list`` is
    mapped to ``stack_str``; values of any other length are ignored.

    Parameters:
        stack_path: path handed to ``load_json``.
    Returns:
        The loaded dict, or the flattened per-api dict described above.
    """
    loaded = load_json(stack_path)
    if not loaded.get(Const.NEW_STACK_FLAG):
        return loaded

    per_api_stack = {}
    for entry in loaded.values():
        if len(entry) != 2:
            continue
        api_names, stack_text = entry
        # Every api in the group shares the same stack string.
        per_api_stack.update(dict.fromkeys(api_names, stack_text))
    return per_api_stack
582
+
583
+
584
def analyze_api_call_stack(name):
    """Render the current call stack, minus the two innermost frames, as text.

    Parameters:
        name: label used only in the warning logged when the stack cannot be
            retrieved.
    Returns:
        One ``File <path>, line <n>, in <func>`` entry per frame that has
        source context, concatenated into a single string;
        ``Const.WITHOUT_CALL_STACK`` when no frames are available.
    """
    try:
        # Drop frame 0 (this function) and frame 1 (its direct caller).
        frames = inspect.stack()[2:]
    except Exception as e:
        logger.warning(f"The call stack of {name} failed to retrieve, {e}.")
        frames = None

    if frames:
        rendered = [
            f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()} \n"
            for (_, path, line, func, code, _) in frames
            if code  # frames without source context are skipped
        ]
    else:
        rendered = [Const.WITHOUT_CALL_STACK]
    return "".join(rendered)
@@ -111,3 +111,10 @@ class BaseConfig:
111
111
  f"The element '{mode}' of data_mode {self.data_mode} is not in {Const.DUMP_DATA_MODE_LIST}.",
112
112
  MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
113
113
  )
114
+
115
def _check_summary_mode(self):
    """Check that ``self.summary_mode`` is one of ``Const.SUMMARY_MODE``.

    A falsy ``summary_mode`` is accepted without further checks. Any other
    value outside ``Const.SUMMARY_MODE`` is reported through
    ``logger.error_log_with_exp`` together with an
    ``MsprobeException(INVALID_PARAM_ERROR)``.
    """
    if not self.summary_mode or self.summary_mode in Const.SUMMARY_MODE:
        return
    logger.error_log_with_exp(
        f"summary_mode is invalid, summary_mode is not in {Const.SUMMARY_MODE}.",
        MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    )
+ )