mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -29,7 +29,6 @@ from msprobe.pytorch.common.log import logger
29
29
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
30
30
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
31
31
 
32
-
33
32
  # NPU vs GPU api list
34
33
  CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
35
34
 
@@ -43,6 +42,15 @@ OnlineApiPrecisionCompareConfig = namedtuple('OnlineApiPrecisionCompareConfig',
43
42
  CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func', 'config'])
44
43
 
45
44
 
45
def get_gpu_device():
    """Tell whether the current process should treat its accelerator as a CUDA GPU.

    The decision depends only on whether ``torch_npu`` can be imported. If the
    import raises ImportError, the Ascend NPU backend is unavailable and a GPU
    environment is assumed.

    :return: bool, True when ``torch_npu`` is not importable, False otherwise.
    """
    try:
        import torch_npu  # noqa: F401  (imported only to probe for the NPU backend)
    except ImportError:
        return True
    return False
52
+
53
+
46
54
  def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file):
47
55
  """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue.
48
56
  :param xpu_id: int
@@ -51,7 +59,9 @@ def run_ut_process(xpu_id, consumer_queue, common_config, api_precision_csv_file
51
59
  :param api_precision_csv_file: list, length is 2, result file name and details file name
52
60
  :return:
53
61
  """
54
- gpu_device = torch.device(f'cuda:{xpu_id}')
62
+ device_info = "cuda" if get_gpu_device() else "npu"
63
+ logger.info(f"Start run_ut_process for {device_info} device, rank: {xpu_id}.")
64
+ gpu_device = torch.device(f'{device_info}:{xpu_id}')
55
65
 
56
66
  while True:
57
67
  if consumer_queue.empty():
@@ -12,19 +12,19 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
16
- import os.path
15
+ from functools import partial
16
+ import os
17
17
  import struct
18
- import hashlib
18
+ import zlib
19
19
  import time
20
20
  import io
21
21
  from threading import Thread
22
22
 
23
- from twisted.internet import reactor, protocol, endpoints
23
+ from twisted.internet import reactor, protocol, endpoints, ssl
24
24
 
25
25
  from msprobe.pytorch.common.utils import logger
26
26
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
27
- STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order
27
+ STRUCT_UNPACK_MODE as unpack_mode, STR_TO_BYTES_ORDER as bytes_order, verify_callback, load_ssl_pem
28
28
 
29
29
 
30
30
  class TCPServer:
@@ -44,15 +44,28 @@ class TCPServer:
44
44
  self.factory.protocol = self.build_protocol
45
45
 
46
46
  if self.tls_path:
47
- from OpenSSL import SSL
48
- from twisted.internet import ssl
49
- server_key = os.path.join(self.tls_path, "server.key")
50
- server_crt = os.path.join(self.tls_path, "server.crt")
51
- server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
52
- server_context_ = server_context_factory.getContext()
53
- server_context_.set_cipher_list(cipher_list)
54
- server_context_.set_options(SSL.OP_NO_RENEGOTIATION)
55
- endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, server_context_factory)
47
+ server_key, server_crt, ca_crt, crl_pem = load_ssl_pem(
48
+ key_file=os.path.join(self.tls_path, "server.key"),
49
+ cert_file=os.path.join(self.tls_path, "server.crt"),
50
+ ca_file=os.path.join(self.tls_path, "ca.crt"),
51
+ crl_file=os.path.join(self.tls_path, "crl.pem")
52
+ )
53
+
54
+ ssl_options = ssl.CertificateOptions(
55
+ privateKey=server_key,
56
+ certificate=server_crt,
57
+ method=ssl.SSL.TLSv1_2_METHOD,
58
+ verify=True,
59
+ requireCertificate=True,
60
+ caCerts=[ca_crt], # 信任的CA证书列表
61
+ )
62
+ ssl_context = ssl_options.getContext()
63
+ ssl_context.set_cipher_list(cipher_list)
64
+ ssl_context.set_options(ssl.SSL.OP_NO_RENEGOTIATION)
65
+ ssl_context.set_verify(ssl.SSL.VERIFY_PEER | ssl.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
66
+ partial(verify_callback, crl=crl_pem))
67
+
68
+ endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, ssl_options)
56
69
  else:
57
70
  endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port)
58
71
  endpoint.listen(self.factory)
@@ -85,10 +98,10 @@ class ServerProtocol(protocol.Protocol):
85
98
  self.consumer_queue = shared_queue
86
99
  self.check_sum = check_sum
87
100
  self.length_width = 8
88
- self.md5_width = 32
101
+ self.crc_width = 8
89
102
  self.obj_length = None
90
103
  self.tell = 0
91
- self.obj_md5 = None
104
+ self.obj_crc = None
92
105
  self.obj_body = None
93
106
  self.sequence_number = -1
94
107
  self.rank = -1
@@ -99,7 +112,7 @@ class ServerProtocol(protocol.Protocol):
99
112
  self.buffer = io.BytesIO()
100
113
  self.obj_length = None
101
114
  self.tell = 0
102
- self.obj_md5 = None
115
+ self.obj_crc = None
103
116
  self.obj_body = None
104
117
  self.factory.transport_dict[self.transport] = 1
105
118
  self.factory.transport_list.append(self.transport)
@@ -132,11 +145,12 @@ class ServerProtocol(protocol.Protocol):
132
145
  time.sleep(0.1)
133
146
 
134
147
  obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step)
148
+ # get the crc value of a 16-bit string with a length of 8
149
+ recv_crc = f"{zlib.crc32(self.obj_body):08x}"
135
150
 
136
- recv_md5 = hashlib.md5(self.obj_body).hexdigest()
137
- if self.check_sum and recv_md5 != self.obj_md5:
138
- # when needs check md5 and check no pass, indicates received data error, send b"ERROR" to client.
139
- logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}")
151
+ if self.check_sum and recv_crc != self.obj_crc:
152
+ # when needs check hash value and check no pass, indicates received data error, send b"ERROR" to client.
153
+ logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_crc}, but get {recv_crc}")
140
154
  self.send_ack(self.ACK_ERROR)
141
155
  else:
142
156
  if self.obj_body == self.ACK_STOP:
@@ -146,7 +160,7 @@ class ServerProtocol(protocol.Protocol):
146
160
  if obj_key in self.sequence_number_dict:
147
161
  logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}")
148
162
  else:
149
- self.sequence_number_dict[obj_key] = self.obj_md5
163
+ self.sequence_number_dict[obj_key] = self.obj_crc
150
164
  self.consumer_queue.put(self.obj_body, block=True)
151
165
 
152
166
  self.reset_env()
@@ -173,7 +187,7 @@ class ServerProtocol(protocol.Protocol):
173
187
  self.sequence_number = -1
174
188
  self.rank = -1
175
189
  self.step = -1
176
- self.obj_md5 = None
190
+ self.obj_crc = None
177
191
  self.obj_body = None
178
192
 
179
193
  def dataReceived(self, data):
@@ -192,15 +206,15 @@ class ServerProtocol(protocol.Protocol):
192
206
  logger.debug(
193
207
  f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
194
208
 
195
- # If needs check md5 but not parse md5 yet, read 32b md5 values
196
- check_sum_and_md5 = (self.check_sum
209
+ # If needs check hash but not parse crc yet, read 8b crc values
210
+ check_sum_and_crc = (self.check_sum
197
211
  and self.obj_length is not None
198
- and self.obj_md5 is None
199
- and len(self.buffer.getvalue()) - self.tell >= self.md5_width)
200
- if check_sum_and_md5:
201
- self.obj_md5 = self.buffer.read(self.md5_width).decode()
202
- self.tell += self.md5_width
203
- logger.debug(f"MD5: {self.obj_md5}")
212
+ and self.obj_crc is None
213
+ and len(self.buffer.getvalue()) - self.tell >= self.crc_width)
214
+ if check_sum_and_crc:
215
+ self.obj_crc = self.buffer.read(self.crc_width).decode()
216
+ self.tell += self.crc_width
217
+ logger.debug(f"Hash value: {self.obj_crc}")
204
218
 
205
219
  current_length = len(self.buffer.getvalue()) - self.tell
206
220
  if self.obj_length is not None and 0 < self.obj_length <= current_length:
@@ -12,6 +12,17 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+ import gc
16
+ import os
17
+ from datetime import datetime, timezone
18
+
19
+ from OpenSSL import crypto
20
+ from cryptography import x509
21
+ from cryptography.hazmat.backends import default_backend
22
+ from dateutil import parser
23
+
24
+ from msprobe.core.common.file_utils import FileOpen
25
+ from msprobe.core.common.log import logger
15
26
 
16
27
  cipher_list = ":".join(
17
28
  ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
@@ -42,3 +53,148 @@ cipher_list = ":".join(
42
53
 
43
54
  STRUCT_UNPACK_MODE = "!Q"
44
55
  STR_TO_BYTES_ORDER = "big"
56
+
57
+
58
def is_certificate_revoked(cert, crl):
    """
    Check whether ``cert`` is listed in the revocation list ``crl``.

    Only serial numbers are compared: each entry obtained by iterating ``crl`` is
    matched through its ``serial_number`` attribute against ``cert.get_serial_number()``.

    :param cert: certificate object exposing ``get_serial_number()`` (e.g. OpenSSL.crypto.X509).
    :param crl: iterable of revoked-certificate entries exposing ``serial_number``.
    :return: bool, True (after logging an error) if the serial number is revoked, else False.
    """
    serial = cert.get_serial_number()
    revoked = any(entry.serial_number == serial for entry in crl)
    if revoked:
        logger.error(f"证书已吊销：{serial:020x}")
    return revoked
69
+
70
+
71
def verify_callback(conn, cert, errno, depth, preverify_ok, crl=None):
    """
    Verify the validity of a peer certificate during the TLS handshake.

    Meant to be installed with ``Context.set_verify``, with the CRL bound via
    ``functools.partial``. OpenSSL calls it once for every certificate in the chain.

    :param conn: OpenSSL.SSL.Connection, the SSL connection object (unused here).
    :param cert: OpenSSL.crypto.X509, the certificate currently being verified.
    :param errno: int, OpenSSL X509_V_* error code. 0 means no error; other examples
        are 9 (certificate not yet valid), 10 (certificate has expired) and
        18 (self-signed leaf certificate).
    :param depth: int, position of ``cert`` in the chain. 0 is the peer (leaf)
        certificate; larger values are issuing CAs closer to the root.
    :param preverify_ok: int, OpenSSL's own verification result (1 = passed, 0 = failed).
    :param crl: cryptography ``CertificateRevocationList`` or None. When given and
        non-empty, ``cert`` is rejected if its serial number is listed in it.
    :return: False to reject the certificate; otherwise ``preverify_ok``.
    """
    if not preverify_ok:
        try:
            # pyOpenSSL has no public API for X509_verify_cert_error_string, so the
            # private cffi bindings are used. Fall back to a generic text if they change,
            # so the failure is still logged and the certificate is still rejected
            # instead of an AttributeError escaping from inside the OpenSSL callback.
            from OpenSSL import SSL
            error_str = SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)).decode()
        except Exception:
            error_str = "unknown certificate verification error"
        logger.error(f"证书验证失败 (depth={depth}, err={errno}): {error_str}")
        return False

    # An empty CRL is falsy and has nothing to check against.
    if crl and is_certificate_revoked(cert, crl):
        return False

    return preverify_ok
93
+
94
+
95
def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
    """
    Load SSL PEM files.

    Steps:
        1. Load the private key, prompting for its password.
        2. Load the certificate and the CA certificate, and check that both are
           inside their validity period.
        3. Check that the certificate matches the private key.
        4. If ``crl_file`` exists, load the CRL and check it against the CA certificate.

    Args:
        key_file (str): The path to the private key file.
        cert_file (str): The path to the certificate file.
        ca_file (str): The path to the CA certificate file.
        crl_file (str): The path to the CRL file; skipped when the path does not exist.

    Returns:
        tuple: (key, crt, ca_crt, crl); ``crl`` is None when no CRL file exists.

    Raises:
        RuntimeError: If a file path is invalid, a file's content is incorrect, the
            password is wrong, or any check fails. The underlying error is chained.
    """
    try:
        # Placeholder for an in-code private-key password. When it is left empty,
        # the user is prompted for it (pwinput masks the typed characters).
        passphrase = ""
        if not passphrase:
            import pwinput
            passphrase = pwinput.pwinput("Enter your password: ")
        try:
            with FileOpen(key_file, "rb") as f:
                key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
        finally:
            # Drop the password even when loading fails. Otherwise it would stay
            # reachable from this function's frame, which the traceback chained onto
            # the RuntimeError below keeps alive.
            del passphrase
            gc.collect()

        with FileOpen(cert_file, "rb") as f:
            crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        check_crt_valid(crt)

        crt_serial_number = hex(crt.get_serial_number())[2:]
        logger.info(f"crt_serial_number: {crt_serial_number}")

        check_certificate_match(crt, key)

        with FileOpen(ca_file, "rb") as f:
            ca_crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
        check_crt_valid(ca_crt)

        ca_serial_number = hex(ca_crt.get_serial_number())[2:]
        logger.info(f"ca_serial_number: {ca_serial_number}")

        # The CRL is optional: it is only loaded and validated when the file exists.
        crl = None
        if os.path.exists(crl_file):
            with FileOpen(crl_file, "rb") as f:
                crl = x509.load_pem_x509_crl(f.read(), default_backend())
            check_crl_valid(crl, ca_crt)
            for revoked_cert in crl:
                logger.info(f"Serial Number: {revoked_cert.serial_number}, "
                            f"Revocation Date: {revoked_cert.revocation_date_utc}")

    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    return key, crt, ca_crt, crl
150
+
151
+
152
def check_crt_valid(pem):
    """
    Check the validity of the SSL certificate.

    The notBefore/notAfter timestamps are parsed and compared with the current UTC
    time. The success message is logged only after every check has passed.

    :param pem: OpenSSL.crypto.X509, the certificate to check.

    Raises:
        RuntimeError: If the validity dates cannot be read or parsed, if the
            certificate is not yet valid, or if it has expired.
    """
    try:
        pem_start = parser.parse(pem.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(pem.get_notAfter().decode("UTF-8"))
    except Exception as e:
        raise RuntimeError("The SSL certificate is invalid") from e

    now_utc = datetime.now(tz=timezone.utc)
    # Report a certificate whose validity period has not started yet accurately,
    # instead of claiming it has expired.
    if now_utc < pem_start:
        raise RuntimeError("The SSL certificate is not yet valid.")
    if pem.has_expired() or now_utc > pem_end:
        raise RuntimeError("The SSL certificate has expired.")

    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
170
+
171
+
172
def check_certificate_match(certificate, private_key):
    """
    Ensure ``private_key`` is the private counterpart of the public key in ``certificate``.

    256 random bytes are signed with the private key using SHA-256, and the signature
    is then verified against the certificate. A match is logged; any failure while
    signing or verifying is reported as a mismatch.

    :param certificate: OpenSSL.crypto.X509, the certificate carrying the public key.
    :param private_key: OpenSSL.crypto.PKey, the private key to test.
    :raises RuntimeError: if the key pair does not match (the original error is chained).
    """
    digest = "sha256"
    probe = os.urandom(256)
    try:
        signed = crypto.sign(private_key, probe, digest)
        # Arguments: certificate holding the public key, signature, original data, hash name.
        crypto.verify(certificate, signed, probe, digest)
        logger.info("公钥和私钥匹配")
    except Exception as err:
        raise RuntimeError("公钥和私钥不匹配") from err
191
+
192
+
193
def check_crl_valid(crl, ca_crt):
    """
    Check that the CRL was signed by ``ca_crt`` and is currently within its validity window.

    :param crl: cryptography ``x509.CertificateRevocationList`` to check.
    :param ca_crt: OpenSSL.crypto.X509, the CA certificate expected to have signed the CRL.
    :raises RuntimeError: if the signature is invalid, if the CRL is not yet valid or
        has expired, or if it carries no nextUpdate time so its expiry cannot be checked.
    """
    # Verify the CRL signature to make sure the CRL has not been tampered with.
    if not crl.is_signature_valid(ca_crt.get_pubkey().to_cryptography_key()):
        raise RuntimeError("CRL签名无效！")

    # Prefer the timezone-aware accessors (cryptography >= 42, where the naive ones are
    # deprecated). On older releases, fall back to the naive accessors, which are in UTC.
    if hasattr(crl, "last_update_utc"):
        last_update, next_update = crl.last_update_utc, crl.next_update_utc
        now = datetime.now(tz=timezone.utc)
    else:
        last_update, next_update = crl.last_update, crl.next_update
        now = datetime.now(tz=timezone.utc).replace(tzinfo=None)

    # nextUpdate is OPTIONAL in the CRL ASN.1 syntax. Without it the expiry cannot be
    # checked, so reject explicitly instead of failing with a TypeError in the comparison.
    if next_update is None:
        raise RuntimeError("CRL has no nextUpdate time, so its validity period cannot be verified.")

    # Check the CRL validity period.
    if not (last_update <= now <= next_update):
        raise RuntimeError("CRL已过期或尚未生效！")
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from msprobe.core.common.runtime import Runtime
18
+ from msprobe.core.common.utils import Const
19
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
20
+ from msprobe.pytorch.common.log import logger
21
+
22
+
23
class ATTLManager:
    """
    Manage the ATTL client used by online run_ut.

    Collected API data is uploaded to ``config.nfs_path`` when that path is set;
    otherwise it is sent through the ATTL socket connection.
    """

    def __init__(self, config):
        # config attributes read by this class: online_run_ut, host, port,
        # nfs_path, tls_path and rank.
        self.config = config
        # Created by attl_init() only when online_run_ut is enabled.
        self.attl = None

    def attl_init(self):
        """Create the ATTL client when online run_ut is enabled; announce "start" on NFS."""
        if not self.config.online_run_ut:
            return
        from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL
        attl_config = ATTLConfig(is_benchmark_device=False,
                                 connect_ip=self.config.host,
                                 connect_port=self.config.port,
                                 nfs_path=self.config.nfs_path,
                                 tls_path=self.config.tls_path)
        # An empty rank list means every rank dumps.
        need_dump = len(self.config.rank) == 0 or Runtime.current_rank in self.config.rank
        self.attl = ATTL('npu', attl_config, need_dump=need_dump)
        if self.config.nfs_path:
            self.attl.upload("start")

    def attl_send(self, name, args, kwargs, output):
        """
        Package one forward API call as ApiData and upload or send it through ATTL.

        :param name: API name ending with Const.FORWARD_NAME_SUFFIX; the suffix is stripped.
        :param args: positional inputs of the API call.
        :param kwargs: keyword inputs of the API call.
        :param output: output of the API call.
        """
        api_data = ApiData(
            name[:-len(Const.FORWARD_NAME_SUFFIX)],
            args,
            kwargs,
            output,
            Runtime.current_iter,
            Runtime.current_rank
        )
        logger.info(f"tools is dumping api: {api_data.name}, rank: {Runtime.current_rank}")
        # Only the leading component (the API type) is needed. Unlike unpacking into
        # exactly three names, this does not raise ValueError for names with more or
        # fewer components.
        api_type = api_data.name.split(Const.SEP)[0]
        if api_type in [Const.DISTRIBUTED]:
            logger.info(f"api {api_data.name} is not supported, skip")
            return
        if self.config.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)

    def attl_stop(self):
        """Announce "end" on NFS, or send the STOP signal over the socket; no-op without ATTL."""
        if self.attl is None:
            # attl_init() did not create a client (online run_ut disabled or not initialised).
            return
        if self.config.nfs_path:
            self.attl.upload("end")
        elif self.attl.socket_manager is not None:
            logger.info(f"pid: {os.getpid()} finished, start sends STOP signal.")
            self.attl.socket_manager.send_stop_signal()
@@ -117,6 +117,12 @@ def fusion_attention_forward(forward_params):
117
117
  pse = forward_params.pse
118
118
  scale = forward_params.scale
119
119
  keep_prob = forward_params.keep_prob
120
+
121
+ # 除零风险拦截:keep_prob 为 0 时会导致除零错误
122
+ if keep_prob == 0:
123
+ raise ValueError("fusion_attention_forward: keep_prob cannot be zero to avoid division by zero.")
124
+
125
+
120
126
  qk = calculate_qk(q, k, atten_mask, pse, scale)
121
127
  softmax_res, softmax_max, softmax_sum = softmax_forward(qk)
122
128
  if drop_mask is None or len(drop_mask.shape) == 0:
@@ -137,6 +143,11 @@ def fusion_attention_backward(backward_params):
137
143
  pse = backward_params.pse
138
144
  scale = backward_params.scale
139
145
  keep_prob = backward_params.keep_prob
146
+
147
+ # 除零风险拦截:keep_prob 为 0 时会导致除零错误
148
+ if keep_prob == 0:
149
+ raise ValueError("fusion_attention_backward: keep_prob cannot be zero to avoid division by zero.")
150
+
140
151
  dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
141
152
  if drop_mask is None or len(drop_mask.shape) == 0:
142
153
  drop_res = softmax_res.permute(0, 1, 3, 2)
@@ -164,23 +175,35 @@ def parse_bsnd_args(query, key, head_num, input_layout):
164
175
  if input_layout == "BSH":
165
176
  b, s1, h1 = query.shape
166
177
  _, s2, h2 = key.shape
178
+ if n1 == 0:
179
+ raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
167
180
  d = h1 // n1
181
+ if d == 0:
182
+ raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
168
183
  n2 = h2 // d
169
184
  elif input_layout == "SBH":
170
185
  s1, b, h1 = query.shape
171
186
  s2, _, h2 = key.shape
187
+ if n1 == 0:
188
+ raise ValueError("parse_bsnd_args: head_num (n1) cannot be zero to avoid division by zero.")
172
189
  d = h1 // n1
190
+ if d == 0:
191
+ raise ValueError("parse_bsnd_args: computed head dimension (d) is zero, division by zero risk.")
173
192
  n2 = h2 // d
174
193
  elif input_layout == "BSND":
175
194
  b, s1, n1, d = query.shape
176
195
  _, s2, n2, _ = key.shape
177
196
  h1 = n1 * d
178
197
  h2 = n2 * d
198
+ if d == 0:
199
+ raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
179
200
  elif input_layout == "BNSD":
180
201
  b, n1, s1, d = query.shape
181
202
  _, n2, s2, _ = key.shape
182
203
  h1 = n1 * d
183
204
  h2 = n2 * d
205
+ if d == 0:
206
+ raise ValueError("parse_bsnd_args: head dimension (d) is zero, division by zero risk.")
184
207
  except Exception as e:
185
208
  raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
186
209
 
@@ -446,6 +469,8 @@ def npu_fusion_attention_forward_patch(*args, **kwargs):
446
469
  input_layout = get_input_layout(*args, **kwargs)
447
470
 
448
471
  b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], head_num, input_layout)
472
+ if d == 0:
473
+ raise ValueError("npu_fusion_attention_forward_patch: head dimension (d) is zero, division by zero risk.")
449
474
  if n1 == n2 and s1 == s2:
450
475
  logger.debug(f"running case : BNSD = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
451
476
  else:
@@ -478,6 +503,8 @@ def npu_fusion_attention_backward_patch(*args, **kwargs):
478
503
  raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")
479
504
 
480
505
  b, s1, s2, n1, n2, d, h1, h2, dtype = parse_bsnd_args(args[0], args[1], args[4], args[5])
506
+ if d == 0:
507
+ raise ValueError("npu_fusion_attention_backward_patch: head dimension (d) is zero, division by zero risk.")
481
508
  if n1 == n2 and s1 == s2:
482
509
  logger.info(f"running case : bnsd = {b}_{n1}_{s1}_{d}, sparse = {kwargs.get('sparse_mode', 0)}")
483
510
  else:
@@ -24,6 +24,7 @@ from functools import wraps
24
24
  import numpy as np
25
25
  import torch
26
26
  import torch.distributed as dist
27
+
27
28
  from msprobe.core.common.exceptions import DistributedNotInitializedError
28
29
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
29
30
  check_file_or_directory_path, check_path_before_create, FileOpen)
@@ -38,7 +39,9 @@ except ImportError:
38
39
  else:
39
40
  is_gpu = False
40
41
 
42
+
41
43
  torch_without_guard_version = torch.__version__ >= '2.1'
44
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
42
45
 
43
46
  if not is_gpu and not torch_without_guard_version:
44
47
  from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
@@ -313,14 +316,14 @@ def print_rank_0(message):
313
316
  logger.info(message)
314
317
 
315
318
 
316
- def load_pt(pt_path, to_cpu=False):
319
+ def load_pt(pt_path, to_cpu=False, weights_only=True):
317
320
  pt_path = os.path.realpath(pt_path)
318
321
  check_file_or_directory_path(pt_path)
319
322
  try:
320
323
  if to_cpu:
321
- pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
324
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=weights_only)
322
325
  else:
323
- pt = torch.load(pt_path, weights_only=True)
326
+ pt = torch.load(pt_path, weights_only=weights_only)
324
327
  except Exception as e:
325
328
  raise RuntimeError(f"load pt file {pt_path} failed") from e
326
329
  return pt
@@ -395,7 +398,7 @@ def save_api_data(api_data):
395
398
  io_buff = io.BytesIO()
396
399
  torch.save(api_data, io_buff)
397
400
  except Exception as e:
398
- raise RuntimeError(f"save api_data to io_buff failed") from e
401
+ raise RuntimeError("save api_data to io_buff failed") from e
399
402
  return io_buff
400
403
 
401
404
 
@@ -403,9 +406,9 @@ def load_api_data(api_data_bytes):
403
406
  """Load data from bytes stream"""
404
407
  try:
405
408
  buffer = io.BytesIO(api_data_bytes)
406
- buffer = torch.load(buffer, map_location="cpu")
409
+ buffer = torch.load(buffer, map_location="cpu", weights_only=False)
407
410
  except Exception as e:
408
- raise RuntimeError(f"load api_data from bytes failed") from e
411
+ raise RuntimeError("load api_data from bytes failed") from e
409
412
  return buffer
410
413
 
411
414
 
@@ -457,7 +460,7 @@ def is_recomputation():
457
460
 
458
461
  def check_save_param(variable, name, save_backward):
459
462
  # try catch this api to skip invalid call
460
- valid_data_types = tuple([torch.Tensor, int, float, str])
463
+ valid_data_types = (torch.Tensor, int, float, str)
461
464
  if not is_save_variable_valid(variable, valid_data_types):
462
465
  valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
463
466
  logger.warning("PrecisionDebugger.save variable type not valid, "
@@ -476,13 +479,8 @@ def check_save_param(variable, name, save_backward):
476
479
  raise ValueError
477
480
 
478
481
 
479
- def replace_last_occurrence(text, old, new):
480
- if text is None:
481
- return text
482
- index = text.rfind(old)
483
- if index != -1:
484
- return text[:index] + text[index:].replace(old, new, 1)
485
- return text
482
+ def is_torch_nn_module(variable):
483
+ return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule)
486
484
 
487
485
 
488
486
  def is_hifloat8_tensor(tensor):
@@ -495,3 +493,17 @@ def is_float8_tensor(tensor):
495
493
  if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
496
494
  return True
497
495
  return is_hifloat8_tensor(tensor)
496
+
497
+
498
+ def register_forward_pre_hook(module, forward_pre_hook):
499
+ if torch_version_above_or_equal_2:
500
+ module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
501
+ else:
502
+ module.register_forward_pre_hook(forward_pre_hook)
503
+
504
+
505
+ def register_forward_hook(module, forward_hook):
506
+ if torch_version_above_or_equal_2:
507
+ module.register_forward_hook(forward_hook, with_kwargs=True)
508
+ else:
509
+ module.register_forward_hook(forward_hook)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2019-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,41 +13,9 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
-
18
- from msprobe.core.common.exceptions import FileCheckException
19
- from msprobe.core.common.file_utils import create_directory
20
- from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, get_dump_mode, \
21
- set_dump_path
22
- from msprobe.core.compare.acc_compare import ModeConfig
23
- from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json, set_stack_json_path
24
- from msprobe.pytorch.common.log import logger
25
- from msprobe.pytorch.compare.pt_compare import PTComparator, compare
16
+ from msprobe.core.compare.utils import compare_distributed_inner
17
+ from msprobe.pytorch.compare.pt_compare import compare
26
18
 
27
19
 
28
20
  def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
29
- if kwargs.get("suffix"):
30
- logger.error("Argument 'suffix' is not supported for compare_distributed.")
31
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
32
- is_print_compare_log = kwargs.get("is_print_compare_log", True)
33
- # get the ranks and match by order
34
- npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
35
- bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
36
- if len(npu_ranks) != len(bench_ranks):
37
- logger.error(
38
- "The number of ranks in the two runs are different. "
39
- "Unable to match the ranks. "
40
- "Please use another folder to compare or use compare() api and manually match the ranks.")
41
- raise CompareException(CompareException.INVALID_PATH_ERROR)
42
- for nr, br in zip(npu_ranks, bench_ranks):
43
- npu_data_dir = os.path.join(npu_dump_dir, nr)
44
- bench_data_dir = os.path.join(bench_dump_dir, br)
45
- npu_path = extract_json(npu_data_dir, stack_json=False)
46
- bench_path = extract_json(bench_data_dir, stack_json=False)
47
-
48
- dump_result_param = {
49
- "npu_json_path": npu_path,
50
- "bench_json_path": bench_path,
51
- "is_print_compare_log": is_print_compare_log
52
- }
53
- compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
21
+ compare_distributed_inner(npu_dump_dir, bench_dump_dir, output_path, compare, **kwargs)