mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -45,11 +45,11 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC
45
45
  from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
46
46
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
47
47
  from msprobe.core.common.file_utils import FileChecker, change_mode, \
48
- create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
48
+ create_directory, get_json_contents, read_csv, check_file_or_directory_path
49
49
  from msprobe.pytorch.common.log import logger
50
50
  from msprobe.pytorch.pt_config import parse_json_config
51
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
52
- from msprobe.core.common.utils import safe_get_value, CompareException
52
+ from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
53
53
  from msprobe.pytorch.common.utils import seed_all
54
54
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
55
55
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
@@ -65,6 +65,8 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
65
65
 
66
66
  not_backward_list = ['repeat_interleave']
67
67
  unsupported_backward_list = ['masked_select']
68
+ unsupported_api_list = ["to", "empty", "empty_like", "empty_strided", "new_empty", "new_empty_strided",
69
+ "empty_with_format"]
68
70
 
69
71
 
70
72
  tqdm_params = {
@@ -83,6 +85,9 @@ tqdm_params = {
83
85
  }
84
86
 
85
87
 
88
+ seed_all()
89
+
90
+
86
91
  def run_ut(config):
87
92
  logger.info("start UT test")
88
93
  if config.online_config.is_online:
@@ -93,7 +98,7 @@ def run_ut(config):
93
98
  logger.info(f"UT task details will be saved in {config.details_csv_path}")
94
99
 
95
100
  if config.save_error_data:
96
- logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
101
+ logger.info(f"UT task error_data will be saved in {config.error_data_path}")
97
102
  compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
98
103
 
99
104
  if config.online_config.is_online:
@@ -117,6 +122,7 @@ def run_ut(config):
117
122
  def run_api_offline(config, compare, api_name_set):
118
123
  err_column = CompareColumn()
119
124
  for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
125
+ check_op_str_pattern_valid(api_full_name)
120
126
  if api_full_name in api_name_set:
121
127
  continue
122
128
  if is_unsupported_api(api_full_name):
@@ -218,6 +224,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
218
224
  If api is both in black_list and black_list, black_list first.
219
225
  return: False for exec api, True for not exec
220
226
  """
227
+ black_list.extend(unsupported_api_list)
221
228
  if black_list and api_name in black_list:
222
229
  return True
223
230
  if white_list and api_name not in white_list:
@@ -317,7 +324,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
317
324
  if kwargs.get("device"):
318
325
  del kwargs["device"]
319
326
 
320
- device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs)
327
+ device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
328
+ device_out = exec_api(device_exec_params)
321
329
  device_out = move2device_exec(device_out, "cpu")
322
330
  return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
323
331
 
@@ -344,6 +352,9 @@ def need_to_backward(grad_index, out):
344
352
 
345
353
  def run_backward(args, grad, grad_index, out):
346
354
  if grad_index is not None:
355
+ if not is_int(grad_index):
356
+ logger.error(f"{grad_index} dtype is not int")
357
+ raise TypeError(f"{grad_index} dtype is not int")
347
358
  if grad_index >= len(out):
348
359
  logger.error(f"Run backward error when grad_index is {grad_index}")
349
360
  raise IndexError(f"Run backward error when grad_index is {grad_index}")
@@ -430,6 +441,7 @@ def preprocess_forward_content(forward_content):
430
441
  arg_cache = {}
431
442
 
432
443
  for key, value in forward_content.items():
444
+ check_op_str_pattern_valid(key)
433
445
  base_key = key.rsplit(Const.SEP, 1)[0]
434
446
 
435
447
  if key not in arg_cache:
@@ -469,7 +481,7 @@ def _run_ut(parser=None):
469
481
  _run_ut_parser(parser)
470
482
  args = parser.parse_args(sys.argv[1:])
471
483
  run_ut_command(args)
472
-
484
+
473
485
 
474
486
  def checked_online_config(online_config):
475
487
  if not online_config.is_online:
@@ -491,7 +503,10 @@ def checked_online_config(online_config):
491
503
  check_file_or_directory_path(online_config.tls_path, isdir=True)
492
504
  check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
493
505
  check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
494
- check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
506
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
507
+ crl_path = os.path.join(online_config.tls_path, "crl.pem")
508
+ if os.path.exists(crl_path):
509
+ check_file_or_directory_path(crl_path)
495
510
 
496
511
  # host and port
497
512
  if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
@@ -561,7 +576,15 @@ def run_ut_command(args):
561
576
  error_data_path = checker_config.error_data_path
562
577
  if save_error_data:
563
578
  if args.result_csv_path:
564
- time_info = result_csv_path.split('.')[0].split('_')[-1]
579
+ parts_by_dot = result_csv_path.split(Const.SEP)
580
+ if len(parts_by_dot) < 2 or not parts_by_dot[0]:
581
+ raise ValueError("result_csv_path does not contain a valid file name with an extension.")
582
+ file_name_part = parts_by_dot[0]
583
+ parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER)
584
+ if len(parts_by_underscore) < 2:
585
+ raise ValueError("File name part does not contain enough '_' separated segments.")
586
+ time_info = parts_by_underscore[-1]
587
+
565
588
  global UT_ERROR_DATA_DIR
566
589
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
567
590
  error_data_path = initialize_save_error_data(error_data_path)
@@ -579,9 +602,8 @@ def run_ut_command(args):
579
602
  }
580
603
  run_ut_config = checker_config.get_run_ut_config(**config_params)
581
604
  run_ut(run_ut_config)
605
+ logger.info("UT task completed.")
582
606
 
583
607
 
584
608
  if __name__ == '__main__':
585
- seed_all()
586
609
  _run_ut()
587
- logger.info("UT task completed.")
@@ -1,9 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
4
2
  # All rights reserved.
5
3
  #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
5
  # you may not use this file except in compliance with the License.
8
6
  # You may obtain a copy of the License at
9
7
  #
@@ -18,8 +16,8 @@
18
16
  import os
19
17
  from collections import namedtuple
20
18
  import re
21
- import torch
22
19
 
20
+ import torch
23
21
  try:
24
22
  import torch_npu
25
23
  except ImportError:
@@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
33
31
  from msprobe.core.common.file_utils import FileChecker
34
32
  from msprobe.core.common.log import logger
35
33
  from msprobe.core.common.utils import CompareException
34
+ from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
36
35
  from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
37
- from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
38
- from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
39
- from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
40
- from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
36
+
41
37
 
42
38
  hf_32_standard_api = ["conv1d", "conv2d"]
43
39
  not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
@@ -108,17 +104,28 @@ def exec_api(exec_params):
108
104
  kwargs = exec_params.kwargs
109
105
  is_autocast = exec_params.is_autocast
110
106
  autocast_dtype = exec_params.autocast_dtype
111
-
112
- if api_type == "Functional":
113
- torch_api = FunctionalOPTemplate(api_name, str, False)
114
- if api_type == "Tensor":
115
- torch_api = TensorOPTemplate(api_name, str, False)
116
- if api_type == "Torch":
117
- torch_api = TorchOPTemplate(api_name, str, False)
118
- if api_type == "Aten":
107
+ out = None
108
+
109
+ prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {})
110
+ if not prefix_map or api_type not in prefix_map.values() or \
111
+ api_type not in (
112
+ Const.FUNCTIONAL_API_TYPE_PREFIX,
113
+ Const.TENSOR_API_TYPE_PREFIX,
114
+ Const.TORCH_API_TYPE_PREFIX,
115
+ Const.ATEN_API_TYPE_PREFIX,
116
+ Const.NPU_API_TYPE_PREFIX
117
+ ):
118
+ return out
119
+
120
+ if api_type == Const.ATEN_API_TYPE_PREFIX:
119
121
  torch_api = AtenOPTemplate(api_name, None, False)
120
- if api_type == "NPU":
121
- torch_api = NpuOPTemplate(api_name, None, False, device)
122
+ else:
123
+ api_register = get_api_register()
124
+ api_register.initialize_hook(None)
125
+ api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
126
+ api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
127
+
128
+ torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
122
129
  if is_autocast:
123
130
  with autocast(dtype=autocast_dtype):
124
131
  out = torch_api.forward(*args, **kwargs)
@@ -248,7 +255,8 @@ def record_skip_info(api_full_name, compare, compare_alg_results):
248
255
 
249
256
  def is_unsupported_api(api_name, is_overflow_check=False):
250
257
  split_name = api_name.split(Const.SEP)[0]
251
- flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
258
+ unsupport_type_list = [Const.DISTRIBUTED, Const.MINDSPEED_API_TYPE_PREFIX]
259
+ flag = (split_name in unsupport_type_list) or (is_overflow_check and split_name == Const.NPU)
252
260
  if flag:
253
261
  logger.info(f"{split_name} api is not supported for run ut. SKIP.")
254
262
  return flag
@@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T
27
27
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
28
28
  from msprobe.core.common.file_utils import remove_path
29
29
  from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
30
+ from msprobe.core.common.decorator import recursion_depth_decorator
30
31
 
31
32
  BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
32
33
 
@@ -168,11 +169,12 @@ class ATTL:
168
169
  return buffer
169
170
 
170
171
 
172
+ @recursion_depth_decorator("move2device_exec")
171
173
  def move2device_exec(obj, device):
172
174
  if isinstance(obj, (tuple, list)):
173
175
  data_list = [move2device_exec(val, device) for val in obj]
174
176
  return data_list if isinstance(obj, list) else tuple(data_list)
175
- if isinstance(obj, dict):
177
+ if isinstance(obj, dict):
176
178
  return {key: move2device_exec(val, device) for key, val in obj.items()}
177
179
  elif isinstance(obj, torch.Tensor):
178
180
  obj = obj.detach()
@@ -12,23 +12,22 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
16
- import hashlib
15
+ from functools import partial
16
+ import zlib
17
17
  import io
18
18
  import struct
19
19
  import time
20
20
  import os
21
- import signal
22
21
  from queue import Queue
23
22
  from threading import Thread
24
23
  from typing import Union
25
24
 
26
- from twisted.internet import reactor, protocol, endpoints
25
+ from twisted.internet import reactor, protocol, endpoints, ssl
27
26
  from twisted.protocols.basic import FileSender
28
27
 
29
28
  from msprobe.pytorch.common.utils import logger
30
29
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
31
- STR_TO_BYTES_ORDER as bytes_order
30
+ STR_TO_BYTES_ORDER as bytes_order, cipher_list, verify_callback, load_ssl_pem
32
31
 
33
32
  MAX_SENDING_QUEUE_SIZE = 20
34
33
 
@@ -104,11 +103,28 @@ class TCPClient:
104
103
  self.factory = MessageClientFactory()
105
104
  self.factory.protocol = cur_protocol
106
105
  if self.tls_path:
107
- from twisted.internet import ssl
108
- client_key = os.path.join(self.tls_path, "client.key")
109
- client_crt = os.path.join(self.tls_path, "client.crt")
110
- client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
111
- endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
106
+ client_key, client_crt, ca_crt, crl_pem = load_ssl_pem(
107
+ key_file=os.path.join(self.tls_path, "client.key"),
108
+ cert_file=os.path.join(self.tls_path, "client.crt"),
109
+ ca_file=os.path.join(self.tls_path, "ca.crt"),
110
+ crl_file=os.path.join(self.tls_path, "crl.pem")
111
+ )
112
+
113
+ ssl_options = ssl.CertificateOptions(
114
+ privateKey=client_key,
115
+ certificate=client_crt,
116
+ method=ssl.SSL.TLSv1_2_METHOD,
117
+ verify=True,
118
+ requireCertificate=True,
119
+ caCerts=[ca_crt], # 信任的CA证书列表
120
+ )
121
+ ssl_context = ssl_options.getContext()
122
+ ssl_context.set_cipher_list(cipher_list)
123
+ ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
124
+ ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
125
+ partial(verify_callback, crl=crl_pem))
126
+
127
+ endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, ssl_options)
112
128
  else:
113
129
  endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
114
130
  d = endpoint.connect(self.factory)
@@ -299,12 +315,12 @@ class ClientProtocol(protocol.Protocol):
299
315
 
300
316
  def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
301
317
  length = len(data)
302
- md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
318
+ data_crc = f"{zlib.crc32(data):08x}" if self.check_sum else ""
303
319
  data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
304
320
  sequence_number.to_bytes(8, byteorder=bytes_order) + \
305
321
  rank.to_bytes(8, byteorder=bytes_order) + \
306
322
  step.to_bytes(8, byteorder=bytes_order) + \
307
- md5_hash.encode() + \
323
+ data_crc.encode() + \
308
324
  data
309
325
  logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")
310
326
 
@@ -346,7 +362,7 @@ class ClientProtocol(protocol.Protocol):
346
362
  def connectionLost(self, reason):
347
363
  self.signal_exit = True
348
364
  self.factory.num_connections -= 1
349
- logger.info(f"Lost connection with server, reason is : {reason}")
365
+ logger.info(f"Lost connection with server, reason is : {reason.value}")
350
366
 
351
367
 
352
368
  class MessageClientFactory(protocol.ClientFactory):
@@ -29,7 +29,6 @@ from msprobe.pytorch.common.log import logger
29
29
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
30
30
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
31
31
 
32
-
33
32
  # NPU vs GPU api list
34
33
  CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
35
34
 
@@ -43,6 +42,15 @@ OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig',
43
42
  CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config'])
44
43
 
45
44
 
45
def get_gpu_device():
    """
    Decide whether this process should run on CUDA devices.

    The decision is an import probe: when ``torch_npu`` cannot be imported the host is treated as a
    GPU host, otherwise as an NPU host.

    :return: bool, True when ``torch_npu`` is unavailable (caller uses "cuda"), False when it imports
        successfully (caller uses "npu").
    """
    try:
        import torch_npu  # noqa: F401 -- imported only to probe for NPU support
    except ImportError:
        return True
    return False
52
+
53
+
46
54
  def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file):
47
55
  """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue.
48
56
  :param xpu_id: int
@@ -51,7 +59,9 @@ def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file
51
59
  :param api_precision_csv_file: list, length is 2, result file name and details file name
52
60
  :return:
53
61
  """
54
- gpu_device = torch.device(f'cuda:{xpu_id}')
62
+ device_info = "cuda" if get_gpu_device() else "npu"
63
+ logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.")
64
+ gpu_device = torch.device(f'{device_info}:{xpu_id}')
55
65
 
56
66
  while True:
57
67
  if consumer_queue.empty():
@@ -12,19 +12,19 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
16
- import os.path
15
+ from functools import partial
16
+ import os
17
17
  import struct
18
- import hashlib
18
+ import zlib
19
19
  import time
20
20
  import io
21
21
  from threading import Thread
22
22
 
23
- from twisted.internet import reactor, protocol, endpoints
23
+ from twisted.internet import reactor, protocol, endpoints, ssl
24
24
 
25
25
  from msprobe.pytorch.common.utils import logger
26
26
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
27
- STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
27
+ STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
28
28
 
29
29
 
30
30
  class TCPServer:
@@ -44,15 +44,28 @@ class TCPServer:
44
44
  self.factory.protocol = self.build_protocol
45
45
 
46
46
  if self.tls_path:
47
- from OpenSSL import SSL
48
- from twisted.internet import ssl
49
- server_key = os.path.join(self.tls_path, "server.key")
50
- server_crt = os.path.join(self.tls_path, "server.crt")
51
- server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
52
- server_context_ = server_context_factory.getContext()
53
- server_context_.set_cipher_list(cipher_list)
54
- server_context_.set_options(SSL.OP_NO_RENEGOTIATION)
55
- endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, server_context_factory)
47
+ server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(
48
+ key_file=os.path.join(self.tls_path, "server.key"),
49
+ cert_file=os.path.join(self.tls_path, "server.crt"),
50
+ ca_file=os.path.join(self.tls_path, "ca.crt"),
51
+ crl_file=os.path.join(self.tls_path, "crl.pem")
52
+ )
53
+
54
+ ssl_options = ssl.CertificateOptions(
55
+ privateKey=server_key,
56
+ certificate=server_crt,
57
+ method=ssl.SSL.TLSv1_2_METHOD,
58
+ verify=True,
59
+ requireCertificate=True,
60
+ caCerts=[ca_crt], # 信任的CA证书列表
61
+ )
62
+ ssl_context = ssl_options.getContext()
63
+ ssl_context.set_cipher_list(cipher_list)
64
+ ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
65
+ ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
66
+ partial(verify_callback, crl=crl_pem))
67
+
68
+ endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
56
69
  else:
57
70
  endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
58
71
  endpoint.listen(self.factory)
@@ -85,10 +98,10 @@ class ServerProtocol(protocol.Protocol):
85
98
  self.consumer_queue = shared_queue
86
99
  self.check_sum = check_sum
87
100
  self.length_width = 8
88
- self.md5_width = 32
101
+ self.crc_width = 8
89
102
  self.obj_length = None
90
103
  self.tell = 0
91
- self.obj_md5 = None
104
+ self.obj_crc = None
92
105
  self.obj_body = None
93
106
  self.sequence_number = -1
94
107
  self.rank = -1
@@ -99,7 +112,7 @@ class ServerProtocol(protocol.Protocol):
99
112
  self.buffer = io.BytesIO()
100
113
  self.obj_length = None
101
114
  self.tell = 0
102
- self.obj_md5 = None
115
+ self.obj_crc = None
103
116
  self.obj_body = None
104
117
  self.factory.transport_dict[self.transport] = 1
105
118
  self.factory.transport_list.append(self.transport)
@@ -132,11 +145,12 @@ class ServerProtocol(protocol.Protocol):
132
145
  time.sleep(0.1)
133
146
 
134
147
  obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
148
+ # get the crc value of a 16-bit string with a length of 8
149
+ recv_crc = f"{zlib.crc32(self.obj_body):08x}"
135
150
 
136
- recv_md5 = hashlib.md5(self.obj_body).hexdigest()
137
- if self.check_sum and recv_md5 != self.obj_md5:
138
- # when needs check md5 and check no pass, indicates received data error, send b"ERROR" to client.
139
- logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}")
151
+ if self.check_sum and recv_crc != self.obj_crc:
152
+ # when needs check hash value and check no pass, indicates received data error, send b"ERROR" to client.
153
+ logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
140
154
  self.send_ack(self.ACK_ERROR)
141
155
  else:
142
156
  if self.obj_body == self.ACK_STOP:
@@ -146,7 +160,7 @@ class ServerProtocol(protocol.Protocol):
146
160
  if obj_key in self.sequence_number_dict:
147
161
  logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
148
162
  else:
149
- self.sequence_number_dict[obj_key] = self.obj_md5
163
+ self.sequence_number_dict[obj_key] = self.obj_crc
150
164
  self.consumer_queue.put(self.obj_body, block=True)
151
165
 
152
166
  self.reset_env()
@@ -173,7 +187,7 @@ class ServerProtocol(protocol.Protocol):
173
187
  self.sequence_number = -1
174
188
  self.rank = -1
175
189
  self.step = -1
176
- self.obj_md5 = None
190
+ self.obj_crc = None
177
191
  self.obj_body = None
178
192
 
179
193
  def dataReceived(self, data):
@@ -192,15 +206,15 @@ class ServerProtocol(protocol.Protocol):
192
206
  logger.debug(
193
207
  f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
194
208
 
195
- # If needs check md5 but not parse md5 yet, read 32b md5 values
196
- check_sum_and_md5 = (self.check_sum
209
+ # If needs check hash but not parse crc yet, read 8b crc values
210
+ check_sum_and_crc = (self.check_sum
197
211
  and self.obj_length is not None
198
- and self.obj_md5 is None
199
- and len(self.buffer.getvalue()) - self.tell >= self.md5_width)
200
- if check_sum_and_md5:
201
- self.obj_md5 = self.buffer.read(self.md5_width).decode()
202
- self.tell += self.md5_width
203
- logger.debug(f"MD5: {self.obj_md5}")
212
+ and self.obj_crc is None
213
+ and len(self.buffer.getvalue()) - self.tell >= self.crc_width)
214
+ if check_sum_and_crc:
215
+ self.obj_crc = self.buffer.read(self.crc_width).decode()
216
+ self.tell += self.crc_width
217
+ logger.debug(f"Hash value: {self.obj_crc}")
204
218
 
205
219
  current_length = len(self.buffer.getvalue()) - self.tell
206
220
  if self.obj_length is not None and 0 < self.obj_length <= current_length:
@@ -12,6 +12,16 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+ import os
16
+ from datetime import datetime, timezone
17
+
18
+ from OpenSSL import crypto
19
+ from cryptography import x509
20
+ from cryptography.hazmat.backends import default_backend
21
+ from dateutil import parser
22
+
23
+ from msprobe.core.common.file_utils import FileOpen
24
+ from msprobe.core.common.log import logger
15
25
 
16
26
  cipher_list = ":".join(
17
27
  ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
@@ -42,3 +52,147 @@ cipher_list = ":".join(
42
52
 
43
53
  STRUCT_UNPACK_MODE = "!Q"
44
54
  STR_TO_BYTES_ORDER = "big"
55
+
56
+
57
def is_certificate_revoked(cert, crl):
    """
    Tell whether ``cert`` is listed in the revocation list ``crl``.

    Only serial numbers are compared; the issuer of ``cert`` is not examined.

    :param cert: certificate object exposing ``get_serial_number()`` (an OpenSSL.crypto.X509 when called
        from verify_callback).
    :param crl: iterable of revoked entries exposing ``serial_number``.
    :return: bool, True when the serial number is revoked (an error is logged), otherwise False.
    """
    serial = cert.get_serial_number()
    revoked = {entry.serial_number for entry in crl}
    if serial not in revoked:
        return False
    logger.error(f"证书已吊销:{serial:020x}")
    return True
68
+
69
+
70
def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
    """
    Peer-certificate verification callback installed with ``OpenSSL.SSL.Context.set_verify``.

    OpenSSL calls it for each certificate of the peer's chain. ``crl`` is not part of OpenSSL's callback
    signature; callers bind it with ``functools.partial``.

    :param conn: OpenSSL.SSL.Connection, the connection being verified (unused here).
    :param cert: OpenSSL.crypto.X509, the certificate currently being checked.
    :param errno: int, OpenSSL X509_V_ERR_* code for this certificate; 0 means no error
        (for example 9: not yet valid, 10: expired, 18: self-signed leaf certificate).
    :param depth: int, position of ``cert`` in the chain: 0 is the peer (leaf) certificate, 1 is the
        certificate that issued it, and larger values move toward the root CA.
    :param preverify_ok: int, OpenSSL's own verdict for this certificate (1 = passed, 0 = failed).
    :param crl: cryptography.x509.CertificateRevocationList or None; when given, a certificate whose
        serial number is listed in it is rejected.
    :return: bool, True to accept the certificate, False to reject it.
    """
    if not preverify_ok:
        try:
            # pyOpenSSL exposes no public API for the error text, so this uses its private cffi
            # bindings; guard it so a change in those internals cannot replace a clean rejection
            # with an unrelated exception raised from inside the handshake.
            from OpenSSL import SSL
            error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode()
        except Exception:
            error_str = "unknown certificate verification error"
        logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
        return False

    if crl and is_certificate_revoked(cert, crl):
        return False

    return True
92
+
93
+
94
def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
    """
    Load and sanity-check the TLS material used by the transport layer.

    The user is prompted for the private-key password (unless one is hard-coded below). Then the private
    key, the local certificate and the CA certificate are loaded. Both certificates go through
    ``check_crt_valid``, and the local certificate must match the private key. The CRL is optional: it is
    loaded and checked against the CA certificate only when ``crl_file`` exists.

    Args:
        key_file (str): path to the PEM-encoded private key.
        cert_file (str): path to the PEM-encoded certificate.
        ca_file (str): path to the PEM-encoded CA certificate.
        crl_file (str): path to the PEM-encoded CRL; the file may be absent.

    Returns:
        tuple: (key, crt, ca_crt, crl); crl is None when ``crl_file`` does not exist.

    Raises:
        RuntimeError: if any file cannot be read or parsed, or any check fails (the cause is chained).
    """
    def read_pem(path):
        # Read the raw bytes of a PEM file through the project's FileOpen wrapper.
        with FileOpen(path, "rb") as pem_file:
            return pem_file.read()

    try:
        # Set a private-key password here to skip the interactive prompt.
        passphrase = ""
        if not passphrase:
            import pwinput
            passphrase = pwinput.pwinput("Enter your password: ")
        key = crypto.load_privatekey(crypto.FILETYPE_PEM, read_pem(key_file), passphrase.encode())
        del passphrase  # drop the plaintext password once the key has been decrypted

        crt = crypto.load_certificate(crypto.FILETYPE_PEM, read_pem(cert_file))
        check_crt_valid(crt)
        logger.info(f"crt_serial_number: {hex(crt.get_serial_number())[2:]}")
        check_certificate_match(crt, key)

        ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, read_pem(ca_file))
        check_crt_valid(ca_crt)
        logger.info(f"ca_serial_number: {hex(ca_crt.get_serial_number())[2:]}")

        crl = None
        if os.path.exists(crl_file):
            crl = x509.load_pem_x509_crl(read_pem(crl_file), default_backend())
            check_crl_valid(crl, ca_crt)
            for entry in crl:
                logger.info(f"Serial Number: {entry.serial_number}, "
                            f"Revocation Date: {entry.revocation_date_utc}")
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    return key, crt, ca_crt, crl
148
+
149
+
150
def check_crt_valid(pem):
    """
    Check that an SSL certificate is currently within its validity period.

    Success is logged only after every check has passed.

    :param pem: OpenSSL.crypto.X509, the loaded certificate.

    Raises:
        RuntimeError: If the validity dates cannot be read, if the certificate is not yet valid,
            or if it has expired.
    """
    # pyOpenSSL reports notBefore/notAfter as ASN.1 GeneralizedTime bytes, e.g. b"20240101000000Z".
    asn1_time_format = "%Y%m%d%H%M%SZ"
    try:
        pem_start = datetime.strptime(
            pem.get_notBefore().decode("UTF-8"), asn1_time_format).replace(tzinfo=timezone.utc)
        pem_end = datetime.strptime(
            pem.get_notAfter().decode("UTF-8"), asn1_time_format).replace(tzinfo=timezone.utc)
    except (AttributeError, TypeError, ValueError) as e:
        # AttributeError covers a missing timestamp (pyOpenSSL returns None).
        raise RuntimeError("The SSL certificate is invalid") from e

    now_utc = datetime.now(tz=timezone.utc)
    if now_utc < pem_start:
        raise RuntimeError("The SSL certificate is not yet valid.")
    if pem.has_expired() or now_utc > pem_end:
        raise RuntimeError("The SSL certificate has expired.")

    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
168
+
169
+
170
def check_certificate_match(certificate, private_key):
    """
    Check that ``private_key`` belongs to the key pair certified by ``certificate``.

    The certificate's public key and the public key derived from the private key are both serialized as
    DER SubjectPublicKeyInfo and compared byte for byte. This does not depend on the key algorithm and
    avoids pyOpenSSL's deprecated ``crypto.sign``/``crypto.verify`` helpers (removed in recent
    pyOpenSSL releases).

    :param certificate: OpenSSL.crypto.X509, the loaded certificate.
    :param private_key: OpenSSL.crypto.PKey, the loaded private key.
    :raises RuntimeError: if the public keys cannot be extracted or do not match.
    """
    from cryptography.hazmat.primitives import serialization

    def to_spki(public_key):
        # Canonical, algorithm-independent encoding of a public key.
        return public_key.public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    try:
        cert_public = to_spki(certificate.get_pubkey().to_cryptography_key())
        key_public = to_spki(private_key.to_cryptography_key().public_key())
    except Exception as e:
        raise RuntimeError("公钥和私钥不匹配") from e

    if cert_public != key_public:
        raise RuntimeError("公钥和私钥不匹配")
    logger.info("公钥和私钥匹配")
189
+
190
+
191
def _crl_time_utc(crl, field):
    """Return the CRL timestamp ``field`` ("last_update" or "next_update") as an aware UTC datetime, or None."""
    aware_name = f"{field}_utc"
    if hasattr(crl, aware_name):
        # cryptography >= 42 provides timezone-aware *_utc properties (the naive ones are deprecated there).
        return getattr(crl, aware_name)
    naive = getattr(crl, field)
    return None if naive is None else naive.replace(tzinfo=timezone.utc)


def check_crl_valid(crl, ca_crt):
    """
    Validate a certificate revocation list against the CA that should have issued it.

    :param crl: cryptography.x509.CertificateRevocationList, the loaded CRL.
    :param ca_crt: OpenSSL.crypto.X509, the CA certificate expected to have signed the CRL.
    :raises RuntimeError: if the signature does not verify with the CA public key, if the CRL has no
        nextUpdate time, or if the current time lies outside [lastUpdate, nextUpdate].
    """
    # Verify the CRL signature (ensures the CRL has not been tampered with).
    if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
        raise RuntimeError("CRL签名无效!")

    last_update = _crl_time_utc(crl, "last_update")
    next_update = _crl_time_utc(crl, "next_update")
    if next_update is None:
        # RFC 5280 requires conforming CAs to set nextUpdate; without it the CRL has no freshness bound.
        raise RuntimeError("CRL缺少nextUpdate字段!")

    # Check the CRL validity period using aware UTC times on both sides.
    if not (last_update <= datetime.now(tz=timezone.utc) <= next_update):
        raise RuntimeError("CRL已过期或尚未生效!")