mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/utils.py +26 -6
  9. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  10. msprobe/core/hook_manager.py +2 -16
  11. msprobe/core/service.py +5 -16
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +0 -13
  14. msprobe/docs/05.data_dump_PyTorch.md +1 -1
  15. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
  16. msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
  17. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  18. msprobe/docs/19.monitor.md +4 -4
  19. msprobe/docs/21.visualization_PyTorch.md +1 -1
  20. msprobe/docs/25.tool_function_introduction.md +0 -1
  21. msprobe/docs/32.ckpt_compare.md +5 -5
  22. msprobe/mindspore/monitor/module_hook.py +17 -20
  23. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  24. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  25. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  26. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  27. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  28. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  29. msprobe/pytorch/common/utils.py +0 -70
  30. msprobe/pytorch/debugger/debugger_config.py +0 -10
  31. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  32. msprobe/pytorch/hook_module/api_register.py +14 -3
  33. msprobe/pytorch/monitor/module_hook.py +16 -34
  34. msprobe/pytorch/pt_config.py +2 -51
  35. msprobe/pytorch/pytorch_service.py +10 -14
  36. msprobe/visualization/builder/graph_builder.py +2 -2
  37. msprobe/visualization/builder/graph_merger.py +13 -0
  38. msprobe/visualization/db_utils.py +42 -18
  39. msprobe/visualization/graph/graph.py +13 -9
  40. msprobe/visualization/graph_service.py +20 -10
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  43. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  44. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  45. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  46. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  47. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  48. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  49. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  50. msprobe/pytorch/attl_manager.py +0 -65
  51. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
  52. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
  53. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
  54. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
@@ -1,115 +0,0 @@
1
-
2
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
3
- # All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import os
18
- from collections import defaultdict
19
- from functools import wraps
20
-
21
- import torch
22
- from torch.utils._python_dispatch import TorchDispatchMode
23
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
24
- from msprobe.pytorch.common.utils import get_tensor_rank
25
- from msprobe.core.common.const import Const
26
- from msprobe.pytorch.common.log import logger
27
- from msprobe.core.common.file_utils import load_yaml
28
-
29
-
30
def singleton(cls):
    """Class decorator that makes ``cls`` a lazily created singleton.

    The decorated name becomes a factory function: the first call builds the
    instance, every later call returns that same object.

    Args:
        cls: The class to wrap.

    Returns:
        A callable that returns the single shared instance of ``cls``.
        Constructor arguments are forwarded on the first call only; they are
        ignored once the instance exists.
    """
    _instance = {}

    @wraps(cls)
    def inner(*args, **kwargs):
        # Forwarding the arguments keeps the decorator usable on classes whose
        # __init__ takes parameters; a bare call still works as before.
        if cls not in _instance:
            _instance[cls] = cls(*args, **kwargs)
        return _instance[cls]
    return inner
39
-
40
-
41
@singleton
class Counter:
    """Process-wide holder of per-API call indices (shared through ``singleton``)."""

    def __init__(self) -> None:
        # Unknown API names start counting at 0.
        self.index_dict = defaultdict(int)
45
-
46
-
47
# Shared call counter used by every dispatch-mode instance in this module.
counter = Counter()
# The op configuration lives next to this module.
_CONFIG_DIR = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(_CONFIG_DIR, "torch_ops_config.yaml")
yaml_file = load_yaml(yaml_path)
50
-
51
-
52
class AccuracyCheckerDispatch(TorchDispatchMode):
    """Dispatch mode that ships the inputs and outputs of aten calls through ATTL.

    Every intercepted op is executed normally. Unless its name is listed under
    ``aten_ops_blacklist`` in the YAML configuration, the call is also wrapped
    in an ``ApiData`` record and handed to ``attl`` (uploaded when
    ``attl.nfs_path`` is set, sent otherwise).
    """

    def __init__(self, attl):
        """Store the transport object and read the op lists from the YAML config.

        Args:
            attl: Transport object exposing ``nfs_path``, ``upload`` and ``send``.
        """
        super(AccuracyCheckerDispatch, self).__init__()
        self.attl = attl
        self.counter = counter
        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', [])
        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', [])

    def __torch_dispatch__(self, func, types, args=None, kwargs=None):
        """Run ``func`` and, for non-blacklisted ops, forward the call data.

        Args:
            func: The aten overload being dispatched.
            types: Tensor subclass types involved in the call (unused here).
            args: Positional arguments of the call; ``None`` means none.
            kwargs: Keyword arguments of the call; ``None`` means none.

        Returns:
            Whatever ``func`` returns.
        """
        # The signature allows None for both; unpacking None would raise TypeError.
        args = () if args is None else args
        kwargs = {} if kwargs is None else kwargs

        # Only the part before the first separator is used as the op name.
        aten_api = func.__name__.split(Const.SEP)[0]
        self.enable_autogard(aten_api)
        if aten_api in self.aten_ops_blacklist:
            return func(*args, **kwargs)

        res = func(*args, **kwargs)
        cur_rank = get_tensor_rank(args, res)
        cur_api_number = self.counter.index_dict[aten_api]
        api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
        logger.info(f"tools is dumping api: {api_name}, rank: {cur_rank}")
        # NOTE(review): the fifth argument is fixed to 0 here — presumably the
        # step index; confirm against the ApiData definition.
        api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
        if "device" in api_data.kwargs:
            api_data.kwargs.pop("device")
        if self.attl.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)
        # Count the call only after it has been handed to the transport.
        self.counter.index_dict[aten_api] += 1

        return res

    def enable_autogard(self, aten_api):
        """Clear the thread-local exclusion of ``AutogradFunctionality`` for listed ops.

        Args:
            aten_api: Op name to look up in ``npu_adjust_autogard``.
        """
        if aten_api in self.npu_adjust_autogard:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
89
-
90
-
91
def dispatch4data(func, attl, status):
    """Wrap ``func`` so that it optionally runs under ``AccuracyCheckerDispatch``.

    Args:
        func: Callable to wrap.
        attl: Transport object passed to ``AccuracyCheckerDispatch``.
        status: When truthy the call runs inside the dispatch mode; otherwise
            ``func`` is called directly.

    Returns:
        The wrapping function, carrying ``func``'s metadata.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        if status:
            with AccuracyCheckerDispatch(attl):
                return func(*args, **kwargs)
        return func(*args, **kwargs)

    return wrapper
101
-
102
-
103
def run_ut_dispatch(attl, status, is_recompute=False):
    """
    Replace ``torch.autograd.backward`` with a version wrapped by ``dispatch4data``.

    Called by online_run_ut to switch dispatch on or off for the backward pass.

    Args:
        attl (ATTL): online_run_ut ATTL object used to upload or send api data to the server.
        status (bool): True enables dispatch, False disables it.
        is_recompute (bool): Recompute flag; recompute conflicts with aten api dispatch,
            so nothing is patched when it is set.
    """
    if not is_recompute:
        torch.autograd.backward = dispatch4data(torch.autograd.backward, attl, status)
@@ -1,250 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- from functools import partial
16
- import os
17
- import struct
18
- import zlib
19
- import time
20
- import io
21
- from threading import Thread
22
-
23
- from twisted.internet import reactor, protocol, endpoints, ssl
24
-
25
- from msprobe.pytorch.common.utils import logger
26
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
27
- STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
28
-
29
-
30
class TCPServer:
    """Twisted-based server that listens on a port and feeds a shared queue.

    The reactor runs in a daemon thread started by ``start``. When ``tls_path``
    is set the listening endpoint is TLS with mandatory client certificates;
    otherwise it is plain TCP.
    """

    def __init__(self, port, shared_queue, check_sum=False, tls_path=None) -> None:
        self.port = port
        self.shared_queue = shared_queue
        self.check_sum = check_sum
        self.tls_path = tls_path
        self.factory = MessageServerFactory()
        self.reactor_thread = None

    @staticmethod
    def run_reactor():
        """Run the reactor without installing signal handlers."""
        reactor.run(installSignalHandlers=False)

    def start(self):
        """Start listening and launch the reactor thread."""
        self.factory.protocol = self.build_protocol

        if self.tls_path:
            endpoint = self._build_tls_endpoint()
        else:
            endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
        endpoint.listen(self.factory)

        worker = Thread(target=self.run_reactor, daemon=True)
        self.reactor_thread = worker
        worker.start()

    def is_running(self):
        """Return True while the factory still tracks an unfinished connection."""
        return not self.factory.is_all_connection_closed()

    def stop(self):
        """Stop the factory, interrupt the reactor and wait for its thread."""
        self.factory.doStop()
        reactor.callFromThread(reactor.sigInt, 2)
        self.reactor_thread.join()

    def build_protocol(self):
        """Create the protocol instance for a new connection."""
        return ServerProtocol(self.shared_queue, self.check_sum)

    def _build_tls_endpoint(self):
        """Load the PEM material under ``tls_path`` and build the TLS endpoint."""
        pem_paths = {
            "key_file": os.path.join(self.tls_path, "server.key"),
            "cert_file": os.path.join(self.tls_path, "server.crt"),
            "ca_file": os.path.join(self.tls_path, "ca.crt"),
            "crl_file": os.path.join(self.tls_path, "crl.pem"),
        }
        server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(**pem_paths)

        ssl_options = ssl.CertificateOptions(
            privateKey=server_key,
            certificate=server_crt,
            method=ssl.SSL.TLSv1_2_METHOD,
            verify=True,
            requireCertificate=True,
            caCerts=[ca_crt],  # list of trusted CA certificates
        )
        # Tighten the context obtained from the options object: restrict the
        # cipher suites, forbid renegotiation and verify peers against the CRL.
        ssl_context = ssl_options.getContext()
        ssl_context.set_cipher_list(cipher_list)
        ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
        verify_flags = ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
        ssl_context.set_verify(verify_flags, partial(verify_callback, crl=crl_pem))

        return endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
85
-
86
-
87
class ServerProtocol(protocol.Protocol):
    """Per-connection protocol that reassembles length-prefixed frames.

    Frame layout as parsed below: four 8-byte header fields (body length,
    sequence number, rank, step), then an 8-byte CRC32 hex string when
    ``check_sum`` is enabled, then the body. Each completed body is put on the
    shared queue and acknowledged with a 5-byte status tag followed by the
    sequence number, rank and step (8 bytes each).
    """

    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    def __init__(self, shared_queue, check_sum=False):
        """
        Args:
            shared_queue: Queue that receives completed message bodies.
            check_sum: When True, every frame carries a CRC32 that is verified.
        """
        self.start_time = None
        self.buffer = io.BytesIO()
        self.consumer_queue = shared_queue
        self.check_sum = check_sum
        self.length_width = 8
        self.crc_width = 8
        self.obj_length = None
        self.tell = 0  # read offset of the frame currently being parsed
        self.obj_crc = None
        self.obj_body = None
        self.sequence_number = -1
        self.rank = -1
        self.step = -1
        self.sequence_number_dict = dict()

    def connectionMade(self):
        """Reset the parsing state and register the transport with the factory."""
        self.buffer = io.BytesIO()
        self.obj_length = None
        self.tell = 0
        self.obj_crc = None
        self.obj_body = None
        self.factory.transport_dict[self.transport] = 1
        self.factory.transport_list.append(self.transport)
        logger.info(f"Connected to {self.transport.getPeer()} successfully.")

    def connectionLost(self, reason):
        """Unregister the transport; signal KILL to the consumer when none remain."""
        self.factory.transport_dict.pop(self.transport, None)
        if len(self.factory.transport_dict) == 0:
            self.consumer_queue.put(self.ACK_KILL_PROCESS)

        logger.info(f"Lost connection with {self.transport.getPeer()}. Reason is: {reason} 与客户端 断开连接, "
                    f"current connection number is: {len(self.factory.transport_dict)}")

    def send_ack(self, ack_info):
        """Write ``ack_info`` followed by the current sequence number, rank and step."""
        ack_message = b"".join([
            ack_info,
            self.sequence_number.to_bytes(8, byteorder=bytes_order),
            self.rank.to_bytes(8, byteorder=bytes_order),
            self.step.to_bytes(8, byteorder=bytes_order)
        ])
        self.transport.write(ack_message)

    def post_process(self):
        """Verify, acknowledge and enqueue the frame that was just completed."""
        send_busy_ack = False
        # Wait for queue space; BUSY is reported to the client only once.
        while self.consumer_queue.full():
            if not send_busy_ack:
                self.send_ack(self.ACK_BUSY)
                logger.debug("sending BUSY ACK")
                send_busy_ack = True
            time.sleep(0.1)

        obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)

        crc_mismatch = False
        if self.check_sum:
            # CRC32 of the body rendered as an 8-character hexadecimal string.
            recv_crc = f"{zlib.crc32(self.obj_body):08x}"
            crc_mismatch = recv_crc != self.obj_crc

        if crc_mismatch:
            # Checksum verification is on and failed: the data is corrupt, reply b"ERROR".
            logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
            self.send_ack(self.ACK_ERROR)
        else:
            if self.obj_body == self.ACK_STOP:
                self.handle_with_stop()
            else:
                self.send_ack(self.ACK_SUCCESS)
                if obj_key in self.sequence_number_dict:
                    # Duplicate (retransmitted) frame: acknowledged but not enqueued again.
                    logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
                else:
                    self.sequence_number_dict[obj_key] = self.obj_crc
                    self.consumer_queue.put(self.obj_body, block=True)

        self.reset_env()
        finish_time = time.time()
        logger.debug(f"finish_time: {finish_time - self.start_time}")

    def handle_with_stop(self):
        """Confirm a STOP frame; once no link is still transmitting, broadcast KILL."""
        logger.debug(f"接收到停止传输信号 TCP{self.transport.getPeer()}")
        self.send_ack(self.ACK_STOP_CONFIRM)
        if len(self.factory.transport_dict) == 0:
            _rank, _step, _sequence_number = 0, 0, 100000000
            # Same field layout and byte order as send_ack.
            ack_kill = b"".join([
                self.ACK_KILL_PROCESS,
                _sequence_number.to_bytes(8, byteorder=bytes_order),
                _rank.to_bytes(8, byteorder=bytes_order),
                _step.to_bytes(8, byteorder=bytes_order)
            ])
            for trans in self.factory.transport_list:
                trans.write(ack_kill)
                # Report the peer that was actually written to.
                logger.debug(f"发送KILL信息给{trans.getPeer()}")
            self.consumer_queue.put(self.ACK_KILL_PROCESS)
            time.sleep(2)

    def reset_env(self):
        """Clear the per-frame fields so the next header can be parsed."""
        self.obj_length = None
        self.sequence_number = -1
        self.rank = -1
        self.step = -1
        self.obj_crc = None
        self.obj_body = None

    def dataReceived(self, data):
        """Append ``data`` to the buffer and process every complete frame in it."""
        self.buffer.seek(0, 2)
        self.buffer.write(data)
        # A single chunk may complete several frames; drain them all so that
        # none waits in the buffer for the next chunk to arrive.
        while self._consume_frame():
            pass

    def _consume_frame(self):
        """Advance parsing of the buffered frame; return True if one was completed."""
        self.buffer.seek(self.tell)
        buffered = len(self.buffer.getvalue())

        # The frame starts with the header: obj_length, sequence_number, rank, step.
        if self.obj_length is None and buffered >= self.length_width * 4:
            self.start_time = time.time()
            self.obj_length = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.sequence_number = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.rank = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.step = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
            self.tell += self.length_width * 4
            logger.debug(
                f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")

        # With checksums enabled the CRC field follows the header.
        crc_pending = self.check_sum and self.obj_length is not None and self.obj_crc is None
        if crc_pending and buffered - self.tell >= self.crc_width:
            self.obj_crc = self.buffer.read(self.crc_width).decode()
            self.tell += self.crc_width
            logger.debug(f"Hash value: {self.obj_crc}")

        # The body must not be read before the CRC field has been consumed,
        # otherwise CRC bytes would be taken for payload.
        crc_ready = not self.check_sum or self.obj_crc is not None
        current_length = buffered - self.tell
        # NOTE(review): a frame announcing length 0 never satisfies this check.
        if self.obj_length is None or not crc_ready or not 0 < self.obj_length <= current_length:
            return False

        # The current frame is complete.
        self.obj_body = self.buffer.read(self.obj_length)

        self.tell += self.obj_length
        # Keep only the bytes that belong to following frames.
        self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
        self.buffer.seek(0)
        self.tell = 0
        recv_data_time = time.time()
        logger.debug(f"self.sequence_number {self.sequence_number} "
                     f"recv_data_time {recv_data_time - self.start_time}")

        if self.obj_body == self.ACK_STOP:
            # This link has finished transmitting: drop it from transport_dict.
            self.factory.transport_dict.pop(self.transport, None)
            logger.debug(f"接收到b'STOP_' self.sequence_number {self.sequence_number} ")
        self.post_process()
        return True
237
-
238
-
239
class MessageServerFactory(protocol.ServerFactory):
    """Server factory that keeps track of the TCP links it has accepted."""

    def __init__(self) -> None:
        """
        transport_dict: links that have not completed data transmission.
        transport_list: every TCP link ever established; a link is appended
                        when its connection is made.
        """
        self.transport_list = []
        self.transport_dict = {}

    def is_all_connection_closed(self):
        """Return True when no link is still transmitting."""
        return not self.transport_dict
@@ -1,63 +0,0 @@
1
- aten_ops_blacklist:
2
- - npu_binary_cross_entropy_with_logits_backward
3
- - npu_ciou_backward
4
- - _cudnn_rnn
5
- - _local_scalar_dense
6
- - _pin_memory
7
- - _to_copy
8
- - _unsafe_view
9
- - clone
10
- - contiguous
11
- - copy_
12
- - cudnn_batch_norm
13
- - cudnn_batch_norm_backward
14
- - detach
15
- - empty
16
- - index_put_
17
- - lift_fresh
18
- - max_pool2d_with_indices_backward # shape unmatch
19
- - native_batch_norm_backward
20
- - new_empty
21
- - new_empty_strided
22
- - new_full
23
- - new_ones
24
- - new_zeros
25
- - ones
26
- - ones_like
27
- - permute
28
- - rand
29
- - rand_like
30
- - randint
31
- - randint_like
32
- - randn
33
- - randn_like
34
- - randperm
35
- - scalar_tensor
36
- - select
37
- - to
38
- - transpose
39
- - unbind
40
- - view
41
- - zero
42
- - zero_
43
- - zeros
44
- - zeros_like
45
- - _record_function_enter_new
46
- - _record_function_exit
47
- - broadcast_
48
- - allreduce_
49
- - npu_clear_float_status
50
- - npu_format_cast
51
- - npu_dtype_cast
52
- - npu_dtype_cast_backward
53
- - _allgather_base_
54
- - _reduce_scatter_base_
55
- - is_same_size
56
-
57
- npu_adjust_autogard:
58
- - adaptive_avg_pool2d
59
- - batch_norm
60
- - log_softmax
61
- - nll_loss
62
- - to
63
-
@@ -1,198 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import gc
16
- import os
17
- from datetime import datetime, timezone
18
-
19
- from OpenSSL import crypto
20
- from cryptography import x509
21
- from cryptography.hazmat.backends import default_backend
22
- from dateutil import parser
23
-
24
- from msprobe.core.common.file_utils import FileOpen
25
- from msprobe.core.common.log import logger
26
-
27
# Colon-separated cipher suite names, encoded to bytes, for the TLS context's
# cipher list.
cipher_list = ":".join(
    ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
     "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
     "TLS_DHE_DSS_WITH_AES_128_GCM_SHA256",
     "TLS_DHE_DSS_WITH_AES_256_GCM_SHA384",
     "TLS_DHE_PSK_WITH_AES_128_GCM_SHA256",
     "TLS_DHE_PSK_WITH_AES_256_GCM_SHA384",
     "TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
     "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
     "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
     "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
     "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256",
     "TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384",
     "TLS_ECDHE_PSK_WITH_AES_128_CCM_SHA256",
     "TLS_DHE_RSA_WITH_AES_128_CCM",
     "TLS_DHE_RSA_WITH_AES_256_CCM",
     "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
     "TLS_DHE_PSK_WITH_AES_128_CCM",
     "TLS_DHE_PSK_WITH_AES_256_CCM",
     "TLS_ECDHE_ECDSA_WITH_AES_128_CCM",
     "TLS_ECDHE_ECDSA_WITH_AES_256_CCM",
     "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
).encode()

# struct format for one header field: network byte order, unsigned 8-byte integer.
STRUCT_UNPACK_MODE = "!Q"
# Byte order passed to int.to_bytes when building acknowledgements.
STR_TO_BYTES_ORDER = "big"
56
-
57
-
58
def is_certificate_revoked(cert, crl):
    """Tell whether the serial number of ``cert`` is listed in ``crl``.

    Args:
        cert: Certificate object exposing ``get_serial_number()``.
        crl: Iterable of revoked entries, each exposing ``serial_number``.

    Returns:
        bool: True when the certificate is revoked (also logged as an error).
    """
    serial = cert.get_serial_number()
    revoked_serials = {entry.serial_number for entry in crl}

    if serial not in revoked_serials:
        return False

    logger.error(f"证书已吊销:{serial:020x}")
    return True
69
-
70
-
71
def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
    """
    Check the validity of the peer certificate.
    :param conn: OpenSSL.SSL.Connection, the SSL connection object
    :param cert: OpenSSL.crypto.X509, the certificate currently being checked
    :param errno: int, OpenSSL error code, e.g. 0: no error | 9: certificate expired | 18: self-signed certificate
    :param depth: int, depth of the certificate in the chain (0 = leaf certificate, higher values = CA certificates)
    :param preverify_ok: int, result of OpenSSL's own verification (1 = passed, 0 = failed)
    :param crl: _CRLInternal, CRL object; no revocation check is done when it is empty or None
    :return: falsy to reject the certificate, ``preverify_ok`` to accept it
    """
    if preverify_ok:
        # OpenSSL accepted the certificate; additionally reject revoked ones.
        if crl and is_certificate_revoked(cert, crl):
            return False
        return preverify_ok

    from OpenSSL import SSL
    error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode()
    logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
    return False
93
-
94
-
95
def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
    """
    Load and validate the PEM files needed for a TLS endpoint.

    The private-key passphrase is asked for interactively. The certificate and
    the CA certificate are checked for validity, the certificate is checked
    against the private key, and the CRL (if the file exists) is checked
    against the CA certificate.

    Args:
        key_file (str): The path to the private key file.
        cert_file (str): The path to the certificate file.
        ca_file (str): The path to the CA certificate file.
        crl_file (str): The path to the CRL file; may be absent on disk.

    Returns:
        tuple: (key, crt, ca_crt, crl); crl is None when crl_file does not exist.

    Raises:
        RuntimeError: If any file cannot be loaded or any check fails.
    """
    try:
        # Prompt for the private key password.
        import pwinput
        passphrase = pwinput.pwinput("Enter your password: ")
        with FileOpen(key_file, "rb") as key_fp:
            key = crypto.load_privatekey(crypto.FILETYPE_PEM, key_fp.read(), passphrase.encode())
        # Drop the reference to the password as soon as the key is loaded.
        del passphrase
        gc.collect()

        with FileOpen(cert_file, "rb") as cert_fp:
            crt = crypto.load_certificate(crypto.FILETYPE_PEM, cert_fp.read())
        check_crt_valid(crt)
        crt_serial_number = hex(crt.get_serial_number())[2:]
        logger.info(f"crt_serial_number: {crt_serial_number}")
        check_certificate_match(crt, key)

        with FileOpen(ca_file, "rb") as ca_fp:
            ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, ca_fp.read())
        check_crt_valid(ca_crt)
        ca_serial_number = hex(ca_crt.get_serial_number())[2:]
        logger.info(f"ca_serial_number: {ca_serial_number}")

        crl = None
        if os.path.exists(crl_file):
            with FileOpen(crl_file, "rb") as crl_fp:
                crl = x509.load_pem_x509_crl(crl_fp.read(), default_backend())
            check_crl_valid(crl, ca_crt)
            for revoked_cert in crl:
                logger.info(f"Serial Number: {revoked_cert.serial_number}, "
                            f"Revocation Date: {revoked_cert.revocation_date_utc}")
        return key, crt, ca_crt, crl
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e
148
-
149
-
150
def check_crt_valid(pem):
    """
    Check the validity period of an SSL certificate.

    Args:
        pem: Certificate object exposing ``get_notBefore()``, ``get_notAfter()``
            and ``has_expired()``.

    Raises:
        RuntimeError: If the validity dates cannot be parsed, or the current
            UTC time lies outside the validity period.
    """
    try:
        pem_start = parser.parse(pem.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(pem.get_notAfter().decode("UTF-8"))
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    now_utc = datetime.now(tz=timezone.utc)
    if pem.has_expired() or not (pem_start <= now_utc <= pem_end):
        raise RuntimeError("The SSL certificate has expired.")

    # Report success only once the validity window has really been checked.
    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
168
-
169
-
170
def check_certificate_match(certificate, private_key):
    """
    Check that ``certificate`` and ``private_key`` belong together.

    Random data is signed with the private key and the signature is verified
    with the certificate; any failure is reported as a mismatch.
    :param certificate: certificate holding the public key
    :param private_key: private key to test against the certificate
    :return: None
    :raises RuntimeError: if signing or verification fails
    """
    digest = "sha256"  # hash algorithm used for both signing and verifying
    probe = os.urandom(256)  # random data to sign
    try:
        signature = crypto.sign(private_key, probe, digest)
        crypto.verify(certificate, signature, probe, digest)
        logger.info("公钥和私钥匹配")
    except Exception as e:
        raise RuntimeError("公钥和私钥不匹配") from e
189
-
190
-
191
def check_crl_valid(crl, ca_crt):
    """
    Check that a CRL is signed by the CA and currently in force.

    Args:
        crl: CRL object exposing ``is_signature_valid()``, ``last_update_utc``
            and ``next_update_utc``.
        ca_crt: CA certificate whose public key must have signed the CRL.

    Raises:
        RuntimeError: If the signature is invalid, or the current UTC time is
            outside [last_update, next_update], or next_update is missing.
    """
    # Verify the CRL signature (makes sure the CRL has not been tampered with).
    if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
        raise RuntimeError("CRL签名无效!")

    # Check the CRL validity period using timezone-aware UTC values on both
    # sides; naive utcnow() and the naive update properties are deprecated.
    now_utc = datetime.now(tz=timezone.utc)
    next_update = crl.next_update_utc
    # A CRL without nextUpdate cannot be shown to be current.
    if next_update is None or not (crl.last_update_utc <= now_utc <= next_update):
        raise RuntimeError("CRL已过期或尚未生效!")
@@ -1,65 +0,0 @@
1
- # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- from msprobe.core.common.runtime import Runtime
18
- from msprobe.core.common.utils import Const
19
- from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
20
- from msprobe.pytorch.common.log import logger
21
-
22
-
23
class ATTLManager:
    """Thin wrapper that owns the ATTL channel used for online run_ut dumps."""

    def __init__(self, config):
        self.config = config
        self.attl = None  # created by attl_init when online_run_ut is enabled

    def attl_init(self):
        """Create the ATTL channel if ``config.online_run_ut`` is set."""
        if not self.config.online_run_ut:
            return

        from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
        attl_config = ATTLConfig(
            is_benchmark_device=False,
            connect_ip=self.config.host,
            connect_port=self.config.port,
            nfs_path=self.config.nfs_path,
            tls_path=self.config.tls_path,
        )
        # An empty rank list means every rank dumps.
        dump_all_ranks = len(self.config.rank) == 0
        need_dump = dump_all_ranks or Runtime.current_rank in self.config.rank
        self.attl = ATTL('npu', attl_config, need_dump=need_dump)
        if self.config.nfs_path:
            self.attl.upload("start")

    def attl_send(self, name, args, kwargs, output):
        """Wrap one forward API call in ``ApiData`` and hand it to the channel.

        ``name`` is stripped of ``Const.FORWARD_NAME_SUFFIX``'s length before
        use. Calls whose API type is ``Const.DISTRIBUTED`` are skipped.
        """
        base_name = name[:-len(Const.FORWARD_NAME_SUFFIX)]
        api_data = ApiData(base_name, args, kwargs, output, Runtime.current_iter, Runtime.current_rank)
        logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")

        api_type, _, _ = api_data.name.split(Const.SEP)
        if api_type == Const.DISTRIBUTED:
            logger.info(f"api {api_data.name} is not supported, skip")
            return

        deliver = self.attl.upload if self.config.nfs_path else self.attl.send
        deliver(api_data)

    def attl_stop(self):
        """Mark the end of the dump: upload "end", or send the STOP signal."""
        if self.config.nfs_path:
            self.attl.upload("end")
            return
        if self.attl.socket_manager is not None:
            logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
            self.attl.socket_manager.send_stop_signal()