mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -20,33 +20,45 @@ import zlib
20
20
  from dataclasses import dataclass
21
21
 
22
22
  import numpy as np
23
+ import pandas as pd
23
24
 
24
25
  from msprobe.core.common.const import Const, CompareConst, FileCheckConst
25
26
  from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
26
27
  from msprobe.core.common.file_utils import check_file_or_directory_path
27
28
 
29
+ json_file_mapping = {
30
+ Const.DUMP_JSON_FILE: "dump.json",
31
+ Const.DEBUG_JSON_FILE: "debug.json",
32
+ Const.STACK_JSON_FILE: "stack.json"
33
+ }
28
34
 
29
- def extract_json(dirname, stack_json=False):
35
+
36
+ def extract_json(dirname, json_file_type):
30
37
  json_path = ''
31
38
  for filename in os.listdir(dirname):
32
- target_file_name = 'stack.json' if stack_json else 'dump.json'
39
+ target_file_name = json_file_mapping.get(json_file_type)
40
+ if target_file_name is None:
41
+ logger.error(f'extract_json failed, invalid json_file_type: {json_file_type}.')
42
+ raise CompareException(CompareException.INVALID_KEY_ERROR)
33
43
  if filename == target_file_name:
34
44
  json_path = os.path.join(dirname, filename)
35
45
  break
36
46
 
37
47
  # Provide robustness on invalid directory inputs
38
48
  if not json_path:
39
- if stack_json:
49
+ if json_file_type == Const.STACK_JSON_FILE:
40
50
  logger.warning(f'stack.json is not found in dump dir {dirname}.')
41
- else:
51
+ elif json_file_type == Const.DUMP_JSON_FILE:
42
52
  logger.error(f'dump.json is not found in dump dir {dirname}.')
43
- raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
53
+ elif json_file_type == Const.DEBUG_JSON_FILE:
54
+ logger.warning(f'debug.json is not found in dump dir {dirname}.')
55
+
44
56
  return json_path
45
57
 
46
58
 
47
59
  def set_stack_json_path(input_param):
48
60
  npu_data_dir = os.path.dirname(input_param.get("npu_json_path"))
49
- stack_path = extract_json(npu_data_dir, stack_json=True)
61
+ stack_path = extract_json(npu_data_dir, json_file_type=Const.STACK_JSON_FILE)
50
62
  input_param["stack_json_path"] = stack_path if stack_path else None
51
63
  return bool(stack_path)
52
64
 
@@ -81,24 +93,9 @@ def check_and_return_dir_contents(dump_dir, prefix):
81
93
  return contents
82
94
 
83
95
 
84
- def rename_api(npu_name, process):
85
- """
86
- 原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
87
- rename后: {api_type}.{api_name}.{input/output}.{参数序号}
88
- """
89
- npu_split = npu_name.split(process)
90
- try:
91
- torch_func_index, in_out = npu_split[0], npu_split[1]
92
- except IndexError as error:
93
- logger.error(f'{npu_name} can not be split with {process}, please check!')
94
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
95
- torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
96
- torch_func = str(torch_func_split[0]) + str(in_out)
97
- return torch_func
98
-
99
-
100
96
  def read_op(op_data, op_name):
101
- if Const.PARAMS_GRAD in op_name.split(Const.SEP):
97
+ split_name = op_name.split(Const.SEP)
98
+ if Const.DEBUG in split_name or Const.PARAMS_GRAD in split_name:
102
99
  op_parsed_list = op_item_parse(op_data, op_name)
103
100
  else:
104
101
  op_parsed_list = []
@@ -191,35 +188,152 @@ def gen_op_item(op_data, op_name):
191
188
  return op_item
192
189
 
193
190
 
194
- def resolve_api_special_parameters(data_dict, full_op_name, item_list):
191
+ @dataclass
192
+ class ApiItemInfo:
193
+ name: str
194
+ struct: tuple
195
+ stack_info: list
196
+
197
+
198
+ def merge_tensor(tensor_list, dump_mode):
199
+ keys = [
200
+ CompareConst.OP_NAME,
201
+ CompareConst.INPUT_STRUCT,
202
+ CompareConst.KWARGS_STRUCT,
203
+ CompareConst.OUTPUT_STRUCT,
204
+ CompareConst.PARAMS_STRUCT,
205
+ CompareConst.PARAMS_GRAD_STRUCT,
206
+ CompareConst.DEBUG_STRUCT,
207
+ Const.SUMMARY,
208
+ Const.STACK_INFO
209
+ ]
210
+ op_dict = {key: [] for key in keys}
211
+
212
+ if dump_mode == Const.ALL:
213
+ op_dict["data_name"] = []
214
+
215
+ for tensor in tensor_list:
216
+ # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
217
+ if len(tensor) == 2:
218
+ op_dict[Const.STACK_INFO].append(tensor['full_info'])
219
+ break
220
+
221
+ op_dict[CompareConst.OP_NAME].append(tensor['full_op_name'])
222
+
223
+ _, state = get_name_and_state(tensor['full_op_name'])
224
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
225
+ if not struct_key:
226
+ continue
227
+ if dump_mode == Const.MD5:
228
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
229
+ else:
230
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
231
+ op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
232
+
233
+ if dump_mode == Const.ALL:
234
+ op_dict["data_name"].append(tensor['data_name'])
235
+
236
+ if not op_dict[CompareConst.KWARGS_STRUCT]:
237
+ del op_dict[CompareConst.KWARGS_STRUCT]
238
+ return op_dict if op_dict[CompareConst.OP_NAME] else {}
239
+
240
+
241
+ def print_compare_ends_info():
242
+ total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
243
+ logger.info('*' * total_len)
244
+ logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
245
+ logger.info('*' * total_len)
246
+
247
+
248
+ def table_value_is_valid(value: str) -> bool:
249
+ if not isinstance(value, str):
250
+ return True
251
+ try:
252
+ # -1.00 or +1.00 should be considered as digit numbers
253
+ float(value)
254
+ except ValueError:
255
+ # otherwise, they will be considered as formular injections
256
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
257
+ return True
258
+
259
+
260
+ def get_name_and_state(name):
195
261
  """
196
- Function Description:
197
- 解析下面格式的数据, 是api参数的一种特殊格式
198
- {
199
- "last_hidden_state": {
200
- "type": "torch.Tensor",
201
- "dtype": "torch.bfloat16",
202
- ...
203
- },
204
- "loss": {
205
- "type": "torch.Tensor",
206
- "dtype": "torch.float32",
207
- ...
208
- }
209
- }
210
- Parameter:
211
- data_dict: 字典格式的数据
212
- full_op_name: 参数的全名字符串
213
- item_list: 参数信息集合
262
+ Get api/module name and state
263
+ example:
264
+ name = 'conv2d.forward.1.input.0'
265
+ return: ('conv2d.forward.1.', 'input')
266
+
267
+ name = 'Functional.pad.0.backward.output.0'
268
+ return: ('Functional.pad.0.backward.', 'output')
269
+
270
+ name = 'x_tensor.0.debug.{index}'
271
+ return: ('x_tensor.0.', 'debug')
272
+
273
+ state type: input, output, kwargs, parameters, parameters_grad, debug
214
274
  """
215
- for key, value in data_dict.items():
216
- if isinstance(value, dict):
217
- parsed_item = value
218
- parts = full_op_name.split(Const.SEP)
219
- parts.insert(-1, key)
220
- full_op_name_new = ".".join(parts)
221
- parsed_item['full_op_name'] = full_op_name_new
222
- item_list.append(parsed_item)
275
+ if not isinstance(name, str):
276
+ logger.error(f'Invalid name: {name}, type should be string, please check.')
277
+ raise CompareException(CompareException.INVALID_API_NAME_ERROR)
278
+
279
+ if Const.DEBUG in name.split(Const.SEP):
280
+ return name.split(Const.DEBUG)[0], Const.DEBUG
281
+ if Const.PARAMS_GRAD in name.split(Const.SEP):
282
+ return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
283
+
284
+ split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
285
+ if len(split) < 3:
286
+ logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
287
+ raise CompareException(CompareException.INVALID_API_NAME_ERROR)
288
+ api = f'{split[0]}.{split[1]}.'
289
+ state_str = split[2]
290
+ match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
291
+ if not match:
292
+ raise CompareException(f'Invalid name string: {name}')
293
+ if match.group(1):
294
+ api = f'{api}{match.group(1)}'
295
+ state = match.group(2)
296
+ return api, state
297
+
298
+
299
+ def reorder_op_name_list(op_name_list):
300
+ if not op_name_list:
301
+ return op_name_list
302
+
303
+ parameters = []
304
+ output = []
305
+ parameters_grad = []
306
+ others = []
307
+ for x in op_name_list:
308
+ state = get_name_and_state(x)[1]
309
+ if state == Const.PARAMS:
310
+ parameters.append(x)
311
+ elif state == Const.OUTPUT:
312
+ output.append(x)
313
+ elif state == Const.PARAMS_GRAD:
314
+ parameters_grad.append(x)
315
+ else:
316
+ others.append(x)
317
+ # 合并others, parameters, 和output,确保parameters排在output前面
318
+ op_name_reorder = others + parameters + output + parameters_grad
319
+ return op_name_reorder
320
+
321
+
322
+ def reorder_op_x_list(op_name_list, summary_list, data_name_list):
323
+ """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理"""
324
+ if not op_name_list or not summary_list:
325
+ return op_name_list, summary_list, data_name_list
326
+
327
+ index_map = {name: index for index, name in enumerate(op_name_list)}
328
+
329
+ op_name_reorder = reorder_op_name_list(op_name_list)
330
+ summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
331
+ if data_name_list:
332
+ data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
333
+ else:
334
+ data_name_reorder = data_name_list
335
+
336
+ return op_name_reorder, summary_reorder, data_name_reorder
223
337
 
224
338
 
225
339
  def process_summary_data(summary_data):
@@ -285,9 +399,9 @@ def result_item_init(n_info, b_info, dump_mode):
285
399
  md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
286
400
  result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
287
401
  elif dump_mode == Const.SUMMARY:
288
- result_item.extend([" "] * 8)
402
+ result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标
289
403
  else:
290
- result_item.extend([" "] * 5)
404
+ result_item.extend([" "] * 6) # 6个真实数据情况的比对指标
291
405
  else:
292
406
  err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
293
407
  f"npu_info_struct is {n_info.struct}\n" \
@@ -321,8 +435,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
321
435
  has_stack = npu_stack_info and bench_stack_info
322
436
 
323
437
  if dump_mode == Const.ALL:
324
- npu_data_name = n_dict.get("data_name", None)
325
- bench_data_name = b_dict.get("data_name", None)
438
+ npu_data_name_list = n_dict.get("data_name", None)
439
+ bench_data_name_list = b_dict.get("data_name", None)
326
440
 
327
441
  for index in range(min_len):
328
442
  n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
@@ -353,7 +467,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
353
467
  result_item.append(err_msg)
354
468
  result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
355
469
  if dump_mode == Const.ALL:
356
- result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
470
+ npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
471
+ bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list")
472
+ result_item.append([npu_data_name, bench_data_name])
357
473
 
358
474
  result.append(result_item)
359
475
 
@@ -371,7 +487,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
371
487
  continue
372
488
  result_item = [
373
489
  n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
374
- " ", " ", " ", " ", " "
490
+ " ", " ", " ", " ", " ", " "
375
491
  ]
376
492
  summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
377
493
  result_item.extend(summary_data)
@@ -388,7 +504,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
388
504
  result_item.append(err_msg)
389
505
  result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
390
506
  if dump_mode == Const.ALL:
391
- result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
507
+ npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
508
+ result_item.append([npu_data_name, "-1"])
392
509
 
393
510
  result.append(result_item)
394
511
 
@@ -404,197 +521,23 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
404
521
  CompareConst.PARAMS_GRAD_STRUCT)
405
522
 
406
523
 
407
- def append_stack_info(result_item, npu_stack_info, index):
408
- """添加堆栈信息到 result_item"""
409
- if npu_stack_info and index == 0:
410
- result_item.extend(npu_stack_info)
411
- else:
412
- result_item.append(CompareConst.NONE)
413
-
524
+ def make_result_table(result, dump_mode, stack_mode):
525
+ header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
414
526
 
415
- def get_un_match_accuracy(result, n_dict, dump_mode):
416
- npu_stack_info = n_dict.get("stack_info", None)
417
- bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
418
-
419
- struct_to_index_mapping = {
420
- CompareConst.INPUT_STRUCT: 0,
421
- CompareConst.OUTPUT_STRUCT: 0,
422
- CompareConst.PARAMS_STRUCT: 0,
423
- CompareConst.PARAMS_GRAD_STRUCT: 0
424
- }
425
-
426
- op_name_list = n_dict.get(CompareConst.OP_NAME)
427
- summary_list = n_dict.get(Const.SUMMARY)
428
- data_name_list = n_dict.get('data_name')
429
- op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
430
- summary_list,
431
- data_name_list)
432
- for index, n_name in enumerate(op_name_reorder):
433
- _, state = get_name_and_state(n_name)
434
- struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
435
- if not struct_key:
436
- continue
437
- n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
438
- struct_to_index_mapping[struct_key] += 1
439
-
440
- try:
441
- result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
442
- except IndexError as e:
443
- err_msg = "index out of bounds error occurs, please check!\n" \
444
- f"op_name of n_dict is {n_dict['op_name']}\n" \
445
- f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
446
- f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
447
- logger.error(err_msg)
448
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
449
-
450
- if dump_mode == Const.MD5:
451
- result_item.extend([CompareConst.N_A] * 3)
452
- append_stack_info(result_item, npu_stack_info, index)
453
- result.append(result_item)
454
- continue
455
- if dump_mode == Const.SUMMARY:
456
- result_item.extend([CompareConst.N_A] * 8)
527
+ if stack_mode:
528
+ header.append(CompareConst.STACK)
457
529
  if dump_mode == Const.ALL:
458
- result_item.extend([CompareConst.N_A] * 5)
459
-
460
- npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
461
- bench_summary_data = [CompareConst.N_A] * 4
462
- result_item.extend(npu_summary_data)
463
- result_item.extend(bench_summary_data)
464
- err_msg = CompareConst.NO_BENCH
465
- accuracy_check_res = CompareConst.N_A
466
- result_item.append(accuracy_check_res)
467
- result_item.append(err_msg)
468
- append_stack_info(result_item, npu_stack_info, index)
469
- if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
470
- result_item.extend(["-1"])
471
- result.append(result_item)
472
-
473
-
474
- def merge_tensor(tensor_list, dump_mode):
475
- op_dict = {}
476
- op_dict["op_name"] = []
477
- op_dict[CompareConst.INPUT_STRUCT] = []
478
- op_dict[CompareConst.KWARGS_STRUCT] = []
479
- op_dict[CompareConst.OUTPUT_STRUCT] = []
480
- op_dict[CompareConst.PARAMS_STRUCT] = []
481
- op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
482
- op_dict[Const.SUMMARY] = []
483
- op_dict["stack_info"] = []
484
-
485
- if dump_mode == Const.ALL:
486
- op_dict["data_name"] = []
487
-
488
- for tensor in tensor_list:
489
- # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
490
- if len(tensor) == 2:
491
- op_dict['stack_info'].append(tensor['full_info'])
492
- break
493
-
494
- op_dict["op_name"].append(tensor['full_op_name'])
495
-
496
- _, state = get_name_and_state(tensor['full_op_name'])
497
- struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
498
- if not struct_key:
499
- continue
500
- if dump_mode == Const.MD5:
501
- op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
502
- else:
503
- op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
504
- op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
505
-
530
+ header.append(CompareConst.DATA_NAME)
531
+ else:
506
532
  if dump_mode == Const.ALL:
507
- op_dict["data_name"].append(tensor['data_name'])
508
-
509
- if not op_dict[CompareConst.KWARGS_STRUCT]:
510
- del op_dict[CompareConst.KWARGS_STRUCT]
511
- return op_dict if op_dict["op_name"] else {}
512
-
513
-
514
- def print_compare_ends_info():
515
- total_len = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
516
- logger.info('*' * total_len)
517
- logger.info(f"*{CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
518
- logger.info('*' * total_len)
519
-
520
-
521
- def table_value_is_valid(value: str) -> bool:
522
- if not isinstance(value, str):
523
- return True
524
- try:
525
- # -1.00 or +1.00 should be consdiered as digit numbers
526
- float(value)
527
- except ValueError:
528
- # otherwise, they will be considered as formular injections
529
- return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
530
- return True
531
-
532
-
533
- def get_name_and_state(name):
534
- """
535
- Get api/module name and state
536
- example:
537
- name = 'conv2d.forward.1.input.0'
538
- return: ('conv2d.forward.1.', 'input')
539
-
540
- name = 'Functional.pad.0.backward.output.0'
541
- return: ('Functional.pad.0.backward.', 'output')
542
-
543
- state type: input, output, kwargs, parameters, parameters_grad
544
- """
545
- if Const.PARAMS_GRAD in name.split(Const.SEP):
546
- return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
547
-
548
- split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
549
- api = f'{split[0]}.{split[1]}.'
550
- state_str = split[2]
551
- match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
552
- if not match:
553
- raise CompareException(f'Invalid name string: {name}')
554
- if match.group(1):
555
- api = f'{api}{match.group(1)}'
556
- state = match.group(2)
557
- return api, state
558
-
559
-
560
- def reorder_op_name_list(op_name_list):
561
- if not op_name_list:
562
- return op_name_list
563
-
564
- parameters = []
565
- output = []
566
- parameters_grad = []
567
- others = []
568
- for x in op_name_list:
569
- state = get_name_and_state(x)[1]
570
- if state == Const.PARAMS:
571
- parameters.append(x)
572
- elif state == Const.OUTPUT:
573
- output.append(x)
574
- elif state == Const.PARAMS_GRAD:
575
- parameters_grad.append(x)
533
+ for row in result:
534
+ del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
535
+ header.append(CompareConst.DATA_NAME)
576
536
  else:
577
- others.append(x)
578
- # 合并others, parameters, 和output,确保parameters排在output前面
579
- op_name_reorder = others + parameters + output + parameters_grad
580
- return op_name_reorder
581
-
582
-
583
- def reorder_op_x_list(op_name_list, summary_list, data_name_list):
584
- """对op_name, summary, data_name重新排序,把parameters放到input后output前,data_name由于统计量比对时,为None,单独处理"""
585
- if not op_name_list or not summary_list:
586
- return op_name_list, summary_list, data_name_list
587
-
588
- index_map = {name: index for index, name in enumerate(op_name_list)}
589
-
590
- op_name_reorder = reorder_op_name_list(op_name_list)
591
- summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
592
- if data_name_list:
593
- data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
594
- else:
595
- data_name_reorder = data_name_list
596
-
597
- return op_name_reorder, summary_reorder, data_name_reorder
537
+ for row in result:
538
+ del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
539
+ result_df = pd.DataFrame(result, columns=header, dtype='object')
540
+ return result_df
598
541
 
599
542
 
600
543
  def _compare_parser(parser):
@@ -617,3 +560,34 @@ def _compare_parser(parser):
617
560
  help="<optional> The data mapping file path.", required=False)
618
561
  parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
619
562
  help="<optional> The layer mapping file path.", required=False)
563
+
564
+
565
def compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare_func, **kwargs):
    """Compare every rank of two distributed dump directories.

    Pairing and skipping rules:

    * The rank entries returned by ``check_and_return_dir_contents(dir, 'rank')``
      are sorted on each side and paired by position.
    * For each pair, and for each of ``Const.DUMP_JSON_FILE`` and
      ``Const.DEBUG_JSON_FILE``, ``compare_func`` is called only when
      ``extract_json`` finds that file on both sides.
    * An empty path from ``extract_json`` counts as missing; that pair is
      skipped with a debug log.

    Args:
        npu_dump_dir: Directory holding the NPU-side rank entries.
        bench_dump_dir: Directory holding the bench-side rank entries.
        output_path: Forwarded to ``compare_func``.
        compare_func: Called as ``compare_func(input_param=..., output_path=...,
            suffix='_<npu rank entry>', **kwargs)``.
        **kwargs: Forwarded to ``compare_func``.
            ``is_print_compare_log`` (default True) is also copied into
            ``input_param``. A caller-supplied ``suffix`` is not supported.

    Raises:
        CompareException: ``INVALID_PARAM_ERROR`` if a non-empty ``suffix`` is
            passed; ``INVALID_PATH_ERROR`` if the two sides have a different
            number of rank entries.
    """
    # The per-rank suffix is generated below, so a caller-supplied one is rejected.
    # A falsy suffix (None / '') is discarded rather than forwarded: forwarding it
    # together with the explicit ``suffix=`` argument would raise
    # "TypeError: got multiple values for keyword argument 'suffix'".
    if kwargs.pop('suffix', None):
        logger.error("Argument 'suffix' is not supported for compare_distributed.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    is_print_compare_log = kwargs.get('is_print_compare_log', True)

    # get the ranks and match by order
    npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
    bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
    if len(npu_ranks) != len(bench_ranks):
        logger.error('The number of ranks in the two runs are different. '
                     'Unable to match the ranks. Please use another folder to compare '
                     'or use compare() api and manually match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    for npu_rank, bench_rank in zip(npu_ranks, bench_ranks):
        npu_data_dir = os.path.join(npu_dump_dir, npu_rank)
        bench_data_dir = os.path.join(bench_dump_dir, bench_rank)
        for file_type in (Const.DUMP_JSON_FILE, Const.DEBUG_JSON_FILE):
            npu_path = extract_json(npu_data_dir, file_type)
            bench_path = extract_json(bench_data_dir, file_type)
            if npu_path == "" or bench_path == "":
                logger.debug(f'Did not find paired {file_type} in {npu_data_dir} and {bench_data_dir},'
                             ' skip comparing.')
                continue
            dump_result_param = {
                'npu_json_path': npu_path,
                'bench_json_path': bench_path,
                'is_print_compare_log': is_print_compare_log
            }
            compare_func(input_param=dump_result_param, output_path=output_path,
                         suffix=f'_{npu_rank}', **kwargs)
@@ -13,7 +13,5 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- class Runtime:
17
- step_count: int = 0
18
- rank_id: int = -1
19
- is_running: bool = False
16
+ import msprobe.core.config_check.checkers
17
+ from msprobe.core.config_check.config_checker import ConfigChecker
@@ -13,21 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
16
# Public API of the checkers package. Only names actually bound in this module
# may be listed, otherwise ``from ... import *`` raises AttributeError.
__all__ = ['BaseChecker']
17
17
 
18
- from msprobe.core.common.file_utils import save_json
18
+ import msprobe.core.config_check.checkers.env_args_checker
19
+ import msprobe.core.config_check.checkers.pip_checker
20
+ import msprobe.core.config_check.checkers.dataset_checker
21
+ import msprobe.core.config_check.checkers.weights_checker
22
+ import msprobe.core.config_check.checkers.hyperparameter_checker
23
+ import msprobe.core.config_check.checkers.random_checker
19
24
 
20
-
21
- def create_kernel_config_json(dump_path, cur_rank):
22
- kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
23
- kernel_config_path = os.path.join(dump_path, kernel_config_name)
24
- config_info = {
25
- "dump": {
26
- "dump_list": [],
27
- "dump_path": dump_path,
28
- "dump_mode": "all",
29
- "dump_op_switch": "on"
30
- }
31
- }
32
- save_json(kernel_config_path, config_info, indent=4)
33
- return kernel_config_path
25
+ from msprobe.core.config_check.checkers.base_checker import BaseChecker
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.framework_adapter import FmkAdp
19
+ from msprobe.core.common.const import FileCheckConst
20
+
21
+
22
class PackInput:
    """Validated inputs for packing config-check data into a zip archive.

    Args:
        output_zip_path: Destination archive path; must be a ``str`` ending with
            ``FileCheckConst.ZIP_SUFFIX``.
        model: A framework module (as judged by ``FmkAdp.is_nn_module``), a list
            whose first element is such a module, or a falsy value (no model).
            For a non-empty list only the first element is kept and validated.
        shell_path: Stored as-is; not validated here.

    Raises:
        TypeError: If ``model`` is set but is not a framework module, or if
            ``output_zip_path`` is not a string.
        ValueError: If ``output_zip_path`` does not end with the zip suffix.

    Both exception types derive from ``Exception``, so callers that caught the
    previously raised bare ``Exception`` keep working.
    """

    def __init__(self, output_zip_path, model, shell_path):
        self.output_zip_path = output_zip_path
        self.shell_path = shell_path
        # Only the first module of a list is used.
        # NOTE(review): presumably the first chunk is representative of the
        # model list; confirm with the pack implementations.
        self.model = model[0] if isinstance(model, list) and model else model
        self.check_input_params()

    def check_input_params(self):
        """Validate ``self.model`` and ``self.output_zip_path``; raise on bad input."""
        # A falsy model (None, empty list) means "no model" and is not checked.
        if self.model and not FmkAdp.is_nn_module(self.model):
            raise TypeError("model is not torch.nn.Module/mindspore.nn.Cell or module list.")
        if not isinstance(self.output_zip_path, str):
            raise TypeError("output zip path must be a string and ends with '.zip'")
        if not self.output_zip_path.endswith(FileCheckConst.ZIP_SUFFIX):
            raise ValueError("output zip path must be a string and ends with '.zip'")
35
+
36
+
37
class BaseChecker:
    """Base class for config-check checkers.

    Subclasses override the static hooks and set the class attributes below.
    The defaults are no-ops.
    """
    # NOTE(review): not read in this class; presumably names the input a
    # subclass needs for ``pack`` -- confirm against the concrete checkers.
    input_needed = None
    # Entry name (file or directory) that ``compare_ex`` requires on both sides.
    target_name_in_zip = None
    # NOTE(review): not read in this class; presumably marks per-rank data.
    multi_rank = False

    @staticmethod
    def pack(pack_input):
        """Collect this checker's data for packing; no-op by default."""

    @staticmethod
    def compare(bench_dir, cmp_dir, output_path, fmk):
        """Compare bench vs. cmp data; no-op (returns None) by default."""

    @staticmethod
    def apply_patches(fmk):
        """Install framework-specific patches; no-op by default."""

    @classmethod
    def compare_ex(cls, bench_dir, cmp_dir, output_path, fmk):
        """Run ``cls.compare`` only if ``target_name_in_zip`` exists on both sides.

        Args:
            bench_dir: Directory holding the bench-side unpacked data.
            cmp_dir: Directory holding the compared-side unpacked data.
            output_path: Forwarded to ``cls.compare``.
            fmk: Framework identifier, forwarded to ``cls.compare``.

        Returns:
            ``(None, None, None)`` in either of these cases:

            * the checker declares no ``target_name_in_zip``;
            * the target is missing from either directory.

            Otherwise, whatever ``cls.compare`` returns.
        """
        # Without a declared target there is nothing to look up; previously this
        # crashed inside os.path.join(bench_dir, None) with a TypeError.
        if cls.target_name_in_zip is None:
            return None, None, None
        bench_filepath = os.path.join(bench_dir, cls.target_name_in_zip)
        cmp_filepath = os.path.join(cmp_dir, cls.target_name_in_zip)
        if not (os.path.exists(bench_filepath) and os.path.exists(cmp_filepath)):
            return None, None, None
        # compare() receives the directories, not the joined target paths.
        return cls.compare(bench_dir, cmp_dir, output_path, fmk)