mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
|
|
2
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
3
|
+
# All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from collections import defaultdict
|
|
19
|
+
from functools import wraps
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from torch.utils._python_dispatch import TorchDispatchMode
|
|
23
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
24
|
+
from msprobe.pytorch.common.utils import get_tensor_rank
|
|
25
|
+
from msprobe.core.common.const import Const
|
|
26
|
+
from msprobe.pytorch.common.log import logger
|
|
27
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def singleton(cls):
|
|
31
|
+
_instance = {}
|
|
32
|
+
|
|
33
|
+
@wraps(cls)
|
|
34
|
+
def inner():
|
|
35
|
+
if cls not in _instance:
|
|
36
|
+
_instance[cls] = cls()
|
|
37
|
+
return _instance[cls]
|
|
38
|
+
return inner
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@singleton
|
|
42
|
+
class Counter:
|
|
43
|
+
def __init__(self) -> None:
|
|
44
|
+
self.index_dict = defaultdict(int)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
counter = Counter()
|
|
48
|
+
yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
|
|
49
|
+
yaml_file = load_yaml(yaml_path)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AccuracyCheckerDispatch(TorchDispatchMode):
|
|
53
|
+
def __init__(self, attl):
|
|
54
|
+
super(AccuracyCheckerDispatch, self).__init__()
|
|
55
|
+
self.attl = attl
|
|
56
|
+
self.counter = counter
|
|
57
|
+
self.aten_ops_blacklist = []
|
|
58
|
+
self.npu_adjust_autogard = []
|
|
59
|
+
self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', [])
|
|
60
|
+
self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', [])
|
|
61
|
+
|
|
62
|
+
def __torch_dispatch__(self, func, types, args=None, kwargs=None):
|
|
63
|
+
func_name_split_list = func.__name__.split(Const.SEP)
|
|
64
|
+
aten_api = func_name_split_list[0]
|
|
65
|
+
self.enable_autogard(aten_api)
|
|
66
|
+
if aten_api in self.aten_ops_blacklist:
|
|
67
|
+
npu_out = func(*args, **kwargs)
|
|
68
|
+
return npu_out
|
|
69
|
+
|
|
70
|
+
res = func(*args, **kwargs)
|
|
71
|
+
cur_rank = get_tensor_rank(args, res)
|
|
72
|
+
cur_api_number = self.counter.index_dict[aten_api]
|
|
73
|
+
api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
|
|
74
|
+
logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
|
|
75
|
+
api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
|
|
76
|
+
if "device" in api_data.kwargs:
|
|
77
|
+
api_data.kwargs.pop("device")
|
|
78
|
+
if self.attl.nfs_path:
|
|
79
|
+
self.attl.upload(api_data)
|
|
80
|
+
else:
|
|
81
|
+
self.attl.send(api_data)
|
|
82
|
+
self.counter.index_dict[aten_api] += 1
|
|
83
|
+
|
|
84
|
+
return res
|
|
85
|
+
|
|
86
|
+
def enable_autogard(self, aten_api):
|
|
87
|
+
if aten_api in self.npu_adjust_autogard:
|
|
88
|
+
torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def dispatch4data(func, attl, status):
|
|
92
|
+
@wraps(func)
|
|
93
|
+
def wrapper(*args, **kwargs):
|
|
94
|
+
if not status:
|
|
95
|
+
return func(*args, **kwargs)
|
|
96
|
+
with AccuracyCheckerDispatch(attl):
|
|
97
|
+
res = func(*args, **kwargs)
|
|
98
|
+
return res
|
|
99
|
+
|
|
100
|
+
return wrapper
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def run_ut_dispatch(attl, status, is_recompute=False):
|
|
104
|
+
"""
|
|
105
|
+
This function called by online_run_ut.
|
|
106
|
+
It is used to enable or disable dispatch for torch.autograd.backward function.
|
|
107
|
+
|
|
108
|
+
Args:
|
|
109
|
+
attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
|
|
110
|
+
status (bool): True means enable dispatch, False means disable dispatch.
|
|
111
|
+
is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data.
|
|
112
|
+
"""
|
|
113
|
+
if is_recompute:
|
|
114
|
+
return
|
|
115
|
+
torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os.path
|
|
2
17
|
import struct
|
|
3
18
|
import hashlib
|
|
@@ -8,7 +23,8 @@ from threading import Thread
|
|
|
8
23
|
from twisted.internet import reactor, protocol, endpoints
|
|
9
24
|
|
|
10
25
|
from msprobe.pytorch.common.utils import logger
|
|
11
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.
|
|
26
|
+
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
|
|
27
|
+
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
|
|
12
28
|
|
|
13
29
|
|
|
14
30
|
class TCPServer:
|
|
@@ -100,9 +116,9 @@ class ServerProtocol(protocol.Protocol):
|
|
|
100
116
|
def send_ack(self, ack_info):
|
|
101
117
|
ack_message = b"".join([
|
|
102
118
|
ack_info,
|
|
103
|
-
self.sequence_number.to_bytes(8, byteorder=
|
|
104
|
-
self.rank.to_bytes(8, byteorder=
|
|
105
|
-
self.step.to_bytes(8, byteorder=
|
|
119
|
+
self.sequence_number.to_bytes(8, byteorder=bytes_order),
|
|
120
|
+
self.rank.to_bytes(8, byteorder=bytes_order),
|
|
121
|
+
self.step.to_bytes(8, byteorder=bytes_order)
|
|
106
122
|
])
|
|
107
123
|
self.transport.write(ack_message)
|
|
108
124
|
|
|
@@ -168,10 +184,10 @@ class ServerProtocol(protocol.Protocol):
|
|
|
168
184
|
# The first data packet is packet header, it contains obj_length, sequence_number, rank, step
|
|
169
185
|
if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4:
|
|
170
186
|
self.start_time = time.time()
|
|
171
|
-
self.obj_length = struct.unpack(
|
|
172
|
-
self.sequence_number = struct.unpack(
|
|
173
|
-
self.rank = struct.unpack(
|
|
174
|
-
self.step = struct.unpack(
|
|
187
|
+
self.obj_length = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
|
|
188
|
+
self.sequence_number = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
|
|
189
|
+
self.rank = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
|
|
190
|
+
self.step = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
|
|
175
191
|
self.tell += self.length_width * 4
|
|
176
192
|
logger.debug(
|
|
177
193
|
f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
|
|
@@ -210,7 +226,8 @@ class MessageServerFactory(protocol.ServerFactory):
|
|
|
210
226
|
def __init__(self) -> None:
|
|
211
227
|
"""
|
|
212
228
|
transport_dict: links that have not completed data transmission.
|
|
213
|
-
transport_list: Records all TCP links. Appends TCP link to the transport list
|
|
229
|
+
transport_list: Records all TCP links. Appends TCP link to the transport list
|
|
230
|
+
when a new TCP link is established.
|
|
214
231
|
"""
|
|
215
232
|
self.transport_dict = {}
|
|
216
233
|
self.transport_list = []
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
aten_ops_blacklist:
|
|
2
|
+
- npu_binary_cross_entropy_with_logits_backward
|
|
3
|
+
- npu_ciou_backward
|
|
4
|
+
- _cudnn_rnn
|
|
5
|
+
- _local_scalar_dense
|
|
6
|
+
- _pin_memory
|
|
7
|
+
- _to_copy
|
|
8
|
+
- _unsafe_view
|
|
9
|
+
- clone
|
|
10
|
+
- contiguous
|
|
11
|
+
- copy_
|
|
12
|
+
- cudnn_batch_norm
|
|
13
|
+
- cudnn_batch_norm_backward
|
|
14
|
+
- detach
|
|
15
|
+
- empty
|
|
16
|
+
- index_put_
|
|
17
|
+
- lift_fresh
|
|
18
|
+
- max_pool2d_with_indices_backward # shape unmatch
|
|
19
|
+
- native_batch_norm_backward
|
|
20
|
+
- new_empty
|
|
21
|
+
- new_empty_strided
|
|
22
|
+
- new_full
|
|
23
|
+
- new_ones
|
|
24
|
+
- new_zeros
|
|
25
|
+
- ones
|
|
26
|
+
- ones_like
|
|
27
|
+
- permute
|
|
28
|
+
- rand
|
|
29
|
+
- rand_like
|
|
30
|
+
- randint
|
|
31
|
+
- randint_like
|
|
32
|
+
- randn
|
|
33
|
+
- randn_like
|
|
34
|
+
- randperm
|
|
35
|
+
- scalar_tensor
|
|
36
|
+
- select
|
|
37
|
+
- to
|
|
38
|
+
- transpose
|
|
39
|
+
- unbind
|
|
40
|
+
- view
|
|
41
|
+
- zero
|
|
42
|
+
- zero_
|
|
43
|
+
- zeros
|
|
44
|
+
- zeros_like
|
|
45
|
+
- _record_function_enter_new
|
|
46
|
+
- _record_function_exit
|
|
47
|
+
- broadcast_
|
|
48
|
+
- allreduce_
|
|
49
|
+
- npu_clear_float_status
|
|
50
|
+
- npu_format_cast
|
|
51
|
+
- npu_dtype_cast
|
|
52
|
+
- npu_dtype_cast_backward
|
|
53
|
+
- _allgather_base_
|
|
54
|
+
- _reduce_scatter_base_
|
|
55
|
+
- is_same_size
|
|
56
|
+
|
|
57
|
+
npu_adjust_autogard:
|
|
58
|
+
- adaptive_avg_pool2d
|
|
59
|
+
- batch_norm
|
|
60
|
+
- log_softmax
|
|
61
|
+
- nll_loss
|
|
62
|
+
- to
|
|
63
|
+
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
cipher_list = ":".join(
|
|
17
|
+
["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
|
|
18
|
+
"TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
|
|
19
|
+
"TLS_DHE_DSS_WITH_AES_128_GCM_SHA256",
|
|
20
|
+
"TLS_DHE_DSS_WITH_AES_256_GCM_SHA384",
|
|
21
|
+
"TLS_DHE_PSK_WITH_AES_128_GCM_SHA256",
|
|
22
|
+
"TLS_DHE_PSK_WITH_AES_256_GCM_SHA384",
|
|
23
|
+
"TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
|
|
24
|
+
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
|
|
25
|
+
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
|
|
26
|
+
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
|
|
27
|
+
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
|
|
28
|
+
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
|
|
29
|
+
"TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
|
|
30
|
+
"TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256",
|
|
31
|
+
"TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384",
|
|
32
|
+
"TLS_ECDHE_PSK_WITH_AES_128_CCM_SHA256",
|
|
33
|
+
"TLS_DHE_RSA_WITH_AES_128_CCM",
|
|
34
|
+
"TLS_DHE_RSA_WITH_AES_256_CCM",
|
|
35
|
+
"TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
|
|
36
|
+
"TLS_DHE_PSK_WITH_AES_128_CCM",
|
|
37
|
+
"TLS_DHE_PSK_WITH_AES_256_CCM",
|
|
38
|
+
"TLS_ECDHE_ECDSA_WITH_AES_128_CCM",
|
|
39
|
+
"TLS_ECDHE_ECDSA_WITH_AES_256_CCM",
|
|
40
|
+
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
|
|
41
|
+
).encode()
|
|
42
|
+
|
|
43
|
+
STRUCT_UNPACK_MODE = "!Q"
|
|
44
|
+
STR_TO_BYTES_ORDER = "big"
|
|
@@ -1,11 +1,26 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
4
15
|
|
|
5
16
|
"""
|
|
6
17
|
gpu and cpu not implement benchmark function, supplementary benchmarking function implementation
|
|
7
18
|
"""
|
|
8
19
|
|
|
20
|
+
import os
|
|
21
|
+
from pkgutil import iter_modules
|
|
22
|
+
from importlib import import_module
|
|
23
|
+
|
|
9
24
|
package_path = os.path.dirname(os.path.realpath(__file__))
|
|
10
25
|
for _, module_name, _ in iter_modules([package_path]):
|
|
11
26
|
module = import_module(f"{__name__}.{module_name}")
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
def npu_confusion_transpose(data, perm, shape, transpose_first):
|
|
2
17
|
if transpose_first:
|
|
3
18
|
output = data.permute(*perm).contiguous().view(shape)
|
|
@@ -7,7 +22,11 @@ def npu_confusion_transpose(data, perm, shape, transpose_first):
|
|
|
7
22
|
|
|
8
23
|
|
|
9
24
|
def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
|
|
10
|
-
|
|
25
|
+
try:
|
|
26
|
+
shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
|
|
27
|
+
except IndexError as e:
|
|
28
|
+
raise IndexError("npu_confusion_transpose_backward: Invalid perm index for shape") from e
|
|
29
|
+
|
|
11
30
|
perm_cal = [0] * len(perm)
|
|
12
31
|
for i, perm_dim in enumerate(perm):
|
|
13
32
|
perm_cal[perm_dim] = i
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,7 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
4
19
|
def matmul_backward(grad, self, other, mask):
|
|
20
|
+
if len(mask) < 2:
|
|
21
|
+
raise RuntimeError("Mask size at least 2")
|
|
22
|
+
|
|
5
23
|
grad_self, grad_other = None, None
|
|
6
24
|
dim_self = self.dim()
|
|
7
25
|
dim_other = other.dim()
|
|
@@ -9,6 +27,7 @@ def matmul_backward(grad, self, other, mask):
|
|
|
9
27
|
size_grad = list(grad.size())
|
|
10
28
|
size_self = list(self.size())
|
|
11
29
|
size_other = list(other.size())
|
|
30
|
+
|
|
12
31
|
if dim_self == 1 and dim_other == 1:
|
|
13
32
|
grad_self = other.mul(grad) if mask[0] else grad_self
|
|
14
33
|
grad_other = self.mul(grad) if mask[1] else grad_other
|
|
@@ -19,28 +38,36 @@ def matmul_backward(grad, self, other, mask):
|
|
|
19
38
|
grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
|
|
20
39
|
grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
|
|
21
40
|
elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
|
|
41
|
+
if len(size_grad) < 1:
|
|
42
|
+
raise RuntimeError("size_grad's length at least 1")
|
|
22
43
|
view_size = 1 if dim_other == 1 else size_grad[-1]
|
|
23
44
|
unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
|
|
24
45
|
if mask[0]:
|
|
25
46
|
grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
|
|
26
47
|
.view(size_self)
|
|
27
48
|
if mask[1]:
|
|
49
|
+
if len(size_self) < 1:
|
|
50
|
+
raise RuntimeError("size_self's length at least 1")
|
|
28
51
|
unfolded_self = self.contiguous().view([-1, size_self[-1]])
|
|
29
52
|
grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
|
|
30
53
|
elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
|
|
54
|
+
if len(size_grad) < 2:
|
|
55
|
+
raise RuntimeError("size_grad's length at least 2")
|
|
31
56
|
view_size = 1 if dim_self == 1 else size_grad[-2]
|
|
32
|
-
|
|
57
|
+
unfolded_grad_t = grad.view([-1, view_size]) \
|
|
33
58
|
if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
|
|
34
59
|
if mask[0]:
|
|
60
|
+
if len(size_other) < 2:
|
|
61
|
+
raise RuntimeError("size_other's length at least 2")
|
|
35
62
|
# create a 2D-matrix from other
|
|
36
|
-
|
|
63
|
+
unfolded_other_t = \
|
|
37
64
|
other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)
|
|
38
|
-
grad_self =
|
|
65
|
+
grad_self = unfolded_other_t.mm(unfolded_grad_t).transpose(-1, -2).view(size_self)
|
|
39
66
|
if mask[1]:
|
|
40
|
-
|
|
41
|
-
|
|
67
|
+
size_other_t = size_other[:-2]
|
|
68
|
+
size_other_t.extend(size_other[::-1][:2])
|
|
42
69
|
grad_other = \
|
|
43
|
-
|
|
70
|
+
unfolded_grad_t.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_t).transpose(-1, -2)
|
|
44
71
|
else:
|
|
45
72
|
grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self
|
|
46
73
|
grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other
|