mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os.path
2
17
  import struct
3
18
  import hashlib
@@ -8,7 +23,8 @@ from threading import Thread
8
23
  from twisted.internet import reactor, protocol, endpoints
9
24
 
10
25
  from msprobe.pytorch.common.utils import logger
11
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
26
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import cipher_list, \
27
+ struct_unpack_mode as unpack_mode, str_to_bytes_order as bytes_order
12
28
 
13
29
 
14
30
  class TCPServer:
@@ -24,14 +40,22 @@ class TCPServer:
24
40
  def run_reactor():
25
41
  reactor.run(installSignalHandlers=False)
26
42
 
43
+ def check_tls_path(self):
44
+ server_key = os.path.join(self.tls_path, "server.key")
45
+ server_crt = os.path.join(self.tls_path, "server.crt")
46
+ if not os.path.exists(server_key):
47
+ raise Exception(f"server_key: {server_key} is not exists.")
48
+ if not os.path.exists(server_crt):
49
+ raise Exception(f"server_crt: {server_crt} is not exists.")
50
+ return server_key, server_crt
51
+
27
52
  def start(self):
28
53
  self.factory.protocol = self.build_protocol
29
54
 
30
55
  if self.tls_path:
31
56
  from OpenSSL import SSL
32
57
  from twisted.internet import ssl
33
- server_key = os.path.join(self.tls_path, "server.key")
34
- server_crt = os.path.join(self.tls_path, "server.crt")
58
+ server_key, server_crt = self.check_tls_path()
35
59
  server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD)
36
60
  server_context_ = server_context_factory.getContext()
37
61
  server_context_.set_cipher_list(cipher_list)
@@ -100,9 +124,9 @@ class ServerProtocol(protocol.Protocol):
100
124
  def send_ack(self, ack_info):
101
125
  ack_message = b"".join([
102
126
  ack_info,
103
- self.sequence_number.to_bytes(8, byteorder='big'),
104
- self.rank.to_bytes(8, byteorder='big'),
105
- self.step.to_bytes(8, byteorder='big')
127
+ self.sequence_number.to_bytes(8, byteorder=bytes_order),
128
+ self.rank.to_bytes(8, byteorder=bytes_order),
129
+ self.step.to_bytes(8, byteorder=bytes_order)
106
130
  ])
107
131
  self.transport.write(ack_message)
108
132
 
@@ -168,10 +192,10 @@ class ServerProtocol(protocol.Protocol):
168
192
  # The first data packet is packet header, it contains obj_length, sequence_number, rank, step
169
193
  if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4:
170
194
  self.start_time = time.time()
171
- self.obj_length = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
172
- self.sequence_number = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
173
- self.rank = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
174
- self.step = struct.unpack('!Q', self.buffer.read(self.length_width))[0]
195
+ self.obj_length = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
196
+ self.sequence_number = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
197
+ self.rank = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
198
+ self.step = struct.unpack(unpack_mode, self.buffer.read(self.length_width))[0]
175
199
  self.tell += self.length_width * 4
176
200
  logger.debug(
177
201
  f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}")
@@ -210,7 +234,8 @@ class MessageServerFactory(protocol.ServerFactory):
210
234
  def __init__(self) -> None:
211
235
  """
212
236
  transport_dict: links that have not completed data transmission.
213
- transport_list: Records all TCP links. Appends TCP link to the transport list when a new TCP link is established.
237
+ transport_list: Records all TCP links. Appends TCP link to the transport list
238
+ when a new TCP link is established.
214
239
  """
215
240
  self.transport_dict = {}
216
241
  self.transport_list = []
@@ -0,0 +1,63 @@
1
+ aten_ops_blacklist:
2
+ - npu_binary_cross_entropy_with_logits_backward
3
+ - npu_ciou_backward
4
+ - _cudnn_rnn
5
+ - _local_scalar_dense
6
+ - _pin_memory
7
+ - _to_copy
8
+ - _unsafe_view
9
+ - clone
10
+ - contiguous
11
+ - copy_
12
+ - cudnn_batch_norm
13
+ - cudnn_batch_norm_backward
14
+ - detach
15
+ - empty
16
+ - index_put_
17
+ - lift_fresh
18
+ - max_pool2d_with_indices_backward # shape unmatch
19
+ - native_batch_norm_backward
20
+ - new_empty
21
+ - new_empty_strided
22
+ - new_full
23
+ - new_ones
24
+ - new_zeros
25
+ - ones
26
+ - ones_like
27
+ - permute
28
+ - rand
29
+ - rand_like
30
+ - randint
31
+ - randint_like
32
+ - randn
33
+ - randn_like
34
+ - randperm
35
+ - scalar_tensor
36
+ - select
37
+ - to
38
+ - transpose
39
+ - unbind
40
+ - view
41
+ - zero
42
+ - zero_
43
+ - zeros
44
+ - zeros_like
45
+ - _record_function_enter_new
46
+ - _record_function_exit
47
+ - broadcast_
48
+ - allreduce_
49
+ - npu_clear_float_status
50
+ - npu_format_cast
51
+ - npu_dtype_cast
52
+ - npu_dtype_cast_backward
53
+ - _allgather_base_
54
+ - _reduce_scatter_base_
55
+ - is_same_size
56
+
57
+ npu_adjust_autogard:
58
+ - adaptive_avg_pool2d
59
+ - batch_norm
60
+ - log_softmax
61
+ - nll_loss
62
+ - to
63
+
@@ -0,0 +1,44 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ cipher_list = ":".join(
17
+ ["TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
18
+ "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
19
+ "TLS_DHE_DSS_WITH_AES_128_GCM_SHA256",
20
+ "TLS_DHE_DSS_WITH_AES_256_GCM_SHA384",
21
+ "TLS_DHE_PSK_WITH_AES_128_GCM_SHA256",
22
+ "TLS_DHE_PSK_WITH_AES_256_GCM_SHA384",
23
+ "TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
24
+ "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
25
+ "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
26
+ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
27
+ "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
28
+ "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
29
+ "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
30
+ "TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256",
31
+ "TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384",
32
+ "TLS_ECDHE_PSK_WITH_AES_128_CCM_SHA256",
33
+ "TLS_DHE_RSA_WITH_AES_128_CCM",
34
+ "TLS_DHE_RSA_WITH_AES_256_CCM",
35
+ "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
36
+ "TLS_DHE_PSK_WITH_AES_128_CCM",
37
+ "TLS_DHE_PSK_WITH_AES_256_CCM",
38
+ "TLS_ECDHE_ECDSA_WITH_AES_128_CCM",
39
+ "TLS_ECDHE_ECDSA_WITH_AES_256_CCM",
40
+ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"]
41
+ ).encode()
42
+
43
+ struct_unpack_mode = "!Q"
44
+ str_to_bytes_order = "big"
@@ -1,11 +1,26 @@
1
- import os
2
- from pkgutil import iter_modules
3
- from importlib import import_module
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
4
15
 
5
16
  """
6
17
  gpu and cpu not implement benchmark function, supplementary benchmarking function implementation
7
18
  """
8
19
 
20
+ import os
21
+ from pkgutil import iter_modules
22
+ from importlib import import_module
23
+
9
24
  package_path = os.path.dirname(os.path.realpath(__file__))
10
25
  for _, module_name, _ in iter_modules([package_path]):
11
26
  module = import_module(f"{__name__}.{module_name}")
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  def npu_confusion_transpose(data, perm, shape, transpose_first):
2
17
  if transpose_first:
3
18
  output = data.permute(*perm).contiguous().view(shape)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import torch
2
17
 
3
18
 
@@ -29,18 +44,18 @@ def matmul_backward(grad, self, other, mask):
29
44
  grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
30
45
  elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
31
46
  view_size = 1 if dim_self == 1 else size_grad[-2]
32
- unfolded_grad_T = grad.view([-1, view_size]) \
47
+ unfolded_grad_t = grad.view([-1, view_size]) \
33
48
  if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
34
49
  if mask[0]:
35
50
  # create a 2D-matrix from other
36
- unfolded_other_T = \
51
+ unfolded_other_t = \
37
52
  other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)
38
- grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self)
53
+ grad_self = unfolded_other_t.mm(unfolded_grad_t).transpose(-1, -2).view(size_self)
39
54
  if mask[1]:
40
- size_other_T = size_other[:-2]
41
- size_other_T.extend(size_other[::-1][:2])
55
+ size_other_t = size_other[:-2]
56
+ size_other_t.extend(size_other[::-1][:2])
42
57
  grad_other = \
43
- unfolded_grad_T.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_T).transpose(-1, -2)
58
+ unfolded_grad_t.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_t).transpose(-1, -2)
44
59
  else:
45
60
  grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self
46
61
  grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other