mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0

--- a/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py
+++ b/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py
@@ -16,7 +16,6 @@
 import glob
 import os.path
 import time
-import re
 from multiprocessing import Queue
 from typing import Optional, Union, Dict, Any
 from dataclasses import dataclass
@@ -26,9 +25,8 @@ import torch
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
-from msprobe.pytorch.common.utils import logger
 from msprobe.core.common.file_utils import remove_path
-from msprobe.pytorch.common.utils import save_api_data, load_api_data,
+from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
 
 BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
 
@@ -55,7 +53,6 @@ class ATTL:
         self.dequeue_list = []
         self.message_end = False
         self.kill_progress = False
-        self.check_attl_config()
         self.nfs_path = None
         if self.session_config.nfs_path:
             self.nfs_path = self.session_config.nfs_path
@@ -73,18 +70,6 @@ class ATTL:
                                               self.session_config.tls_path)
             self.socket_manager.start()
 
-    def check_attl_config(self):
-        if self.session_config.nfs_path:
-            if os.path.exists(self.session_config.nfs_path):
-                return
-            else:
-                raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
-        ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
-        if not re.match(ipv4_pattern, self.session_config.connect_ip):
-            raise Exception(f"host {self.session_config.connect_ip} is invalid.")
-        if not (0 < self.session_config.connect_port <= 65535):
-            raise Exception(f"port {self.session_config.connect_port} is invalid.")
-
     def stop_serve(self):
         if isinstance(self.socket_manager, TCPServer):
             self.socket_manager.stop()
@@ -115,21 +100,21 @@ class ATTL:
             self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
 
     def recv(self, timeout_ms=0) -> Optional[BufferType]:
-        buffer =
-        while buffer
+        buffer = ''
+        while not buffer:
             if timeout_ms > 0:
                 time.sleep(timeout_ms / 1000.0)
-            if buffer
+            if not buffer and not self.data_queue.empty():
                 buffer = self.data_queue.get()
                 break
-            if buffer
+            if not buffer and timeout_ms > 0:  # timeout is the only case we give up and return None
                 break
             if self.message_end and self.data_queue.empty():
                 buffer = b"KILL_CONFIRM"
                 self.kill_progress = True
                 break
             time.sleep(0.1)  # waiting outside the lock before next attempt
-        if buffer
+        if not buffer:
             # this is a result of a timeout
             self.logger.info(f"RECEIVE API DATA TIMED OUT")
         else:
@@ -146,7 +131,7 @@ class ATTL:
         except Exception as e:
             self.logger.warning("there is something error. please check it. %s", e)
         if isinstance(buffer, bytes):
-            return
+            return ''
         if isinstance(buffer, str):
             return buffer
 
@@ -160,7 +145,7 @@ class ATTL:
             file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
 
         try:
-
+            save_pkl(buffer, file_path)
         except Exception as e:
             self.logger.warning("there is something error in save_pt. please check it. %s", e)
 
@@ -176,7 +161,7 @@ class ATTL:
 
         if cur_file is not None:
             try:
-                buffer =
+                buffer = load_pkl(cur_file)
             except Exception as e:
                 self.logger.warning("there is something error. please check it. %s", e)
             remove_path(cur_file)

--- a/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py
+++ b/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py
@@ -27,8 +27,8 @@ from twisted.internet import reactor, protocol, endpoints
 from twisted.protocols.basic import FileSender
 
 from msprobe.pytorch.common.utils import logger
-from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import
-
+from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
+    STR_TO_BYTES_ORDER as bytes_order
 
 MAX_SENDING_QUEUE_SIZE = 20
 
@@ -84,15 +84,6 @@ class TCPClient:
         def run_reactor():
             reactor.run(installSignalHandlers=False)
 
-    def check_tls_path(self):
-        client_key = os.path.join(self.tls_path, "client.key")
-        client_crt = os.path.join(self.tls_path, "client.crt")
-        if not os.path.exists(client_key):
-            raise Exception(f"client_key: {client_key} is not exists.")
-        if not os.path.exists(client_crt):
-            raise Exception(f"client_crt: {client_crt} is not exists.")
-        return client_key, client_crt
-
     def start(self):
         def conn_callback(cur_protocol):
             if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
@@ -114,7 +105,8 @@ class TCPClient:
         self.factory.protocol = cur_protocol
         if self.tls_path:
             from twisted.internet import ssl
-            client_key
+            client_key = os.path.join(self.tls_path, "client.key")
+            client_crt = os.path.join(self.tls_path, "client.crt")
             client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
             endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
         else:

--- a/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
+++ b/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
@@ -24,7 +24,7 @@ from msprobe.core.common.const import Const, CompareConst
 from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
 from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
     binary_standard_api, absolute_standard_api
-from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
 from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
@@ -92,8 +92,10 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
 
     try:
         # NPU vs CPU
-
-
+        cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
+        cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
+        cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
+        cpu_out = exec_api(cpu_exec_params)
         npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
         npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
         npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])

--- a/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py
+++ b/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py
@@ -1,3 +1,4 @@
+
 # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
@@ -14,6 +15,7 @@
 # limitations under the License.
 
 import os
+from collections import defaultdict
 from functools import wraps
 
 import torch
@@ -39,7 +41,7 @@ def singleton(cls):
 @singleton
 class Counter:
     def __init__(self) -> None:
-        self.index_dict =
+        self.index_dict = defaultdict(int)
 
 
 counter = Counter()
@@ -67,9 +69,9 @@ class AccuracyCheckerDispatch(TorchDispatchMode):
 
         res = func(*args, **kwargs)
         cur_rank = get_tensor_rank(args, res)
-        cur_api_number = self.counter.index_dict
+        cur_api_number = self.counter.index_dict[aten_api]
         api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
-        logger.info(f"tools is dumping api: {api_name}")
+        logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
         api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
         if "device" in api_data.kwargs:
             api_data.kwargs.pop("device")
@@ -98,7 +100,7 @@ def dispatch4data(func, attl, status):
     return wrapper
 
 
-def run_ut_dispatch(attl, status):
+def run_ut_dispatch(attl, status, is_recompute=False):
     """
     This function called by online_run_ut.
     It is used to enable or disable dispatch for torch.autograd.backward function.
@@ -106,5 +108,8 @@ def run_ut_dispatch(attl, status):
     Args:
         attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
         status (bool): True means enable dispatch, False means disable dispatch.
+        is_recompute (bool): Flag of recompute, which is conflicted with aten api, then skip dispatch4data.
     """
+    if is_recompute:
+        return
     torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)

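The Counter change above swaps a plain dict for collections.defaultdict(int), so each aten API name gets a per-name counter without an explicit key-existence check. A minimal sketch of that pattern; the helper and API names below are illustrative, not the tool's actual code:

```python
from collections import defaultdict

index_dict = defaultdict(int)  # missing keys start at 0


def next_index(aten_api: str) -> int:
    """Return the current index for an API name and advance its counter."""
    current = index_dict[aten_api]  # no KeyError even on first use
    index_dict[aten_api] += 1
    return current


# hypothetical API names, only to show the numbering behaviour
print(next_index("aten.add"))  # 0
print(next_index("aten.add"))  # 1
print(next_index("aten.mul"))  # 0
```
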
--- a/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py
+++ b/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py
@@ -24,7 +24,7 @@ from twisted.internet import reactor, protocol, endpoints
 
 from msprobe.pytorch.common.utils import logger
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
-
+    STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
 
 
 class TCPServer:
@@ -40,22 +40,14 @@ class TCPServer:
         def run_reactor():
             reactor.run(installSignalHandlers=False)
 
-    def check_tls_path(self):
-        server_key = os.path.join(self.tls_path, "server.key")
-        server_crt = os.path.join(self.tls_path, "server.crt")
-        if not os.path.exists(server_key):
-            raise Exception(f"server_key: {server_key} is not exists.")
-        if not os.path.exists(server_crt):
-            raise Exception(f"server_crt: {server_crt} is not exists.")
-        return server_key, server_crt
-
     def start(self):
         self.factory.protocol = self.build_protocol
 
         if self.tls_path:
             from OpenSSL import SSL
             from twisted.internet import ssl
-            server_key
+            server_key = os.path.join(self.tls_path, "server.key")
+            server_crt = os.path.join(self.tls_path, "server.crt")
             server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
             server_context_ = server_context_factory.getContext()
             server_context_.set_cipher_list(cipher_list)

--- a/msprobe/pytorch/bench_functions/confusion_transpose.py
+++ b/msprobe/pytorch/bench_functions/confusion_transpose.py
@@ -22,7 +22,11 @@ def npu_confusion_transpose(data, perm, shape, transpose_first):
 
 
 def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
-
+    try:
+        shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm]
+    except IndexError as e:
+        raise IndexError("npu_confusion_transpose_backward: Invalid perm index for shape") from e
+
     perm_cal = [0] * len(perm)
     for i, perm_dim in enumerate(perm):
         perm_cal[perm_dim] = i

--- a/msprobe/pytorch/bench_functions/matmul_backward.py
+++ b/msprobe/pytorch/bench_functions/matmul_backward.py
@@ -17,6 +17,9 @@ import torch
 
 
 def matmul_backward(grad, self, other, mask):
+    if len(mask) < 2:
+        raise RuntimeError("Mask size at least 2")
+
     grad_self, grad_other = None, None
     dim_self = self.dim()
     dim_other = other.dim()
@@ -24,6 +27,7 @@ def matmul_backward(grad, self, other, mask):
     size_grad = list(grad.size())
     size_self = list(self.size())
     size_other = list(other.size())
+
     if dim_self == 1 and dim_other == 1:
         grad_self = other.mul(grad) if mask[0] else grad_self
         grad_other = self.mul(grad) if mask[1] else grad_other
@@ -34,19 +38,27 @@ def matmul_backward(grad, self, other, mask):
         grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
         grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
     elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
+        if len(size_grad) < 1:
+            raise RuntimeError("size_grad's length at least 1")
         view_size = 1 if dim_other == 1 else size_grad[-1]
         unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
         if mask[0]:
             grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
                 .view(size_self)
         if mask[1]:
+            if len(size_self) < 1:
+                raise RuntimeError("size_self's length at least 1")
             unfolded_self = self.contiguous().view([-1, size_self[-1]])
             grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
     elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
+        if len(size_grad) < 2:
+            raise RuntimeError("size_grad's length at least 2")
         view_size = 1 if dim_self == 1 else size_grad[-2]
         unfolded_grad_t = grad.view([-1, view_size]) \
             if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
         if mask[0]:
+            if len(size_other) < 2:
+                raise RuntimeError("size_other's length at least 2")
             # create a 2D-matrix from other
             unfolded_other_t = \
                 other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)

--- a/msprobe/pytorch/bench_functions/npu_fusion_attention.py
+++ b/msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -30,6 +30,7 @@
     numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
 """
 
+from collections import namedtuple
 import torch
 import numpy as np
 from einops import rearrange
@@ -50,8 +51,16 @@ else:
 from msprobe.pytorch.common.utils import logger
 from msprobe.core.common.const import Const, CompareConst
 
-
-
+GTYPE = torch.float64  # arm hosts must use float64; on x86, float32 is enough (float64 also works). arm is slow, so x86 is recommended for s=8k cases
+SOFTMAX_BUILD_MODE = "QKV"  # "MAX_SUM"
+
+
+FaForwardParams = namedtuple("FaForwardParams",
+                             ["q", "k", "v", "drop_mask", "atten_mask", "pse", "scale", "keep_prob"])
+FaBackwardParams = namedtuple("FaBackwardParams",
+                              ["dx", "q", "k", "v", "softmax_res", "drop_mask", "pse", "scale", "keep_prob"])
+RebuildSoftmaxParams = namedtuple("RebuildSoftmaxParams",
+                                  ["q", "k", "atten_mask", "pse", "scale", "softmax_max", "softmax_sum"])
 
 
 def softmax_forward(x):
@@ -99,7 +108,15 @@ def calculate_qk(q, k, atten_mask, pse, scale):
     return qk
 
 
-def fusion_attention_forward(
+def fusion_attention_forward(forward_params):
+    q = forward_params.q
+    k = forward_params.k
+    v = forward_params.v
+    drop_mask = forward_params.drop_mask
+    atten_mask = forward_params.atten_mask
+    pse = forward_params.pse
+    scale = forward_params.scale
+    keep_prob = forward_params.keep_prob
     qk = calculate_qk(q, k, atten_mask, pse, scale)
     softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
     if drop_mask is None or len(drop_mask.shape) == 0:
@@ -110,7 +127,16 @@ def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_pr
     return y, softmax_max, softmax_sum
 
 
-def fusion_attention_backward(
+def fusion_attention_backward(backward_params):
+    dx = backward_params.dx
+    q = backward_params.q
+    k = backward_params.k
+    v = backward_params.v
+    softmax_res = backward_params.softmax_res
+    drop_mask = backward_params.drop_mask
+    pse = backward_params.pse
+    scale = backward_params.scale
+    keep_prob = backward_params.keep_prob
     dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
     if drop_mask is None or len(drop_mask.shape) == 0:
         drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -166,6 +192,18 @@ def parse_bsnd_args(query, key, head_num, input_layout):
 
 
 def convert_from_bnsd(_input, input_layout):
+    """
+    transform qkv from bnsd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
     if input_layout == "BSH":
         # (B,N,S,D)=>(B,S,N*D)
         out = rearrange(_input, 'b n s d -> b s (n d)').contiguous()
@@ -183,7 +221,19 @@ def convert_from_bnsd(_input, input_layout):
 
 
 def convert_to_bnsd(_input, n, input_layout):
-
+    """
+    transform qkv from input_layout to bnsd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D)
+    """
     if input_layout == "BSH":
         # (B,S,N*D)=>(B,N,S,D)
         out = rearrange(_input, 'b s (n d) -> b n s d', n=n)
@@ -199,7 +249,68 @@ def convert_to_bnsd(_input, n, input_layout):
         out = _input
     if out.dim() != 4:
         raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
-    return out.to(
+    return out.to(GTYPE)
+
+
+def convert_from_bsnd(_input, input_layout):
+    """
+    transform qkv from bsnd to input_layout.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,S,N,D)
+        input_layout (str): "BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+    """
+    if input_layout == "BSH":
+        # (B,S,N,D)=>(B,S,N*D)
+        out = rearrange(_input, 'b s n d -> b s (n d)').contiguous()
+    elif input_layout == "SBH":
+        # (B,S,N,D)=>(S,B,N*D)
+        out = rearrange(_input, 'b s n d -> s b (n d)').contiguous()
+    elif input_layout == "BNSD":
+        # (B,S,N,D)=>(B,N,S,D)
+        out = rearrange(_input, 'b s n d -> b n s d').contiguous()
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} does not supported for now.")
+    else:
+        out = _input
+    return out
+
+
+def convert_to_bsnd(_input, n, input_layout):
+    """
+    transform qkv from input_layout to bsnd.
+    B: batch_size
+    S: sequence_length
+    N: num_heads
+    D: head_dim
+    Args:
+        _input (torch.Tensor): tensor of shape (B,N,S,D) or (B,S,N,D) or (S,B,H) or (B,S,H)
+        n (int): num_heads
+        input_layout (str):"BSH" or "SBH" or "BSND" or "BNSD" or "TND"
+    Returns:
+        tensor of shape (B,S,N,D)
+    """
+    if input_layout == "BSH":
+        # (B,S,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 'b s (n d) -> b s n d', n=n)
+    elif input_layout == "SBH":
+        # (S,B,N*D)=>(B,S,N,D)
+        out = rearrange(_input, 's b (n d) -> b s n d', n=n)
+    elif input_layout == "BNSD":
+        # (B,N,S,D)=>(B,S,N,D)
+        out = rearrange(_input, 'b n s d -> b s n d', n=n)
+    elif input_layout == "TND":
+        raise ValueError(f"input_layout {input_layout} does not supported for now.")
+    else:
+        out = _input
+    if out.dim() != 4:
+        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
+    return out
 
 
 def generate_atten_mask(*args):
@@ -279,15 +390,22 @@ def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
     """
     logger.info("Using QKV to rebuild original softmax")
     qk = calculate_qk(q, k, atten_mask, pse, scale)
-    softmax_res,
+    softmax_res, _, _ = softmax_forward(qk)
     return softmax_res
 
 
-def rebuild_softmax_by_max_sum(
+def rebuild_softmax_by_max_sum(softmax_params):
     """
     attention = softmax(QK^T/sqrt(d))V
     softmax(x_i) = e^(x_i - x_max_i) / x_sum_i)
     """
+    q = softmax_params.q
+    k = softmax_params.k
+    atten_mask = softmax_params.atten_mask
+    pse = softmax_params.pse
+    scale = softmax_params.scale
+    softmax_max = softmax_params.softmax_max
+    softmax_sum = softmax_params.softmax_sum
     logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
     qk = calculate_qk(q, k, atten_mask, pse, scale)
     if softmax_max.shape[-1] == 0:
@@ -319,6 +437,10 @@ def get_input_layout(*args, **kwargs):
 
 
 def npu_fusion_attention_forward_patch(*args, **kwargs):
+
+    if len(args) < 2:
+        raise RuntimeError("npu_fusion_attention_forward_patch: length of args should greater than or equal to 2.")
+
     # query, key, value, head_num, input_layout
     head_num = get_head_num(*args, **kwargs)
     input_layout = get_input_layout(*args, **kwargs)
@@ -413,10 +535,8 @@ def npu_fusion_attention(*args, **kwargs):
     key = convert_to_bnsd(key, n2, input_layout)
     value = convert_to_bnsd(value, n2, input_layout)
     k_new, v_new = generate_kv(key, value, n1, n2)
-
-
-                                                       pse=pse, scale=scale,
-                                                       keep_prob=keep_prob)
+    forward_params = FaForwardParams(query, k_new, v_new, None, atten_mask, pse, scale, keep_prob)
+    out_golden, softmax_max, softmax_sum = fusion_attention_forward(forward_params)
     if out_golden.dim() == 5:
         out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
                                         out_golden.size(4))
@@ -454,12 +574,13 @@ def npu_fusion_attention_grad(*args, **kwargs):
     value = convert_to_bnsd(value, n2, input_layout)
     k_new, v_new = generate_kv(key, value, n1, n2)
 
-    if
+    if SOFTMAX_BUILD_MODE == "QKV":
         softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
     else:
-
-
-
+        softmax_params = RebuildSoftmaxParams(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)
+        softmax_res = rebuild_softmax_by_max_sum(softmax_params)
+    backward_params = FaBackwardParams(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)
+    dq, dk, dv = fusion_attention_backward(backward_params)
 
     # adaptation for unequal N, by cdy
     if not (n1 == n2):
@@ -531,8 +652,13 @@ def gpu_fusion_attention(*args, **kwargs):
     else:
         alibi_slopes = None
 
+    input_layout = get_input_layout(*args, **kwargs)
+    query = convert_to_bsnd(query, n1, input_layout)
+    key = convert_to_bsnd(key, n2, input_layout)
+    value = convert_to_bsnd(value, n2, input_layout)
     out = flash_attn_func(
         query, key, value, dropout_p=(1 - keep_prob), softmax_scale=scale, causal=causal_switch,
         window_size=(window_left, window_right), alibi_slopes=alibi_slopes, deterministic=deterministic
     )
+    out = convert_from_bsnd(out, input_layout)
     return out, Const.NONE, Const.NONE

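Several of the golden-function helpers above (fusion_attention_forward, fusion_attention_backward, rebuild_softmax_by_max_sum) now receive a single namedtuple parameter object instead of a long positional argument list. A minimal sketch of that pattern using the FaForwardParams fields from the diff; the tensors below are placeholders, not real attention inputs:

```python
from collections import namedtuple

import torch

FaForwardParams = namedtuple("FaForwardParams",
                             ["q", "k", "v", "drop_mask", "atten_mask", "pse", "scale", "keep_prob"])

# placeholder tensors, only to show how the parameter object is built and read
q = k = v = torch.zeros(1, 1, 4, 8)
params = FaForwardParams(q=q, k=k, v=v, drop_mask=None, atten_mask=None,
                         pse=None, scale=0.125, keep_prob=1.0)

# a callee reads named fields instead of relying on positional order
print(params.scale, params.keep_prob)
```

Grouping the arguments this way keeps call sites such as fusion_attention_forward(forward_params) short and makes it harder to pass q, k and v in the wrong order.
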
--- a/msprobe/pytorch/bench_functions/rotary_mul.py
+++ b/msprobe/pytorch/bench_functions/rotary_mul.py
@@ -40,6 +40,9 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
     x_shape = x.shape
     h = x.float()
     grad = dy_tensor.float()
+    if len(r1_shape) < 4 or len(x_shape) < 4:
+        raise RuntimeError(f"Shape of r1 and x should at least be 4-dimension, "
+                           f"but got r1 shape:{r1_shape}, x shape:{x_shape}")
     condition_1 = (r1_shape[0] == 1
                    and r1_shape[1] == x_shape[1]
                    and r1_shape[2] == 1
@@ -68,4 +71,5 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
         for j in range(x_shape[2]):
             r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
             r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
+
     return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()

--- a/msprobe/pytorch/bench_functions/swiglu.py
+++ b/msprobe/pytorch/bench_functions/swiglu.py
@@ -19,7 +19,11 @@ import torch
 def npu_swiglu(x, dim=-1):
     tensor_dtype = x.dtype
 
-
+    try:
+        in_tensors = torch.chunk(x, 2, dim=dim)
+    except Exception as e:
+        raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
+
     if tensor_dtype == torch.float32:
         tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
         output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
@@ -34,7 +38,11 @@ def npu_swiglu(x, dim=-1):
 
 def npu_swiglu_backward(grad, x, dim=-1):
     tensor_dtype = grad.dtype
-
+    try:
+        in_tensors = torch.chunk(x, 2, dim=dim)
+    except Exception as e:
+        raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
+
     tensor_grad_out = grad
 
     if tensor_dtype == torch.float16:

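For context on the chunk guard added above: npu_swiglu splits x into two halves along dim and gates one half with the sigmoid-weighted other half, so an input that cannot be split in two is a hard error. A rough sketch of the float32 forward path shown in the diff (illustrative only, not a drop-in replacement for the bench function):

```python
import torch


def swiglu_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # split the input into two equal halves along `dim`
    a, b = torch.chunk(x, 2, dim=dim)
    # SwiGLU: sigmoid(a) * a acts as the gate applied to b
    return torch.sigmoid(a) * a * b


out = swiglu_reference(torch.randn(2, 8), dim=-1)
print(out.shape)  # torch.Size([2, 4])
```
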
--- a/msprobe/pytorch/common/parse_json.py
+++ b/msprobe/pytorch/common/parse_json.py
@@ -13,20 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-
 from msprobe.core.common.exceptions import ParseJsonException
-from msprobe.core.common.file_utils import
+from msprobe.core.common.file_utils import load_json
+from msprobe.core.common.log import logger
 
 
 def parse_json_info_forward_backward(json_path):
-
-    dump_json = json.load(f)
+    dump_json = load_json(json_path)
 
     real_data_path = dump_json.get("dump_data_dir")
     dump_data = dump_json.get("data")
+    if dump_data is None:
+        raise ParseJsonException(ParseJsonException.InvalidDumpJson,
+                                 "something wrong with dump, no data found in dump.json")
     if not dump_data:
-
+        logger.warning("data field is empty, no overflow data found.")
 
     forward_data = {}
     backward_data = {}
|