mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
```diff
@@ -40,6 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dty
     DETAIL_TEST_ROWS, BENCHMARK_COMPARE_SUPPORT_LIST
 from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
 from msprobe.pytorch.common.log import logger
+from msprobe.core.common.decorator import recursion_depth_decorator


 ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
```
```diff
@@ -178,6 +179,41 @@ class Comparator:
         if not os.path.exists(detail_save_path):
             write_csv(DETAIL_TEST_ROWS, detail_save_path)

+    @recursion_depth_decorator("compare_core")
+    def _compare_core(self, api_name, bench_output, device_output):
+        compare_column = CompareColumn()
+        if not isinstance(bench_output, type(device_output)):
+            status = CompareConst.ERROR
+            message = "bench and npu output type is different."
+        elif isinstance(bench_output, dict):
+            b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
+            if b_keys != n_keys:
+                status = CompareConst.ERROR
+                message = "bench and npu output dict keys are different."
+            else:
+                status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
+                                                                     list(device_output.values()))
+        elif isinstance(bench_output, torch.Tensor):
+            copy_bench_out = bench_output.detach().clone()
+            copy_device_output = device_output.detach().clone()
+            compare_column.bench_type = str(copy_bench_out.dtype)
+            compare_column.npu_type = str(copy_device_output.dtype)
+            compare_column.shape = tuple(device_output.shape)
+            status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
+                                                                         compare_column)
+        elif isinstance(bench_output, (bool, int, float, str)):
+            compare_column.bench_type = str(type(bench_output))
+            compare_column.npu_type = str(type(device_output))
+            status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
+        elif bench_output is None:
+            status = CompareConst.SKIP
+            message = "Bench output is None, skip this test."
+        else:
+            status = CompareConst.ERROR
+            message = "Unexpected output type in compare_core: {}".format(type(bench_output))
+
+        return status, compare_column, message
+
     def write_summary_csv(self, test_result):
         test_rows = []
         try:
```
```diff
@@ -293,40 +329,6 @@ class Comparator:
             test_final_success = CompareConst.WARNING
         return test_final_success, detailed_result_total

-    def _compare_core(self, api_name, bench_output, device_output):
-        compare_column = CompareColumn()
-        if not isinstance(bench_output, type(device_output)):
-            status = CompareConst.ERROR
-            message = "bench and npu output type is different."
-        elif isinstance(bench_output, dict):
-            b_keys, n_keys = set(bench_output.keys()), set(device_output.keys())
-            if b_keys != n_keys:
-                status = CompareConst.ERROR
-                message = "bench and npu output dict keys are different."
-            else:
-                status, compare_column, message = self._compare_core(api_name, list(bench_output.values()),
-                                                                     list(device_output.values()))
-        elif isinstance(bench_output, torch.Tensor):
-            copy_bench_out = bench_output.detach().clone()
-            copy_device_output = device_output.detach().clone()
-            compare_column.bench_type = str(copy_bench_out.dtype)
-            compare_column.npu_type = str(copy_device_output.dtype)
-            compare_column.shape = tuple(device_output.shape)
-            status, compare_column, message = self._compare_torch_tensor(api_name, copy_bench_out, copy_device_output,
-                                                                         compare_column)
-        elif isinstance(bench_output, (bool, int, float, str)):
-            compare_column.bench_type = str(type(bench_output))
-            compare_column.npu_type = str(type(device_output))
-            status, compare_column, message = self._compare_builtin_type(bench_output, device_output, compare_column)
-        elif bench_output is None:
-            status = CompareConst.SKIP
-            message = "Bench output is None, skip this test."
-        else:
-            status = CompareConst.ERROR
-            message = "Unexpected output type in compare_core: {}".format(type(bench_output))
-
-        return status, compare_column, message
-
     def _compare_torch_tensor(self, api_name, bench_output, device_output, compare_column):
         cpu_shape = bench_output.shape
         npu_shape = device_output.shape
```
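Note: the relocated `_compare_core` above now carries `@recursion_depth_decorator("compare_core")` because it recurses into nested dict outputs. As a rough, hypothetical sketch only (the real decorator ships in `msprobe.core.common.decorator`; its limit and error handling may differ), such a guard can look like this:

```python
import functools

def recursion_depth_decorator(label, max_depth=50):
    """Hypothetical sketch: abort cleanly when a recursive call nests too deeply."""
    def wrapper(func):
        depth = {"value": 0}

        @functools.wraps(func)
        def inner(*args, **kwargs):
            depth["value"] += 1
            try:
                if depth["value"] > max_depth:
                    raise RecursionError(f"{label}: recursion depth exceeded {max_depth}")
                return func(*args, **kwargs)
            finally:
                depth["value"] -= 1
        return inner
    return wrapper
```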
```diff
@@ -73,27 +73,27 @@ DETAIL_TEST_ROWS = [


 precision_configs = {
-    torch.float16
-        'small_value'
+    torch.float16: {
+        'small_value': [
             1e-3
         ],
-        'small_value_atol'
+        'small_value_atol': [
             1e-5
         ]
     },
     torch.bfloat16: {
-        'small_value'
+        'small_value': [
             1e-3
         ],
-        'small_value_atol'
+        'small_value_atol': [
             1e-5
         ]
     },
-    torch.float32:{
-        'small_value'
+    torch.float32: {
+        'small_value': [
             1e-6
         ],
-        'small_value_atol'
+        'small_value_atol': [
             1e-9
         ]
     }
```
```diff
@@ -101,33 +101,33 @@ precision_configs = {


 ULP_PARAMETERS = {
-    torch.float16
-        'min_eb'
+    torch.float16: {
+        'min_eb': [
             -14
         ],
-        'exponent_num'
+        'exponent_num': [
             10
         ]
     },
-    torch.bfloat16
-        'min_eb'
+    torch.bfloat16: {
+        'min_eb': [
             -126
         ],
-        'exponent_num'
+        'exponent_num': [
             7
         ]
     },
-    torch.float32
-        'min_eb'
+    torch.float32: {
+        'min_eb': [
             -126
         ],
-        'exponent_num'
+        'exponent_num': [
             23
         ]
     }
 }
-
-
+
+
 class ApiPrecisionCompareColumn:
     API_NAME = 'API Name'
     DEVICE_DTYPE = 'DEVICE Dtype'
```
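For reference, `min_eb` is the smallest normal exponent of each dtype and `exponent_num` its number of mantissa bits. A hedged sketch of how such parameters are commonly combined into a per-element ULP error (the checker's actual formula may differ):

```python
import torch

def ulp_error(device_out, bench_out, min_eb=-14, exponent_num=10):
    """Rough illustration for float16: one ULP of x is 2 ** (exponent(x) - exponent_num)."""
    bench64 = bench_out.to(torch.float64)
    # Exponent of the benchmark value, clamped at the smallest normal exponent.
    eb = torch.clamp(torch.floor(torch.log2(torch.abs(bench64))), min=float(min_eb))
    ulp = torch.pow(2.0, eb - exponent_num)
    return torch.abs(device_out.to(torch.float64) - bench64) / ulp
```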
```diff
@@ -202,7 +202,7 @@ class ApiPrecisionCompareColumn:


 CompareMessage = {
-    "topk"
+    "topk": "在npu上，topk的入参sorted=False时不生效，会返回有序tensor，而cpu上会返回无序tensor。 如果topk精度不达标，请检查是否是该原因导致的。"
 }


```
```diff
@@ -411,19 +411,16 @@ class OperatorScriptGenerator:
     return kwargs_dict_generator


-
 def _op_generator_parser(parser):
-    parser.add_argument("-i", "--config_input", dest="config_input",
-                        help="<
+    parser.add_argument("-i", "--config_input", dest="config_input", type=str,
+                        help="<Required> Path of config json file", required=True)
     parser.add_argument("-o", "--api_output_path", dest="api_output_path", type=str,
-                        help="<Required> Path of extract api_name.json.",
-                        required=True)
+                        help="<Required> Path of extract api_name.json.", required=True)


 def parse_json_config(json_file_path):
     if not json_file_path:
-
-        json_file_path = os.path.join(config_dir, "config.json")
+        raise Exception("config_input path can not be empty, please check.")
     json_config = load_json(json_file_path)
     common_config = CommonConfig(json_config)
     return common_config
```
```diff
@@ -1,6 +1,6 @@
-import json
 import os
-import
+import re
+import stat
 from enum import Enum, auto
 import torch
 try:
```
```diff
@@ -25,6 +25,31 @@ RAISE_PRECISION = {{
 }}
 THOUSANDTH_THRESHOLDING = 0.001
 BACKWARD = 'backward'
+DIR = "dir"
+FILE = "file"
+READ_ABLE = "read"
+WRITE_ABLE = "write"
+READ_WRITE_ABLE = "read and write"
+DIRECTORY_LENGTH = 4096
+FILE_NAME_LENGTH = 255
+SOFT_LINK_ERROR = "检测到软链接"
+FILE_PERMISSION_ERROR = "文件权限错误"
+INVALID_FILE_ERROR = "无效文件"
+ILLEGAL_PATH_ERROR = "非法文件路径"
+ILLEGAL_PARAM_ERROR = "非法打开方式"
+FILE_TOO_LARGE_ERROR = "文件过大"
+FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+FILE_SIZE_DICT = {{
+    ".pkl": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".npy": 10737418240,  # 10 * 1024 * 1024 * 1024
+    ".json": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".pt": 10737418240,  # 10 * 1024 * 1024 * 1024
+    ".csv": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".xlsx": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".yaml": 1073741824,  # 1 * 1024 * 1024 * 1024
+    ".ir": 1073741824  # 1 * 1024 * 1024 * 1024
+}}
+COMMOM_FILE_SIZE = 1048576  # 1 * 1024 * 1024

 class CompareStandard(Enum):
     BINARY_EQUALITY_STANDARD = auto()
```
```diff
@@ -33,8 +58,184 @@ class CompareStandard(Enum):
     BENCHMARK_STANDARD = auto()
     THOUSANDTH_STANDARD = auto()

+class FileChecker:
+    """
+    The class for check file.
+
+    Attributes:
+        file_path: The file or dictionary path to be verified.
+        path_type: file or dictionary
+        ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability
+        file_type(str): The correct file type for file
+    """
+
+    def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True):
+        self.file_path = file_path
+        self.path_type = self._check_path_type(path_type)
+        self.ability = ability
+        self.file_type = file_type
+        self.is_script = is_script
+
+    @staticmethod
+    def _check_path_type(path_type):
+        if path_type not in [DIR, FILE]:
+            print(f'ERROR: The path_type must be {{DIR}} or {{FILE}}.')
+            raise Exception(ILLEGAL_PARAM_ERROR)
+        return path_type
+
+    def common_check(self):
+        """
+        功能：用户校验基本文件权限：软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符
+        注意：文件后缀的合法性，非通用操作，可使用其他独立接口实现
+        """
+        FileChecker.check_path_exists(self.file_path)
+        FileChecker.check_link(self.file_path)
+        self.file_path = os.path.realpath(self.file_path)
+        FileChecker.check_path_length(self.file_path)
+        FileChecker.check_path_type(self.file_path, self.path_type)
+        self.check_path_ability()
+        if self.is_script:
+            FileChecker.check_path_owner_consistent(self.file_path)
+        FileChecker.check_path_pattern_valid(self.file_path)
+        FileChecker.check_common_file_size(self.file_path)
+        FileChecker.check_file_suffix(self.file_path, self.file_type)
+        if self.path_type == FILE:
+            FileChecker.check_dirpath_before_read(self.file_path)
+        return self.file_path
+
+    def check_path_ability(self):
+        if self.ability == WRITE_ABLE:
+            FileChecker.check_path_writability(self.file_path)
+        if self.ability == READ_ABLE:
+            FileChecker.check_path_readability(self.file_path)
+        if self.ability == READ_WRITE_ABLE:
+            FileChecker.check_path_readability(self.file_path)
+            FileChecker.check_path_writability(self.file_path)
+
+    @staticmethod
+    def check_path_exists(path):
+        if not os.path.exists(path):
+            print(f'ERROR: The file path %s does not exist.' % path)
+            raise Exception()
+
+    @staticmethod
+    def check_link(path):
+        abs_path = os.path.abspath(path)
+        if os.path.islink(abs_path):
+            print('ERROR: The file path {{}} is a soft link.'.format(path))
+            raise Exception(SOFT_LINK_ERROR)
+
+    @staticmethod
+    def check_path_length(path, name_length=None):
+        file_max_name_length = name_length if name_length else FILE_NAME_LENGTH
+        if len(path) > DIRECTORY_LENGTH or \
+                len(os.path.basename(path)) > file_max_name_length:
+            print(f'ERROR: The file path length exceeds limit.')
+            raise Exception(ILLEGAL_PATH_ERROR)
+
+    @staticmethod
+    def check_path_type(file_path, file_type):
+        if file_type == FILE:
+            if not os.path.isfile(file_path):
+                print(f"ERROR: The {{file_path}} should be a file!")
+                raise Exception(INVALID_FILE_ERROR)
+        if file_type == DIR:
+            if not os.path.isdir(file_path):
+                print(f"ERROR: The {{file_path}} should be a dictionary!")
+                raise Exception(INVALID_FILE_ERROR)
+
+    @staticmethod
+    def check_path_owner_consistent(path):
+        file_owner = os.stat(path).st_uid
+        if file_owner != os.getuid() and os.getuid() != 0:
+            print('ERROR: The file path %s may be insecure because is does not belong to you.' % path)
+            raise Exception(FILE_PERMISSION_ERROR)
+
+    @staticmethod
+    def check_path_pattern_valid(path):
+        if not re.match(FILE_VALID_PATTERN, path):
+            print('ERROR: The file path %s contains special characters.' % (path))
+            raise Exception(ILLEGAL_PATH_ERROR)
+
+    @staticmethod
+    def check_common_file_size(file_path):
+        if os.path.isfile(file_path):
+            for suffix, max_size in FILE_SIZE_DICT.items():
+                if file_path.endswith(suffix):
+                    FileChecker.check_file_size(file_path, max_size)
+                    return
+            FileChecker.check_file_size(file_path, COMMOM_FILE_SIZE)
+
+    @staticmethod
+    def check_file_size(file_path, max_size):
+        try:
+            file_size = os.path.getsize(file_path)
+        except OSError as os_error:
+            print(f'ERROR: Failed to open "{{file_path}}". {{str(os_error)}}')
+            raise Exception(INVALID_FILE_ERROR) from os_error
+        if file_size >= max_size:
+            print(f'ERROR: The size ({{file_size}}) of {{file_path}} exceeds ({{max_size}}) bytes, tools not support.')
+            raise Exception(FILE_TOO_LARGE_ERROR)
+
+    @staticmethod
+    def check_file_suffix(file_path, file_suffix):
+        if file_suffix:
+            if not file_path.endswith(file_suffix):
+                print(f"The {{file_path}} should be a {{file_suffix}} file!")
+                raise Exception(INVALID_FILE_ERROR)
+
+    @staticmethod
+    def check_dirpath_before_read(path):
+        path = os.path.realpath(path)
+        dirpath = os.path.dirname(path)
+        if FileChecker.check_others_writable(dirpath):
+            print(f"WARNING: The directory is writable by others: {{dirpath}}.")
+        try:
+            FileChecker.check_path_owner_consistent(dirpath)
+        except Exception:
+            print(f"WARNING: The directory {{dirpath}} is not yours.")
+
+    @staticmethod
+    def check_others_writable(directory):
+        dir_stat = os.stat(directory)
+        is_writable = (
+            bool(dir_stat.st_mode & stat.S_IWGRP) or  # 组可写
+            bool(dir_stat.st_mode & stat.S_IWOTH)  # 其他用户可写
+        )
+        return is_writable
+
+    @staticmethod
+    def check_path_readability(path):
+        if not os.access(path, os.R_OK):
+            print('ERROR: The file path %s is not readable.' % path)
+            raise Exception(FILE_PERMISSION_ERROR)
+
+    @staticmethod
+    def check_path_writability(path):
+        if not os.access(path, os.W_OK):
+            print('ERROR: The file path %s is not writable.' % path)
+            raise Exception(FILE_PERMISSION_ERROR)
+
+
+def check_file_or_directory_path(path, isdir=False):
+    """
+    Function Description:
+        check whether the path is valid
+    Parameter:
+        path: the path to check
+        isdir: the path is dir or file
+    Exception Description:
+        when invalid data throw exception
+    """
+    if isdir:
+        path_checker = FileChecker(path, DIR, WRITE_ABLE)
+    else:
+        path_checker = FileChecker(path, FILE, READ_ABLE)
+    path_checker.common_check()
+
+
 def load_pt(pt_path, to_cpu=False):
     pt_path = os.path.realpath(pt_path)
+    check_file_or_directory_path(pt_path)
     try:
         if to_cpu:
             pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
```
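A brief usage sketch of the path validation added above to the generated script (the paths are illustrative only):

```python
# Readable-file check, as load_pt now does before calling torch.load:
check_file_or_directory_path("./dump_data/Torch.matmul.0.forward.input.0.pt")

# Writable-directory check:
check_file_or_directory_path("./op_output", isdir=True)

# Enforce a suffix explicitly via FileChecker:
real_path = FileChecker("./config.json", FILE, READ_ABLE, file_type=".json").common_check()
```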
```diff
@@ -202,6 +403,7 @@ def compare_tensor(out_device, out_bench, api_name):
     else:
         abs_err = torch.abs(out_device - out_bench)
         abs_bench = torch.abs(out_bench)
+        eps = 2 ** -23
         if dtype_bench == torch.float32:
             eps = 2 ** -23
         if dtype_bench == torch.float64:
```
```diff
@@ -70,7 +70,7 @@ def split_json_file(input_file, num_splits, filter_api):
         split_forward_data = dict(items[start:end])
         temp_data = {
             **input_data,
-            "data":{
+            "data": {
                 **split_forward_data,
                 **backward_data
             }
```
```diff
@@ -87,10 +87,6 @@ def signal_handler(signum, frame):
     raise KeyboardInterrupt()


-signal.signal(signal.SIGINT, signal_handler)
-signal.signal(signal.SIGTERM, signal_handler)
-
-
 ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
                                                    'save_error_data_flag', 'jit_compile_flag', 'device_id',
                                                    'result_csv_path', 'total_items', 'config_path'])
```
```diff
@@ -132,6 +128,9 @@ def run_parallel_ut(config):
                 sys.stdout.flush()
         except ValueError as e:
             logger.warning(f"An error occurred while reading subprocess output: {e}")
+        finally:
+            if process.poll() is None:
+                process.stdout.close()

     def update_progress_bar(progress_bar, result_csv_path):
         while any(process.poll() is None for process in processes):
```
```diff
@@ -142,7 +141,7 @@ def run_parallel_ut(config):

     for api_info in config.api_files:
         cmd = create_cmd(api_info, next(device_id_cycle))
-        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                                    text=True, bufsize=1, shell=False)
         processes.append(process)
         threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
```
```diff
@@ -188,8 +187,8 @@ def run_parallel_ut(config):


 def prepare_config(args):
-    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
-
+    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
+                                        ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
     api_info = api_info_file_checker.common_check()
     out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
     create_directory(out_path)
```
```diff
@@ -198,11 +197,11 @@ def prepare_config(args):
     split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
     config_path = args.config_path if args.config_path else None
     if config_path:
-        config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
+        config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
                                           FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
         config_path = config_path_checker.common_check()
     result_csv_path = args.result_csv_path or os.path.join(
-
+        out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
     if not args.result_csv_path:
         details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
         comparator = Comparator(result_csv_path, details_csv_path, False)
```
```diff
@@ -217,9 +216,11 @@ def prepare_config(args):


 def main():
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
     parser = argparse.ArgumentParser(description='Run UT in parallel')
     _run_ut_parser(parser)
-    parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
+    parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
                         help='Number of splits for parallel processing. Range: 1-64')
     args = parser.parse_args()
     config = prepare_config(args)
```
```diff
@@ -45,7 +45,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC
 from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
 from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
 from msprobe.core.common.file_utils import FileChecker, change_mode, \
-    create_directory, get_json_contents, read_csv, check_file_or_directory_path
+    create_directory, get_json_contents, read_csv, check_file_or_directory_path
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.core.common.const import Const, FileCheckConst, CompareConst
```
```diff
@@ -65,7 +65,8 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"

 not_backward_list = ['repeat_interleave']
 unsupported_backward_list = ['masked_select']
-unsupported_api_list = ["to"
+unsupported_api_list = ["to", "empty", "empty_like", "empty_strided", "new_empty", "new_empty_strided",
+                        "empty_with_format"]


 tqdm_params = {
```
```diff
@@ -482,7 +483,6 @@ def _run_ut(parser=None):
     run_ut_command(args)


-
 def checked_online_config(online_config):
     if not online_config.is_online:
         return
```
```diff
@@ -503,8 +503,10 @@ def checked_online_config(online_config):
         check_file_or_directory_path(online_config.tls_path, isdir=True)
         check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
         check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
-
-
+        check_file_or_directory_path(os.path.join(online_config.tls_path, "ca.crt"))
+        crl_path = os.path.join(online_config.tls_path, "crl.pem")
+        if os.path.exists(crl_path):
+            check_file_or_directory_path(crl_path)

     # host and port
     if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
```
```diff
@@ -582,6 +584,7 @@ def run_ut_command(args):
         if len(parts_by_underscore) < 2:
             raise ValueError("File name part does not contain enough '_' separated segments.")
         time_info = parts_by_underscore[-1]
+
         global UT_ERROR_DATA_DIR
         UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
         error_data_path = initialize_save_error_data(error_data_path)
```
```diff
@@ -124,8 +124,6 @@ def exec_api(exec_params):
     api_register.initialize_hook(None)
     api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
     api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
-    if api_func is None:
-        return out

     torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
     if is_autocast:
```
```diff
@@ -257,7 +255,8 @@ def record_skip_info(api_full_name, compare, compare_alg_results):

 def is_unsupported_api(api_name, is_overflow_check=False):
     split_name = api_name.split(Const.SEP)[0]
-
+    unsupport_type_list = [Const.DISTRIBUTED, Const.MINDSPEED_API_TYPE_PREFIX]
+    flag = (split_name in unsupport_type_list) or (is_overflow_check and split_name == Const.NPU)
     if flag:
         logger.info(f"{split_name} api is not supported for run ut. SKIP.")
     return flag
```
```diff
@@ -12,23 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import
+from functools import partial
+import zlib
 import io
 import struct
 import time
 import os
-import signal
 from queue import Queue
 from threading import Thread
 from typing import Union

-from twisted.internet import reactor, protocol, endpoints
+from twisted.internet import reactor, protocol, endpoints, ssl
 from twisted.protocols.basic import FileSender

 from msprobe.pytorch.common.utils import logger
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
-    STR_TO_BYTES_ORDER as bytes_order
+    STR_TO_BYTES_ORDER as bytes_order, cipher_list, verify_callback, load_ssl_pem

 MAX_SENDING_QUEUE_SIZE = 20

```
```diff
@@ -104,11 +103,28 @@ class TCPClient:
         self.factory = MessageClientFactory()
         self.factory.protocol = cur_protocol
         if self.tls_path:
-
-
-
-
-
+            client_key, client_crt, ca_crt, crl_pem = load_ssl_pem(
+                key_file=os.path.join(self.tls_path, "client.key"),
+                cert_file=os.path.join(self.tls_path, "client.crt"),
+                ca_file=os.path.join(self.tls_path, "ca.crt"),
+                crl_file=os.path.join(self.tls_path, "crl.pem")
+            )
+
+            ssl_options = ssl.CertificateOptions(
+                privateKey=client_key,
+                certificate=client_crt,
+                method=ssl.SSL.TLSv1_2_METHOD,
+                verify=True,
+                requireCertificate=True,
+                caCerts=[ca_crt],  # 信任的CA证书列表
+            )
+            ssl_context = ssl_options.getContext()
+            ssl_context.set_cipher_list(cipher_list)
+            ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
+            ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
+                                   partial(verify_callback, crl=crl_pem))
+
+            endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, ssl_options)
         else:
             endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
         d = endpoint.connect(self.factory)
```
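`load_ssl_pem`, `cipher_list`, and `verify_callback` come from the new `tensor_transport_layer/utils.py` (+156 lines, not shown in this hunk). As a speculative sketch only, a CRL-aware callback compatible with the `partial(verify_callback, crl=crl_pem)` call above could look roughly like this, assuming `crl` is an `OpenSSL.crypto.CRL` or `None` (the shipped implementation may differ):

```python
def verify_callback(connection, x509, errnum, errdepth, preverify_ok, crl=None):
    """Speculative sketch: reject peers whose certificate appears on the CRL."""
    if not preverify_ok:
        return False  # base OpenSSL chain verification already failed
    if crl is not None:
        revoked = crl.get_revoked() or ()
        peer_serial = x509.get_serial_number()
        if any(int(entry.get_serial(), 16) == peer_serial for entry in revoked):
            return False  # peer certificate has been revoked
    return True
```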
```diff
@@ -299,12 +315,12 @@ class ClientProtocol(protocol.Protocol):

     def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
         length = len(data)
-
+        data_crc = f"{zlib.crc32(data):08x}" if self.check_sum else ""
         data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
                        sequence_number.to_bytes(8, byteorder=bytes_order) + \
                        rank.to_bytes(8, byteorder=bytes_order) + \
                        step.to_bytes(8, byteorder=bytes_order) + \
-
+                       data_crc.encode() + \
                        data
         logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")

```
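The optional CRC32 is inserted between the four 8-byte header fields and the payload. A hedged sketch of how a receiver could unpack such a frame (the byte order is an assumption here; the real code takes it from `STR_TO_BYTES_ORDER`, and the server side may differ):

```python
import zlib

def unpack_frame(message: bytes, check_sum: bool = True, byteorder: str = "big"):
    """Illustrative only: mirrors send_wrapped_data's length/seq/rank/step/crc/payload layout."""
    length = int.from_bytes(message[0:8], byteorder)
    sequence_number = int.from_bytes(message[8:16], byteorder)
    rank = int.from_bytes(message[16:24], byteorder)
    step = int.from_bytes(message[24:32], byteorder)
    crc_len = 8 if check_sum else 0
    data_crc = message[32:32 + crc_len].decode()
    data = message[32 + crc_len:32 + crc_len + length]
    if check_sum and f"{zlib.crc32(data):08x}" != data_crc:
        raise ValueError(f"CRC mismatch for sequence {sequence_number} (rank {rank}, step {step})")
    return sequence_number, rank, step, data
```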
```diff
@@ -346,7 +362,7 @@ class ClientProtocol(protocol.Protocol):
     def connectionLost(self, reason):
         self.signal_exit = True
         self.factory.num_connections -= 1
-        logger.info(f"Lost connection with server, reason is : {reason}")
+        logger.info(f"Lost connection with server, reason is : {reason.value}")


 class MessageClientFactory(protocol.ClientFactory):
```