mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -29,7 +29,6 @@ from msprobe.pytorch.common.log import logger
|
|
|
29
29
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
|
|
30
30
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
|
|
31
31
|
|
|
32
|
-
|
|
33
32
|
# NPU vs GPU api list
|
|
34
33
|
CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
|
|
35
34
|
|
|
@@ -43,6 +42,15 @@ OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig',
|
|
|
43
42
|
CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config'])
|
|
44
43
|
|
|
45
44
|
|
|
45
|
+
def get_gpu_device():
    """
    Report whether the process should target a CUDA device.

    The decision is made by attempting to import ``torch_npu``.

    :return: bool, True when ``torch_npu`` cannot be imported, False when it can
    """
    try:
        import torch_npu  # imported only to find out whether it is available
    except ImportError:
        return True
    return False
|
|
52
|
+
|
|
53
|
+
|
|
46
54
|
def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file):
|
|
47
55
|
""" When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue.
|
|
48
56
|
:param xpu_id: int
|
|
@@ -51,7 +59,9 @@ def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file
|
|
|
51
59
|
:param api_precision_csv_file: list, length is 2, result file name and details file name
|
|
52
60
|
:return:
|
|
53
61
|
"""
|
|
54
|
-
|
|
62
|
+
device_info = "cuda" if get_gpu_device() else "npu"
|
|
63
|
+
logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.")
|
|
64
|
+
gpu_device = torch.device(f'{device_info}:{xpu_id}')
|
|
55
65
|
|
|
56
66
|
while True:
|
|
57
67
|
if consumer_queue.empty():
|
|
@@ -12,19 +12,19 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import os
|
|
15
|
+
from functools import partial
|
|
16
|
+
import os
|
|
17
17
|
import struct
|
|
18
|
-
import
|
|
18
|
+
import zlib
|
|
19
19
|
import time
|
|
20
20
|
import io
|
|
21
21
|
from threading import Thread
|
|
22
22
|
|
|
23
|
-
from twisted.internet import reactor, protocol, endpoints
|
|
23
|
+
from twisted.internet import reactor, protocol, endpoints, ssl
|
|
24
24
|
|
|
25
25
|
from msprobe.pytorch.common.utils import logger
|
|
26
26
|
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
|
|
27
|
-
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
|
|
27
|
+
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class TCPServer:
|
|
@@ -44,15 +44,28 @@ class TCPServer:
|
|
|
44
44
|
self.factory.protocol = self.build_protocol
|
|
45
45
|
|
|
46
46
|
if self.tls_path:
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
47
|
+
server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(
|
|
48
|
+
key_file=os.path.join(self.tls_path, "server.key"),
|
|
49
|
+
cert_file=os.path.join(self.tls_path, "server.crt"),
|
|
50
|
+
ca_file=os.path.join(self.tls_path, "ca.crt"),
|
|
51
|
+
crl_file=os.path.join(self.tls_path, "crl.pem")
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
ssl_options = ssl.CertificateOptions(
|
|
55
|
+
privateKey=server_key,
|
|
56
|
+
certificate=server_crt,
|
|
57
|
+
method=ssl.SSL.TLSv1_2_METHOD,
|
|
58
|
+
verify=True,
|
|
59
|
+
requireCertificate=True,
|
|
60
|
+
caCerts=[ca_crt], # 信任的CA证书列表
|
|
61
|
+
)
|
|
62
|
+
ssl_context = ssl_options.getContext()
|
|
63
|
+
ssl_context.set_cipher_list(cipher_list)
|
|
64
|
+
ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
|
|
65
|
+
ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
|
|
66
|
+
partial(verify_callback, crl=crl_pem))
|
|
67
|
+
|
|
68
|
+
endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
|
|
56
69
|
else:
|
|
57
70
|
endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
|
|
58
71
|
endpoint.listen(self.factory)
|
|
@@ -85,10 +98,10 @@ class ServerProtocol(protocol.Protocol):
|
|
|
85
98
|
self.consumer_queue = shared_queue
|
|
86
99
|
self.check_sum = check_sum
|
|
87
100
|
self.length_width = 8
|
|
88
|
-
self.
|
|
101
|
+
self.crc_width = 8
|
|
89
102
|
self.obj_length = None
|
|
90
103
|
self.tell = 0
|
|
91
|
-
self.
|
|
104
|
+
self.obj_crc = None
|
|
92
105
|
self.obj_body = None
|
|
93
106
|
self.sequence_number = -1
|
|
94
107
|
self.rank = -1
|
|
@@ -99,7 +112,7 @@ class ServerProtocol(protocol.Protocol):
|
|
|
99
112
|
self.buffer = io.BytesIO()
|
|
100
113
|
self.obj_length = None
|
|
101
114
|
self.tell = 0
|
|
102
|
-
self.
|
|
115
|
+
self.obj_crc = None
|
|
103
116
|
self.obj_body = None
|
|
104
117
|
self.factory.transport_dict[self.transport] = 1
|
|
105
118
|
self.factory.transport_list.append(self.transport)
|
|
@@ -132,11 +145,12 @@ class ServerProtocol(protocol.Protocol):
|
|
|
132
145
|
time.sleep(0.1)
|
|
133
146
|
|
|
134
147
|
obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
|
|
148
|
+
# compute the CRC32 of the received body as an 8-character hexadecimal string
|
|
149
|
+
recv_crc = f"{zlib.crc32(self.obj_body):08x}"
|
|
135
150
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}")
|
|
151
|
+
if self.check_sum and recv_crc != self.obj_crc:
|
|
152
|
+
# when checksum verification is enabled and the CRC does not match, the received data is corrupt; send b"ERROR" to the client.
|
|
153
|
+
logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
|
|
140
154
|
self.send_ack(self.ACK_ERROR)
|
|
141
155
|
else:
|
|
142
156
|
if self.obj_body == self.ACK_STOP:
|
|
@@ -146,7 +160,7 @@ class ServerProtocol(protocol.Protocol):
|
|
|
146
160
|
if obj_key in self.sequence_number_dict:
|
|
147
161
|
logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
|
|
148
162
|
else:
|
|
149
|
-
self.sequence_number_dict[obj_key] = self.
|
|
163
|
+
self.sequence_number_dict[obj_key] = self.obj_crc
|
|
150
164
|
self.consumer_queue.put(self.obj_body, block=True)
|
|
151
165
|
|
|
152
166
|
self.reset_env()
|
|
@@ -173,7 +187,7 @@ class ServerProtocol(protocol.Protocol):
|
|
|
173
187
|
self.sequence_number = -1
|
|
174
188
|
self.rank = -1
|
|
175
189
|
self.step = -1
|
|
176
|
-
self.
|
|
190
|
+
self.obj_crc = None
|
|
177
191
|
self.obj_body = None
|
|
178
192
|
|
|
179
193
|
def dataReceived(self, data):
|
|
@@ -192,15 +206,15 @@ class ServerProtocol(protocol.Protocol):
|
|
|
192
206
|
logger.debug(
|
|
193
207
|
f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
|
|
194
208
|
|
|
195
|
-
# If needs check
|
|
196
|
-
|
|
209
|
+
# If checksum verification is enabled and the CRC has not been parsed yet, read the 8-byte CRC field
|
|
210
|
+
check_sum_and_crc = (self.check_sum
|
|
197
211
|
and self.obj_length is not None
|
|
198
|
-
and self.
|
|
199
|
-
and len(self.buffer.getvalue()) - self.tell >= self.
|
|
200
|
-
if
|
|
201
|
-
self.
|
|
202
|
-
self.tell += self.
|
|
203
|
-
logger.debug(f"
|
|
212
|
+
and self.obj_crc is None
|
|
213
|
+
and len(self.buffer.getvalue()) - self.tell >= self.crc_width)
|
|
214
|
+
if check_sum_and_crc:
|
|
215
|
+
self.obj_crc = self.buffer.read(self.crc_width).decode()
|
|
216
|
+
self.tell += self.crc_width
|
|
217
|
+
logger.debug(f"Hash value: {self.obj_crc}")
|
|
204
218
|
|
|
205
219
|
current_length = len(self.buffer.getvalue()) - self.tell
|
|
206
220
|
if self.obj_length is not None and 0 < self.obj_length <= current_length:
|
|
@@ -12,6 +12,17 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
+
import gc
|
|
16
|
+
import os
|
|
17
|
+
from datetime import datetime, timezone
|
|
18
|
+
|
|
19
|
+
from OpenSSL import crypto
|
|
20
|
+
from cryptography import x509
|
|
21
|
+
from cryptography.hazmat.backends import default_backend
|
|
22
|
+
from dateutil import parser
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.file_utils import FileOpen
|
|
25
|
+
from msprobe.core.common.log import logger
|
|
15
26
|
|
|
16
27
|
cipher_list = ":".join(
|
|
17
28
|
["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
|
|
@@ -42,3 +53,148 @@ cipher_list = ":".join(
|
|
|
42
53
|
|
|
43
54
|
STRUCT_UNPACK_MODE = "!Q"
|
|
44
55
|
STR_TO_BYTES_ORDER = "big"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def is_certificate_revoked(cert, crl):
    """
    Tell whether a certificate is listed in a revocation list.

    :param cert: certificate object exposing ``get_serial_number()``
    :param crl: iterable of revoked entries, each exposing ``serial_number``
    :return: bool, True if the certificate's serial number appears in ``crl``
    """
    serial = cert.get_serial_number()

    # Look for the certificate's serial number among the revoked entries.
    listed = any(serial == entry.serial_number for entry in crl)
    if not listed:
        return False

    # The serial number is logged as zero-padded hexadecimal.
    logger.error(f"证书已吊销:{serial:020x}")
    return True
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
    """
    Decide whether a certificate presented by the peer is accepted.

    :param conn: OpenSSL.SSL.Connection, the SSL connection object (not used here)
    :param cert: OpenSSL.crypto.X509, the certificate currently being checked
    :param errno: int, OpenSSL error code (0 means no error)
    :param depth: int, depth of the certificate in the chain (0 is the leaf certificate)
    :param preverify_ok: int, result of the preceding verification (1 = passed, 0 = failed)
    :param crl: optional revocation list; when truthy, ``cert`` is also checked against it
    :return: False when verification failed or the certificate is revoked,
        otherwise ``preverify_ok`` is returned unchanged
    """
    if preverify_ok:
        # The chain check passed; additionally reject certificates listed in the CRL.
        if crl and is_certificate_revoked(cert, crl):
            return False
        return preverify_ok

    from OpenSSL import SSL
    # Translate the numeric error code into OpenSSL's textual description for the log.
    raw_reason = SSL._lib.X509_verify_cert_error_string(errno)
    error_str = SSL._ffi.string(raw_reason).decode()
    logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
    return False
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
    """
    Load the private key, certificate, CA certificate and optional CRL from PEM files.

    The private key password is requested interactively. The certificate and the CA
    certificate are checked with ``check_crt_valid``, the certificate is checked against
    the key with ``check_certificate_match``, and the CRL, when its file exists, is
    checked with ``check_crl_valid``.

    Args:
        key_file (str): The path to the private key file.
        cert_file (str): The path to the certificate file.
        ca_file (str): The path to the CA certificate file.
        crl_file (str): The path to the CRL file; the CRL is skipped when the path does not exist.

    Returns:
        tuple: (key, crt, ca_crt, crl), where crl is None when no CRL file was found.

    Raises:
        RuntimeError: If any step of loading or checking raises an exception.
    """

    try:
        # Placeholder for a preset private key password; while it is empty the user is prompted.
        passphrase = ""
        if not passphrase:
            import pwinput
            passphrase = pwinput.pwinput("Enter your password: ")
        with FileOpen(key_file, "rb") as key_stream:
            key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_stream.read(), passphrase.encode())
        # Drop the local reference to the password as soon as the key has been loaded.
        del passphrase
        gc.collect()

        with FileOpen(cert_file, "rb") as cert_stream:
            crt = crypto.load_certificate(crypto.FILETYPE_PEM, cert_stream.read())
        check_crt_valid(crt)
        logger.info(f"crt_serial_number: {hex(crt.get_serial_number())[2:]}")

        check_certificate_match(crt, key)

        with FileOpen(ca_file, "rb") as ca_stream:
            ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, ca_stream.read())
        check_crt_valid(ca_crt)
        logger.info(f"ca_serial_number: {hex(ca_crt.get_serial_number())[2:]}")

        # The CRL is optional: it is loaded and checked only when the file is present.
        crl = None
        if os.path.exists(crl_file):
            with FileOpen(crl_file, "rb") as crl_stream:
                crl = x509.load_pem_x509_crl(crl_stream.read(), default_backend())
            check_crl_valid(crl, ca_crt)
            for entry in crl:
                logger.info(f"Serial Number: {entry.serial_number}, "
                            f"Revocation Date: {entry.revocation_date_utc}")

    except Exception as e:
        raise RuntimeError(f"The SSL certificate is invalid") from e

    return key, crt, ca_crt, crl
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def check_crt_valid(pem):
    """
    Check that an SSL certificate is currently within its validity period.

    Args:
        pem: certificate object providing ``get_notBefore()``, ``get_notAfter()``
            and ``has_expired()``.

    Raises:
        RuntimeError: If the validity dates cannot be read or parsed, or if the
            certificate has expired or the current UTC time is outside the validity period.
    """
    try:
        pem_start = parser.parse(pem.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(pem.get_notAfter().decode("UTF-8"))
    except Exception as e:
        raise RuntimeError(f"The SSL certificate is invalid") from e

    now_utc = datetime.now(tz=timezone.utc)
    if pem.has_expired() or not (pem_start <= now_utc <= pem_end):
        raise RuntimeError(f"The SSL certificate has expired.")

    # Report success only after the validity period has actually been checked, so the
    # log never claims that an expired or not-yet-valid certificate passed.
    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def check_certificate_match(certificate, private_key):
    """
    Verify that a private key belongs to a certificate.

    Random data is signed with the private key and the signature is then verified
    against the certificate; any failure along the way is reported as a mismatch.

    :param certificate: certificate passed to ``crypto.verify`` (holds the public key)
    :param private_key: private key passed to ``crypto.sign``
    :return: None
    :raises RuntimeError: if signing or verification raises an exception
    """
    digest = "sha256"
    probe = os.urandom(256)
    try:
        signature = crypto.sign(private_key, probe, digest)
        # Arguments: certificate with the public key, signature, original data, hash algorithm.
        crypto.verify(certificate, signature, probe, digest)
        logger.info("公钥和私钥匹配")
    except Exception as e:
        raise RuntimeError("公钥和私钥不匹配") from e
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _crl_update_time(crl, field):
    """Return the CRL timestamp ``field`` as a timezone-aware UTC datetime, or None if absent."""
    aware_field = f"{field}_utc"
    # Prefer the timezone-aware attribute when the CRL object provides it.
    if hasattr(crl, aware_field):
        value = getattr(crl, aware_field)
    else:
        value = getattr(crl, field)
    # Naive timestamps are treated as UTC, matching the previous comparison against UTC "now".
    if value is not None and value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value


def check_crl_valid(crl, ca_crt):
    """
    Check that a CRL was signed by the given CA and is currently in force.

    :param crl: CRL object providing ``is_signature_valid()`` and the
        ``last_update``/``next_update`` timestamps (or their ``*_utc`` variants)
    :param ca_crt: CA certificate providing ``get_pubkey().to_cryptography_key()``
    :return: None
    :raises RuntimeError: if the signature is invalid, if either timestamp is missing,
        or if the current UTC time is outside [last update, next update]
    """
    # Verify the CRL signature to make sure the CRL has not been tampered with.
    if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
        raise RuntimeError("CRL签名无效!")

    # Check the CRL validity window using timezone-aware UTC values throughout.
    now_utc = datetime.now(tz=timezone.utc)
    last_update = _crl_update_time(crl, "last_update")
    next_update = _crl_update_time(crl, "next_update")
    # A missing timestamp means the window cannot be established, so the CRL is rejected
    # with the documented RuntimeError instead of failing on a None comparison.
    if last_update is None or next_update is None or not (last_update <= now_utc <= next_update):
        raise RuntimeError("CRL已过期或尚未生效!")
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from msprobe.core.common.runtime import Runtime
|
|
18
|
+
from msprobe.core.common.utils import Const
|
|
19
|
+
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
20
|
+
from msprobe.pytorch.common.log import logger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ATTLManager:
    """Creates the ATTL transport from the configuration and forwards API data through it."""

    def __init__(self, config):
        # The transport is created later by attl_init().
        self.attl = None
        self.config = config

    def attl_init(self):
        """Create the ATTL object when online run_ut is enabled; otherwise do nothing."""
        if not self.config.online_run_ut:
            return

        from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
        attl_config = ATTLConfig(
            is_benchmark_device=False,
            connect_ip=self.config.host,
            connect_port=self.config.port,
            nfs_path=self.config.nfs_path,
            tls_path=self.config.tls_path,
        )
        # Dump on every rank when no rank list is configured, else only on the listed ranks.
        selected_ranks = self.config.rank
        need_dump = len(selected_ranks) == 0 or Runtime.current_rank in selected_ranks
        self.attl = ATTL('npu', attl_config, need_dump=need_dump)
        if self.config.nfs_path:
            self.attl.upload("start")

    def attl_send(self, name, args, kwargs, output):
        """
        Wrap one API call into ApiData and hand it to the ATTL transport.

        :param name: API name ending with the forward suffix, which is cut off
        :param args: positional arguments of the API call
        :param kwargs: keyword arguments of the API call
        :param output: output of the API call
        """
        api_name = name[:-len(Const.FORWARD_NAME_SUFFIX)]
        api_data = ApiData(api_name, args, kwargs, output, Runtime.current_iter, Runtime.current_rank)
        logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")

        # The name must split into exactly three parts; the first one is the API type.
        api_type, _, _ = api_data.name.split(Const.SEP)
        if api_type in (Const.DISTRIBUTED,):
            logger.info(f"api {api_data.name} is not supported, skip")
            return

        # With an NFS path the data is uploaded, otherwise it is sent.
        deliver = self.attl.upload if self.config.nfs_path else self.attl.send
        deliver(api_data)

    def attl_stop(self):
        """Signal the end of the transfer: upload "end" in NFS mode, else send STOP if a socket manager exists."""
        if self.config.nfs_path:
            self.attl.upload("end")
            return

        socket_manager = self.attl.socket_manager
        if socket_manager is None:
            return
        logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
        self.attl.socket_manager.send_stop_signal()
|
|
@@ -117,6 +117,12 @@ def fusion_attention_forward(forward_params):
|
|
|
117
117
|
pse = forward_params.pse
|
|
118
118
|
scale = forward_params.scale
|
|
119
119
|
keep_prob = forward_params.keep_prob
|
|
120
|
+
|
|
121
|
+
# 除零风险拦截:keep_prob 为 0 时会导致除零错误
|
|
122
|
+
if keep_prob == 0:
|
|
123
|
+
raise ValueError("fusion_attention_forward: keep_prob cannot be zero to avoid division by zero.")
|
|
124
|
+
|
|
125
|
+
|
|
120
126
|
qk = calculate_qk(q, k, atten_mask, pse, scale)
|
|
121
127
|
softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
|
|
122
128
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
@@ -137,6 +143,11 @@ def fusion_attention_backward(backward_params):
|
|
|
137
143
|
pse = backward_params.pse
|
|
138
144
|
scale = backward_params.scale
|
|
139
145
|
keep_prob = backward_params.keep_prob
|
|
146
|
+
|
|
147
|
+
# 除零风险拦截:keep_prob 为 0 时会导致除零错误
|
|
148
|
+
if keep_prob == 0:
|
|
149
|
+
raise ValueError("fusion_attention_backward: keep_prob cannot be zero to avoid division by zero.")
|
|
150
|
+
|
|
140
151
|
dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
|
|
141
152
|
if drop_mask is None or len(drop_mask.shape) == 0:
|
|
142
153
|
drop_res = softmax_res.permute(0, 1, 3, 2)
|
|
@@ -164,23 +175,35 @@ def parse_bsnd_args(query, key, head_num, input_layout):
|
|
|
164
175
|
if input_layout == "BSH":
|
|
165
176
|
b, s1, h1 = query.shape
|
|
166
177
|
_, s2, h2 = key.shape
|
|
178
|
+
if n1 == 0:
|
|
179
|
+
raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
|
|
167
180
|
d = h1 // n1
|
|
181
|
+
if d == 0:
|
|
182
|
+
raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
|
|
168
183
|
n2 = h2 // d
|
|
169
184
|
elif input_layout == "SBH":
|
|
170
185
|
s1, b, h1 = query.shape
|
|
171
186
|
s2, _, h2 = key.shape
|
|
187
|
+
if n1 == 0:
|
|
188
|
+
raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
|
|
172
189
|
d = h1 // n1
|
|
190
|
+
if d == 0:
|
|
191
|
+
raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
|
|
173
192
|
n2 = h2 // d
|
|
174
193
|
elif input_layout == "BSND":
|
|
175
194
|
b, s1, n1, d = query.shape
|
|
176
195
|
_, s2, n2, _ = key.shape
|
|
177
196
|
h1 = n1 * d
|
|
178
197
|
h2 = n2 * d
|
|
198
|
+
if d == 0:
|
|
199
|
+
raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
|
|
179
200
|
elif input_layout == "BNSD":
|
|
180
201
|
b, n1, s1, d = query.shape
|
|
181
202
|
_, n2, s2, _ = key.shape
|
|
182
203
|
h1 = n1 * d
|
|
183
204
|
h2 = n2 * d
|
|
205
|
+
if d == 0:
|
|
206
|
+
raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
|
|
184
207
|
except Exception as e:
|
|
185
208
|
raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
|
|
186
209
|
|
|
@@ -446,6 +469,8 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
|
|
|
446
469
|
input_layout = get_input_layout(*args, **kwargs)
|
|
447
470
|
|
|
448
471
|
b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
|
|
472
|
+
if d == 0:
|
|
473
|
+
raise ValueError("npu_fusion_attention_forward_patch: head dimension (d) is zero, division by zero risk.")
|
|
449
474
|
if n1 == n2 and s1 == s2:
|
|
450
475
|
logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
451
476
|
else:
|
|
@@ -478,6 +503,8 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
|
|
|
478
503
|
raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
|
|
479
504
|
|
|
480
505
|
b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
|
|
506
|
+
if d == 0:
|
|
507
|
+
raise ValueError("npu_fusion_attention_backward_patch: head dimension (d) is zero, division by zero risk.")
|
|
481
508
|
if n1 == n2 and s1 == s2:
|
|
482
509
|
logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
|
|
483
510
|
else:
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -24,6 +24,7 @@ from functools import wraps
|
|
|
24
24
|
import numpy as np
|
|
25
25
|
import torch
|
|
26
26
|
import torch.distributed as dist
|
|
27
|
+
|
|
27
28
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
28
29
|
from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
|
|
29
30
|
check_file_or_directory_path, check_path_before_create, FileOpen)
|
|
@@ -38,7 +39,9 @@ except ImportError:
|
|
|
38
39
|
else:
|
|
39
40
|
is_gpu = False
|
|
40
41
|
|
|
42
|
+
|
|
41
43
|
def is_torch_version_at_least(major, minor=0, version=None):
    """Return True when a torch version string is numerically >= ``major.minor``.

    Comparing version strings with ``>=`` is lexicographic, so a plain ``str``
    such as ``'10.0.1'`` sorts *before* ``'2.0'``. This helper compares the
    leading integer parts of the first two release fields instead.

    Args:
        major: Required major version number.
        minor: Required minor version number (defaults to 0).
        version: Version string to inspect; ``torch.__version__`` when None.

    Returns:
        bool: True if the parsed ``(major, minor)`` of ``version`` is at least
        the requested pair.
    """
    raw = str(torch.__version__ if version is None else version)
    # Drop the local build tag ("+cu118", "+cpu", ...) before splitting the release fields.
    fields = raw.split('+')[0].split('.')
    numbers = []
    for field in fields[:2]:
        digits = ''
        for char in field:
            # Stop at the first non-digit so suffixes like "0a0" / "1rc1" parse as 0 / 1.
            if not '0' <= char <= '9':
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers) >= (major, minor)


# Kept byte-identical: this compares ``torch.__version__`` itself rather than a stripped str.
# NOTE(review): presumably relies on TorchVersion's version-aware comparison - confirm for the
# oldest supported torch release.
torch_without_guard_version = torch.__version__ >= '2.1'
# Numeric comparison: the previous ``split('+')[0] >= '2.0'`` compared plain strings lexicographically.
torch_version_above_or_equal_2 = is_torch_version_at_least(2, 0)
|
|
42
45
|
|
|
43
46
|
if not is_gpu and not torch_without_guard_version:
|
|
44
47
|
from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
|
|
@@ -313,14 +316,14 @@ def print_rank_0(message):
|
|
|
313
316
|
logger.info(message)
|
|
314
317
|
|
|
315
318
|
|
|
316
|
-
def load_pt(pt_path, to_cpu=False):
|
|
319
|
+
def load_pt(pt_path, to_cpu=False, weights_only=True):
    """Load a serialized PyTorch object from ``pt_path``.

    The path is resolved with ``os.path.realpath`` and passed through
    ``check_file_or_directory_path`` before ``torch.load`` is called.

    Args:
        pt_path: Path to the ``.pt`` file.
        to_cpu: When True, load with ``map_location=torch.device("cpu")``;
            otherwise no device remapping is requested.
        weights_only: Forwarded unchanged to ``torch.load``.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        RuntimeError: If ``torch.load`` (or building the map location) raises;
            the original exception is chained as the cause.
    """
    pt_path = os.path.realpath(pt_path)
    check_file_or_directory_path(pt_path)
    try:
        # None is torch.load's default map_location, so this matches omitting the argument.
        target_device = torch.device("cpu") if to_cpu else None
        loaded = torch.load(pt_path, map_location=target_device, weights_only=weights_only)
    except Exception as e:
        raise RuntimeError(f"load pt file {pt_path} failed") from e
    return loaded
|
|
@@ -395,7 +398,7 @@ def save_api_data(api_data):
|
|
|
395
398
|
io_buff = io.BytesIO()
|
|
396
399
|
torch.save(api_data, io_buff)
|
|
397
400
|
except Exception as e:
|
|
398
|
-
raise RuntimeError(
|
|
401
|
+
raise RuntimeError("save api_data to io_buff failed") from e
|
|
399
402
|
return io_buff
|
|
400
403
|
|
|
401
404
|
|
|
@@ -403,9 +406,9 @@ def load_api_data(api_data_bytes):
|
|
|
403
406
|
"""Load data from bytes stream"""
|
|
404
407
|
try:
|
|
405
408
|
buffer = io.BytesIO(api_data_bytes)
|
|
406
|
-
buffer = torch.load(buffer, map_location="cpu")
|
|
409
|
+
buffer = torch.load(buffer, map_location="cpu", weights_only=False)
|
|
407
410
|
except Exception as e:
|
|
408
|
-
raise RuntimeError(
|
|
411
|
+
raise RuntimeError("load api_data from bytes failed") from e
|
|
409
412
|
return buffer
|
|
410
413
|
|
|
411
414
|
|
|
@@ -457,7 +460,7 @@ def is_recomputation():
|
|
|
457
460
|
|
|
458
461
|
def check_save_param(variable, name, save_backward):
|
|
459
462
|
# try catch this api to skip invalid call
|
|
460
|
-
valid_data_types =
|
|
463
|
+
valid_data_types = (torch.Tensor, int, float, str)
|
|
461
464
|
if not is_save_variable_valid(variable, valid_data_types):
|
|
462
465
|
valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
|
|
463
466
|
logger.warning("PrecisionDebugger.save variable type not valid, "
|
|
@@ -476,13 +479,8 @@ def check_save_param(variable, name, save_backward):
|
|
|
476
479
|
raise ValueError
|
|
477
480
|
|
|
478
481
|
|
|
479
|
-
def
|
|
480
|
-
|
|
481
|
-
return text
|
|
482
|
-
index = text.rfind(old)
|
|
483
|
-
if index != -1:
|
|
484
|
-
return text[:index] + text[index:].replace(old, new, 1)
|
|
485
|
-
return text
|
|
482
|
+
def is_torch_nn_module(variable):
    """Tell whether ``variable`` is an eager ``torch.nn.Module``.

    Args:
        variable: Any object.

    Returns:
        bool: True for ``torch.nn.Module`` instances that are not
        ``torch.jit.ScriptModule`` instances; False for everything else.
    """
    # Scripted modules also subclass nn.Module, so rule them out first.
    if isinstance(variable, torch.jit.ScriptModule):
        return False
    return isinstance(variable, torch.nn.Module)
|
|
486
484
|
|
|
487
485
|
|
|
488
486
|
def is_hifloat8_tensor(tensor):
|
|
@@ -495,3 +493,17 @@ def is_float8_tensor(tensor):
|
|
|
495
493
|
if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
|
|
496
494
|
return True
|
|
497
495
|
return is_hifloat8_tensor(tensor)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def register_forward_pre_hook(module, forward_pre_hook):
    """Register ``forward_pre_hook`` on ``module`` and return the removal handle.

    When the module-level flag ``torch_version_above_or_equal_2`` is true the
    hook is registered with ``with_kwargs=True``; otherwise it is registered
    without that argument.

    Args:
        module: Object exposing ``register_forward_pre_hook`` (a ``torch.nn.Module``).
        forward_pre_hook: Callable to install as the forward pre-hook.

    Returns:
        Whatever ``module.register_forward_pre_hook`` returns (for
        ``torch.nn.Module`` a removable handle). Previously the handle was
        discarded, so callers could not detach the hook; callers that ignore
        the return value are unaffected.
    """
    if torch_version_above_or_equal_2:
        return module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
    return module.register_forward_pre_hook(forward_pre_hook)
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def register_forward_hook(module, forward_hook):
    """Register ``forward_hook`` on ``module`` and return the removal handle.

    When the module-level flag ``torch_version_above_or_equal_2`` is true the
    hook is registered with ``with_kwargs=True``; otherwise it is registered
    without that argument.

    Args:
        module: Object exposing ``register_forward_hook`` (a ``torch.nn.Module``).
        forward_hook: Callable to install as the forward hook.

    Returns:
        Whatever ``module.register_forward_hook`` returns (for
        ``torch.nn.Module`` a removable handle). Previously the handle was
        discarded, so callers could not detach the hook; callers that ignore
        the return value are unaffected.
    """
    if torch_version_above_or_equal_2:
        return module.register_forward_hook(forward_hook, with_kwargs=True)
    return module.register_forward_hook(forward_hook)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c)
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,41 +13,9 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import
|
|
17
|
-
|
|
18
|
-
from msprobe.core.common.exceptions import FileCheckException
|
|
19
|
-
from msprobe.core.common.file_utils import create_directory
|
|
20
|
-
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
|
|
21
|
-
set_dump_path
|
|
22
|
-
from msprobe.core.compare.acc_compare import ModeConfig
|
|
23
|
-
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
|
|
24
|
-
from msprobe.pytorch.common.log import logger
|
|
25
|
-
from msprobe.pytorch.compare.pt_compare import PTComparator, compare
|
|
16
|
+
from msprobe.core.compare.utils import compare_distributed_inner
|
|
17
|
+
from msprobe.pytorch.compare.pt_compare import compare
|
|
26
18
|
|
|
27
19
|
|
|
28
20
|
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Run a multi-rank dump comparison through the shared core routine.

    This is a thin PyTorch entry point: it forwards its arguments to
    ``compare_distributed_inner`` and supplies the PyTorch ``compare``
    function as the per-pair comparison callback.

    Args:
        npu_dump_dir: Dump directory of the NPU run.
        bench_dump_dir: Dump directory of the benchmark run.
        output_path: Destination passed through for the comparison results.
        **kwargs: Extra options forwarded unchanged.

    Returns:
        None. The result of ``compare_distributed_inner`` is not propagated.
    """
    forwarded = (npu_dump_dir, bench_dump_dir, output_path, compare)
    compare_distributed_inner(*forwarded, **kwargs)
|