mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -45,11 +45,11 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC
|
|
|
45
45
|
from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
|
|
46
46
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
47
47
|
from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
48
|
-
create_directory, get_json_contents, read_csv, check_file_or_directory_path
|
|
48
|
+
create_directory, get_json_contents, read_csv, check_file_or_directory_path
|
|
49
49
|
from msprobe.pytorch.common.log import logger
|
|
50
50
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
51
51
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
52
|
-
from msprobe.core.common.utils import safe_get_value, CompareException
|
|
52
|
+
from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
|
|
53
53
|
from msprobe.pytorch.common.utils import seed_all
|
|
54
54
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
55
55
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
@@ -65,6 +65,8 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
|
65
65
|
|
|
66
66
|
not_backward_list = ['repeat_interleave']
|
|
67
67
|
unsupported_backward_list = ['masked_select']
|
|
68
|
+
unsupported_api_list = ["to", "empty", "empty_like", "empty_strided", "new_empty", "new_empty_strided",
|
|
69
|
+
"empty_with_format"]
|
|
68
70
|
|
|
69
71
|
|
|
70
72
|
tqdm_params = {
|
|
@@ -83,6 +85,9 @@ tqdm_params = {
|
|
|
83
85
|
}
|
|
84
86
|
|
|
85
87
|
|
|
88
|
+
seed_all()
|
|
89
|
+
|
|
90
|
+
|
|
86
91
|
def run_ut(config):
|
|
87
92
|
logger.info("start UT test")
|
|
88
93
|
if config.online_config.is_online:
|
|
@@ -93,7 +98,7 @@ def run_ut(config):
|
|
|
93
98
|
logger.info(f"UT task details will be saved in {config.details_csv_path}")
|
|
94
99
|
|
|
95
100
|
if config.save_error_data:
|
|
96
|
-
logger.info(f"UT task
|
|
101
|
+
logger.info(f"UT task error_data will be saved in {config.error_data_path}")
|
|
97
102
|
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
|
|
98
103
|
|
|
99
104
|
if config.online_config.is_online:
|
|
@@ -117,6 +122,7 @@ def run_ut(config):
|
|
|
117
122
|
def run_api_offline(config, compare, api_name_set):
|
|
118
123
|
err_column = CompareColumn()
|
|
119
124
|
for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
|
|
125
|
+
check_op_str_pattern_valid(api_full_name)
|
|
120
126
|
if api_full_name in api_name_set:
|
|
121
127
|
continue
|
|
122
128
|
if is_unsupported_api(api_full_name):
|
|
@@ -218,6 +224,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
|
218
224
|
If api is both in black_list and black_list, black_list first.
|
|
219
225
|
return: False for exec api, True for not exec
|
|
220
226
|
"""
|
|
227
|
+
black_list.extend(unsupported_api_list)
|
|
221
228
|
if black_list and api_name in black_list:
|
|
222
229
|
return True
|
|
223
230
|
if white_list and api_name not in white_list:
|
|
@@ -317,7 +324,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
|
|
|
317
324
|
if kwargs.get("device"):
|
|
318
325
|
del kwargs["device"]
|
|
319
326
|
|
|
320
|
-
|
|
327
|
+
device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
|
|
328
|
+
device_out = exec_api(device_exec_params)
|
|
321
329
|
device_out = move2device_exec(device_out, "cpu")
|
|
322
330
|
return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
|
|
323
331
|
|
|
@@ -344,6 +352,9 @@ def need_to_backward(grad_index, out):
|
|
|
344
352
|
|
|
345
353
|
def run_backward(args, grad, grad_index, out):
|
|
346
354
|
if grad_index is not None:
|
|
355
|
+
if not is_int(grad_index):
|
|
356
|
+
logger.error(f"{grad_index} dtype is not int")
|
|
357
|
+
raise TypeError(f"{grad_index} dtype is not int")
|
|
347
358
|
if grad_index >= len(out):
|
|
348
359
|
logger.error(f"Run backward error when grad_index is {grad_index}")
|
|
349
360
|
raise IndexError(f"Run backward error when grad_index is {grad_index}")
|
|
@@ -430,6 +441,7 @@ def preprocess_forward_content(forward_content):
|
|
|
430
441
|
arg_cache = {}
|
|
431
442
|
|
|
432
443
|
for key, value in forward_content.items():
|
|
444
|
+
check_op_str_pattern_valid(key)
|
|
433
445
|
base_key = key.rsplit(Const.SEP, 1)[0]
|
|
434
446
|
|
|
435
447
|
if key not in arg_cache:
|
|
@@ -469,7 +481,7 @@ def _run_ut(parser=None):
|
|
|
469
481
|
_run_ut_parser(parser)
|
|
470
482
|
args = parser.parse_args(sys.argv[1:])
|
|
471
483
|
run_ut_command(args)
|
|
472
|
-
|
|
484
|
+
|
|
473
485
|
|
|
474
486
|
def checked_online_config(online_config):
|
|
475
487
|
if not online_config.is_online:
|
|
@@ -491,7 +503,10 @@ def checked_online_config(online_config):
|
|
|
491
503
|
check_file_or_directory_path(online_config.tls_path, isdir=True)
|
|
492
504
|
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
493
505
|
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
494
|
-
|
|
506
|
+
check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
|
|
507
|
+
crl_path = os.path.join(online_config.tls_path, "crl.pem")
|
|
508
|
+
if os.path.exists(crl_path):
|
|
509
|
+
check_file_or_directory_path(crl_path)
|
|
495
510
|
|
|
496
511
|
# host and port
|
|
497
512
|
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
@@ -561,7 +576,15 @@ def run_ut_command(args):
|
|
|
561
576
|
error_data_path = checker_config.error_data_path
|
|
562
577
|
if save_error_data:
|
|
563
578
|
if args.result_csv_path:
|
|
564
|
-
|
|
579
|
+
parts_by_dot = result_csv_path.split(Const.SEP)
|
|
580
|
+
if len(parts_by_dot) < 2 or not parts_by_dot[0]:
|
|
581
|
+
raise ValueError("result_csv_path does not contain a valid file name with an extension.")
|
|
582
|
+
file_name_part = parts_by_dot[0]
|
|
583
|
+
parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER)
|
|
584
|
+
if len(parts_by_underscore) < 2:
|
|
585
|
+
raise ValueError("File name part does not contain enough '_' separated segments.")
|
|
586
|
+
time_info = parts_by_underscore[-1]
|
|
587
|
+
|
|
565
588
|
global UT_ERROR_DATA_DIR
|
|
566
589
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
567
590
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
@@ -579,9 +602,8 @@ def run_ut_command(args):
|
|
|
579
602
|
}
|
|
580
603
|
run_ut_config = checker_config.get_run_ut_config(**config_params)
|
|
581
604
|
run_ut(run_ut_config)
|
|
605
|
+
logger.info("UT task completed.")
|
|
582
606
|
|
|
583
607
|
|
|
584
608
|
if __name__ == '__main__':
|
|
585
|
-
seed_all()
|
|
586
609
|
_run_ut()
|
|
587
|
-
logger.info("UT task completed.")
|
|
@@ -1,9 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
4
2
|
# All rights reserved.
|
|
5
3
|
#
|
|
6
|
-
# Licensed under the Apache License, Version 2.0
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
5
|
# you may not use this file except in compliance with the License.
|
|
8
6
|
# You may obtain a copy of the License at
|
|
9
7
|
#
|
|
@@ -18,8 +16,8 @@
|
|
|
18
16
|
import os
|
|
19
17
|
from collections import namedtuple
|
|
20
18
|
import re
|
|
21
|
-
import torch
|
|
22
19
|
|
|
20
|
+
import torch
|
|
23
21
|
try:
|
|
24
22
|
import torch_npu
|
|
25
23
|
except ImportError:
|
|
@@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
|
33
31
|
from msprobe.core.common.file_utils import FileChecker
|
|
34
32
|
from msprobe.core.common.log import logger
|
|
35
33
|
from msprobe.core.common.utils import CompareException
|
|
34
|
+
from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
|
|
36
35
|
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
|
|
37
|
-
|
|
38
|
-
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
|
|
39
|
-
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
|
|
40
|
-
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
|
|
36
|
+
|
|
41
37
|
|
|
42
38
|
hf_32_standard_api = ["conv1d", "conv2d"]
|
|
43
39
|
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
|
|
@@ -108,17 +104,28 @@ def exec_api(exec_params):
|
|
|
108
104
|
kwargs = exec_params.kwargs
|
|
109
105
|
is_autocast = exec_params.is_autocast
|
|
110
106
|
autocast_dtype = exec_params.autocast_dtype
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
if api_type
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
107
|
+
out = None
|
|
108
|
+
|
|
109
|
+
prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {})
|
|
110
|
+
if not prefix_map or api_type not in prefix_map.values() or \
|
|
111
|
+
api_type not in (
|
|
112
|
+
Const.FUNCTIONAL_API_TYPE_PREFIX,
|
|
113
|
+
Const.TENSOR_API_TYPE_PREFIX,
|
|
114
|
+
Const.TORCH_API_TYPE_PREFIX,
|
|
115
|
+
Const.ATEN_API_TYPE_PREFIX,
|
|
116
|
+
Const.NPU_API_TYPE_PREFIX
|
|
117
|
+
):
|
|
118
|
+
return out
|
|
119
|
+
|
|
120
|
+
if api_type == Const.ATEN_API_TYPE_PREFIX:
|
|
119
121
|
torch_api = AtenOPTemplate(api_name, None, False)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
+
else:
|
|
123
|
+
api_register = get_api_register()
|
|
124
|
+
api_register.initialize_hook(None)
|
|
125
|
+
api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
|
|
126
|
+
api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
|
|
127
|
+
|
|
128
|
+
torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
|
|
122
129
|
if is_autocast:
|
|
123
130
|
with autocast(dtype=autocast_dtype):
|
|
124
131
|
out = torch_api.forward(*args, **kwargs)
|
|
@@ -248,7 +255,8 @@ def record_skip_info(api_full_name, compare, compare_alg_results):
|
|
|
248
255
|
|
|
249
256
|
def is_unsupported_api(api_name, is_overflow_check=False):
|
|
250
257
|
split_name = api_name.split(Const.SEP)[0]
|
|
251
|
-
|
|
258
|
+
unsupport_type_list = [Const.DISTRIBUTED, Const.MINDSPEED_API_TYPE_PREFIX]
|
|
259
|
+
flag = (split_name in unsupport_type_list) or (is_overflow_check and split_name == Const.NPU)
|
|
252
260
|
if flag:
|
|
253
261
|
logger.info(f"{split_name} api is not supported for run ut. SKIP.")
|
|
254
262
|
return flag
|
|
@@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T
|
|
|
27
27
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
28
28
|
from msprobe.core.common.file_utils import remove_path
|
|
29
29
|
from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
|
|
30
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
30
31
|
|
|
31
32
|
BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
|
|
32
33
|
|
|
@@ -168,11 +169,12 @@ class ATTL:
|
|
|
168
169
|
return buffer
|
|
169
170
|
|
|
170
171
|
|
|
172
|
+
@recursion_depth_decorator("move2device_exec")
|
|
171
173
|
def move2device_exec(obj, device):
|
|
172
174
|
if isinstance(obj, (tuple, list)):
|
|
173
175
|
data_list = [move2device_exec(val, device) for val in obj]
|
|
174
176
|
return data_list if isinstance(obj, list) else tuple(data_list)
|
|
175
|
-
if isinstance(obj, dict):
|
|
177
|
+
if isinstance(obj, dict):
|
|
176
178
|
return {key: move2device_exec(val, device) for key, val in obj.items()}
|
|
177
179
|
elif isinstance(obj, torch.Tensor):
|
|
178
180
|
obj = obj.detach()
|
|
@@ -12,23 +12,22 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import
|
|
15
|
+
from functools import partial
|
|
16
|
+
import zlib
|
|
17
17
|
import io
|
|
18
18
|
import struct
|
|
19
19
|
import time
|
|
20
20
|
import os
|
|
21
|
-
import signal
|
|
22
21
|
from queue import Queue
|
|
23
22
|
from threading import Thread
|
|
24
23
|
from typing import Union
|
|
25
24
|
|
|
26
|
-
from twisted.internet import reactor, protocol, endpoints
|
|
25
|
+
from twisted.internet import reactor, protocol, endpoints, ssl
|
|
27
26
|
from twisted.protocols.basic import FileSender
|
|
28
27
|
|
|
29
28
|
from msprobe.pytorch.common.utils import logger
|
|
30
29
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
|
|
31
|
-
STR_TO_BYTES_ORDER as bytes_order
|
|
30
|
+
STR_TO_BYTES_ORDER as bytes_order, cipher_list, verify_callback, load_ssl_pem
|
|
32
31
|
|
|
33
32
|
MAX_SENDING_QUEUE_SIZE = 20
|
|
34
33
|
|
|
@@ -104,11 +103,28 @@ class TCPClient:
|
|
|
104
103
|
self.factory = MessageClientFactory()
|
|
105
104
|
self.factory.protocol = cur_protocol
|
|
106
105
|
if self.tls_path:
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
106
|
+
client_key, client_crt, ca_crt, crl_pem = load_ssl_pem(
|
|
107
|
+
key_file=os.path.join(self.tls_path, "client.key"),
|
|
108
|
+
cert_file=os.path.join(self.tls_path, "client.crt"),
|
|
109
|
+
ca_file=os.path.join(self.tls_path, "ca.crt"),
|
|
110
|
+
crl_file=os.path.join(self.tls_path, "crl.pem")
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
ssl_options = ssl.CertificateOptions(
|
|
114
|
+
privateKey=client_key,
|
|
115
|
+
certificate=client_crt,
|
|
116
|
+
method=ssl.SSL.TLSv1_2_METHOD,
|
|
117
|
+
verify=True,
|
|
118
|
+
requireCertificate=True,
|
|
119
|
+
caCerts=[ca_crt], # 信任的CA证书列表
|
|
120
|
+
)
|
|
121
|
+
ssl_context = ssl_options.getContext()
|
|
122
|
+
ssl_context.set_cipher_list(cipher_list)
|
|
123
|
+
ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
|
|
124
|
+
ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
|
|
125
|
+
partial(verify_callback, crl=crl_pem))
|
|
126
|
+
|
|
127
|
+
endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, ssl_options)
|
|
112
128
|
else:
|
|
113
129
|
endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
|
|
114
130
|
d = endpoint.connect(self.factory)
|
|
@@ -299,12 +315,12 @@ class ClientProtocol(protocol.Protocol):
|
|
|
299
315
|
|
|
300
316
|
def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
|
|
301
317
|
length = len(data)
|
|
302
|
-
|
|
318
|
+
data_crc = f"{zlib.crc32(data):08x}" if self.check_sum else ""
|
|
303
319
|
data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
|
|
304
320
|
sequence_number.to_bytes(8, byteorder=bytes_order) + \
|
|
305
321
|
rank.to_bytes(8, byteorder=bytes_order) + \
|
|
306
322
|
step.to_bytes(8, byteorder=bytes_order) + \
|
|
307
|
-
|
|
323
|
+
data_crc.encode() + \
|
|
308
324
|
data
|
|
309
325
|
logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")
|
|
310
326
|
|
|
@@ -346,7 +362,7 @@ class ClientProtocol(protocol.Protocol):
|
|
|
346
362
|
def connectionLost(self, reason):
|
|
347
363
|
self.signal_exit = True
|
|
348
364
|
self.factory.num_connections -= 1
|
|
349
|
-
logger.info(f"Lost connection with server, reason is : {reason}")
|
|
365
|
+
logger.info(f"Lost connection with server, reason is : {reason.value}")
|
|
350
366
|
|
|
351
367
|
|
|
352
368
|
class MessageClientFactory(protocol.ClientFactory):
|
|
@@ -29,7 +29,6 @@ from msprobe.pytorch.common.log import logger
|
|
|
29
29
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
|
|
30
30
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
|
|
31
31
|
|
|
32
|
-
|
|
33
32
|
# NPU vs GPU api list
|
|
34
33
|
CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
|
|
35
34
|
|
|
@@ -43,6 +42,15 @@ OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig',
|
|
|
43
42
|
CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config'])
|
|
44
43
|
|
|
45
44
|
|
|
45
|
+
def get_gpu_device():
|
|
46
|
+
try:
|
|
47
|
+
import torch_npu
|
|
48
|
+
is_gpu = False
|
|
49
|
+
except ImportError:
|
|
50
|
+
is_gpu = True
|
|
51
|
+
return is_gpu
|
|
52
|
+
|
|
53
|
+
|
|
46
54
|
def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file):
|
|
47
55
|
""" When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue.
|
|
48
56
|
:param xpu_id: int
|
|
@@ -51,7 +59,9 @@ def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file
|
|
|
51
59
|
:param api_precision_csv_file: list, length is 2, result file name and details file name
|
|
52
60
|
:return:
|
|
53
61
|
"""
|
|
54
|
-
|
|
62
|
+
device_info = "cuda" if get_gpu_device() else "npu"
|
|
63
|
+
logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.")
|
|
64
|
+
gpu_device = torch.device(f'{device_info}:{xpu_id}')
|
|
55
65
|
|
|
56
66
|
while True:
|
|
57
67
|
if consumer_queue.empty():
|
|
@@ -12,19 +12,19 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import os
|
|
15
|
+
from functools import partial
|
|
16
|
+
import os
|
|
17
17
|
import struct
|
|
18
|
-
import
|
|
18
|
+
import zlib
|
|
19
19
|
import time
|
|
20
20
|
import io
|
|
21
21
|
from threading import Thread
|
|
22
22
|
|
|
23
|
-
from twisted.internet import reactor, protocol, endpoints
|
|
23
|
+
from twisted.internet import reactor, protocol, endpoints, ssl
|
|
24
24
|
|
|
25
25
|
from msprobe.pytorch.common.utils import logger
|
|
26
26
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
|
|
27
|
-
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
|
|
27
|
+
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class TCPServer:
|
|
@@ -44,15 +44,28 @@ class TCPServer:
|
|
|
44
44
|
self.factory.protocol = self.build_protocol
|
|
45
45
|
|
|
46
46
|
if self.tls_path:
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
47
|
+
server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(
|
|
48
|
+
key_file=os.path.join(self.tls_path, "server.key"),
|
|
49
|
+
cert_file=os.path.join(self.tls_path, "server.crt"),
|
|
50
|
+
ca_file=os.path.join(self.tls_path, "ca.crt"),
|
|
51
|
+
crl_file=os.path.join(self.tls_path, "crl.pem")
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
ssl_options = ssl.CertificateOptions(
|
|
55
|
+
privateKey=server_key,
|
|
56
|
+
certificate=server_crt,
|
|
57
|
+
method=ssl.SSL.TLSv1_2_METHOD,
|
|
58
|
+
verify=True,
|
|
59
|
+
requireCertificate=True,
|
|
60
|
+
caCerts=[ca_crt], # 信任的CA证书列表
|
|
61
|
+
)
|
|
62
|
+
ssl_context = ssl_options.getContext()
|
|
63
|
+
ssl_context.set_cipher_list(cipher_list)
|
|
64
|
+
ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
|
|
65
|
+
ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
|
|
66
|
+
partial(verify_callback, crl=crl_pem))
|
|
67
|
+
|
|
68
|
+
endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
|
|
56
69
|
else:
|
|
57
70
|
endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
|
|
58
71
|
endpoint.listen(self.factory)
|
|
@@ -85,10 +98,10 @@ class ServerProtocol(protocol.Protocol):
|
|
|
85
98
|
self.consumer_queue = shared_queue
|
|
86
99
|
self.check_sum = check_sum
|
|
87
100
|
self.length_width = 8
|
|
88
|
-
self.
|
|
101
|
+
self.crc_width = 8
|
|
89
102
|
self.obj_length = None
|
|
90
103
|
self.tell = 0
|
|
91
|
-
self.
|
|
104
|
+
self.obj_crc = None
|
|
92
105
|
self.obj_body = None
|
|
93
106
|
self.sequence_number = -1
|
|
94
107
|
self.rank = -1
|
|
@@ -99,7 +112,7 @@ class ServerProtocol(protocol.Protocol):
|
|
|
99
112
|
self.buffer = io.BytesIO()
|
|
100
113
|
self.obj_length = None
|
|
101
114
|
self.tell = 0
|
|
102
|
-
self.
|
|
115
|
+
self.obj_crc = None
|
|
103
116
|
self.obj_body = None
|
|
104
117
|
self.factory.transport_dict[self.transport] = 1
|
|
105
118
|
self.factory.transport_list.append(self.transport)
|
|
@@ -132,11 +145,12 @@ class ServerProtocol(protocol.Protocol):
|
|
|
132
145
|
time.sleep(0.1)
|
|
133
146
|
|
|
134
147
|
obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
|
|
148
|
+
# compute the CRC32 of the received body as a hexadecimal string of length 8
|
|
149
|
+
recv_crc = f"{zlib.crc32(self.obj_body):08x}"
|
|
135
150
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}")
|
|
151
|
+
if self.check_sum and recv_crc != self.obj_crc:
|
|
152
|
+
# When checksum verification is enabled and the CRC does not match, the received data is corrupt: send b"ERROR" to the client.
|
|
153
|
+
logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
|
|
140
154
|
self.send_ack(self.ACK_ERROR)
|
|
141
155
|
else:
|
|
142
156
|
if self.obj_body == self.ACK_STOP:
|
|
@@ -146,7 +160,7 @@ class ServerProtocol(protocol.Protocol):
|
|
|
146
160
|
if obj_key in self.sequence_number_dict:
|
|
147
161
|
logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
|
|
148
162
|
else:
|
|
149
|
-
self.sequence_number_dict[obj_key] = self.
|
|
163
|
+
self.sequence_number_dict[obj_key] = self.obj_crc
|
|
150
164
|
self.consumer_queue.put(self.obj_body, block=True)
|
|
151
165
|
|
|
152
166
|
self.reset_env()
|
|
@@ -173,7 +187,7 @@ class ServerProtocol(protocol.Protocol):
|
|
|
173
187
|
self.sequence_number = -1
|
|
174
188
|
self.rank = -1
|
|
175
189
|
self.step = -1
|
|
176
|
-
self.
|
|
190
|
+
self.obj_crc = None
|
|
177
191
|
self.obj_body = None
|
|
178
192
|
|
|
179
193
|
def dataReceived(self, data):
|
|
@@ -192,15 +206,15 @@ class ServerProtocol(protocol.Protocol):
|
|
|
192
206
|
logger.debug(
|
|
193
207
|
f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
|
|
194
208
|
|
|
195
|
-
# If needs check
|
|
196
|
-
|
|
209
|
+
# If checksum verification is enabled and the CRC has not been parsed yet, read the 8-byte CRC field
|
|
210
|
+
check_sum_and_crc = (self.check_sum
|
|
197
211
|
and self.obj_length is not None
|
|
198
|
-
and self.
|
|
199
|
-
and len(self.buffer.getvalue()) - self.tell >= self.
|
|
200
|
-
if
|
|
201
|
-
self.
|
|
202
|
-
self.tell += self.
|
|
203
|
-
logger.debug(f"
|
|
212
|
+
and self.obj_crc is None
|
|
213
|
+
and len(self.buffer.getvalue()) - self.tell >= self.crc_width)
|
|
214
|
+
if check_sum_and_crc:
|
|
215
|
+
self.obj_crc = self.buffer.read(self.crc_width).decode()
|
|
216
|
+
self.tell += self.crc_width
|
|
217
|
+
logger.debug(f"Hash value: {self.obj_crc}")
|
|
204
218
|
|
|
205
219
|
current_length = len(self.buffer.getvalue()) - self.tell
|
|
206
220
|
if self.obj_length is not None and 0 < self.obj_length <= current_length:
|
|
@@ -12,6 +12,16 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
+
import os
|
|
16
|
+
from datetime import datetime, timezone
|
|
17
|
+
|
|
18
|
+
from OpenSSL import crypto
|
|
19
|
+
from cryptography import x509
|
|
20
|
+
from cryptography.hazmat.backends import default_backend
|
|
21
|
+
from dateutil import parser
|
|
22
|
+
|
|
23
|
+
from msprobe.core.common.file_utils import FileOpen
|
|
24
|
+
from msprobe.core.common.log import logger
|
|
15
25
|
|
|
16
26
|
cipher_list = ":".join(
|
|
17
27
|
["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
|
|
@@ -42,3 +52,147 @@ cipher_list = ":".join(
|
|
|
42
52
|
|
|
43
53
|
STRUCT_UNPACK_MODE = "!Q"
|
|
44
54
|
STR_TO_BYTES_ORDER = "big"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def is_certificate_revoked(cert, crl):
    """
    Tell whether a certificate's serial number appears in a revocation list.

    :param cert: certificate object exposing ``get_serial_number()``
    :param crl: iterable of revoked entries, each exposing ``serial_number``
    :return: bool, True (after logging an error) if the serial number is
        listed in ``crl``, otherwise False
    """
    serial = cert.get_serial_number()

    # Look the serial number up among the revoked entries.
    if not any(entry.serial_number == serial for entry in crl):
        return False

    logger.error(f"证书已吊销:{serial:020x}")
    return True
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
    """
    Decide whether a peer certificate is accepted.

    :param conn: OpenSSL.SSL.Connection, the SSL connection object (unused here)
    :param cert: OpenSSL.crypto.X509, the certificate currently being checked
    :param errno: int, OpenSSL verification error code (0 means no error)
    :param depth: int, position of the certificate in the chain (0 = leaf certificate)
    :param preverify_ok: int, result of OpenSSL's own verification (1 = passed, 0 = failed)
    :param crl: revocation list to check ``cert`` against; skipped when falsy
    :return: False if the certificate is rejected, otherwise ``preverify_ok`` unchanged
    """
    if preverify_ok:
        # OpenSSL accepted the certificate; additionally reject it if it is listed in the CRL.
        if crl and is_certificate_revoked(cert, crl):
            return False
        return preverify_ok

    # OpenSSL's own verification failed: log the human-readable reason and reject.
    from OpenSSL import SSL
    error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode()
    logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
    return False
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
    """
    Load and sanity-check the PEM files needed for a TLS endpoint.

    The private key passphrase is requested interactively. The certificate and
    the CA certificate are checked with ``check_crt_valid``, the certificate is
    checked against the private key with ``check_certificate_match``, and the
    CRL (only if ``crl_file`` exists) is checked with ``check_crl_valid``.

    Args:
        key_file (str): The path to the private key file.
        cert_file (str): The path to the certificate file.
        ca_file (str): The path to the CA certificate file.
        crl_file (str): The path to the CRL file; may be absent on disk.

    Returns:
        tuple: (key, crt, ca_crt, crl), where ``crl`` is None when ``crl_file``
        does not exist.

    Raises:
        RuntimeError: If any step fails; the original exception is chained as the cause.
    """
    try:
        # The passphrase is always asked for interactively and dropped right
        # after the private key has been loaded.
        import pwinput
        passphrase = pwinput.pwinput("Enter your password: ")
        with FileOpen(key_file, "rb") as key_fp:
            key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_fp.read(), passphrase.encode())
            del passphrase

        with FileOpen(cert_file, "rb") as cert_fp:
            crt = crypto.load_certificate(crypto.FILETYPE_PEM, cert_fp.read())
        check_crt_valid(crt)
        logger.info(f"crt_serial_number: {hex(crt.get_serial_number())[2:]}")

        check_certificate_match(crt, key)

        with FileOpen(ca_file, "rb") as ca_fp:
            ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, ca_fp.read())
        check_crt_valid(ca_crt)
        logger.info(f"ca_serial_number: {hex(ca_crt.get_serial_number())[2:]}")

        # The CRL is optional: it is only loaded when the file is present.
        crl = None
        if os.path.exists(crl_file):
            with FileOpen(crl_file, "rb") as crl_fp:
                crl = x509.load_pem_x509_crl(crl_fp.read(), default_backend())
            check_crl_valid(crl, ca_crt)
            for entry in crl:
                logger.info(f"Serial Number: {entry.serial_number}, "
                            f"Revocation Date: {entry.revocation_date_utc}")
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    return key, crt, ca_crt, crl
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def check_crt_valid(pem):
    """
    Check that a certificate is currently inside its validity period.

    The success message is logged only after every check has passed, so a
    rejected certificate never produces a "passes the verification" log line.

    Args:
        pem: certificate object exposing ``get_notBefore()``, ``get_notAfter()``
            (both returning bytes) and ``has_expired()``.

    Raises:
        RuntimeError: If the validity dates cannot be read or parsed, if the
            certificate is not yet valid, or if it has expired.
    """
    try:
        pem_start = parser.parse(pem.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(pem.get_notAfter().decode("UTF-8"))
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    # Compare against a timezone-aware "now"; the parsed dates are expected to
    # carry a timezone as well (assumption: ASN.1 times end in "Z" -- verify).
    now_utc = datetime.now(tz=timezone.utc)
    if now_utc < pem_start:
        raise RuntimeError("The SSL certificate is not yet valid.")
    if pem.has_expired() or now_utc > pem_end:
        raise RuntimeError("The SSL certificate has expired.")

    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def check_certificate_match(certificate, private_key):
    """
    Check that a certificate and a private key belong together.

    Random data is signed with ``private_key`` and the signature is verified
    against ``certificate``.

    :param certificate: certificate holding the public key, passed to ``crypto.verify``
    :param private_key: private key, passed to ``crypto.sign``
    :return: None
    :raises RuntimeError: if signing or verification fails for any reason
    """
    digest = "sha256"  # hash algorithm used for both signing and verifying
    probe = os.urandom(256)  # throw-away random payload to sign
    try:
        crypto.verify(certificate, crypto.sign(private_key, probe, digest), probe, digest)
        logger.info("公钥和私钥匹配")
    except Exception as e:
        raise RuntimeError("公钥和私钥不匹配") from e
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def check_crl_valid(crl, ca_crt):
    """
    Validate a CRL: its signature and its validity window.

    Timestamps are compared as timezone-aware UTC values. The ``*_utc``
    attributes of the CRL are used when available; otherwise the naive
    ``last_update`` / ``next_update`` attributes are read and interpreted as UTC.

    :param crl: CRL object exposing ``is_signature_valid(public_key)`` and either
        ``last_update_utc`` / ``next_update_utc`` or ``last_update`` / ``next_update``
    :param ca_crt: CA certificate exposing ``get_pubkey().to_cryptography_key()``
    :return: None
    :raises RuntimeError: if the signature is invalid, the CRL carries no
        nextUpdate value, or the current time is outside the validity window
    """
    def _aware_utc(moment):
        # Naive datetimes are interpreted as UTC; aware ones are kept as-is.
        return moment if moment.tzinfo is not None else moment.replace(tzinfo=timezone.utc)

    # Verify the CRL signature (ensures the CRL has not been tampered with).
    if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
        raise RuntimeError("CRL签名无效!")

    # Prefer the timezone-aware accessors; fall back to the naive ones.
    if hasattr(crl, "last_update_utc"):
        last_update, next_update = crl.last_update_utc, crl.next_update_utc
    else:
        last_update, next_update = crl.last_update, crl.next_update

    # Without nextUpdate the window cannot be checked; fail explicitly instead
    # of letting a comparison against None raise TypeError.
    if next_update is None:
        raise RuntimeError("The CRL has no nextUpdate field; its validity period cannot be checked.")

    # Check the CRL validity period.
    now_utc = datetime.now(tz=timezone.utc)
    if not (_aware_utc(last_update) <= now_utc <= _aware_utc(next_update)):
        raise RuntimeError("CRL已过期或尚未生效!")
|