mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat
40
40
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
41
41
  from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
42
42
  from msprobe.pytorch.common.log import logger
43
- from msprobe.core.common.utils import CompareException
43
+ from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
44
44
  from msprobe.core.common.const import Const, CompareConst, FileCheckConst
45
45
 
46
46
  CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
@@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config):
151
151
  message = ''
152
152
  compare_column = ApiPrecisionOutputColumn()
153
153
  full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
154
+ check_op_str_pattern_valid(full_api_name_with_direction_status)
154
155
  row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
155
156
  api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
156
157
  if not api_full_name:
@@ -430,6 +431,7 @@ def _api_precision_compare(parser=None):
430
431
  _api_precision_compare_parser(parser)
431
432
  args = parser.parse_args(sys.argv[1:])
432
433
  _api_precision_compare_command(args)
434
+ logger.info("Compare task completed.")
433
435
 
434
436
 
435
437
  def _api_precision_compare_command(args):
@@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser):
457
459
  parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
458
460
  help="<optional> The api precision compare task result out path.",
459
461
  required=False)
460
-
461
-
462
- if __name__ == '__main__':
463
- _api_precision_compare()
464
- logger.info("Compare task completed.")
@@ -40,6 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dty
40
40
  DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
41
41
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
42
42
  from msprobe.pytorch.common.log import logger
43
+ from msprobe.core.common.decorator import recursion_depth_decorator
43
44
 
44
45
 
45
46
  ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
@@ -178,6 +179,41 @@ class Comparator:
178
179
  if not os.path.exists(detail_save_path):
179
180
  write_csv(DETAIL_TEST_ROWS, detail_save_path)
180
181
 
182
+ @recursion_depth_decorator("compare_core")
183
+ def _compare_core(self, api_name, bench_output, device_output):
184
+ compare_column = CompareColumn()
185
+ if not isinstance(bench_output, type(device_output)):
186
+ status = CompareConst.ERROR
187
+ message = "bench and npu output type is different."
188
+ elif isinstance(bench_output, dict):
189
+ b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
190
+ if b_keys != n_keys:
191
+ status = CompareConst.ERROR
192
+ message = "bench and npu output dict keys are different."
193
+ else:
194
+ status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
195
+ list(device_output.values()))
196
+ elif isinstance(bench_output, torch.Tensor):
197
+ copy_bench_out = bench_output.detach().clone()
198
+ copy_device_output = device_output.detach().clone()
199
+ compare_column.bench_type = str(copy_bench_out.dtype)
200
+ compare_column.npu_type = str(copy_device_output.dtype)
201
+ compare_column.shape = tuple(device_output.shape)
202
+ status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
203
+ compare_column)
204
+ elif isinstance(bench_output, (bool, int, float, str)):
205
+ compare_column.bench_type = str(type(bench_output))
206
+ compare_column.npu_type = str(type(device_output))
207
+ status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
208
+ elif bench_output is None:
209
+ status = CompareConst.SKIP
210
+ message = "Bench output is None, skip this test."
211
+ else:
212
+ status = CompareConst.ERROR
213
+ message = "Unexpected output type in compare_core: {}".format(type(bench_output))
214
+
215
+ return status, compare_column, message
216
+
181
217
  def write_summary_csv(self, test_result):
182
218
  test_rows = []
183
219
  try:
@@ -293,40 +329,6 @@ class Comparator:
293
329
  test_final_success = CompareConst.WARNING
294
330
  return test_final_success, detailed_result_total
295
331
 
296
- def _compare_core(self, api_name, bench_output, device_output):
297
- compare_column = CompareColumn()
298
- if not isinstance(bench_output, type(device_output)):
299
- status = CompareConst.ERROR
300
- message = "bench and npu output type is different."
301
- elif isinstance(bench_output, dict):
302
- b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
303
- if b_keys != n_keys:
304
- status = CompareConst.ERROR
305
- message = "bench and npu output dict keys are different."
306
- else:
307
- status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
308
- list(device_output.values()))
309
- elif isinstance(bench_output, torch.Tensor):
310
- copy_bench_out = bench_output.detach().clone()
311
- copy_device_output = device_output.detach().clone()
312
- compare_column.bench_type = str(copy_bench_out.dtype)
313
- compare_column.npu_type = str(copy_device_output.dtype)
314
- compare_column.shape = tuple(device_output.shape)
315
- status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
316
- compare_column)
317
- elif isinstance(bench_output, (bool, int, float, str)):
318
- compare_column.bench_type = str(type(bench_output))
319
- compare_column.npu_type = str(type(device_output))
320
- status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
321
- elif bench_output is None:
322
- status = CompareConst.SKIP
323
- message = "Bench output is None, skip this test."
324
- else:
325
- status = CompareConst.ERROR
326
- message = "Unexpected output type in compare_core: {}".format(type(bench_output))
327
-
328
- return status, compare_column, message
329
-
330
332
  def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
331
333
  cpu_shape = bench_output.shape
332
334
  npu_shape = device_output.shape
@@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st
28
28
  ulp_standard_api, thousandth_standard_api
29
29
  from msprobe.core.common.file_utils import FileOpen, load_json, save_json
30
30
  from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
31
- from msprobe.core.common.const import Const, MonitorConst, MsgConst
31
+ from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
32
32
  from msprobe.core.common.log import logger
33
- from msprobe.core.common.file_utils import make_dir
34
- from msprobe.core.common.utils import recursion_depth_decorator
33
+ from msprobe.core.common.file_utils import make_dir, change_mode
34
+ from msprobe.core.common.decorator import recursion_depth_decorator
35
35
 
36
36
  TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
37
37
  TORCH_BOOL_TYPE = ["torch.bool"]
@@ -50,6 +50,7 @@ DATA_NAME = "data_name"
50
50
  API_MAX_LENGTH = 30
51
51
  PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
52
52
  DATAMODE_LIST = ["random_data", "real_data"]
53
+ ITER_MAX_TIMES = 1000
53
54
 
54
55
 
55
56
  class APIInfo:
@@ -97,6 +98,8 @@ class CommonConfig:
97
98
  iter_t = self.iter_times
98
99
  if iter_t <= 0:
99
100
  raise ValueError("iter_times should be an integer bigger than zero!")
101
+ if iter_t > ITER_MAX_TIMES:
102
+ raise ValueError("iter_times should not be greater than 1000!")
100
103
 
101
104
  json_file = self.extract_api_path
102
105
  propagation = self.propagation
@@ -117,7 +120,7 @@ class CommonConfig:
117
120
 
118
121
  # Retrieve the first API name and dictionary
119
122
  forward_item = next(iter(json_content.items()), None)
120
- if not forward_item or not isinstance(forward_item[1], dict):
123
+ if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
121
124
  raise ValueError(f'Invalid forward API data in json_content!')
122
125
 
123
126
  # if propagation is backward, ensure json file contains forward and backward info
@@ -127,7 +130,7 @@ class CommonConfig:
127
130
  # if propagation is backward, ensure it has valid data
128
131
  if propagation == Const.BACKWARD:
129
132
  backward_item = list(json_content.items())[1]
130
- if not isinstance(backward_item[1], dict):
133
+ if not isinstance(backward_item[1], dict) or not backward_item[1]:
131
134
  raise ValueError(f'Invalid backward API data in json_content!')
132
135
 
133
136
  return json_content
@@ -169,7 +172,7 @@ class APIExtractor:
169
172
  value = self.load_real_data_path(value, real_data_path)
170
173
  new_data[key] = value
171
174
  if not new_data:
172
- logger.error(f"Error: The api '{self.api_name}' does not exist in the file.")
175
+ logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
173
176
  else:
174
177
  save_json(self.output_file, new_data, indent=4)
175
178
  logger.info(
@@ -183,6 +186,7 @@ class APIExtractor:
183
186
  self.update_data_name(v, dump_data_dir)
184
187
  return value
185
188
 
189
+ @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
186
190
  def update_data_name(self, data, dump_data_dir):
187
191
  if isinstance(data, list):
188
192
  for item in data:
@@ -407,19 +411,16 @@ class OperatorScriptGenerator:
407
411
  return kwargs_dict_generator
408
412
 
409
413
 
410
-
411
414
  def _op_generator_parser(parser):
412
- parser.add_argument("-i", "--config_input", dest="config_input", default='', type=str,
413
- help="<Optional> Path of config json file", required=True)
415
+ parser.add_argument("-i", "--config_input", dest="config_input", type=str,
416
+ help="<Required> Path of config json file", required=True)
414
417
  parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
415
- help="<Required> Path of extract api_name.json.",
416
- required=True)
418
+ help="<Required> Path of extract api_name.json.", required=True)
417
419
 
418
420
 
419
421
  def parse_json_config(json_file_path):
420
422
  if not json_file_path:
421
- config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
422
- json_file_path = os.path.join(config_dir, "config.json")
423
+ raise Exception("config_input path can not be empty, please check.")
423
424
  json_config = load_json(json_file_path)
424
425
  common_config = CommonConfig(json_config)
425
426
  return common_config
@@ -467,6 +468,7 @@ def _run_operator_generate_commond(cmd_args):
467
468
  fout.write(code_template.format(**internal_settings))
468
469
  except OSError:
469
470
  logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
471
+ change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
470
472
 
471
473
  logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
472
474
 
@@ -1,6 +1,6 @@
1
- import json
2
1
  import os
3
- import math
2
+ import re
3
+ import stat
4
4
  from enum import Enum, auto
5
5
  import torch
6
6
  try:
@@ -25,6 +25,31 @@ RAISE_PRECISION = {{
25
25
  }}
26
26
  THOUSANDTH_THRESHOLDING = 0.001
27
27
  BACKWARD = 'backward'
28
+ DIR = "dir"
29
+ FILE = "file"
30
+ READ_ABLE = "read"
31
+ WRITE_ABLE = "write"
32
+ READ_WRITE_ABLE = "read and write"
33
+ DIRECTORY_LENGTH = 4096
34
+ FILE_NAME_LENGTH = 255
35
+ SOFT_LINK_ERROR = "检测到软链接"
36
+ FILE_PERMISSION_ERROR = "文件权限错误"
37
+ INVALID_FILE_ERROR = "无效文件"
38
+ ILLEGAL_PATH_ERROR = "非法文件路径"
39
+ ILLEGAL_PARAM_ERROR = "非法打开方式"
40
+ FILE_TOO_LARGE_ERROR = "文件过大"
41
+ FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
42
+ FILE_SIZE_DICT = {{
43
+ ".pkl": 1073741824, # 1 * 1024 * 1024 * 1024
44
+ ".npy": 10737418240, # 10 * 1024 * 1024 * 1024
45
+ ".json": 1073741824, # 1 * 1024 * 1024 * 1024
46
+ ".pt": 10737418240, # 10 * 1024 * 1024 * 1024
47
+ ".csv": 1073741824, # 1 * 1024 * 1024 * 1024
48
+ ".xlsx": 1073741824, # 1 * 1024 * 1024 * 1024
49
+ ".yaml": 1073741824, # 1 * 1024 * 1024 * 1024
50
+ ".ir": 1073741824 # 1 * 1024 * 1024 * 1024
51
+ }}
52
+ COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
28
53
 
29
54
  class CompareStandard(Enum):
30
55
  BINARY_EQUALITY_STANDARD = auto()
@@ -33,13 +58,189 @@ class CompareStandard(Enum):
33
58
  BENCHMARK_STANDARD = auto()
34
59
  THOUSANDTH_STANDARD = auto()
35
60
 
61
+ class FileChecker:
62
+ """
63
+ The class for check file.
64
+
65
+ Attributes:
66
+ file_path: The file or dictionary path to be verified.
67
+ path_type: file or dictionary
68
+ ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability
69
+ file_type(str): The correct file type for file
70
+ """
71
+
72
+ def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
73
+ self.file_path = file_path
74
+ self.path_type = self._check_path_type(path_type)
75
+ self.ability = ability
76
+ self.file_type = file_type
77
+ self.is_script = is_script
78
+
79
+ @staticmethod
80
+ def _check_path_type(path_type):
81
+ if path_type not in [DIR, FILE]:
82
+ print(f'ERROR: The path_type must be {{DIR}} or {{FILE}}.')
83
+ raise Exception(ILLEGAL_PARAM_ERROR)
84
+ return path_type
85
+
86
+ def common_check(self):
87
+ """
88
+ 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符
89
+ 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现
90
+ """
91
+ FileChecker.check_path_exists(self.file_path)
92
+ FileChecker.check_link(self.file_path)
93
+ self.file_path = os.path.realpath(self.file_path)
94
+ FileChecker.check_path_length(self.file_path)
95
+ FileChecker.check_path_type(self.file_path, self.path_type)
96
+ self.check_path_ability()
97
+ if self.is_script:
98
+ FileChecker.check_path_owner_consistent(self.file_path)
99
+ FileChecker.check_path_pattern_valid(self.file_path)
100
+ FileChecker.check_common_file_size(self.file_path)
101
+ FileChecker.check_file_suffix(self.file_path, self.file_type)
102
+ if self.path_type == FILE:
103
+ FileChecker.check_dirpath_before_read(self.file_path)
104
+ return self.file_path
105
+
106
+ def check_path_ability(self):
107
+ if self.ability == WRITE_ABLE:
108
+ FileChecker.check_path_writability(self.file_path)
109
+ if self.ability == READ_ABLE:
110
+ FileChecker.check_path_readability(self.file_path)
111
+ if self.ability == READ_WRITE_ABLE:
112
+ FileChecker.check_path_readability(self.file_path)
113
+ FileChecker.check_path_writability(self.file_path)
114
+
115
+ @staticmethod
116
+ def check_path_exists(path):
117
+ if not os.path.exists(path):
118
+ print(f'ERROR: The file path %s does not exist.' % path)
119
+ raise Exception()
120
+
121
+ @staticmethod
122
+ def check_link(path):
123
+ abs_path = os.path.abspath(path)
124
+ if os.path.islink(abs_path):
125
+ print('ERROR: The file path {{}} is a soft link.'.format(path))
126
+ raise Exception(SOFT_LINK_ERROR)
127
+
128
+ @staticmethod
129
+ def check_path_length(path, name_length=None):
130
+ file_max_name_length = name_length if name_length else FILE_NAME_LENGTH
131
+ if len(path) > DIRECTORY_LENGTH or \
132
+ len(os.path.basename(path)) > file_max_name_length:
133
+ print(f'ERROR: The file path length exceeds limit.')
134
+ raise Exception(ILLEGAL_PATH_ERROR)
135
+
136
+ @staticmethod
137
+ def check_path_type(file_path, file_type):
138
+ if file_type == FILE:
139
+ if not os.path.isfile(file_path):
140
+ print(f"ERROR: The {{file_path}} should be a file!")
141
+ raise Exception(INVALID_FILE_ERROR)
142
+ if file_type == DIR:
143
+ if not os.path.isdir(file_path):
144
+ print(f"ERROR: The {{file_path}} should be a dictionary!")
145
+ raise Exception(INVALID_FILE_ERROR)
146
+
147
+ @staticmethod
148
+ def check_path_owner_consistent(path):
149
+ file_owner = os.stat(path).st_uid
150
+ if file_owner != os.getuid() and os.getuid() != 0:
151
+ print('ERROR: The file path %s may be insecure because is does not belong to you.' % path)
152
+ raise Exception(FILE_PERMISSION_ERROR)
153
+
154
+ @staticmethod
155
+ def check_path_pattern_valid(path):
156
+ if not re.match(FILE_VALID_PATTERN, path):
157
+ print('ERROR: The file path %s contains special characters.' % (path))
158
+ raise Exception(ILLEGAL_PATH_ERROR)
159
+
160
+ @staticmethod
161
+ def check_common_file_size(file_path):
162
+ if os.path.isfile(file_path):
163
+ for suffix, max_size in FILE_SIZE_DICT.items():
164
+ if file_path.endswith(suffix):
165
+ FileChecker.check_file_size(file_path, max_size)
166
+ return
167
+ FileChecker.check_file_size(file_path, COMMOM_FILE_SIZE)
168
+
169
+ @staticmethod
170
+ def check_file_size(file_path, max_size):
171
+ try:
172
+ file_size = os.path.getsize(file_path)
173
+ except OSError as os_error:
174
+ print(f'ERROR: Failed to open "{{file_path}}". {{str(os_error)}}')
175
+ raise Exception(INVALID_FILE_ERROR) from os_error
176
+ if file_size >= max_size:
177
+ print(f'ERROR: The size ({{file_size}}) of {{file_path}} exceeds ({{max_size}}) bytes, tools not support.')
178
+ raise Exception(FILE_TOO_LARGE_ERROR)
179
+
180
+ @staticmethod
181
+ def check_file_suffix(file_path, file_suffix):
182
+ if file_suffix:
183
+ if not file_path.endswith(file_suffix):
184
+ print(f"The {{file_path}} should be a {{file_suffix}} file!")
185
+ raise Exception(INVALID_FILE_ERROR)
186
+
187
+ @staticmethod
188
+ def check_dirpath_before_read(path):
189
+ path = os.path.realpath(path)
190
+ dirpath = os.path.dirname(path)
191
+ if FileChecker.check_others_writable(dirpath):
192
+ print(f"WARNING: The directory is writable by others: {{dirpath}}.")
193
+ try:
194
+ FileChecker.check_path_owner_consistent(dirpath)
195
+ except Exception:
196
+ print(f"WARNING: The directory {{dirpath}} is not yours.")
197
+
198
+ @staticmethod
199
+ def check_others_writable(directory):
200
+ dir_stat = os.stat(directory)
201
+ is_writable = (
202
+ bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写
203
+ bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写
204
+ )
205
+ return is_writable
206
+
207
+ @staticmethod
208
+ def check_path_readability(path):
209
+ if not os.access(path, os.R_OK):
210
+ print('ERROR: The file path %s is not readable.' % path)
211
+ raise Exception(FILE_PERMISSION_ERROR)
212
+
213
+ @staticmethod
214
+ def check_path_writability(path):
215
+ if not os.access(path, os.W_OK):
216
+ print('ERROR: The file path %s is not writable.' % path)
217
+ raise Exception(FILE_PERMISSION_ERROR)
218
+
219
+
220
+ def check_file_or_directory_path(path, isdir=False):
221
+ """
222
+ Function Description:
223
+ check whether the path is valid
224
+ Parameter:
225
+ path: the path to check
226
+ isdir: the path is dir or file
227
+ Exception Description:
228
+ when invalid data throw exception
229
+ """
230
+ if isdir:
231
+ path_checker = FileChecker(path, DIR, WRITE_ABLE)
232
+ else:
233
+ path_checker = FileChecker(path, FILE, READ_ABLE)
234
+ path_checker.common_check()
235
+
36
236
  def load_pt(pt_path, to_cpu=False):
37
237
  pt_path = os.path.realpath(pt_path)
238
+ check_file_or_directory_path(pt_path)
38
239
  try:
39
240
  if to_cpu:
40
- pt = torch.load(pt_path, map_location=torch.device("cpu"))
241
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
41
242
  else:
42
- pt = torch.load(pt_path)
243
+ pt = torch.load(pt_path, weights_only=True)
43
244
  except Exception as e:
44
245
  raise RuntimeError(f"load pt file {{pt_path}} failed") from e
45
246
  return pt
@@ -202,6 +403,7 @@ def compare_tensor(out_device, out_bench, api_name):
202
403
  else:
203
404
  abs_err = torch.abs(out_device - out_bench)
204
405
  abs_bench = torch.abs(out_bench)
406
+ eps = 2 ** -23
205
407
  if dtype_bench == torch.float32:
206
408
  eps = 2 ** -23
207
409
  if dtype_bench == torch.float64:
@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
50
50
  backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
51
51
 
52
52
  input_data = load_json(input_file)
53
+ if "dump_data_dir" not in input_data.keys():
54
+ logger.error("Invalid input file, 'dump_data_dir' field is missing")
55
+ raise CompareException("Invalid input file, 'dump_data_dir' field is missing")
53
56
  if input_data.get("data") is None:
54
57
  logger.error("Invalid input file, 'data' field is missing")
55
58
  raise CompareException("Invalid input file, 'data' field is missing")
@@ -84,10 +87,6 @@ def signal_handler(signum, frame):
84
87
  raise KeyboardInterrupt()
85
88
 
86
89
 
87
- signal.signal(signal.SIGINT, signal_handler)
88
- signal.signal(signal.SIGTERM, signal_handler)
89
-
90
-
91
90
  ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
92
91
  'save_error_data_flag', 'jit_compile_flag', 'device_id',
93
92
  'result_csv_path', 'total_items', 'config_path'])
@@ -97,7 +96,7 @@ def run_parallel_ut(config):
97
96
  processes = []
98
97
  device_id_cycle = cycle(config.device_id)
99
98
  if config.save_error_data_flag:
100
- logger.info("UT task error datas will be saved")
99
+ logger.info("UT task error data will be saved")
101
100
  logger.info(f"Starting parallel UT with {config.num_splits} processes")
102
101
  progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
103
102
 
@@ -129,6 +128,9 @@ def run_parallel_ut(config):
129
128
  sys.stdout.flush()
130
129
  except ValueError as e:
131
130
  logger.warning(f"An error occurred while reading subprocess output: {e}")
131
+ finally:
132
+ if process.poll() is None:
133
+ process.stdout.close()
132
134
 
133
135
  def update_progress_bar(progress_bar, result_csv_path):
134
136
  while any(process.poll() is None for process in processes):
@@ -214,6 +216,8 @@ def prepare_config(args):
214
216
 
215
217
 
216
218
  def main():
219
+ signal.signal(signal.SIGINT, signal_handler)
220
+ signal.signal(signal.SIGTERM, signal_handler)
217
221
  parser = argparse.ArgumentParser(description='Run UT in parallel')
218
222
  _run_ut_parser(parser)
219
223
  parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
@@ -221,7 +225,3 @@ def main():
221
225
  args = parser.parse_args()
222
226
  config = prepare_config(args)
223
227
  run_parallel_ut(config)
224
-
225
-
226
- if __name__ == '__main__':
227
- main()
@@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i
34
34
  from msprobe.core.common.file_utils import check_link, FileChecker
35
35
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
36
36
  from msprobe.core.common.const import FileCheckConst, Const
37
+ from msprobe.core.common.utils import check_op_str_pattern_valid
37
38
  from msprobe.pytorch.common.log import logger
38
39
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
40
+ from msprobe.core.common.decorator import recursion_depth_decorator
39
41
 
40
42
 
41
43
  def check_tensor_overflow(x):
@@ -63,6 +65,7 @@ def check_tensor_overflow(x):
63
65
  return False
64
66
 
65
67
 
68
+ @recursion_depth_decorator("check_data_overflow")
66
69
  def check_data_overflow(x, device):
67
70
  if isinstance(x, (tuple, list)):
68
71
  if not x:
@@ -75,6 +78,7 @@ def check_data_overflow(x, device):
75
78
  return torch_npu.npu.utils.npu_check_overflow(x)
76
79
 
77
80
 
81
+ @recursion_depth_decorator("is_bool_output")
78
82
  def is_bool_output(x):
79
83
  if isinstance(x, (tuple, list)):
80
84
  if not x:
@@ -91,6 +95,7 @@ def run_overflow_check(forward_file):
91
95
  dump_path = os.path.dirname(forward_file)
92
96
  real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
93
97
  for api_full_name, api_info_dict in tqdm(forward_content.items()):
98
+ check_op_str_pattern_valid(api_full_name)
94
99
  if is_unsupported_api(api_full_name, is_overflow_check=True):
95
100
  continue
96
101
  try:
@@ -161,6 +166,7 @@ def _run_overflow_check(parser=None):
161
166
  _run_overflow_check_parser(parser)
162
167
  args = parser.parse_args(sys.argv[1:])
163
168
  _run_overflow_check_command(args)
169
+ logger.info("UT task completed.")
164
170
 
165
171
 
166
172
  def _run_overflow_check_command(args):
@@ -175,8 +181,3 @@ def _run_overflow_check_command(args):
175
181
  logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
176
182
  raise NotImplementedError from error
177
183
  run_overflow_check(api_info)
178
-
179
-
180
- if __name__ == '__main__':
181
- _run_overflow_check()
182
- logger.info("UT task completed.")