mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,35 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import os
2
19
  import re
20
+ import torch
21
+
22
+ try:
23
+ import torch_npu
24
+ except ImportError:
25
+ current_device = "cuda"
26
+ else:
27
+ current_device = "npu"
3
28
 
4
- from msprobe.core.common.const import FileCheckConst
29
+ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
5
30
  from msprobe.core.common.file_utils import FileChecker
31
+ from msprobe.core.common.log import logger
32
+ from msprobe.core.common.utils import CompareException
6
33
  from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
7
34
  from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
8
35
  from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
@@ -10,11 +37,20 @@ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
10
37
  from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
11
38
 
12
39
# APIs compared under the HF32 standard: their fp32 inputs are not up-cast.
hf_32_standard_api = ["conv1d", "conv2d"]
# In-place view/resize ops whose inputs must not be detached before execution.
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
# APIs whose benchmark inputs must never be raised to a higher-precision dtype.
not_raise_dtype_set = {'type_as'}

# Low-precision dtype -> higher-precision dtype used when building CPU benchmark inputs.
PRECISION_MAPPING = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.float32: torch.float64
}
13
48
 
14
49
 
15
class BackwardMessage:
    """Log messages emitted when an API's backward pass must be skipped."""

    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
    UNSUPPORT_BACKWARD_MESSAGE = ("function with out=... arguments don't support automatic differentiation, "
                                  "skip backward.")
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
19
55
 
20
56
 
@@ -68,3 +104,110 @@ def exec_api(api_type, api_name, device, args, kwargs):
68
104
  torch_api = NpuOPTemplate(api_name, None, False, device)
69
105
  out = torch_api.forward(*args, **kwargs)
70
106
  return out
107
+
108
+
109
def deal_detach(arg, to_detach=True):
    """Return a detached copy of *arg* when *to_detach* is True; otherwise return *arg* unchanged."""
    if to_detach:
        return arg.detach()
    return arg
111
+
112
+
113
def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
    """Cast a benchmark (golden) input tensor to *raise_dtype* for higher-precision comparison.

    Args:
        api_name: name of the API under test.
        arg: benchmark input tensor.
        raise_dtype: target dtype to raise the tensor to, or None for no cast.

    Returns:
        The tensor cast to raise_dtype, or the original tensor when no cast applies.
    """
    # conv1d/conv2d follow the HF32 standard, so their fp32 inputs stay fp32.
    if api_name in hf_32_standard_api and arg.dtype == torch.float32:
        return arg
    cast_not_needed = (
        raise_dtype is None
        or arg.dtype not in PRECISION_MAPPING
        or raise_dtype == arg.dtype
    )
    return arg if cast_not_needed else arg.type(raise_dtype)
128
+
129
+
130
def generate_device_params(input_args, input_kwargs, need_backward, api_name):
    """Clone API inputs onto the current device (npu/cuda) for device-side execution.

    Tensors are detached (unless *api_name* is an in-place op in not_detach_set),
    cloned, moved to current_device, and re-marked for gradient tracking when
    *need_backward* is set. Lists/tuples are handled recursively; other values
    pass through untouched. Kwarg values under the "out" key are never detached.

    Returns:
        (device_args, device_kwargs): mirrors of the inputs on the device.

    Raises:
        CompareException: when the input nesting exceeds Const.MAX_DEPTH.
    """
    def to_device(value, detach, depth=0):
        if depth > Const.MAX_DEPTH:
            logger.error("The depth of arg_in is too large, please check the arg_in.")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
        if isinstance(value, (list, tuple)):
            return type(value)(to_device(item, detach, depth=depth + 1) for item in value)
        if not isinstance(value, torch.Tensor):
            return value
        if need_backward and value.requires_grad:
            moved = deal_detach(value.clone(), detach).to(current_device).requires_grad_()
            # Multiply by one to get a fresh non-leaf node (same dtype via type_as)
            # so retain_grad() captures the gradient on the copy.
            moved = (moved * 1).type_as(moved)
            moved.retain_grad()
            return moved
        return deal_detach(value.clone(), detach).to(current_device)

    detach = api_name not in not_detach_set
    device_args = to_device(input_args, detach)
    device_kwargs = {key: to_device(value, key != "out" and detach)
                     for key, value in input_kwargs.items()}
    return device_args, device_kwargs
+
155
+
156
def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
    """Build CPU-side benchmark inputs, optionally raising low-precision dtypes.

    First scans the inputs for tensors whose dtype should be raised (via
    PRECISION_MAPPING; fp16/bf16 also count inside kwargs) to pick a single
    target dtype, then recursively clones/detaches every tensor and casts it
    with raise_bench_data_dtype. APIs in not_raise_dtype_set are never up-cast;
    in-place APIs in not_detach_set are not detached; "out" kwargs are never
    detached.

    Returns:
        (cpu_args, cpu_kwargs): mirrors of the inputs for CPU execution.

    Raises:
        CompareException: when the input nesting exceeds Const.MAX_DEPTH.
    """
    def to_cpu(value, detach, raise_dtype=None, depth=0):
        if depth > Const.MAX_DEPTH:
            logger.error("The depth of arg_in is too large, please check the arg_in.")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
        if isinstance(value, (list, tuple)):
            return type(value)(to_cpu(item, detach, raise_dtype=raise_dtype, depth=depth + 1)
                               for item in value)
        if not isinstance(value, torch.Tensor):
            return value
        if need_backward and value.requires_grad:
            raised = deal_detach(
                raise_bench_data_dtype(api_name, value.clone(), raise_dtype=raise_dtype),
                detach).requires_grad_()
            # Multiply by one for a fresh non-leaf node so retain_grad() works on the copy.
            raised = (raised * 1).type_as(raised)
            raised.retain_grad()
            return raised
        return deal_detach(raise_bench_data_dtype(api_name, value.clone(), raise_dtype=raise_dtype),
                           detach)

    def needs_raise(tensor, check_kwargs=False):
        # Inside kwargs, fp16/bf16 always count even beyond PRECISION_MAPPING keys.
        if tensor.dtype in PRECISION_MAPPING:
            return True
        return check_kwargs and tensor.dtype in (torch.half, torch.bfloat16)

    def collect_dtypes(value, check_kwargs=False, depth=0):
        if depth > Const.MAX_DEPTH:
            logger.error("The depth of arg_in is too large, please check the arg_in.")
            raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
        if isinstance(value, (list, tuple)):
            return set().union(*(collect_dtypes(item, check_kwargs=check_kwargs, depth=depth + 1)
                                 for item in value))
        if isinstance(value, torch.Tensor) and needs_raise(value, check_kwargs):
            return {value.dtype}
        if isinstance(value, dict) and check_kwargs:
            return set().union(*(collect_dtypes(item, check_kwargs=True, depth=depth + 1)
                                 for item in value.values()))
        return set()

    found_dtypes = collect_dtypes(input_args)
    found_dtypes.update(collect_dtypes(input_kwargs, check_kwargs=True))
    if len(found_dtypes) == 1:
        raise_dtype = PRECISION_MAPPING.get(found_dtypes.pop(), torch.float32)
    elif len(found_dtypes) >= 2:
        # Mixed low-precision inputs: fall back to float32 as the common target.
        raise_dtype = torch.float32
    else:
        raise_dtype = None

    if api_name in not_raise_dtype_set:
        raise_dtype = None
    detach = api_name not in not_detach_set
    cpu_args = to_cpu(input_args, detach, raise_dtype=raise_dtype)
    cpu_kwargs = {key: to_cpu(value, key != "out" and detach, raise_dtype=raise_dtype)
                  for key, value in input_kwargs.items()}
    return cpu_args, cpu_kwargs
+
210
+
211
def record_skip_info(api_full_name, compare, compare_alg_results):
    """Record a skipped-API row (forward and backward marked SKIP) in the compare results."""
    skipped_row = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
    compare.record_results(skipped_row)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import glob
2
17
  import os.path
3
18
  import time
@@ -41,6 +56,7 @@ class ATTL:
41
56
  self.message_end = False
42
57
  self.kill_progress = False
43
58
  self.check_attl_config()
59
+ self.nfs_path = None
44
60
  if self.session_config.nfs_path:
45
61
  self.nfs_path = self.session_config.nfs_path
46
62
  elif self.session_config.is_benchmark_device:
@@ -77,6 +93,11 @@ class ATTL:
77
93
  """
78
94
  npu major in 'send' (client)
79
95
  """
96
+
97
+ # if tcp connection lost,
98
+ if self.socket_manager.signal_exit:
99
+ raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")
100
+
80
101
  # know receiver receive and go next
81
102
  if isinstance(buffer, ApiData):
82
103
  buffer = move2target_device(buffer, torch.device('cpu'))
@@ -1,10 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import hashlib
2
17
  import io
3
18
  import struct
4
19
  import time
5
20
  import os
6
21
  import signal
7
- import sys
8
22
  from queue import Queue
9
23
  from threading import Thread
10
24
  from typing import Union
@@ -13,7 +27,10 @@ from twisted.internet import reactor, protocol, endpoints
13
27
  from twisted.protocols.basic import FileSender
14
28
 
15
29
  from msprobe.pytorch.common.utils import logger
16
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
30
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import struct_unpack_mode as unpack_mode, \
31
+ str_to_bytes_order as bytes_order
32
+
33
+ MAX_SENDING_QUEUE_SIZE = 20
17
34
 
18
35
 
19
36
  class TCPDataItem:
@@ -31,7 +48,6 @@ class TCPDataItem:
31
48
 
32
49
 
33
50
  class TCPClient:
34
- MAX_SENDING_QUEUE_SIZE = 20
35
51
  ACK_SUCCESS = b"OK___"
36
52
  ACK_ERROR = b"ERROR"
37
53
  ACK_BUSY = b"BUSY_"
@@ -39,13 +55,13 @@ class TCPClient:
39
55
  ACK_STOP_CONFIRM = b"OVER_"
40
56
  ACK_KILL_PROCESS = b"KILL_"
41
57
 
42
- QUEUE_PENDING_TIME = 600 # 队列10分钟都处于阻塞状态,则终止sending进程
58
+ QUEUE_PENDING_TIME = 60
43
59
  RESEND_RETRY_TIMES = 2 # 最大重传数
44
60
  RESEND_TIMER_TIME = 5 # 接收ACK超时定时器
45
61
  RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据
46
62
 
47
63
  def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
48
- self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
64
+ self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
49
65
  self.resend_dict = dict()
50
66
  self.host = host
51
67
  self.port = port
@@ -55,7 +71,8 @@ class TCPClient:
55
71
  self.signal_exit = False
56
72
  self.tcp_manager = ClientProtocol(ack_queue_size=100,
57
73
  chunk_size=655360,
58
- check_sum=check_sum)
74
+ check_sum=check_sum,
75
+ tls=self.tls_path)
59
76
  self.send_thread = Thread(target=self._sending_queue_data)
60
77
  self.send_thread.setDaemon(True)
61
78
  self.send_thread.start()
@@ -67,6 +84,15 @@ class TCPClient:
67
84
  def run_reactor():
68
85
  reactor.run(installSignalHandlers=False)
69
86
 
87
+ def check_tls_path(self):
88
+ client_key = os.path.join(self.tls_path, "client.key")
89
+ client_crt = os.path.join(self.tls_path, "client.crt")
90
+ if not os.path.exists(client_key):
91
+ raise Exception(f"client_key: {client_key} is not exists.")
92
+ if not os.path.exists(client_crt):
93
+ raise Exception(f"client_crt: {client_crt} is not exists.")
94
+ return client_key, client_crt
95
+
70
96
  def start(self):
71
97
  def conn_callback(cur_protocol):
72
98
  if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host:
@@ -80,8 +106,6 @@ class TCPClient:
80
106
  time.sleep(1)
81
107
  reactor.stop()
82
108
  logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
83
- os.kill(os.getpid(), signal.SIGKILL)
84
- os.kill(os.getppid(), signal.SIGKILL)
85
109
 
86
110
  def cur_protocol():
87
111
  return self.tcp_manager
@@ -89,14 +113,9 @@ class TCPClient:
89
113
  self.factory = MessageClientFactory()
90
114
  self.factory.protocol = cur_protocol
91
115
  if self.tls_path:
92
- from OpenSSL import SSL
93
116
  from twisted.internet import ssl
94
- client_key = os.path.join(self.tls_path, "client.key")
95
- client_crt = os.path.join(self.tls_path, "client.crt")
96
- client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
97
- client_context_ = client_context_factory.getContext()
98
- client_context_.set_cipher_list(cipher_list)
99
- client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
117
+ client_key, client_crt = self.check_tls_path()
118
+ client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
100
119
  endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
101
120
  else:
102
121
  endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
@@ -109,7 +128,11 @@ class TCPClient:
109
128
 
110
129
def send_after_queue_empty(self, data):
    """Keep re-queueing *data* until the client is ready to exit.

    With TLS enabled the protocol batches frames before flushing, so a full
    batch of MAX_SENDING_QUEUE_SIZE copies is enqueued per round to force the
    flush; without TLS one copy per round suffices.
    """
    while not self._ready_to_exit():
        repeats = MAX_SENDING_QUEUE_SIZE if self.tls_path else 1
        for _ in range(repeats):
            self.add_to_sending_queue(data)
        time.sleep(2)
114
137
 
115
138
  def check_client_alive(self):
@@ -124,8 +147,6 @@ class TCPClient:
124
147
  if not self.check_client_alive():
125
148
  break
126
149
  time.sleep(1)
127
- while not self.tcp_manager.kill_process:
128
- time.sleep(1)
129
150
 
130
151
  def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
131
152
  if self._ready_to_exit():
@@ -142,7 +163,8 @@ class TCPClient:
142
163
  self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
143
164
  except Exception as e:
144
165
  logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
145
- f"sequence_number: {send_data.sequence_number}, {str(e)}")
166
+ f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
167
+ f"{str(e)}")
146
168
 
147
169
  def _send_data(self, data: TCPDataItem):
148
170
  self.tcp_manager.send_wrapped_data(data.raw_data,
@@ -159,10 +181,11 @@ class TCPClient:
159
181
  while self.send_queue.qsize() > 0:
160
182
  if self._ready_to_exit():
161
183
  break
162
- if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
184
+ if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
163
185
  data_obj = self.send_queue.get()
164
- self._send_data(data_obj)
165
186
  resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
187
+ logger.debug(f"get {resend_key} from send_queue, and send to server.")
188
+ self._send_data(data_obj)
166
189
  if resend_key not in self.resend_dict.keys():
167
190
  # Send data for the first time
168
191
  self.resend_dict[resend_key] = data_obj
@@ -233,7 +256,7 @@ class TCPClient:
233
256
  class ClientProtocol(protocol.Protocol):
234
257
  TIMEOUT = 60 * 10
235
258
 
236
- def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
259
+ def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
237
260
  self.buffer = io.BytesIO()
238
261
  self.is_connected = False
239
262
  self.check_sum = check_sum
@@ -244,6 +267,13 @@ class ClientProtocol(protocol.Protocol):
244
267
  self.signal_exit = False
245
268
  self.defer = None
246
269
  self.kill_process = False
270
+ self.ack = None
271
+
272
+ self.timeout_call = None
273
+
274
+ self.tls = tls
275
+ self.send_buffer = b""
276
+ self.buffer_cnt = 0
247
277
 
248
278
  def dataReceived(self, data):
249
279
  if self.timeout_call.active():
@@ -255,9 +285,11 @@ class ClientProtocol(protocol.Protocol):
255
285
  while True:
256
286
  if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3
257
287
  ack = self.buffer.read(5)
258
- seq_number = struct.unpack('!Q', self.buffer.read(8))[0]
259
- rank = struct.unpack('!Q', self.buffer.read(8))[0]
260
- step = struct.unpack('!Q', self.buffer.read(8))[0]
288
+ self.ack = ack
289
+ seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
290
+ rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
291
+ step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
292
+ logger.debug(f"receive 流水号: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
261
293
  if ack == b"KILL_":
262
294
  self.kill_process = True
263
295
  logger.debug(f"接收到KILL信号, PID {os.getpid()}")
@@ -276,20 +308,33 @@ class ClientProtocol(protocol.Protocol):
276
308
def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
    """Frame *data* with length, sequence number, rank, step and an optional md5
    checksum, then hand the frame to send_large_data once the previous
    transfer's deferred has completed.
    """
    length = len(data)
    # NOTE(review): md5 appears to serve only as a transport integrity checksum here,
    # not a security control — confirm before relying on it.
    md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
    header = b"".join([
        length.to_bytes(8, byteorder=bytes_order),
        sequence_number.to_bytes(8, byteorder=bytes_order),
        rank.to_bytes(8, byteorder=bytes_order),
        step.to_bytes(8, byteorder=bytes_order),
        md5_hash.encode(),
    ])
    frame = header + data
    logger.debug(f"send 流水号: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")

    # Busy-wait until any in-flight transfer finishes before starting the next one.
    while True:
        if self.defer is None or self.defer.called:
            self.defer = self.send_large_data(frame)
            break
        time.sleep(0.01)
 
291
325
def send_large_data(self, data):
    """Transfer *data* over the transport via twisted's FileSender.

    In TLS mode frames are accumulated and flushed as a single transfer every
    MAX_SENDING_QUEUE_SIZE frames (returning None while still buffering);
    otherwise every call starts its own transfer immediately.

    Returns:
        The transfer's Deferred, or None when the TLS buffer is not yet full.
    """
    if not self.tls:
        return self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)

    self.send_buffer += data
    self.buffer_cnt += 1
    if self.buffer_cnt < MAX_SENDING_QUEUE_SIZE:
        return None
    deferred = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
    self.send_buffer = b""
    self.buffer_cnt = 0
    return deferred
294
339
 
295
340
  def connection_timeout(self):
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import time
2
17
  from collections import namedtuple
3
18
 
@@ -12,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TE
12
27
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
13
28
  from msprobe.pytorch.common.log import logger
14
29
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
30
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
31
+
15
32
 
16
33
  # NPU vs GPU api list
17
34
  CompareApi = set(absolute_standard_api) | set(binary_standard_api) | set(thousandth_standard_api)
@@ -75,7 +92,8 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
75
92
 
76
93
  try:
77
94
  # NPU vs CPU
78
- cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, npu_args, npu_kwargs)
95
+ cpu_args, cpu_kwargs = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
96
+ cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
79
97
  npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
80
98
  npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
81
99
  npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
@@ -156,7 +174,10 @@ class ConsumerDispatcher:
156
174
 
157
175
  def start(self, handle_func, config):
158
176
  self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
159
- api_precision_csv_file = [ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME, ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME]
177
+ api_precision_csv_file = [
178
+ ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME,
179
+ ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME
180
+ ]
160
181
  common_config = CommonCompareConfig(self.compare, handle_func, config)
161
182
  for xpu_id, q in enumerate(self.queues):
162
183
  p = mp.Process(name="run_ut_process", target=run_ut_process,
@@ -164,8 +185,10 @@ class ConsumerDispatcher:
164
185
 
165
186
  p.start()
166
187
  self.processes.append(p)
167
- logger.info(f"Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}")
168
- logger.info(f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
188
+ logger.info(
189
+ f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}')
190
+ logger.info(
191
+ f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
169
192
  logger.info("Successfully start unittest process.")
170
193
 
171
194
  def stop(self):
@@ -0,0 +1,110 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from functools import wraps
18
+
19
+ import torch
20
+ from torch.utils._python_dispatch import TorchDispatchMode
21
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
22
+ from msprobe.pytorch.common.utils import get_tensor_rank
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.pytorch.common.log import logger
25
+ from msprobe.core.common.file_utils import load_yaml
26
+
27
+
28
def singleton(cls):
    """Class decorator that makes ``cls`` produce one shared instance.

    Generalized to forward constructor arguments: the arguments of the
    first call win; subsequent calls return the cached instance unchanged.
    Backward compatible with the original zero-argument usage.
    """
    _instance = {}

    @wraps(cls)
    def inner(*args, **kwargs):
        if cls not in _instance:
            _instance[cls] = cls(*args, **kwargs)
        return _instance[cls]
    return inner
37
+
38
+
39
@singleton
class Counter:
    """Process-wide registry of per-op dump counters.

    ``index_dict`` maps an aten api name to the index that will be used
    for its next dump (consumers read it via ``setdefault`` and increment
    it after a successful send).
    """

    def __init__(self) -> None:
        # api name -> next dump index
        self.index_dict = dict()
43
+
44
+
45
# Module-level singletons shared by every AccuracyCheckerDispatch instance.
counter = Counter()
# torch_ops_config.yaml lives beside this module; it supplies the
# 'aten_ops_blacklist' and 'npu_adjust_autogard' op lists read below.
yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
yaml_file = load_yaml(yaml_path)
48
+
49
+
50
class AccuracyCheckerDispatch(TorchDispatchMode):
    """Torch dispatch mode that forwards every executed aten op to the
    online accuracy checker.

    For each non-blacklisted aten call, the op runs normally and its name,
    args, kwargs, result and rank are packaged into an ApiData record,
    which is either uploaded to an NFS path or sent over the network via
    ``attl``.
    """

    def __init__(self, attl):
        super(AccuracyCheckerDispatch, self).__init__()
        self.attl = attl
        self.counter = counter  # module-level singleton: per-op dump counters
        # Loaded once at module import from torch_ops_config.yaml.
        # (Dropped the original's dead `= []` assignments that were
        # immediately overwritten by these reads.)
        self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist', [])
        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard', [])

    def __torch_dispatch__(self, func, types, args=None, kwargs=None):
        # Guard the None defaults: `func(*None)` / `**None` would raise
        # TypeError before the op even runs.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}
        func_name_split_list = func.__name__.split(Const.SEP)
        aten_api = func_name_split_list[0]
        self.enable_autogard(aten_api)
        if aten_api in self.aten_ops_blacklist:
            # Blacklisted ops are executed but never dumped.
            return func(*args, **kwargs)

        res = func(*args, **kwargs)
        cur_rank = get_tensor_rank(args, res)
        cur_api_number = self.counter.index_dict.setdefault(aten_api, 0)
        api_name = f'{Const.ATEN}{Const.SEP}{aten_api}{Const.SEP}{cur_api_number}'
        logger.info(f"tools is dumping api: {api_name}")
        api_data = ApiData(api_name, args, kwargs, res, 0, cur_rank)
        if "device" in api_data.kwargs:
            # The receiving side decides device placement itself.
            api_data.kwargs.pop("device")
        if self.attl.nfs_path:
            self.attl.upload(api_data)
        else:
            self.attl.send(api_data)
        self.counter.index_dict[aten_api] += 1

        return res

    def enable_autogard(self, aten_api):
        # Re-include the autograd dispatch key for configured ops
        # ("autogard" spelling kept to match the YAML key).
        if aten_api in self.npu_adjust_autogard:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
87
+
88
+
89
def dispatch4data(func, attl, status):
    """Wrap ``func`` so it executes under AccuracyCheckerDispatch when enabled.

    Args:
        func: callable to wrap (typically torch.autograd.backward).
        attl: transport object handed to the dispatch mode.
        status: when falsy, ``func`` runs completely unmodified.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not status:
            return func(*args, **kwargs)
        with AccuracyCheckerDispatch(attl):
            return func(*args, **kwargs)

    return wrapper
99
+
100
+
101
def run_ut_dispatch(attl, status):
    """Enable or disable AccuracyCheckerDispatch for torch.autograd.backward.

    Called by online_run_ut.

    Args:
        attl (ATTL): online_run_ut class ATTL, used to upload or send api
            data to the server.
        status (bool): True enables dispatch, False disables it.

    Bug fix: the original always re-wrapped the *current* function, so
    disabling installed a pass-through wrapper around a still-dispatching
    wrapper (dispatch could never actually be turned off), and repeated
    enables stacked wrappers.  ``functools.wraps`` inside dispatch4data
    records the wrapped callable in ``__wrapped__``, so we unwrap back to
    the pristine function before deciding what to install.
    """
    original = getattr(torch.autograd.backward, "__wrapped__", torch.autograd.backward)
    if status:
        torch.autograd.backward = dispatch4data(original, attl, status)
    else:
        # Restore the unwrapped function instead of adding another layer.
        torch.autograd.backward = original