mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,35 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import os
  import re
+ import torch
+
+ try:
+     import torch_npu
+ except ImportError:
+     current_device = "cuda"
+ else:
+     current_device = "npu"
 
- from msprobe.core.common.const import FileCheckConst
+ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
  from msprobe.core.common.file_utils import FileChecker
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.utils import CompareException
  from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
  from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
  from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
@@ -10,12 +37,21 @@ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
  from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
 
  hf_32_standard_api = ["conv1d", "conv2d"]
+ not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
+ not_raise_dtype_set = {'type_as'}
+
+ PRECISION_MAPPING = {
+     torch.float16: torch.float32,
+     torch.bfloat16: torch.float32,
+     torch.float32: torch.float64
+ }
 
 
- class Backward_Message:
+ class BackwardMessage:
      MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
-     UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
-     NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
+     UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
+                                  "skip backward."
+     NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
 
 
  class UtDataInfo:
@@ -68,3 +104,121 @@ def exec_api(api_type, api_name, device, args, kwargs):
          torch_api = NpuOPTemplate(api_name, None, False, device)
          out = torch_api.forward(*args, **kwargs)
      return out
+
+
+ def deal_detach(arg, to_detach=True):
+     return arg.detach() if to_detach else arg
+
+
+ def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
+     '''
+     Convert the dtype of the benchmark (CPU baseline) input to raise_dtype.
+     Args:
+         api_name: API name
+         arg: benchmark input
+         raise_dtype: target dtype to convert to
+     Returns:
+         arg: benchmark input converted to the target dtype
+     '''
+     if api_name in hf_32_standard_api and arg.dtype == torch.float32:
+         return arg
+     if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype:
+         return arg
+     return arg.type(raise_dtype)
+
+
+ def generate_device_params(input_args, input_kwargs, need_backward, api_name):
+     def recursive_arg_to_device(arg_in, to_detach, depth=0):
+         if depth > Const.MAX_DEPTH:
+             logger.error("The depth of arg_in is too large, please check the arg_in.")
+             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+         if isinstance(arg_in, (list, tuple)):
+             return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in)
+         elif isinstance(arg_in, torch.Tensor):
+             if need_backward and arg_in.requires_grad:
+                 arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
+                 temp_arg_in = arg_in * 1
+                 arg_in = temp_arg_in.type_as(arg_in)
+                 arg_in.retain_grad()
+                 return arg_in
+             else:
+                 return deal_detach(arg_in.clone(), to_detach).to(current_device)
+         else:
+             return arg_in
+
+     is_detach = api_name not in not_detach_set
+     device_args = recursive_arg_to_device(input_args, is_detach)
+     device_kwargs = \
+         {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
+     return device_args, device_kwargs
+
+
+ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
+     def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None, depth=0):
+         if depth > Const.MAX_DEPTH:
+             logger.error("The depth of arg_in is too large, please check the arg_in.")
+             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+         if isinstance(arg_in, (list, tuple)):
+             return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype, depth=depth+1)
+                                 for arg in arg_in)
+         elif isinstance(arg_in, torch.Tensor):
+             if need_backward and arg_in.requires_grad:
+                 arg_in = deal_detach(raise_bench_data_dtype(
+                     api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
+                 temp_arg_in = arg_in * 1
+                 arg_in = temp_arg_in.type_as(arg_in)
+                 arg_in.retain_grad()
+                 return arg_in
+             else:
+                 return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
+         else:
+             return arg_in
+
+     def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
+         if arg_in.dtype in PRECISION_MAPPING:
+             return True
+         if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
+             return True
+         return False
+
+     def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False, depth=0):
+         if depth > Const.MAX_DEPTH:
+             logger.error("The depth of arg_in is too large, please check the arg_in.")
+             raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
+         if isinstance(arg_in, (list, tuple)):
+             return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
+                                       arg in arg_in))
+         elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
+             return set([arg_in.dtype])
+         elif isinstance(arg_in, dict) and check_kwargs:
+             return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
+                                       v in arg_in.values()))
+         return set()
+
+     raise_dtype = None
+     need_raise_dtypes = recursive_find_dtypes(input_args)
+     need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
+     if len(need_raise_dtypes) == 1:
+         raise_dtype = PRECISION_MAPPING.get(need_raise_dtypes.pop(), torch.float32)
+     elif len(need_raise_dtypes) >= 2:
+         raise_dtype = torch.float32
+
+     raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
+     is_detach = api_name not in not_detach_set
+     cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
+     cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
+                   key, value in input_kwargs.items()}
+     return cpu_args, cpu_kwargs
+
+
+ def record_skip_info(api_full_name, compare, compare_alg_results):
+     result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
+     compare.record_results(result_info)
+
+
+ def is_unsupported_api(api_name, is_overflow_check=False):
+     split_name = api_name.split(Const.SEP)[0]
+     flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
+     if flag:
+         logger.info(f"{split_name} api is not supported for run ut. SKIP.")
+     return flag
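
Note on the run_ut_utils helpers added in the hunk above: `generate_cpu_params` scans the positional and keyword inputs for low-precision tensors and promotes the CPU baseline according to `PRECISION_MAPPING` (fp16/bf16 to fp32, fp32 to fp64), unless the API is listed in `not_raise_dtype_set`. A minimal, self-contained sketch of that dtype-selection rule (toy code, not msprobe API; `pick_raise_dtype` is an illustrative name):

```python
# Toy sketch of the dtype-promotion rule used by generate_cpu_params above.
import torch

PRECISION_MAPPING = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.float32: torch.float64,
}

def pick_raise_dtype(tensors):
    """Return the dtype the CPU baseline should be promoted to, or None."""
    dtypes = {t.dtype for t in tensors
              if isinstance(t, torch.Tensor) and t.dtype in PRECISION_MAPPING}
    if len(dtypes) == 1:
        return PRECISION_MAPPING[dtypes.pop()]
    # mixed low-precision inputs fall back to fp32, as in the diff above
    return torch.float32 if len(dtypes) >= 2 else None

x = torch.randn(4, 4, dtype=torch.float16)
y = torch.randn(4, 4, dtype=torch.bfloat16)
print(pick_raise_dtype([x]))     # torch.float32
print(pick_raise_dtype([x, y]))  # torch.float32
```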
@@ -1,7 +1,21 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import glob
  import os.path
  import time
- import re
  from multiprocessing import Queue
  from typing import Optional, Union, Dict, Any
  from dataclasses import dataclass
@@ -11,9 +25,8 @@ import torch
  from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
- from msprobe.pytorch.common.utils import logger
  from msprobe.core.common.file_utils import remove_path
- from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
+ from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
 
  BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
 
@@ -40,7 +53,7 @@ class ATTL:
          self.dequeue_list = []
          self.message_end = False
          self.kill_progress = False
-         self.check_attl_config()
+         self.nfs_path = None
          if self.session_config.nfs_path:
              self.nfs_path = self.session_config.nfs_path
          elif self.session_config.is_benchmark_device:
@@ -57,18 +70,6 @@ class ATTL:
                                              self.session_config.tls_path)
              self.socket_manager.start()
 
-     def check_attl_config(self):
-         if self.session_config.nfs_path:
-             if os.path.exists(self.session_config.nfs_path):
-                 return
-             else:
-                 raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
-         ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
-         if not re.match(ipv4_pattern, self.session_config.connect_ip):
-             raise Exception(f"host {self.session_config.connect_ip} is invalid.")
-         if not (0 < self.session_config.connect_port <= 65535):
-             raise Exception(f"port {self.session_config.connect_port} is invalid.")
-
      def stop_serve(self):
          if isinstance(self.socket_manager, TCPServer):
              self.socket_manager.stop()
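
The `check_attl_config` method removed above was where the NFS path, benchmark host address, and port were validated when an ATTL session was constructed. For reference, a standalone sketch of the same checks, reconstructed from the removed lines (the function name and use of `ValueError` are illustrative choices, not msprobe API):

```python
# Standalone reconstruction of the validation the removed check_attl_config performed.
import os
import re

IPV4_PATTERN = r"([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"

def validate_attl_config(nfs_path, connect_ip, connect_port):
    # An existing NFS path short-circuits the host/port checks, as in the removed code.
    if nfs_path:
        if not os.path.exists(nfs_path):
            raise ValueError(f"nfs path {nfs_path} doesn't exist.")
        return
    if not re.match(IPV4_PATTERN, connect_ip):
        raise ValueError(f"host {connect_ip} is invalid.")
    if not (0 < connect_port <= 65535):
        raise ValueError(f"port {connect_port} is invalid.")

validate_attl_config(None, "127.0.0.1", 8000)  # passes; bad input raises ValueError
```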
@@ -77,6 +78,11 @@ class ATTL:
          """
          npu major in 'send' (client)
          """
+
+         # if tcp connection lost,
+         if self.socket_manager.signal_exit:
+             raise ConnectionError(f"Failed to connect to {self.session_config.connect_ip}.")
+
          # know receiver receive and go next
          if isinstance(buffer, ApiData):
              buffer = move2target_device(buffer, torch.device('cpu'))
@@ -94,21 +100,21 @@
          self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
 
      def recv(self, timeout_ms=0) -> Optional[BufferType]:
-         buffer = None
-         while buffer is None:
+         buffer = ''
+         while not buffer:
              if timeout_ms > 0:
                  time.sleep(timeout_ms / 1000.0)
-             if buffer is None and not self.data_queue.empty():
+             if not buffer and not self.data_queue.empty():
                  buffer = self.data_queue.get()
                  break
-             if buffer is None and timeout_ms > 0:  # timeout is the only case we give up and return None
+             if not buffer and timeout_ms > 0:  # timeout is the only case we give up and return None
                  break
              if self.message_end and self.data_queue.empty():
                  buffer = b"KILL_CONFIRM"
                  self.kill_progress = True
                  break
              time.sleep(0.1)  # waiting outside the lock before next attempt
-         if buffer is None:
+         if not buffer:
              # this is a result of a timeout
              self.logger.info(f"RECEIVE API DATA TIMED OUT")
          else:
@@ -125,7 +131,7 @@ class ATTL:
          except Exception as e:
              self.logger.warning("there is something error. please check it. %s", e)
          if isinstance(buffer, bytes):
-             return None
+             return ''
          if isinstance(buffer, str):
              return buffer
 
@@ -139,7 +145,7 @@ class ATTL:
          file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
 
          try:
-             save_pt(buffer, file_path)
+             save_pkl(buffer, file_path)
          except Exception as e:
              self.logger.warning("there is something error in save_pt. please check it. %s", e)
 
@@ -155,7 +161,7 @@ class ATTL:
 
          if cur_file is not None:
              try:
-                 buffer = load_pt(cur_file)
+                 buffer = load_pkl(cur_file)
              except Exception as e:
                  self.logger.warning("there is something error. please check it. %s", e)
              remove_path(cur_file)
@@ -1,10 +1,24 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import hashlib
  import io
  import struct
  import time
  import os
  import signal
- import sys
  from queue import Queue
  from threading import Thread
  from typing import Union
@@ -13,7 +27,10 @@ from twisted.internet import reactor, protocol, endpoints
  from twisted.protocols.basic import FileSender
 
  from msprobe.pytorch.common.utils import logger
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.utils import STRUCT_UNPACK_MODE as unpack_mode, \
+     STR_TO_BYTES_ORDER as bytes_order
+
+ MAX_SENDING_QUEUE_SIZE = 20
 
 
  class TCPDataItem:
@@ -31,7 +48,6 @@ class TCPDataItem:
 
 
  class TCPClient:
-     MAX_SENDING_QUEUE_SIZE = 20
      ACK_SUCCESS = b"OK___"
      ACK_ERROR = b"ERROR"
      ACK_BUSY = b"BUSY_"
@@ -39,13 +55,13 @@ class TCPClient:
      ACK_STOP_CONFIRM = b"OVER_"
      ACK_KILL_PROCESS = b"KILL_"
 
-     QUEUE_PENDING_TIME = 600  # terminate the sending process if the queue stays blocked for 10 minutes
+     QUEUE_PENDING_TIME = 60
      RESEND_RETRY_TIMES = 2  # maximum number of retransmissions
      RESEND_TIMER_TIME = 5  # timer for ACK receive timeout
      RESEND_PENDING_TIME = 60  # give up the data if it stays pending for more than 1 minute
 
      def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None):
-         self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE)
+         self.send_queue = Queue(MAX_SENDING_QUEUE_SIZE)
          self.resend_dict = dict()
          self.host = host
          self.port = port
@@ -55,7 +71,8 @@ class TCPClient:
          self.signal_exit = False
          self.tcp_manager = ClientProtocol(ack_queue_size=100,
                                            chunk_size=655360,
-                                           check_sum=check_sum)
+                                           check_sum=check_sum,
+                                           tls=self.tls_path)
          self.send_thread = Thread(target=self._sending_queue_data)
          self.send_thread.setDaemon(True)
          self.send_thread.start()
@@ -80,8 +97,6 @@ class TCPClient:
              time.sleep(1)
              reactor.stop()
              logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}")
-             os.kill(os.getpid(), signal.SIGKILL)
-             os.kill(os.getppid(), signal.SIGKILL)
 
          def cur_protocol():
              return self.tcp_manager
@@ -89,14 +104,10 @@
          self.factory = MessageClientFactory()
          self.factory.protocol = cur_protocol
          if self.tls_path:
-             from OpenSSL import SSL
              from twisted.internet import ssl
              client_key = os.path.join(self.tls_path, "client.key")
              client_crt = os.path.join(self.tls_path, "client.crt")
-             client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD)
-             client_context_ = client_context_factory.getContext()
-             client_context_.set_cipher_list(cipher_list)
-             client_context_.set_options(SSL.OP_NO_RENEGOTIATION)
+             client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt)
              endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory)
          else:
              endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port)
@@ -109,7 +120,11 @@ class TCPClient:
 
      def send_after_queue_empty(self, data):
          while not self._ready_to_exit():
-             self.add_to_sending_queue(data)
+             if not self.tls_path:
+                 self.add_to_sending_queue(data)
+             else:
+                 for _ in range(MAX_SENDING_QUEUE_SIZE):
+                     self.add_to_sending_queue(data)
              time.sleep(2)
 
      def check_client_alive(self):
@@ -124,8 +139,6 @@
              if not self.check_client_alive():
                  break
              time.sleep(1)
-         while not self.tcp_manager.kill_process:
-             time.sleep(1)
 
      def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0):
          if self._ready_to_exit():
@@ -142,7 +155,8 @@ class TCPClient:
              self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME)
          except Exception as e:
              logger.error(f"send_queue put send_data timeout, rank: {send_data.rank}, step: {send_data.step},"
-                          f"sequence_number: {send_data.sequence_number}, {str(e)}")
+                          f"sequence_number: {send_data.sequence_number}, send_queue size: {self.send_queue.qsize()},"
+                          f"{str(e)}")
 
      def _send_data(self, data: TCPDataItem):
          self.tcp_manager.send_wrapped_data(data.raw_data,
@@ -159,10 +173,11 @@ class TCPClient:
          while self.send_queue.qsize() > 0:
              if self._ready_to_exit():
                  break
-             if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE:
+             if len(self.resend_dict) < MAX_SENDING_QUEUE_SIZE:
                  data_obj = self.send_queue.get()
-                 self._send_data(data_obj)
                  resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step)
+                 logger.debug(f"get {resend_key} from send_queue, and send to server.")
+                 self._send_data(data_obj)
                  if resend_key not in self.resend_dict.keys():
                      # Send data for the first time
                      self.resend_dict[resend_key] = data_obj
@@ -233,7 +248,7 @@ class TCPClient:
  class ClientProtocol(protocol.Protocol):
      TIMEOUT = 60 * 10
 
-     def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False):
+     def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False, tls=None):
          self.buffer = io.BytesIO()
          self.is_connected = False
          self.check_sum = check_sum
@@ -244,6 +259,13 @@ class ClientProtocol(protocol.Protocol):
          self.signal_exit = False
          self.defer = None
          self.kill_process = False
+         self.ack = None
+
+         self.timeout_call = None
+
+         self.tls = tls
+         self.send_buffer = b""
+         self.buffer_cnt = 0
 
      def dataReceived(self, data):
          if self.timeout_call.active():
@@ -255,9 +277,11 @@ class ClientProtocol(protocol.Protocol):
          while True:
              if len(self.buffer.getvalue()) >= 29:  # 5 + 8 * 3
                  ack = self.buffer.read(5)
-                 seq_number = struct.unpack('!Q', self.buffer.read(8))[0]
-                 rank = struct.unpack('!Q', self.buffer.read(8))[0]
-                 step = struct.unpack('!Q', self.buffer.read(8))[0]
+                 self.ack = ack
+                 seq_number = struct.unpack(unpack_mode, self.buffer.read(8))[0]
+                 rank = struct.unpack(unpack_mode, self.buffer.read(8))[0]
+                 step = struct.unpack(unpack_mode, self.buffer.read(8))[0]
+                 logger.debug(f"receive sequence number: {seq_number}; RANK: {rank}; STEP: {step}; ACK: {ack}")
                  if ack == b"KILL_":
                      self.kill_process = True
                      logger.debug(f"KILL signal received, PID {os.getpid()}")
@@ -276,20 +300,33 @@ class ClientProtocol(protocol.Protocol):
      def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0):
          length = len(data)
          md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else ""
+         data_meaasge = length.to_bytes(8, byteorder=bytes_order) + \
+                        sequence_number.to_bytes(8, byteorder=bytes_order) + \
+                        rank.to_bytes(8, byteorder=bytes_order) + \
+                        step.to_bytes(8, byteorder=bytes_order) + \
+                        md5_hash.encode() + \
+                        data
+         logger.debug(f"send sequence number: {sequence_number}; RANK: {rank}; STEP: {step}; LENGTH: {length}")
+
          while True:
              if self.defer is None or self.defer.called:
-                 self.defer = self.send_large_data(
-                     length.to_bytes(8, byteorder='big') +
-                     sequence_number.to_bytes(8, byteorder='big') +
-                     rank.to_bytes(8, byteorder='big') +
-                     step.to_bytes(8, byteorder='big') +
-                     md5_hash.encode() +
-                     data)
+                 self.defer = self.send_large_data(data_meaasge)
                  break
              time.sleep(0.01)
 
      def send_large_data(self, data):
+
+         if self.tls:
+             self.send_buffer += data
+             self.buffer_cnt += 1
+             if self.buffer_cnt >= MAX_SENDING_QUEUE_SIZE:
+                 d = self.file_sender.beginFileTransfer(io.BytesIO(self.send_buffer), self.transport)
+                 self.send_buffer = b""
+                 self.buffer_cnt = 0
+             else:
+                 d = None
+         else:
+             d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport)
          return d
 
      def connection_timeout(self):
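
For context on the framing touched by the client.py hunks above: each payload is prefixed with four 8-byte big-endian unsigned integers (length, sequence number, rank, step) plus an optional MD5 digest, and the acknowledgement parsed in `dataReceived` is a fixed 29-byte header (a 5-byte ACK tag followed by three 8-byte integers, matching the `>= 29  # 5 + 8 * 3` check). A minimal round-trip sketch, assuming the new `STRUCT_UNPACK_MODE` / `STR_TO_BYTES_ORDER` constants are `'!Q'` and `'big'` as in the literals they replace:

```python
# Round-trip sketch of the ACK framing used by ClientProtocol above (toy code).
import struct

UNPACK_MODE = "!Q"   # 8-byte unsigned int, network (big-endian) byte order
BYTES_ORDER = "big"

def pack_ack(tag: bytes, seq: int, rank: int, step: int) -> bytes:
    assert len(tag) == 5
    return (tag + seq.to_bytes(8, BYTES_ORDER) +
            rank.to_bytes(8, BYTES_ORDER) + step.to_bytes(8, BYTES_ORDER))

def unpack_ack(header: bytes):
    tag = header[:5]
    seq, rank, step = (struct.unpack(UNPACK_MODE, header[5 + 8 * i:5 + 8 * (i + 1)])[0]
                       for i in range(3))
    return tag, seq, rank, step

ack = pack_ack(b"OK___", seq=42, rank=1, step=7)
assert len(ack) == 29          # 5 + 8 * 3, the threshold checked in dataReceived
print(unpack_ack(ack))         # (b'OK___', 42, 1, 7)
```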
@@ -1,3 +1,18 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import time
  from collections import namedtuple
 
@@ -12,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import DETAIL_TE
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import UtDataInfo, exec_api
  from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params
+
 
 
  # NPU vs GPU api list
@@ -75,7 +92,8 @@ def online_precision_compare(api_data, device, common_config, api_precision_csv_
 
      try:
          # NPU vs CPU
-         cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, npu_args, npu_kwargs)
+         cpu_args, cpu_kwargs = generate_cpu_params(npu_args, npu_kwargs, False, api_name)
+         cpu_out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
          npu_data_info = UtDataInfo(None, None, npu_out, cpu_out, None, [], None, rank=api_data.rank)
          npu_detail = compare.compare_output(api_full_name, npu_data_info, True)
          npu_data = pd.DataFrame(npu_detail, columns=DETAIL_TEST_ROWS[-1])
@@ -156,7 +174,10 @@ class ConsumerDispatcher:
 
      def start(self, handle_func, config):
          self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)]
-         api_precision_csv_file = [ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME, ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME]
+         api_precision_csv_file = [
+             ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME,
+             ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME
+         ]
          common_config = CommonCompareConfig(self.compare, handle_func, config)
          for xpu_id, q in enumerate(self.queues):
              p = mp.Process(name="run_ut_process", target=run_ut_process,
@@ -164,8 +185,10 @@ class ConsumerDispatcher:
 
              p.start()
              self.processes.append(p)
-         logger.info(f"Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}")
-         logger.info(f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
+         logger.info(
+             f'Api_precision_compare task result will be saved in {ONLINE_API_PRECISION_COMPARE_RESULT_FILE_NAME}')
+         logger.info(
+             f"Api_precision_compare task details will be saved in {ONLINE_API_PRECISION_COMPARE_DETAILS_FILE_NAME}")
          logger.info("Successfully start unittest process.")
 
      def stop(self):