mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/utils.py +26 -6
  9. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  10. msprobe/core/hook_manager.py +2 -16
  11. msprobe/core/service.py +5 -16
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +0 -13
  14. msprobe/docs/05.data_dump_PyTorch.md +1 -1
  15. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
  16. msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
  17. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  18. msprobe/docs/19.monitor.md +4 -4
  19. msprobe/docs/21.visualization_PyTorch.md +1 -1
  20. msprobe/docs/25.tool_function_introduction.md +0 -1
  21. msprobe/docs/32.ckpt_compare.md +5 -5
  22. msprobe/mindspore/monitor/module_hook.py +17 -20
  23. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  24. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  25. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  26. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  27. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  28. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  29. msprobe/pytorch/common/utils.py +0 -70
  30. msprobe/pytorch/debugger/debugger_config.py +0 -10
  31. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  32. msprobe/pytorch/hook_module/api_register.py +14 -3
  33. msprobe/pytorch/monitor/module_hook.py +16 -34
  34. msprobe/pytorch/pt_config.py +2 -51
  35. msprobe/pytorch/pytorch_service.py +10 -14
  36. msprobe/visualization/builder/graph_builder.py +2 -2
  37. msprobe/visualization/builder/graph_merger.py +13 -0
  38. msprobe/visualization/db_utils.py +42 -18
  39. msprobe/visualization/graph/graph.py +13 -9
  40. msprobe/visualization/graph_service.py +20 -10
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  43. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  44. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  45. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  46. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  47. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  48. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  49. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  50. msprobe/pytorch/attl_manager.py +0 -65
  51. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
  52. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
  53. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
  54. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
@@ -1,378 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- from functools import partial
16
- import zlib
17
- import io
18
- import struct
19
- import time
20
- import os
21
- from queue import Queue
22
- from threading import Thread
23
- from typing import Union
24
-
25
- from twisted.internet import reactor, protocol, endpoints, ssl
26
- from twisted.protocols.basic import FileSender
27
-
28
- from msprobe.pytorch.common.utils import logger
29
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
30
- STR_TO_BYTES_ORDER as bytes_order, cipher_list, verify_callback, load_ssl_pem
31
-
32
# Upper bound on queued-but-unsent items; also used as the client-side batching
# window (see TCPClient.send_after_queue_empty and ClientProtocol.send_large_data).
MAX_SENDING_QUEUE_SIZE = 20
33
-
34
-
35
class TCPDataItem:
    """A single payload queued for transmission by the TCP client.

    Besides the raw payload and its identity (sequence number / rank / step),
    it carries the bookkeeping counters used by the sender's
    retry / busy / pending back-off logic, all starting at zero.
    """

    def __init__(self, data,
                 sequence_number: int,
                 rank: int = 0,
                 step: int = 0):
        # Identity of this payload within the sending session.
        self.sequence_number = sequence_number
        self.rank = rank
        self.step = step
        # The payload itself.
        self.raw_data = data
        # Counters driven by ACK handling.
        self.retry_times = 0
        self.busy_time = 0
        self.pending_time = 0
47
-
48
-
49
class TCPClient:
    """Threaded TCP/TLS client that streams framed data items to a collector server.

    Payloads are enqueued via ``add_to_sending_queue``; a daemon sender thread
    drains the queue through the twisted reactor while a second daemon thread
    consumes ACKs from the protocol and drives the retry / back-off logic.

    Fixes applied in this revision:
      * ``_resend_data`` / ``_pending_data`` popped ``resend_dict`` with the bare
        ``sequence_number`` although the dict is keyed by
        ``"<seq>_<rank>_<step>"`` (see ``_sending_queue_data``), which always
        raised ``KeyError``. Both now use the composed key via ``_resend_key``.
      * ``Thread.setDaemon(True)`` (deprecated) replaced by ``daemon=True``.
    """

    # 5-byte ACK opcodes exchanged with the server.
    ACK_SUCCESS = b"OK___"
    ACK_ERROR = b"ERROR"
    ACK_BUSY = b"BUSY_"
    ACK_STOP = b"STOP_"
    ACK_STOP_CONFIRM = b"OVER_"
    ACK_KILL_PROCESS = b"KILL_"

    QUEUE_PENDING_TIME = 60   # seconds to block on a full send_queue before giving up
    RESEND_RETRY_TIMES = 2    # maximum number of retransmissions per item
    RESEND_TIMER_TIME = 5     # ACK-reception timeout timer (seconds)
    RESEND_PENDING_TIME = 60  # drop an item once it has been pending for one minute

    def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
        self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
        # In-flight items awaiting ACK, keyed by "<seq>_<rank>_<step>".
        self.resend_dict = dict()
        self.host = host
        self.port = port
        self.tls_path = tls_path
        self.factory = None
        self.sequence_number = 0
        self.signal_exit = False
        self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                          chunk_size=655360,
                                          check_sum=check_sum,
                                          tls=self.tls_path)
        # daemon=True replaces the deprecated Thread.setDaemon().
        self.send_thread = Thread(target=self._sending_queue_data, daemon=True)
        self.send_thread.start()
        self.destroy_thread = Thread(target=self._destroy_queue_data, daemon=True)
        self.destroy_thread.start()

    @staticmethod
    def run_reactor():
        # Reactor runs on a worker thread, so signal handlers must stay off.
        reactor.run(installSignalHandlers=False)

    @staticmethod
    def _resend_key(data) -> str:
        """Build the resend_dict key for a TCPDataItem: "<seq>_<rank>_<step>"."""
        return f"{data.sequence_number}_{data.rank}_{data.step}"

    def start(self):
        """Connect to the server (optionally over mutually-authenticated TLS 1.2)."""
        def conn_callback(cur_protocol):
            if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
                logger.debug(f"Process: {os.getpid()} connects to server successfully.")
            else:
                logger.warning(f"Process: {os.getpid()} fails to connect to server. ")
                raise ConnectionError(f"Failed to connect to {self.host}.")

        def conn_err_callback(failure):
            self.signal_exit = True
            time.sleep(1)
            reactor.stop()
            logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")

        def cur_protocol():
            # Factory hook: always hand back the single shared protocol instance.
            return self.tcp_manager

        self.factory = MessageClientFactory()
        self.factory.protocol = cur_protocol
        if self.tls_path:
            client_key, client_crt, ca_crt, crl_pem = load_ssl_pem(
                key_file=os.path.join(self.tls_path, "client.key"),
                cert_file=os.path.join(self.tls_path, "client.crt"),
                ca_file=os.path.join(self.tls_path, "ca.crt"),
                crl_file=os.path.join(self.tls_path, "crl.pem")
            )

            ssl_options = ssl.CertificateOptions(
                privateKey=client_key,
                certificate=client_crt,
                method=ssl.SSL.TLSv1_2_METHOD,
                verify=True,
                requireCertificate=True,
                caCerts=[ca_crt],  # trusted CA certificate list
            )
            ssl_context = ssl_options.getContext()
            ssl_context.set_cipher_list(cipher_list)
            ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
            # Peer certificate is mandatory and checked against the CRL.
            ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
                                   partial(verify_callback, crl=crl_pem))

            endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, ssl_options)
        else:
            endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
        d = endpoint.connect(self.factory)
        d.addCallback(conn_callback)
        d.addErrback(conn_err_callback)

        reactor_thread = Thread(target=self.run_reactor, daemon=True)
        reactor_thread.start()

    def send_after_queue_empty(self, data):
        """Keep enqueuing *data* until the client is ready to exit.

        Over TLS the protocol coalesces frames, so the item is pushed
        MAX_SENDING_QUEUE_SIZE times per round to flush the coalescing buffer.
        """
        while not self._ready_to_exit():
            if not self.tls_path:
                self.add_to_sending_queue(data)
            else:
                for _ in range(MAX_SENDING_QUEUE_SIZE):
                    self.add_to_sending_queue(data)
            time.sleep(2)

    def check_client_alive(self):
        """Return True while at least one connection is established."""
        return self.factory.num_connections > 0

    def stop(self):
        """Force-close the connection via the protocol's timeout path."""
        self.tcp_manager.connection_timeout()

    def send_stop_signal(self):
        """Send the STOP opcode and wait until the link is torn down."""
        self.send_after_queue_empty(self.ACK_STOP)
        while not self._ready_to_exit():
            if not self.check_client_alive():
                break
            time.sleep(1)

    def add_to_sending_queue(self, data: Union[bytes, "TCPDataItem"], rank: int = 0, step: int = 0):
        """Wrap raw bytes into a TCPDataItem (assigning the next sequence number)
        and enqueue it; logs instead of raising when the queue stays full."""
        if self._ready_to_exit():
            return

        send_data = data
        if not isinstance(data, TCPDataItem):
            send_data = TCPDataItem(data=data,
                                    sequence_number=self.sequence_number,
                                    rank=rank,
                                    step=step)
            self.sequence_number += 1
        try:
            self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
        except Exception as e:
            logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
                         f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
                         f"{str(e)}")

    def _send_data(self, data: "TCPDataItem"):
        self.tcp_manager.send_wrapped_data(data.raw_data,
                                           sequence_number=data.sequence_number,
                                           rank=data.rank,
                                           step=data.step
                                           )

    def _sending_queue_data(self):
        """Sender-thread loop: drain send_queue while the ACK window has room."""
        while True:
            if not self.tcp_manager.is_connected:
                continue

            while self.send_queue.qsize() > 0:
                if self._ready_to_exit():
                    break
                # Flow control: no more than MAX_SENDING_QUEUE_SIZE un-ACKed items.
                if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
                    data_obj = self.send_queue.get()
                    resend_key = self._resend_key(data_obj)
                    logger.debug(f"get {resend_key} from send_queue, and send to server.")
                    self._send_data(data_obj)
                    if resend_key not in self.resend_dict.keys():
                        # Send data for the first time
                        self.resend_dict[resend_key] = data_obj
                else:
                    time.sleep(0.1)

            if self._ready_to_exit():
                logger.debug("Successfully close sending process.")
                break
            time.sleep(0.1)

    def _destroy_queue_data(self):
        """ACK-thread loop: match ACKs against in-flight items and retire/retry them."""
        while True:
            if self._ready_to_exit():
                break

            while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0:
                ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get()
                obj_key = f"{seq_number}_{rank}_{step}"
                current_item = self.resend_dict.get(obj_key)

                if current_item is None:
                    # Unknown item (already retired): move on to the next ACK.
                    continue

                if ack_info == self.ACK_SUCCESS:
                    self.resend_dict.pop(obj_key)
                elif ack_info == self.ACK_BUSY:
                    logger.debug("RECV BUSY ACK")
                    if current_item.busy_time > 5:
                        self._resend_data(current_item)
                    else:
                        current_item.busy_time += 1
                elif ack_info == self.ACK_ERROR:
                    logger.debug("RECV ERROR ACK")
                    self._resend_data(current_item)
                elif ack_info == self.ACK_STOP_CONFIRM:
                    logger.debug("RECV STOP ACK")
                    self.factory.num_connections -= 1

                # Handle at most one matched ACK per outer pass (throttled below).
                break

            time.sleep(0.1)

    def _resend_data(self, data: "TCPDataItem"):
        """Retry sending *data* up to RESEND_RETRY_TIMES, then drop it.

        Bug fix: resend_dict is keyed by "<seq>_<rank>_<step>", not by the bare
        sequence number; the old pop(data.sequence_number) raised KeyError.
        """
        if data.retry_times < self.RESEND_RETRY_TIMES:
            data.retry_times += 1
            logger.debug(f"Resend data seq number: {data.sequence_number}")
            self.add_to_sending_queue(data)
        else:
            self.resend_dict.pop(self._resend_key(data), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!")

    def _pending_data(self, data: "TCPDataItem"):
        """Back off proportionally to payload size; drop after RESEND_PENDING_TIME.

        Same key fix as _resend_data.
        """
        if data.pending_time >= self.RESEND_PENDING_TIME:
            self.resend_dict.pop(self._resend_key(data), None)
            logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!")
            return

        # wait time is 100MB per second
        pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50))
        data.pending_time += pending_time
        time.sleep(pending_time)

    def _ready_to_exit(self):
        return self.signal_exit or self.tcp_manager.signal_exit
262
-
263
-
264
class ClientProtocol(protocol.Protocol):
    """Twisted protocol for the client side of the transfer channel.

    Outgoing payloads are framed by ``send_wrapped_data`` and streamed through a
    ``FileSender``; inbound bytes accumulate in an internal ``BytesIO`` and are
    parsed as fixed 29-byte ACK frames: a 5-byte opcode followed by three 8-byte
    integers (sequence number, rank, step) decoded with STRUCT_UNPACK_MODE.
    """

    # Idle timeout in seconds; the connection is dropped when no data arrives.
    TIMEOUT = 60 * 10

    def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
        self.buffer = io.BytesIO()  # accumulates unparsed inbound bytes
        self.is_connected = False
        self.check_sum = check_sum  # when True, a CRC32 hex digest is appended to each outgoing frame
        self.tell = 0  # current read offset into self.buffer
        self.ack_queue = Queue(maxsize=ack_queue_size)  # parsed ACK tuples consumed by TCPClient
        self.file_sender = FileSender()
        self.file_sender.CHUNK_SIZE = chunk_size
        self.signal_exit = False
        self.defer = None  # Deferred of the in-flight transfer; next send waits on it
        self.kill_process = False
        self.ack = None  # last ACK opcode observed

        self.timeout_call = None  # reactor.callLater handle, armed in connectionMade

        self.tls = tls  # when truthy, small frames are coalesced before sending
        self.send_buffer = b""
        self.buffer_cnt = 0

    def dataReceived(self, data):
        """Append inbound bytes and parse as many complete 29-byte ACK frames as possible."""
        # Any traffic resets the idle-timeout timer.
        if self.timeout_call.active():
            self.timeout_call.reset(self.TIMEOUT)

        # Append at the end, then rewind to the unparsed offset.
        self.buffer.seek(0, 2)
        self.buffer.write(data)
        self.buffer.seek(self.tell)
        while True:
            if len(self.buffer.getvalue()) >= 29:  # 5 + 8 * 3
                ack = self.buffer.read(5)
                self.ack = ack
                seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
                rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
                step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
                logger.debug(f"receive 流水号: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
                if ack == b"KILL_":
                    self.kill_process = True
                    logger.debug(f"接收到KILL信号, PID {os.getpid()}")
                if ack == b"OVER_":
                    self.factory.num_connections -= 1
                self.tell += 29
                if not self.ack_queue.full():
                    self.ack_queue.put((ack, seq_number, rank, step))
                    # Compact the buffer: drop consumed bytes, reset the offset.
                    self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:])
                    self.tell = 0
                else:
                    # NOTE(review): sleeping here blocks the reactor thread while
                    # the ACK queue is full — confirm this back-pressure is intended.
                    time.sleep(0.1)
            else:
                break

    def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
        """Frame *data* (length, sequence, rank, step [, crc]) and start the transfer."""
        length = len(data)
        data_crc = f"{zlib.crc32(data):08x}" if self.check_sum else ""
        # NOTE(review): 'data_meaasge' is a misspelling of data_message (local only).
        data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
                       sequence_number.to_bytes(8, byteorder=bytes_order) + \
                       rank.to_bytes(8, byteorder=bytes_order) + \
                       step.to_bytes(8, byteorder=bytes_order) + \
                       data_crc.encode() + \
                       data
        logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")

        # Busy-wait until the previous transfer's Deferred has fired.
        while True:
            if self.defer is None or self.defer.called:
                self.defer = self.send_large_data(data_meaasge)
                break
            time.sleep(0.01)

    def send_large_data(self, data):
        """Start a FileSender transfer; over TLS, coalesce frames first.

        Returns the transfer's Deferred, or None when the TLS coalescing
        buffer has not reached MAX_SENDING_QUEUE_SIZE frames yet.
        """
        if self.tls:
            self.send_buffer += data
            self.buffer_cnt += 1
            if self.buffer_cnt >= MAX_SENDING_QUEUE_SIZE:
                d = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
                self.send_buffer = b""
                self.buffer_cnt = 0
            else:
                d = None
        else:
            d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
        return d

    def connection_timeout(self):
        """Close the connection after the idle timeout (no-op if already closed)."""
        if self.factory.num_connections <= 0:
            return

        self.factory.num_connections -= 1
        logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}")
        self.transport.loseConnection()

    def connectionMade(self):
        # Arm the idle-timeout timer and announce the new connection.
        self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout)
        self.is_connected = True
        self.factory.num_connections += 1
        logger.info("successfully connect server")

    def connectionLost(self, reason):
        # Mark exit so TCPClient's worker threads can wind down.
        self.signal_exit = True
        self.factory.num_connections -= 1
        logger.info(f"Lost connection with server, reason is : {reason.value}")
366
-
367
-
368
class MessageClientFactory(protocol.ClientFactory):
    """Twisted client factory that stops the reactor once the link fails or drops."""

    def __init__(self):
        # Count of currently-established connections; shared with the protocol
        # and read by TCPClient.check_client_alive().
        self.num_connections = 0

    def clientConnectionFailed(self, connector, reason):
        # The connection could not be established at all.
        logger.info(f"Fail to connection with server: {reason.getErrorMessage()}")
        reactor.stop()

    def clientConnectionLost(self, connector, reason):
        # An established connection was dropped.
        logger.info(f"Client lost connection with server: {reason.getErrorMessage()}")
        reactor.stop()
@@ -1,239 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import time
17
- from collections import namedtuple
18
-
19
- import pandas as pd
20
- import torch
21
- import torch.multiprocessing as mp
22
-
23
- from msprobe.core.common.const import Const, CompareConst
24
- from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import online_api_precision_compare
25
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TEST_ROWS, thousandth_standard_api, \
26
- binary_standard_api, absolute_standard_api
27
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api, ExecParams
28
- from msprobe.pytorch.common.log import logger
29
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
30
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
31
-
32
- # NPU vs GPU api list
33
- CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
34
-
35
- current_time = time.strftime("%Y%m%d%H%M%S")
36
- ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME = "api_precision_compare_result_" + current_time + "_rank*.csv"
37
- ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME = "api_precision_compare_details_" + current_time + "_rank*.csv"
38
-
39
- OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig',
40
- ['npu_data', 'gpu_data', 'rank', 'result_csv_path', 'details_csv_path'])
41
- # namedtuple of [instance of Comparator, func of run_touch_api_online, config of run_ut_config]
42
- CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config'])
43
-
44
-
45
def get_gpu_device():
    """Return True when torch_npu cannot be imported (i.e. we run on GPU)."""
    try:
        import torch_npu  # noqa: F401  # import probe only; presence implies NPU
    except ImportError:
        return True
    return False
- return is_gpu
52
-
53
-
54
def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file):
    """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue.
    :param xpu_id: int
    :param consumer_queue: shared queues of ConsumerDispatcher
    :param common_config: namedtuple of CommonCompareConfig
    :param api_precision_csv_file: list, length is 2, result file name and details file name
    :return:
    """
    device_info = "cuda" if get_gpu_device() else "npu"
    logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.")
    gpu_device = torch.device(f'{device_info}:{xpu_id}')

    # Poll the shared queue until the dispatcher sends the "KILL_" poison pill.
    while True:
        if consumer_queue.empty():
            time.sleep(0.1)
            continue

        api_data = consumer_queue.get()
        if api_data == "KILL_":
            # current consumer finish
            return

        # api_data.name is expected to split into three parts on Const.SEP
        # ("<type><SEP><api><SEP><suffix>") — raises ValueError otherwise.
        _, api_name, _ = api_data.name.split(Const.SEP)
        if api_name in CompareApi:
            # NPU vs GPU
            online_compare(api_data, gpu_device, common_config)
        else:
            # NPUvsCPU vs GPUvsCPU
            online_precision_compare(api_data, gpu_device, common_config, api_precision_csv_file)
- online_precision_compare(api_data, gpu_device, common_config, api_precision_csv_file)
83
-
84
-
85
def online_precision_compare(api_data, device, common_config, api_precision_csv_file):
    """online run_ut for precision_compare: NPUvsCPU vs GPUvsCPU
    1. get NPUvsCPU compare result
    2. get GPUvsCPU compare result
    3. call online_api_precision_compare
    :param api_data
    :param device
    :param common_config: namedtuple of CommonCompareConfig
    :param api_precision_csv_file: [result_file_name, details_file_name]
    """
    compare, func, config = common_config.compare, common_config.handle_func, common_config.config
    api_full_name = api_data.name
    [api_type, api_name, _] = api_full_name.split(Const.SEP)
    npu_args, npu_kwargs, npu_out = api_data.args, api_data.kwargs, api_data.result

    # The "device" kwarg captured on NPU must not be replayed on CPU/GPU.
    if npu_kwargs.get("device"):
        del npu_kwargs["device"]

    try:
        # NPU vs CPU
        cpu_params = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
        cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
        cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, None)
        cpu_out = exec_api(cpu_exec_params)
        npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
        npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
        npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])

        # GPU vs CPU
        api_data_gpu = move2target_device(api_data, device)  # args, kwargs -> gpu, result -> npu
        data_info = func(api_full_name, api_data_gpu, config.backward_content)
        gpu_out = data_info.bench_output
        gpu_data_info = UtDataInfo(None, None, gpu_out, cpu_out, None, [], None, rank=api_data.rank)
        gpu_detail = compare.compare_output(api_full_name, gpu_data_info, True)
        gpu_data = pd.DataFrame(gpu_detail, columns=DETAIL_TEST_ROWS[-1])

        # NPUvsCPU vs GPUvsCPU
        result_file_name, details_file_name = api_precision_csv_file
        precision_compare_config = OnlineApiPrecisionCompareConfig(npu_data, gpu_data, api_data.rank,
                                                                   result_file_name, details_file_name)
        online_api_precision_compare(precision_compare_config)

    except Exception as err:
        # Known failure modes get tailored messages; everything else is logged
        # as an error. In all cases a SKIP row is recorded in the summary csv.
        if "expected scalar type Long" in str(err):
            logger.warning(
                f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
        elif api_type in [Const.DISTRIBUTED]:
            logger.info(f"{api_full_name} is not supported for run ut. SKIP.")
        else:
            logger.error(f"Run {api_full_name} UT Error: {str(err)}")

        compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank))

    finally:
        # Release cached device memory between APIs.
        torch.cuda.empty_cache()
- torch.cuda.empty_cache()
141
-
142
-
143
def online_compare(api_data, device, common_config):
    """online run_ut for compare:NPU vs GPU

    Moves the captured api_data to *device*, re-executes it via
    common_config.handle_func, and records the forward/backward comparison.
    """
    compare, func, config = common_config.compare, common_config.handle_func, common_config.config
    api_full_name = api_data.name
    api_data = move2target_device(api_data, device)
    try:
        data_info = func(api_full_name, api_data, config.backward_content)
        is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
        logger.info(f"running api_full_name {api_full_name} ut, "
                    f"is_fwd_success: {is_fwd_success}, "
                    f"is_bwd_success: {is_bwd_success}")
    except Exception as err:
        # Mirror the error handling of online_precision_compare: tailored
        # messages for known cases, then a SKIP row in the summary csv.
        [api_type, api_name, _] = api_full_name.split(Const.SEP)
        if "expected scalar type Long" in str(err):
            logger.warning(
                f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
        elif api_type in [Const.DISTRIBUTED]:
            logger.info(f"{api_full_name} is not supported for run ut. SKIP.")
        else:
            logger.error(f"Run {api_full_name} UT Error: {str(err)}")

        compare.write_summary_csv((api_full_name, CompareConst.SKIP, CompareConst.SKIP, [[str(err)]], api_data.rank))

    finally:
        # Release cached device memory between APIs.
        torch.cuda.empty_cache()
- torch.cuda.empty_cache()
170
-
171
-
172
class ConsumerDispatcher:
    """Fans captured API data out to per-device worker processes.

    One multiprocessing queue + one ``run_ut_process`` worker is created per
    device (num_workers). Items are routed to the queue with the most free
    capacity, alternating the scan direction between calls to spread load.
    """

    def __init__(self, compare, capacity=10, num_workers=8, device: str = "gpu") -> None:
        self.num_workers = num_workers   # number of worker processes / queues
        self.capacity = capacity         # per-queue max size
        self.compare = compare           # Comparator instance shared with workers
        self.queues = []
        self.processes = []
        self.reverse_sort = False        # toggled each dispatch to alternate scan direction
        self.pool = None
        self.device = device
        self.data_id = 0
        self.lock = mp.Lock()
        self.result_queue = mp.Queue()
        # Workers must be spawned (not forked) — forced for all child processes.
        mp.set_start_method("spawn", force=True)

    def start(self, handle_func, config):
        """Create the queues and launch one run_ut_process worker per device."""
        self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
        api_precision_csv_file = [
            ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME,
            ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME
        ]
        common_config = CommonCompareConfig(self.compare, handle_func, config)
        for xpu_id, q in enumerate(self.queues):
            p = mp.Process(name="run_ut_process", target=run_ut_process,
                           args=(xpu_id, q, common_config, api_precision_csv_file))

            p.start()
            self.processes.append(p)
        logger.info(
            f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}')
        logger.info(
            f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
        logger.info("Successfully start unittest process.")

    def stop(self):
        """Send the "KILL_" poison pill to every queue and join all workers."""
        for q in self.queues:
            while q.full():
                time.sleep(0.1)
            q.put("KILL_")

        for p in self.processes:
            p.join()
        logger.info("Successfully stop unittest process.")
        logger.info(f"Api_precision_compare task result is saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}")
        logger.info(f"Api_precision_compare task details is saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")

    def update_consume_queue(self, api_data):
        """Block until some queue has free capacity, then enqueue *api_data*."""
        while True:
            index = self._choose_max_empty_site_strategy()
            if index != -1:
                q = self.queues[index]
                q.put(api_data)
                break
            time.sleep(0.1)

    def _choose_max_empty_site_strategy(self):
        """Return the index of the queue with the most free slots, or -1 if all full."""
        maximum = 0
        index = -1
        # Make full use of all devices: alternate the scan direction between
        # calls so the front cards are not always favored on ties.
        _reverse = 1 if not self.reverse_sort else -1
        for i, q in enumerate(self.queues[::_reverse]):
            empty_site = self.capacity - q.qsize()
            if empty_site > maximum:
                maximum = empty_site
                index = i
        # Map the index found in the reversed scan back to the original order.
        index = len(self.queues) - index - 1 if index != -1 and self.reverse_sort else index
        self.reverse_sort = not self.reverse_sort
        return index
- return index