mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +37 -47
- msprobe/README.md +8 -5
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +64 -13
- msprobe/core/common/framework_adapter.py +10 -1
- msprobe/core/common/utils.py +17 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +2 -16
- msprobe/core/service.py +5 -16
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +0 -13
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +0 -70
- msprobe/pytorch/debugger/debugger_config.py +0 -10
- msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
- msprobe/pytorch/hook_module/api_register.py +5 -1
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +2 -11
- msprobe/visualization/builder/graph_builder.py +2 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
|
@@ -1,115 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
3
|
-
# All rights reserved.
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
# you may not use this file except in compliance with the License.
|
|
7
|
-
# You may obtain a copy of the License at
|
|
8
|
-
#
|
|
9
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
#
|
|
11
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
# See the License for the specific language governing permissions and
|
|
15
|
-
# limitations under the License.
|
|
16
|
-
|
|
17
|
-
import os
|
|
18
|
-
from collections import defaultdict
|
|
19
|
-
from functools import wraps
|
|
20
|
-
|
|
21
|
-
import torch
|
|
22
|
-
from torch.utils._python_dispatch import TorchDispatchMode
|
|
23
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
24
|
-
from msprobe.pytorch.common.utils import get_tensor_rank
|
|
25
|
-
from msprobe.core.common.const import Const
|
|
26
|
-
from msprobe.pytorch.common.log import logger
|
|
27
|
-
from msprobe.core.common.file_utils import load_yaml
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def singleton(cls):
    """Class decorator that makes every call return one shared instance of ``cls``.

    The instance is created lazily on the first call. Constructor arguments
    are forwarded on that first call only; arguments passed on later calls are
    ignored because the cached instance is returned unchanged.

    Args:
        cls: The class to wrap.

    Returns:
        A factory function (carrying ``cls``'s metadata via ``functools.wraps``)
        that returns the single shared instance.
    """
    _instance = {}

    @wraps(cls)
    def inner(*args, **kwargs):
        # Build the instance only once; afterwards always hand back the cached one.
        if cls not in _instance:
            _instance[cls] = cls(*args, **kwargs)
        return _instance[cls]

    return inner
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
@singleton
class Counter:
    """Holds per-key integer counters that start at zero.

    The class is wrapped by ``singleton``, so ``Counter()`` yields the same
    object on every call.
    """

    def __init__(self) -> None:
        # Missing keys read as 0 because the mapping is a defaultdict(int).
        self.index_dict: defaultdict = defaultdict(int)
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
# Shared Counter instance used by AccuracyCheckerDispatch to number aten calls.
counter = Counter()
# torch_ops_config.yaml is looked up next to this module (symlinks resolved via realpath).
yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
# Loaded once, when this module is imported.
yaml_file = load_yaml(yaml_path)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class AccuracyCheckerDispatch(TorchDispatchMode):
    """Dispatch mode that runs each aten call and forwards its data through ``attl``.

    Calls whose aten name is listed under ``aten_ops_blacklist`` in the YAML
    configuration are executed but not forwarded. Every forwarded call is named
    ``<Const.ATEN><SEP><aten_api><SEP><index>``, where ``index`` is a per-API
    running number kept in the shared ``counter``.
    """

    def __init__(self, attl):
        """Store the transport object and read both op lists from the YAML config.

        Args:
            attl: Transport object; this class uses its ``nfs_path`` attribute
                and its ``upload`` / ``send`` methods.
        """
        super(AccuracyCheckerDispatch, self).__init__()
        self.attl = attl
        self.counter = counter
        # ``or []`` also covers a key that is present in the YAML but has a null value.
        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', []) or []
        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', []) or []

    def __torch_dispatch__(self, func, types, args=None, kwargs=None):
        """Execute ``func`` and, unless it is blacklisted, ship its inputs and output.

        Args:
            func: The operator being dispatched; its ``__name__`` is split on
                ``Const.SEP`` and the first part is used as the aten API name.
            types: Unused here.
            args: Positional arguments for ``func`` (``None`` is treated as empty).
            kwargs: Keyword arguments for ``func`` (``None`` is treated as empty).

        Returns:
            Whatever ``func(*args, **kwargs)`` returns.
        """
        # The signature allows None for both; unpacking None would raise TypeError.
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs

        aten_api = func.__name__.split(Const.SEP)[0]
        self.enable_autogard(aten_api)
        if aten_api in self.aten_ops_blacklist:
            return func(*args, **kwargs)

        res = func(*args, **kwargs)
        cur_rank = get_tensor_rank(args, res)
        cur_api_number = self.counter.index_dict[aten_api]
        api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
        logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
        api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
        # "device" is stripped from the recorded kwargs before the data is shipped.
        # NOTE(review): if ApiData keeps a reference to ``kwargs`` this also mutates the
        # caller's dict; func has already run by now -- confirm this is intended.
        api_data.kwargs.pop("device", None)
        if self.attl.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)
        self.counter.index_dict[aten_api] += 1

        return res

    def enable_autogard(self, aten_api):
        """Clear the AutogradFunctionality exclusion flag for APIs listed in ``npu_adjust_autogard``."""
        if aten_api in self.npu_adjust_autogard:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def dispatch4data(func, attl, status):
    """Wrap ``func`` so it optionally runs inside an ``AccuracyCheckerDispatch`` context.

    Args:
        func: Callable to wrap.
        attl: Transport object handed to ``AccuracyCheckerDispatch``.
        status: When falsy the wrapper simply calls ``func``; when truthy the
            call happens inside ``with AccuracyCheckerDispatch(attl)``.

    Returns:
        The wrapping function (metadata copied from ``func``).
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        if status:
            with AccuracyCheckerDispatch(attl):
                result = func(*args, **kwargs)
            return result
        return func(*args, **kwargs)

    return wrapper
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
def run_ut_dispatch(attl, status, is_recompute=False):
    """
    Replace ``torch.autograd.backward`` with a version wrapped by ``dispatch4data``.

    This function is called by online_run_ut to enable or disable dispatch for
    ``torch.autograd.backward``.

    Args:
        attl (ATTL): online_run_ut class ATTL, which is used to upload or send api data to server.
        status (bool): True means enable dispatch, False means disable dispatch.
        is_recompute (bool): Recompute flag; recompute conflicts with the aten api dispatch,
            so nothing is patched when it is set.
    """
    if not is_recompute:
        torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
|
|
@@ -1,250 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
from functools import partial
|
|
16
|
-
import os
|
|
17
|
-
import struct
|
|
18
|
-
import zlib
|
|
19
|
-
import time
|
|
20
|
-
import io
|
|
21
|
-
from threading import Thread
|
|
22
|
-
|
|
23
|
-
from twisted.internet import reactor, protocol, endpoints, ssl
|
|
24
|
-
|
|
25
|
-
from msprobe.pytorch.common.utils import logger
|
|
26
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
|
|
27
|
-
STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class TCPServer:
    """Runs a Twisted TCP (optionally TLS) listener in a background daemon thread.

    Each accepted connection gets a ``ServerProtocol`` that puts received
    message bodies on ``shared_queue``.
    """

    def __init__(self, port, shared_queue, check_sum=False, tls_path=None) -> None:
        """
        Args:
            port: Port to listen on.
            shared_queue: Queue handed to every ``ServerProtocol``.
            check_sum: Passed to every ``ServerProtocol``.
            tls_path: Directory holding ``server.key``, ``server.crt``, ``ca.crt`` and
                optionally ``crl.pem``; when falsy a plain TCP endpoint is used.
        """
        self.port = port
        self.shared_queue = shared_queue
        self.check_sum = check_sum
        self.tls_path = tls_path
        self.factory = MessageServerFactory()
        # Set by start(); stays None until then.
        self.reactor_thread = None

    @staticmethod
    def run_reactor():
        """Run the Twisted reactor without installing signal handlers."""
        reactor.run(installSignalHandlers=False)

    def start(self):
        """Create the listening endpoint and start the reactor thread."""
        self.factory.protocol = self.build_protocol

        if self.tls_path:
            server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(
                key_file=os.path.join(self.tls_path, "server.key"),
                cert_file=os.path.join(self.tls_path, "server.crt"),
                ca_file=os.path.join(self.tls_path, "ca.crt"),
                crl_file=os.path.join(self.tls_path, "crl.pem")
            )

            ssl_options = ssl.CertificateOptions(
                privateKey=server_key,
                certificate=server_crt,
                method=ssl.SSL.TLSv1_2_METHOD,
                verify=True,
                requireCertificate=True,
                caCerts=[ca_crt],  # list of trusted CA certificates
            )
            # NOTE(review): this relies on getContext() returning the same context object
            # the endpoint later uses -- confirm against the Twisted version in use.
            ssl_context = ssl_options.getContext()
            ssl_context.set_cipher_list(cipher_list)
            ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
            # Client certificates are mandatory; verify_callback additionally checks the CRL.
            ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
                                   partial(verify_callback, crl=crl_pem))

            endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
        else:
            endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
        endpoint.listen(self.factory)
        self.reactor_thread = Thread(target=self.run_reactor, daemon=True)
        self.reactor_thread.start()

    def is_running(self):
        """Return True while the factory still tracks at least one unfinished connection."""
        return not self.factory.is_all_connection_closed()

    def stop(self):
        """Stop the factory, ask the reactor to shut down and wait for its thread."""
        self.factory.doStop()
        reactor.callFromThread(reactor.sigInt, 2)
        # start() may never have been called; joining None would raise AttributeError.
        if self.reactor_thread is not None:
            self.reactor_thread.join()

    def build_protocol(self):
        """Factory hook: create the protocol instance for a new connection."""
        return ServerProtocol(self.shared_queue, self.check_sum)
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
class ServerProtocol(protocol.Protocol):
    """Per-connection protocol that reassembles framed messages and queues their bodies.

    Frame layout as parsed by ``dataReceived``: four fields of ``length_width`` bytes each
    (body length, sequence number, rank, step), then -- only when ``check_sum`` is set -- a
    ``crc_width``-byte CRC field, then the body.
    """

    # Acknowledgement tokens (5 bytes each). send_ack() appends sequence number, rank and step.
    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    def __init__(self, shared_queue, check_sum=False):
        """
        Args:
            shared_queue: Queue that receives message bodies (and the KILL token).
            check_sum: When truthy, each frame carries a CRC field that is verified.
        """
        self.start_time = None
        self.buffer = io.BytesIO()
        self.consumer_queue = shared_queue
        self.check_sum = check_sum
        self.length_width = 8
        self.crc_width = 8
        self.obj_length = None
        # Read offset into self.buffer for the frame currently being parsed.
        self.tell = 0
        self.obj_crc = None
        self.obj_body = None
        self.sequence_number = -1
        self.rank = -1
        self.step = -1
        # Maps "<sequence>_<rank>_<step>" keys of accepted frames to their CRC values.
        self.sequence_number_dict = dict()

    def connectionMade(self):
        """Reset the parsing state and register this transport with the factory."""
        self.buffer = io.BytesIO()
        self.obj_length = None
        self.tell = 0
        self.obj_crc = None
        self.obj_body = None
        self.factory.transport_dict[self.transport] = 1
        self.factory.transport_list.append(self.transport)
        logger.info(f"Connected to {self.transport.getPeer()} successfully.")

    def connectionLost(self, reason):
        """Unregister this transport; queue the KILL token once no transport is left."""
        self.factory.transport_dict.pop(self.transport, None)
        if len(self.factory.transport_dict) == 0:
            self.consumer_queue.put(self.ACK_KILL_PROCESS)

        logger.info(f"Lost connection with {self.transport.getPeer()}. Reason is: {reason} 与客户端 断开连接, "
                    f"current connection number is: {len(self.factory.transport_dict)}")

    def send_ack(self, ack_info):
        """Write ``ack_info`` followed by the current sequence number, rank and step.

        Each of the three integers is encoded in 8 bytes using ``bytes_order``.
        NOTE(review): ``int.to_bytes`` is unsigned by default, so this would raise
        OverflowError while the fields still hold their -1 reset value -- confirm it is
        only reached after a header has been parsed.
        """
        ack_message = b"".join([
            ack_info,
            self.sequence_number.to_bytes(8, byteorder=bytes_order),
            self.rank.to_bytes(8, byteorder=bytes_order),
            self.step.to_bytes(8, byteorder=bytes_order)
        ])
        self.transport.write(ack_message)

    def post_process(self):
        """Acknowledge the frame just received and queue its body when it is new and valid."""
        send_busy_ack = False
        # While the queue is full, send a single BUSY ack and then poll every 0.1 s.
        # NOTE(review): time.sleep here presumably blocks the reactor thread -- confirm.
        while self.consumer_queue.full():
            if not send_busy_ack:
                self.send_ack(self.ACK_BUSY)
                logger.debug("sending BUSY ACK")
                send_busy_ack = True
            time.sleep(0.1)

        obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
        # CRC32 of the body, formatted as a zero-padded 8-character lowercase hex string.
        recv_crc = f"{zlib.crc32(self.obj_body):08x}"

        if self.check_sum and recv_crc != self.obj_crc:
            # Checksum verification is on and the values differ: the data is corrupt,
            # so reply with the ERROR ack.
            logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
            self.send_ack(self.ACK_ERROR)
        else:
            if self.obj_body == self.ACK_STOP:
                self.handle_with_stop()
            else:
                self.send_ack(self.ACK_SUCCESS)
                if obj_key in self.sequence_number_dict:
                    # Same (sequence, rank, step) seen before: a retransmission, do not queue again.
                    logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
                else:
                    self.sequence_number_dict[obj_key] = self.obj_crc
                    self.consumer_queue.put(self.obj_body, block=True)

        self.reset_env()
        finish_time = time.time()
        logger.debug(f"finish_time: {finish_time - self.start_time}")

    def handle_with_stop(self):
        """Confirm a STOP request; once no unfinished transport remains, broadcast KILL."""
        logger.debug(f"接收到停止传输信号 TCP{self.transport.getPeer()}")
        self.send_ack(self.ACK_STOP_CONFIRM)
        if len(self.factory.transport_dict) == 0:
            # Fixed placeholder header values for the KILL message.
            _rank, _step, _sequence_number = 0, 0, 100000000
            # NOTE(review): byte order is hard-coded to 'big' here while send_ack uses
            # ``bytes_order``.
            ack_kill = self.ACK_KILL_PROCESS + \
                       _sequence_number.to_bytes(8, byteorder='big') + \
                       _rank.to_bytes(8, byteorder='big') + \
                       _step.to_bytes(8, byteorder='big')
            # transport_list holds every transport ever connected, including closed ones.
            for trans in self.factory.transport_list:
                trans.write(ack_kill)
            logger.debug(f"发送KILL信息给{self.transport.getPeer()}")
            self.consumer_queue.put(self.ACK_KILL_PROCESS)
            time.sleep(2)

    def reset_env(self):
        """Clear the per-frame fields so the next header can be parsed."""
        self.obj_length = None
        self.sequence_number = -1
        self.rank = -1
        self.step = -1
        self.obj_crc = None
        self.obj_body = None

    def dataReceived(self, data):
        """Append ``data`` to the buffer and parse as much of one frame as is available.

        NOTE(review): at most one complete frame is handled per call; bytes of a following
        frame stay buffered until more data arrives. A frame with length 0 never satisfies
        the ``0 < self.obj_length`` check below. Confirm both are acceptable.
        """
        # Append at the end of the buffer, then return to the current parse offset.
        self.buffer.seek(0, 2)
        self.buffer.write(data)
        self.buffer.seek(self.tell)

        # The frame header consists of obj_length, sequence_number, rank and step.
        if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4:
            self.start_time = time.time()
            self.obj_length = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.sequence_number = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.rank = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.step = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.tell += self.length_width * 4
            logger.debug(
                f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")

        # With checksum verification on, read the crc_width-byte CRC field once the header
        # is parsed and enough bytes are buffered.
        check_sum_and_crc = (self.check_sum
                             and self.obj_length is not None
                             and self.obj_crc is None
                             and len(self.buffer.getvalue()) - self.tell >= self.crc_width)
        if check_sum_and_crc:
            self.obj_crc = self.buffer.read(self.crc_width).decode()
            self.tell += self.crc_width
            logger.debug(f"Hash value: {self.obj_crc}")

        current_length = len(self.buffer.getvalue()) - self.tell
        if self.obj_length is not None and 0 < self.obj_length <= current_length:
            # The whole body of the current frame has arrived.
            self.obj_body = self.buffer.read(self.obj_length)

            # Drop the consumed frame; keep only the unread remainder in a fresh buffer.
            self.tell += self.obj_length
            self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
            self.buffer.seek(0)
            self.tell = 0
            recv_data_time = time.time()
            logger.debug(f"self.sequence_number {self.sequence_number} "
                         f"recv_data_time {recv_data_time - self.start_time}")

            if self.obj_body == self.ACK_STOP:
                # This link received STOP: remove it from the unfinished-transport dict.
                _transport = self.factory.transport_dict.pop(self.transport, None)
                logger.debug(f"接收到b'STOP_' self.sequence_number {self.sequence_number} ")
            self.post_process()
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
class MessageServerFactory(protocol.ServerFactory):
    """Server factory that keeps track of the TCP links opened against it."""

    def __init__(self) -> None:
        """
        transport_dict: links that have not completed data transmission.
        transport_list: every TCP link ever established; a link is appended
                        when it is created.
        """
        self.transport_list = []
        self.transport_dict = {}

    def is_all_connection_closed(self):
        """Return True when no unfinished link remains in ``transport_dict``."""
        return not self.transport_dict
|
|
@@ -1,63 +0,0 @@
|
|
|
1
|
-
aten_ops_blacklist:
|
|
2
|
-
- npu_binary_cross_entropy_with_logits_backward
|
|
3
|
-
- npu_ciou_backward
|
|
4
|
-
- _cudnn_rnn
|
|
5
|
-
- _local_scalar_dense
|
|
6
|
-
- _pin_memory
|
|
7
|
-
- _to_copy
|
|
8
|
-
- _unsafe_view
|
|
9
|
-
- clone
|
|
10
|
-
- contiguous
|
|
11
|
-
- copy_
|
|
12
|
-
- cudnn_batch_norm
|
|
13
|
-
- cudnn_batch_norm_backward
|
|
14
|
-
- detach
|
|
15
|
-
- empty
|
|
16
|
-
- index_put_
|
|
17
|
-
- lift_fresh
|
|
18
|
-
- max_pool2d_with_indices_backward # shape mismatch
|
|
19
|
-
- native_batch_norm_backward
|
|
20
|
-
- new_empty
|
|
21
|
-
- new_empty_strided
|
|
22
|
-
- new_full
|
|
23
|
-
- new_ones
|
|
24
|
-
- new_zeros
|
|
25
|
-
- ones
|
|
26
|
-
- ones_like
|
|
27
|
-
- permute
|
|
28
|
-
- rand
|
|
29
|
-
- rand_like
|
|
30
|
-
- randint
|
|
31
|
-
- randint_like
|
|
32
|
-
- randn
|
|
33
|
-
- randn_like
|
|
34
|
-
- randperm
|
|
35
|
-
- scalar_tensor
|
|
36
|
-
- select
|
|
37
|
-
- to
|
|
38
|
-
- transpose
|
|
39
|
-
- unbind
|
|
40
|
-
- view
|
|
41
|
-
- zero
|
|
42
|
-
- zero_
|
|
43
|
-
- zeros
|
|
44
|
-
- zeros_like
|
|
45
|
-
- _record_function_enter_new
|
|
46
|
-
- _record_function_exit
|
|
47
|
-
- broadcast_
|
|
48
|
-
- allreduce_
|
|
49
|
-
- npu_clear_float_status
|
|
50
|
-
- npu_format_cast
|
|
51
|
-
- npu_dtype_cast
|
|
52
|
-
- npu_dtype_cast_backward
|
|
53
|
-
- _allgather_base_
|
|
54
|
-
- _reduce_scatter_base_
|
|
55
|
-
- is_same_size
|
|
56
|
-
|
|
57
|
-
npu_adjust_autogard:
|
|
58
|
-
- adaptive_avg_pool2d
|
|
59
|
-
- batch_norm
|
|
60
|
-
- log_softmax
|
|
61
|
-
- nll_loss
|
|
62
|
-
- to
|
|
63
|
-
|
|
@@ -1,198 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
import gc
|
|
16
|
-
import os
|
|
17
|
-
from datetime import datetime, timezone
|
|
18
|
-
|
|
19
|
-
from OpenSSL import crypto
|
|
20
|
-
from cryptography import x509
|
|
21
|
-
from cryptography.hazmat.backends import default_backend
|
|
22
|
-
from dateutil import parser
|
|
23
|
-
|
|
24
|
-
from msprobe.core.common.file_utils import FileOpen
|
|
25
|
-
from msprobe.core.common.log import logger
|
|
26
|
-
|
|
27
|
-
# Colon-separated cipher suite names, encoded to bytes; the TCP server passes this value
# to the TLS context's set_cipher_list().
cipher_list = ":".join(
    ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
     "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
     "TLS_DHE_DSS_WITH_AES_128_GCM_SHA256",
     "TLS_DHE_DSS_WITH_AES_256_GCM_SHA384",
     "TLS_DHE_PSK_WITH_AES_128_GCM_SHA256",
     "TLS_DHE_PSK_WITH_AES_256_GCM_SHA384",
     "TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
     "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
     "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
     "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
     "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256",
     "TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384",
     "TLS_ECDHE_PSK_WITH_AES_128_CCM_SHA256",
     "TLS_DHE_RSA_WITH_AES_128_CCM",
     "TLS_DHE_RSA_WITH_AES_256_CCM",
     "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_DHE_PSK_WITH_AES_128_CCM",
     "TLS_DHE_PSK_WITH_AES_256_CCM",
     "TLS_ECDHE_ECDSA_WITH_AES_128_CCM",
     "TLS_ECDHE_ECDSA_WITH_AES_256_CCM",
     "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
).encode()

# struct format for one header field: network byte order ("!"), unsigned 8-byte integer ("Q").
STRUCT_UNPACK_MODE = "!Q"
# Byte order passed to int.to_bytes when building acknowledgement messages.
STR_TO_BYTES_ORDER = "big"
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def is_certificate_revoked(cert, crl):
    """Tell whether ``cert``'s serial number appears among the entries of ``crl``.

    Args:
        cert: Object exposing ``get_serial_number()``.
        crl: Iterable of entries exposing ``serial_number``.

    Returns:
        bool: True (after logging an error) when the certificate is revoked, else False.
    """
    # Serial number of the certificate under test.
    cert_serial_number = cert.get_serial_number()

    # Collect every revoked serial number, then look ours up.
    revoked_serials = []
    for revoked_cert in crl:
        revoked_serials.append(revoked_cert.serial_number)

    if cert_serial_number not in revoked_serials:
        return False

    logger.error(f"证书已吊销:{cert_serial_number:020x}")
    return True
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
    """
    Verify the validity of the peer certificate.
    :param conn: OpenSSL.SSL.Connection, the SSL connection object
    :param cert: OpenSSL.crypto.X509, the certificate currently being checked
    :param errno: int, OpenSSL error code, e.g. 0: no error | 9: certificate expired | 18: self-signed certificate
    :param depth: int, depth of the certificate in the chain (0 = leaf; higher values are CA certificates)
    :param preverify_ok: int, result of OpenSSL's own verification (1 = passed, 0 = failed)
    :param crl: _CRLInternal, CRL object; when falsy no revocation check is made
    :return: a truthy value to accept the certificate, False to reject it
    """
    if preverify_ok:
        # OpenSSL accepted the certificate; reject it only if the CRL lists it.
        if crl and is_certificate_revoked(cert, crl):
            return False
        return preverify_ok

    # OpenSSL's own verification failed: log the textual reason and reject.
    from OpenSSL import SSL
    error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode()
    logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
    return False
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
    """
    Load and validate the SSL PEM files.

    The private-key passphrase is requested interactively. The certificate and the
    CA certificate are checked for validity, the certificate is checked against the
    private key, and the CRL (only if ``crl_file`` exists) is checked against the CA.

    Args:
        key_file (str): The path to the private key file.
        cert_file (str): The path to the certificate file.
        ca_file (str): The path to the CA certificate file.
        crl_file (str): The path to the CRL file.

    Returns:
        tuple: (key, crt, ca_crt, crl); ``crl`` is None when ``crl_file`` does not exist.

    Raises:
        RuntimeError: If any file cannot be read or any check fails.
    """

    try:
        # Ask for the private-key passphrase interactively.
        import pwinput
        passphrase = pwinput.pwinput("Enter your password: ")
        try:
            with FileOpen(key_file, "rb") as f:
                key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
        finally:
            # Best effort: drop the passphrase reference even when loading the key fails.
            del passphrase
            gc.collect()

        with FileOpen(cert_file, "rb") as f:
            crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        check_crt_valid(crt)

        crt_serial_number = hex(crt.get_serial_number())[2:]
        logger.info(f"crt_serial_number: {crt_serial_number}")

        check_certificate_match(crt, key)

        with FileOpen(ca_file, "rb") as f:
            ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        check_crt_valid(ca_crt)

        ca_serial_number = hex(ca_crt.get_serial_number())[2:]
        logger.info(f"ca_serial_number: {ca_serial_number}")

        # The CRL is optional: without the file no revocation list is returned.
        crl = None
        if os.path.exists(crl_file):
            with FileOpen(crl_file, "rb") as f:
                crl = x509.load_pem_x509_crl(f.read(), default_backend())
            check_crl_valid(crl, ca_crt)
            for revoked_cert in crl:
                logger.info(f"Serial Number: {revoked_cert.serial_number}, "
                            f"Revocation Date: {revoked_cert.revocation_date_utc}")

    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    return key, crt, ca_crt, crl
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def check_crt_valid(pem):
    """
    Check that the SSL certificate is readable and currently within its validity period.

    Args:
        pem: Certificate object exposing ``get_notBefore()``, ``get_notAfter()``
            (both returning bytes) and ``has_expired()``.

    Raises:
        RuntimeError: If the validity dates cannot be parsed, or the current UTC time
            is outside the validity period.
    """
    try:
        pem_start = parser.parse(pem.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(pem.get_notAfter().decode("UTF-8"))
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    # NOTE(review): the comparison below assumes the parsed dates are timezone-aware;
    # naive values would raise TypeError -- confirm the input format.
    now_utc = datetime.now(tz=timezone.utc)
    if pem.has_expired() or not (pem_start <= now_utc <= pem_end):
        raise RuntimeError("The SSL certificate has expired.")

    # Report success only after the validity period has actually been checked.
    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
def check_certificate_match(certificate, private_key):
    """
    Check that ``certificate`` and ``private_key`` belong together.

    256 random bytes are signed with the private key and the signature is then
    verified with the certificate.

    :param certificate: certificate containing the public key
    :param private_key: private key used to produce the test signature
    :return: None
    :raises RuntimeError: if signing or verification fails
    """
    probe = os.urandom(256)
    digest = "sha256"  # hash algorithm used for both signing and verifying
    try:
        probe_signature = crypto.sign(private_key, probe, digest)
        crypto.verify(certificate, probe_signature, probe, digest)
        logger.info("公钥和私钥匹配")
    except Exception as e:
        raise RuntimeError("公钥和私钥不匹配") from e
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
def check_crl_valid(crl, ca_crt):
    """Validate a CRL's signature and validity window.

    Args:
        crl: CRL object exposing ``is_signature_valid()``, ``last_update_utc``
            and ``next_update_utc``.
        ca_crt: CA certificate whose public key must have signed the CRL.

    Raises:
        RuntimeError: If the signature is invalid, or the current UTC time is outside
            ``[last_update_utc, next_update_utc]``, or the CRL has no next-update time.
    """
    # Verify the CRL signature to make sure the CRL has not been tampered with.
    if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
        raise RuntimeError("CRL签名无效!")

    # Check the CRL validity period using timezone-aware values on both sides.
    now_utc = datetime.now(tz=timezone.utc)
    next_update = crl.next_update_utc
    # Without a next-update time the window cannot be established, so reject the CRL.
    if next_update is None or not (crl.last_update_utc <= now_utc <= next_update):
        raise RuntimeError("CRL已过期或尚未生效!")
|
msprobe/pytorch/attl_manager.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
|
-
import os
|
|
17
|
-
from msprobe.core.common.runtime import Runtime
|
|
18
|
-
from msprobe.core.common.utils import Const
|
|
19
|
-
from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
|
|
20
|
-
from msprobe.pytorch.common.log import logger
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class ATTLManager:
    """Creates and drives the ATTL transport according to the given configuration."""

    def __init__(self, config):
        """Keep the configuration; the transport itself is created by ``attl_init``."""
        self.config = config
        self.attl = None

    def attl_init(self):
        """Create the ATTL object when ``config.online_run_ut`` is set.

        In NFS mode (``config.nfs_path`` set) a "start" marker is uploaded right away.
        """
        if not self.config.online_run_ut:
            return

        from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
        attl_config = ATTLConfig(is_benchmark_device=False,
                                 connect_ip=self.config.host,
                                 connect_port=self.config.port,
                                 nfs_path=self.config.nfs_path,
                                 tls_path=self.config.tls_path)
        # Dump when no rank filter is configured, or when the current rank is listed.
        rank_filter = self.config.rank
        need_dump = len(rank_filter) == 0 or Runtime.current_rank in rank_filter
        self.attl = ATTL('npu', attl_config, need_dump=need_dump)
        if self.config.nfs_path:
            self.attl.upload("start")

    def attl_send(self, name, args, kwargs, output):
        """Package one API call as ``ApiData`` and upload (NFS mode) or send it.

        ``name`` is expected to end with ``Const.FORWARD_NAME_SUFFIX``, which is cut off.
        APIs whose type is ``Const.DISTRIBUTED`` are skipped.
        """
        trimmed_name = name[:-len(Const.FORWARD_NAME_SUFFIX)]
        api_data = ApiData(trimmed_name, args, kwargs, output, Runtime.current_iter, Runtime.current_rank)
        logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")

        # The name must consist of exactly three SEP-separated parts; the first is the type.
        api_type, _, _ = api_data.name.split(Const.SEP)
        if api_type in [Const.DISTRIBUTED]:
            logger.info(f"api {api_data.name} is not supported, skip")
            return

        if not self.config.nfs_path:
            self.attl.send(api_data)
        else:
            self.attl.upload(api_data)

    def attl_stop(self):
        """Signal the end of transmission: an "end" marker in NFS mode, else a STOP signal."""
        if self.config.nfs_path:
            self.attl.upload("end")
            return
        if self.attl.socket_manager is not None:
            logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
            self.attl.socket_manager.send_stop_signal()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|