mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
19
|
import os
|
|
20
|
-
import
|
|
20
|
+
import re
|
|
21
21
|
import sys
|
|
22
22
|
import time
|
|
23
23
|
import gc
|
|
@@ -35,19 +35,20 @@ import torch
|
|
|
35
35
|
from tqdm import tqdm
|
|
36
36
|
|
|
37
37
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
|
|
38
|
-
get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
|
|
38
|
+
get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
|
|
39
39
|
from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
|
|
40
40
|
from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
|
|
41
41
|
initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
|
|
42
42
|
from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
|
|
43
43
|
from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
|
|
44
|
-
from msprobe.pytorch.api_accuracy_checker.common.config import
|
|
44
|
+
from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
|
|
45
45
|
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
|
|
46
|
-
from msprobe.core.common.file_utils import FileChecker, change_mode,
|
|
47
|
-
create_directory, get_json_contents, read_csv
|
|
46
|
+
from msprobe.core.common.file_utils import FileChecker, change_mode, \
|
|
47
|
+
create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
|
|
48
48
|
from msprobe.pytorch.common.log import logger
|
|
49
49
|
from msprobe.pytorch.pt_config import parse_json_config
|
|
50
50
|
from msprobe.core.common.const import Const, FileCheckConst, CompareConst
|
|
51
|
+
from msprobe.core.common.utils import safe_get_value
|
|
51
52
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
|
|
52
53
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
|
|
53
54
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
|
|
@@ -57,11 +58,7 @@ current_time = time.strftime("%Y%m%d%H%M%S")
|
|
|
57
58
|
UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
|
|
58
59
|
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
|
|
59
60
|
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
60
|
-
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
61
|
-
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
62
|
-
'black_list', 'error_data_path', 'online_config'])
|
|
63
61
|
|
|
64
|
-
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
65
62
|
|
|
66
63
|
not_backward_list = ['repeat_interleave']
|
|
67
64
|
|
|
@@ -99,7 +96,11 @@ def run_ut(config):
|
|
|
99
96
|
run_api_online(config, compare)
|
|
100
97
|
else:
|
|
101
98
|
csv_df = read_csv(config.result_csv_path)
|
|
102
|
-
|
|
99
|
+
try:
|
|
100
|
+
api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
|
|
101
|
+
except IndexError:
|
|
102
|
+
logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
|
|
103
|
+
api_name_set = set()
|
|
103
104
|
run_api_offline(config, compare, api_name_set)
|
|
104
105
|
for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
|
|
105
106
|
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
@@ -140,7 +141,7 @@ def run_api_offline(config, compare, api_name_set):
|
|
|
140
141
|
except Exception as err:
|
|
141
142
|
if "expected scalar type Long" in str(err):
|
|
142
143
|
logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
|
|
143
|
-
|
|
144
|
+
"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
|
|
144
145
|
else:
|
|
145
146
|
logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
|
|
146
147
|
compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
|
|
@@ -220,14 +221,6 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
|
|
|
220
221
|
return False
|
|
221
222
|
|
|
222
223
|
|
|
223
|
-
def is_unsupported_api(api_name):
|
|
224
|
-
split_name = api_name.split(Const.SEP)[0]
|
|
225
|
-
flag = split_name == Const.DISTRIBUTED
|
|
226
|
-
if flag:
|
|
227
|
-
logger.info(f"{split_name} api is not supported for run ut. SKIP.")
|
|
228
|
-
return flag
|
|
229
|
-
|
|
230
|
-
|
|
231
224
|
def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
|
|
232
225
|
if not is_fwd_success or not is_bwd_success:
|
|
233
226
|
processor = UtDataProcessor(error_data_path)
|
|
@@ -253,7 +246,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
253
246
|
backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
|
|
254
247
|
if api_name in not_backward_list:
|
|
255
248
|
need_grad = False
|
|
256
|
-
logger.
|
|
249
|
+
logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
|
|
257
250
|
backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
|
|
258
251
|
need_backward = need_backward and need_grad
|
|
259
252
|
if kwargs.get("device"):
|
|
@@ -278,7 +271,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
278
271
|
func_options = {
|
|
279
272
|
'real_data_path': real_data_path
|
|
280
273
|
}
|
|
281
|
-
grad = gen_args(backward_args, api_name, func_options)
|
|
274
|
+
grad = gen_args(backward_args, api_name, func_options)
|
|
275
|
+
grad = safe_get_value(grad, 0, "grad")
|
|
282
276
|
bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
|
|
283
277
|
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
|
|
284
278
|
device_grad = grad.clone().detach().to(current_device)
|
|
@@ -286,8 +280,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
|
|
|
286
280
|
else:
|
|
287
281
|
backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
|
|
288
282
|
if api_name == "npu_fusion_attention":
|
|
289
|
-
out = out
|
|
290
|
-
device_out = device_out
|
|
283
|
+
out = safe_get_value(out, 0, "out")
|
|
284
|
+
device_out = safe_get_value(device_out, 0, "device_out")
|
|
291
285
|
|
|
292
286
|
return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
|
|
293
287
|
|
|
@@ -323,6 +317,9 @@ def need_to_backward(grad_index, out):
|
|
|
323
317
|
|
|
324
318
|
def run_backward(args, grad, grad_index, out):
|
|
325
319
|
if grad_index is not None:
|
|
320
|
+
if grad_index >= len(out):
|
|
321
|
+
logger.error(f"Run backward error when grad_index is {grad_index}")
|
|
322
|
+
raise IndexError(f"Run backward error when grad_index is {grad_index}")
|
|
326
323
|
out[grad_index].backward(grad)
|
|
327
324
|
else:
|
|
328
325
|
out.backward(grad)
|
|
@@ -336,7 +333,6 @@ def run_backward(args, grad, grad_index, out):
|
|
|
336
333
|
|
|
337
334
|
|
|
338
335
|
def initialize_save_error_data(error_data_path):
|
|
339
|
-
check_path_before_create(error_data_path)
|
|
340
336
|
create_directory(error_data_path)
|
|
341
337
|
error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
|
|
342
338
|
ability=FileCheckConst.WRITE_ABLE)
|
|
@@ -438,7 +434,49 @@ def _run_ut(parser=None):
|
|
|
438
434
|
run_ut_command(args)
|
|
439
435
|
|
|
440
436
|
|
|
437
|
+
def checked_online_config(online_config):
|
|
438
|
+
if not online_config.is_online:
|
|
439
|
+
return
|
|
440
|
+
if not isinstance(online_config.is_online, bool):
|
|
441
|
+
raise ValueError("is_online must be bool type")
|
|
442
|
+
# rank_list
|
|
443
|
+
if not isinstance(online_config.rank_list, list):
|
|
444
|
+
raise ValueError("rank_list must be a list")
|
|
445
|
+
if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
|
|
446
|
+
raise ValueError("All elements in rank_list must be integers")
|
|
447
|
+
|
|
448
|
+
# nfs_path
|
|
449
|
+
if online_config.nfs_path:
|
|
450
|
+
check_file_or_directory_path(online_config.nfs_path, isdir=True)
|
|
451
|
+
return
|
|
452
|
+
# tls_path
|
|
453
|
+
if online_config.tls_path:
|
|
454
|
+
check_file_or_directory_path(online_config.tls_path, isdir=True)
|
|
455
|
+
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
|
|
456
|
+
check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
|
|
457
|
+
check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
|
|
458
|
+
|
|
459
|
+
# host and port
|
|
460
|
+
if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
|
|
461
|
+
raise Exception(f"host: {online_config.host} is invalid.")
|
|
462
|
+
if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
|
|
463
|
+
raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
|
|
464
|
+
|
|
465
|
+
|
|
441
466
|
def run_ut_command(args):
|
|
467
|
+
if args.config_path:
|
|
468
|
+
config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
|
|
469
|
+
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
470
|
+
checked_config_path = config_path_checker.common_check()
|
|
471
|
+
_, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
|
|
472
|
+
checker_config = CheckerConfig(task_config)
|
|
473
|
+
else:
|
|
474
|
+
checker_config = CheckerConfig()
|
|
475
|
+
|
|
476
|
+
if not checker_config.is_online and not args.api_info_file:
|
|
477
|
+
logger.error("Please provide api_info_file for offline run ut.")
|
|
478
|
+
raise Exception("Please provide api_info_file for offline run ut.")
|
|
479
|
+
|
|
442
480
|
if not is_gpu:
|
|
443
481
|
torch.npu.set_compile_mode(jit_compile=args.jit_compile)
|
|
444
482
|
used_device = current_device + ":" + str(args.device_id[0])
|
|
@@ -464,8 +502,7 @@ def run_ut_command(args):
|
|
|
464
502
|
forward_content = preprocess_forward_content(forward_content)
|
|
465
503
|
logger.info("Finish filtering the api in the api_info_file.")
|
|
466
504
|
|
|
467
|
-
out_path =
|
|
468
|
-
check_path_before_create(out_path)
|
|
505
|
+
out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
|
|
469
506
|
create_directory(out_path)
|
|
470
507
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
471
508
|
out_path = out_path_checker.common_check()
|
|
@@ -476,40 +513,27 @@ def run_ut_command(args):
|
|
|
476
513
|
if args.result_csv_path:
|
|
477
514
|
result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
|
|
478
515
|
details_csv_path = get_validated_details_csv_path(result_csv_path)
|
|
479
|
-
white_list = msCheckerConfig.white_list
|
|
480
|
-
black_list = msCheckerConfig.black_list
|
|
481
|
-
error_data_path = msCheckerConfig.error_data_path
|
|
482
|
-
is_online = msCheckerConfig.is_online
|
|
483
|
-
nfs_path = msCheckerConfig.nfs_path
|
|
484
|
-
host = msCheckerConfig.host
|
|
485
|
-
port = msCheckerConfig.port
|
|
486
|
-
rank_list = msCheckerConfig.rank_list
|
|
487
|
-
tls_path = msCheckerConfig.tls_path
|
|
488
|
-
if args.config_path:
|
|
489
|
-
config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
|
|
490
|
-
FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
|
|
491
|
-
checked_config_path = config_path_checker.common_check()
|
|
492
|
-
_, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
|
|
493
|
-
white_list = task_config.white_list
|
|
494
|
-
black_list = task_config.black_list
|
|
495
|
-
error_data_path = task_config.error_data_path
|
|
496
|
-
is_online = task_config.is_online
|
|
497
|
-
nfs_path = task_config.nfs_path
|
|
498
|
-
host = task_config.host
|
|
499
|
-
port = task_config.port
|
|
500
|
-
rank_list = task_config.rank_list
|
|
501
|
-
tls_path = task_config.tls_path
|
|
502
516
|
|
|
517
|
+
error_data_path = checker_config.error_data_path
|
|
503
518
|
if save_error_data:
|
|
504
519
|
if args.result_csv_path:
|
|
505
520
|
time_info = result_csv_path.split('.')[0].split('_')[-1]
|
|
506
521
|
global UT_ERROR_DATA_DIR
|
|
507
522
|
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
|
|
508
523
|
error_data_path = initialize_save_error_data(error_data_path)
|
|
509
|
-
online_config =
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
524
|
+
online_config = checker_config.get_online_config()
|
|
525
|
+
checked_online_config(online_config)
|
|
526
|
+
config_params = {
|
|
527
|
+
'forward_content': forward_content,
|
|
528
|
+
'backward_content': backward_content,
|
|
529
|
+
'result_csv_path': result_csv_path,
|
|
530
|
+
'details_csv_path': details_csv_path,
|
|
531
|
+
'save_error_data': save_error_data,
|
|
532
|
+
'is_continue_run_ut': args.result_csv_path,
|
|
533
|
+
'real_data_path': real_data_path,
|
|
534
|
+
'error_data_path': error_data_path
|
|
535
|
+
}
|
|
536
|
+
run_ut_config = checker_config.get_run_ut_config(**config_params)
|
|
513
537
|
run_ut(run_ut_config)
|
|
514
538
|
|
|
515
539
|
|
|
@@ -51,7 +51,7 @@ class BackwardMessage:
|
|
|
51
51
|
MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
|
|
52
52
|
UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
|
|
53
53
|
"skip backward."
|
|
54
|
-
NO_BACKWARD_RESULT_MESSAGE = "
|
|
54
|
+
NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
|
|
55
55
|
|
|
56
56
|
|
|
57
57
|
class UtDataInfo:
|
|
@@ -186,11 +186,13 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
186
186
|
logger.error("The depth of arg_in is too large, please check the arg_in.")
|
|
187
187
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
188
188
|
if isinstance(arg_in, (list, tuple)):
|
|
189
|
-
return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
|
|
189
|
+
return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
|
|
190
|
+
arg in arg_in))
|
|
190
191
|
elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
|
|
191
192
|
return set([arg_in.dtype])
|
|
192
193
|
elif isinstance(arg_in, dict) and check_kwargs:
|
|
193
|
-
return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
|
|
194
|
+
return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
|
|
195
|
+
v in arg_in.values()))
|
|
194
196
|
return set()
|
|
195
197
|
|
|
196
198
|
raise_dtype = None
|
|
@@ -204,10 +206,19 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
|
|
|
204
206
|
raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
|
|
205
207
|
is_detach = api_name not in not_detach_set
|
|
206
208
|
cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
|
|
207
|
-
cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
|
|
209
|
+
cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
|
|
210
|
+
key, value in input_kwargs.items()}
|
|
208
211
|
return cpu_args, cpu_kwargs
|
|
209
212
|
|
|
210
213
|
|
|
211
214
|
def record_skip_info(api_full_name, compare, compare_alg_results):
|
|
212
215
|
result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
|
|
213
216
|
compare.record_results(result_info)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def is_unsupported_api(api_name, is_overflow_check=False):
|
|
220
|
+
split_name = api_name.split(Const.SEP)[0]
|
|
221
|
+
flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
|
|
222
|
+
if flag:
|
|
223
|
+
logger.info(f"{split_name} api is not supported for run ut. SKIP.")
|
|
224
|
+
return flag
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
import glob
|
|
17
17
|
import os.path
|
|
18
18
|
import time
|
|
19
|
-
import re
|
|
20
19
|
from multiprocessing import Queue
|
|
21
20
|
from typing import Optional, Union, Dict, Any
|
|
22
21
|
from dataclasses import dataclass
|
|
@@ -26,9 +25,8 @@ import torch
|
|
|
26
25
|
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
27
26
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
|
|
28
27
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
|
|
29
|
-
from msprobe.pytorch.common.utils import logger
|
|
30
28
|
from msprobe.core.common.file_utils import remove_path
|
|
31
|
-
from msprobe.pytorch.common.utils import save_api_data, load_api_data,
|
|
29
|
+
from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
|
|
32
30
|
|
|
33
31
|
BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]]
|
|
34
32
|
|
|
@@ -55,7 +53,6 @@ class ATTL:
|
|
|
55
53
|
self.dequeue_list = []
|
|
56
54
|
self.message_end = False
|
|
57
55
|
self.kill_progress = False
|
|
58
|
-
self.check_attl_config()
|
|
59
56
|
self.nfs_path = None
|
|
60
57
|
if self.session_config.nfs_path:
|
|
61
58
|
self.nfs_path = self.session_config.nfs_path
|
|
@@ -73,18 +70,6 @@ class ATTL:
|
|
|
73
70
|
self.session_config.tls_path)
|
|
74
71
|
self.socket_manager.start()
|
|
75
72
|
|
|
76
|
-
def check_attl_config(self):
|
|
77
|
-
if self.session_config.nfs_path:
|
|
78
|
-
if os.path.exists(self.session_config.nfs_path):
|
|
79
|
-
return
|
|
80
|
-
else:
|
|
81
|
-
raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
|
|
82
|
-
ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
|
|
83
|
-
if not re.match(ipv4_pattern, self.session_config.connect_ip):
|
|
84
|
-
raise Exception(f"host {self.session_config.connect_ip} is invalid.")
|
|
85
|
-
if not (0 < self.session_config.connect_port <= 65535):
|
|
86
|
-
raise Exception(f"port {self.session_config.connect_port} is invalid.")
|
|
87
|
-
|
|
88
73
|
def stop_serve(self):
|
|
89
74
|
if isinstance(self.socket_manager, TCPServer):
|
|
90
75
|
self.socket_manager.stop()
|
|
@@ -115,21 +100,21 @@ class ATTL:
|
|
|
115
100
|
self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
|
|
116
101
|
|
|
117
102
|
def recv(self, timeout_ms=0) -> Optional[BufferType]:
|
|
118
|
-
buffer =
|
|
119
|
-
while buffer
|
|
103
|
+
buffer = ''
|
|
104
|
+
while not buffer:
|
|
120
105
|
if timeout_ms > 0:
|
|
121
106
|
time.sleep(timeout_ms / 1000.0)
|
|
122
|
-
if buffer
|
|
107
|
+
if not buffer and not self.data_queue.empty():
|
|
123
108
|
buffer = self.data_queue.get()
|
|
124
109
|
break
|
|
125
|
-
if buffer
|
|
110
|
+
if not buffer and timeout_ms > 0: # timeout is the only case we give up and return None
|
|
126
111
|
break
|
|
127
112
|
if self.message_end and self.data_queue.empty():
|
|
128
113
|
buffer = b"KILL_CONFIRM"
|
|
129
114
|
self.kill_progress = True
|
|
130
115
|
break
|
|
131
116
|
time.sleep(0.1) # waiting outside the lock before next attempt
|
|
132
|
-
if buffer
|
|
117
|
+
if not buffer:
|
|
133
118
|
# this is a result of a timeout
|
|
134
119
|
self.logger.info(f"RECEIVE API DATA TIMED OUT")
|
|
135
120
|
else:
|
|
@@ -146,7 +131,7 @@ class ATTL:
|
|
|
146
131
|
except Exception as e:
|
|
147
132
|
self.logger.warning("there is something error. please check it. %s", e)
|
|
148
133
|
if isinstance(buffer, bytes):
|
|
149
|
-
return
|
|
134
|
+
return ''
|
|
150
135
|
if isinstance(buffer, str):
|
|
151
136
|
return buffer
|
|
152
137
|
|
|
@@ -160,7 +145,7 @@ class ATTL:
|
|
|
160
145
|
file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
|
|
161
146
|
|
|
162
147
|
try:
|
|
163
|
-
|
|
148
|
+
save_pkl(buffer, file_path)
|
|
164
149
|
except Exception as e:
|
|
165
150
|
self.logger.warning("there is something error in save_pt. please check it. %s", e)
|
|
166
151
|
|
|
@@ -176,7 +161,7 @@ class ATTL:
|
|
|
176
161
|
|
|
177
162
|
if cur_file is not None:
|
|
178
163
|
try:
|
|
179
|
-
buffer =
|
|
164
|
+
buffer = load_pkl(cur_file)
|
|
180
165
|
except Exception as e:
|
|
181
166
|
self.logger.warning("there is something error. please check it. %s", e)
|
|
182
167
|
remove_path(cur_file)
|
|
@@ -27,8 +27,8 @@ from twisted.internet import reactor, protocol, endpoints
|
|
|
27
27
|
from twisted.protocols.basic import FileSender
|
|
28
28
|
|
|
29
29
|
from msprobe.pytorch.common.utils import logger
|
|
30
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import
|
|
31
|
-
|
|
30
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
|
|
31
|
+
STR_TO_BYTES_ORDER as bytes_order
|
|
32
32
|
|
|
33
33
|
MAX_SENDING_QUEUE_SIZE = 20
|
|
34
34
|
|
|
@@ -84,15 +84,6 @@ class TCPClient:
|
|
|
84
84
|
def run_reactor():
|
|
85
85
|
reactor.run(installSignalHandlers=False)
|
|
86
86
|
|
|
87
|
-
def check_tls_path(self):
|
|
88
|
-
client_key = os.path.join(self.tls_path, "client.key")
|
|
89
|
-
client_crt = os.path.join(self.tls_path, "client.crt")
|
|
90
|
-
if not os.path.exists(client_key):
|
|
91
|
-
raise Exception(f"client_key: {client_key} is not exists.")
|
|
92
|
-
if not os.path.exists(client_crt):
|
|
93
|
-
raise Exception(f"client_crt: {client_crt} is not exists.")
|
|
94
|
-
return client_key, client_crt
|
|
95
|
-
|
|
96
87
|
def start(self):
|
|
97
88
|
def conn_callback(cur_protocol):
|
|
98
89
|
if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
|
|
@@ -114,7 +105,8 @@ class TCPClient:
|
|
|
114
105
|
self.factory.protocol = cur_protocol
|
|
115
106
|
if self.tls_path:
|
|
116
107
|
from twisted.internet import ssl
|
|
117
|
-
client_key
|
|
108
|
+
client_key = os.path.join(self.tls_path, "client.key")
|
|
109
|
+
client_crt = os.path.join(self.tls_path, "client.crt")
|
|
118
110
|
client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
|
|
119
111
|
endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
|
|
120
112
|
else:
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
|
|
1
2
|
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
3
|
# All rights reserved.
|
|
3
4
|
#
|
|
@@ -14,6 +15,7 @@
|
|
|
14
15
|
# limitations under the License.
|
|
15
16
|
|
|
16
17
|
import os
|
|
18
|
+
from collections import defaultdict
|
|
17
19
|
from functools import wraps
|
|
18
20
|
|
|
19
21
|
import torch
|
|
@@ -39,7 +41,7 @@ def singleton(cls):
|
|
|
39
41
|
@singleton
|
|
40
42
|
class Counter:
|
|
41
43
|
def __init__(self) -> None:
|
|
42
|
-
self.index_dict =
|
|
44
|
+
self.index_dict = defaultdict(int)
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
counter = Counter()
|
|
@@ -67,9 +69,9 @@ class AccuracyCheckerDispatch(TorchDispatchMode):
|
|
|
67
69
|
|
|
68
70
|
res = func(*args, **kwargs)
|
|
69
71
|
cur_rank = get_tensor_rank(args, res)
|
|
70
|
-
cur_api_number = self.counter.index_dict
|
|
72
|
+
cur_api_number = self.counter.index_dict[aten_api]
|
|
71
73
|
api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
|
|
72
|
-
logger.info(f"tools is dumping api: {api_name}")
|
|
74
|
+
logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
|
|
73
75
|
api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
|
|
74
76
|
if "device" in api_data.kwargs:
|
|
75
77
|
api_data.kwargs.pop("device")
|
|
@@ -98,7 +100,7 @@ def dispatch4data(func, attl, status):
|
|
|
98
100
|
return wrapper
|
|
99
101
|
|
|
100
102
|
|
|
101
|
-
def run_ut_dispatch(attl, status):
|
|
103
|
+
def run_ut_dispatch(attl, status, is_recompute=False):
|
|
102
104
|
"""
|
|
103
105
|
This function called by online_run_ut.
|
|
104
106
|
It is used to enable or disable dispatch for torch.autograd.backward function.
|
|
@@ -106,5 +108,8 @@ def run_ut_dispatch(attl, status):
|
|
|
106
108
|
Args:
|
|
107
109
|
attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
|
|
108
110
|
status (bool): True means enable dispatch, False means disable dispatch.
|
|
111
|
+
is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data.
|
|
109
112
|
"""
|
|
113
|
+
if is_recompute:
|
|
114
|
+
return
|
|
110
115
|
torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
|
|
@@ -24,7 +24,7 @@ from twisted.internet import reactor, protocol, endpoints
|
|
|
24
24
|
|
|
25
25
|
from msprobe.pytorch.common.utils import logger
|
|
26
26
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
|
|
27
|
-
|
|
27
|
+
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class TCPServer:
|
|
@@ -40,22 +40,14 @@ class TCPServer:
|
|
|
40
40
|
def run_reactor():
|
|
41
41
|
reactor.run(installSignalHandlers=False)
|
|
42
42
|
|
|
43
|
-
def check_tls_path(self):
|
|
44
|
-
server_key = os.path.join(self.tls_path, "server.key")
|
|
45
|
-
server_crt = os.path.join(self.tls_path, "server.crt")
|
|
46
|
-
if not os.path.exists(server_key):
|
|
47
|
-
raise Exception(f"server_key: {server_key} is not exists.")
|
|
48
|
-
if not os.path.exists(server_crt):
|
|
49
|
-
raise Exception(f"server_crt: {server_crt} is not exists.")
|
|
50
|
-
return server_key, server_crt
|
|
51
|
-
|
|
52
43
|
def start(self):
|
|
53
44
|
self.factory.protocol = self.build_protocol
|
|
54
45
|
|
|
55
46
|
if self.tls_path:
|
|
56
47
|
from OpenSSL import SSL
|
|
57
48
|
from twisted.internet import ssl
|
|
58
|
-
server_key
|
|
49
|
+
server_key = os.path.join(self.tls_path, "server.key")
|
|
50
|
+
server_crt = os.path.join(self.tls_path, "server.crt")
|
|
59
51
|
server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
|
|
60
52
|
server_context_ = server_context_factory.getContext()
|
|
61
53
|
server_context_.set_cipher_list(cipher_list)
|
|
@@ -22,7 +22,11 @@ def npu_confusion_transpose(data, perm, shape, transpose_first):
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
|
|
25
|
-
|
|
25
|
+
try:
|
|
26
|
+
shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
|
|
27
|
+
except IndexError as e:
|
|
28
|
+
raise IndexError("npu_confusion_transpose_backward: Invalid perm index for shape") from e
|
|
29
|
+
|
|
26
30
|
perm_cal = [0] * len(perm)
|
|
27
31
|
for i, perm_dim in enumerate(perm):
|
|
28
32
|
perm_cal[perm_dim] = i
|
|
@@ -17,6 +17,9 @@ import torch
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
def matmul_backward(grad, self, other, mask):
|
|
20
|
+
if len(mask) < 2:
|
|
21
|
+
raise RuntimeError("Mask size at least 2")
|
|
22
|
+
|
|
20
23
|
grad_self, grad_other = None, None
|
|
21
24
|
dim_self = self.dim()
|
|
22
25
|
dim_other = other.dim()
|
|
@@ -24,6 +27,7 @@ def matmul_backward(grad, self, other, mask):
|
|
|
24
27
|
size_grad = list(grad.size())
|
|
25
28
|
size_self = list(self.size())
|
|
26
29
|
size_other = list(other.size())
|
|
30
|
+
|
|
27
31
|
if dim_self == 1 and dim_other == 1:
|
|
28
32
|
grad_self = other.mul(grad) if mask[0] else grad_self
|
|
29
33
|
grad_other = self.mul(grad) if mask[1] else grad_other
|
|
@@ -34,19 +38,27 @@ def matmul_backward(grad, self, other, mask):
|
|
|
34
38
|
grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
|
|
35
39
|
grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
|
|
36
40
|
elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
|
|
41
|
+
if len(size_grad) < 1:
|
|
42
|
+
raise RuntimeError("size_grad's length at least 1")
|
|
37
43
|
view_size = 1 if dim_other == 1 else size_grad[-1]
|
|
38
44
|
unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
|
|
39
45
|
if mask[0]:
|
|
40
46
|
grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
|
|
41
47
|
.view(size_self)
|
|
42
48
|
if mask[1]:
|
|
49
|
+
if len(size_self) < 1:
|
|
50
|
+
raise RuntimeError("size_self's length at least 1")
|
|
43
51
|
unfolded_self = self.contiguous().view([-1, size_self[-1]])
|
|
44
52
|
grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
|
|
45
53
|
elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
|
|
54
|
+
if len(size_grad) < 2:
|
|
55
|
+
raise RuntimeError("size_grad's length at least 2")
|
|
46
56
|
view_size = 1 if dim_self == 1 else size_grad[-2]
|
|
47
57
|
unfolded_grad_t = grad.view([-1, view_size]) \
|
|
48
58
|
if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
|
|
49
59
|
if mask[0]:
|
|
60
|
+
if len(size_other) < 2:
|
|
61
|
+
raise RuntimeError("size_other's length at least 2")
|
|
50
62
|
# create a 2D-matrix from other
|
|
51
63
|
unfolded_other_t = \
|
|
52
64
|
other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)
|