mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
  import torch
@@ -35,25 +33,26 @@ def remove_dropout():
35
33
  def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True,
36
34
  inplace: bool = False) -> torch.Tensor:
37
35
  if has_torch_function_unary(input):
38
- return handle_torch_function(function_dropout, (input,), input, p=0., training=training, inplace=inplace)
36
+ return handle_torch_function(
37
+ function_dropout, (input,), input, p=0., training=training, inplace=inplace)
39
38
  if p < 0.0 or p > 1.0:
40
39
  raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
41
40
  return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training)
42
41
 
43
-
44
42
  def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True,
45
43
  inplace: bool = False) -> torch.Tensor:
46
44
  if has_torch_function_unary(input):
47
- return handle_torch_function(function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
45
+ return handle_torch_function(
46
+ function_dropout2d, (input,), input, p=0., training=training, inplace=inplace)
48
47
  if p < 0.0 or p > 1.0:
49
48
  raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
50
49
  return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
51
50
 
52
-
53
51
  def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True,
54
52
  inplace: bool = False) -> torch.Tensor:
55
53
  if has_torch_function_unary(input):
56
- return handle_torch_function(function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
54
+ return handle_torch_function(
55
+ function_dropout3d, (input,), input, p=0., training=training, inplace=inplace)
57
56
  if p < 0.0 or p > 1.0:
58
57
  raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
59
58
  return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training)
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
  import torch
@@ -21,24 +19,19 @@ import torch
21
19
  from msprobe.pytorch.hook_module.hook_module import HOOKModule
22
20
  from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
23
21
  from msprobe.core.common.const import Const
22
+ from msprobe.core.common.log import logger
24
23
  from msprobe.core.common.file_utils import load_yaml
25
24
  from msprobe.pytorch.function_factory import npu_custom_functions
26
25
 
27
- cur_path = os.path.dirname(os.path.realpath(__file__))
28
- yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
29
-
30
-
31
26
  try:
32
27
  import torch_npu
33
28
  except ImportError:
34
- is_gpu = True
35
- else:
36
- is_gpu = False
29
+ logger.info("Failing to import torch_npu.")
37
30
 
38
31
 
39
- cuda_func_mapping = {
40
- "npu_fusion_attention" : "gpu_fusion_attention"
41
- }
32
+ cur_path = os.path.dirname(os.path.realpath(__file__))
33
+ yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
34
+ cuda_func_mapping = {"npu_fusion_attention" : "gpu_fusion_attention"}
42
35
 
43
36
 
44
37
  def get_npu_ops():
@@ -83,7 +76,6 @@ class NpuOPTemplate(HOOKModule):
83
76
  def wrap_npu_op(op_name, hook):
84
77
  def npu_op_template(*args, **kwargs):
85
78
  return NpuOPTemplate(op_name, hook)(*args, **kwargs)
86
-
87
79
  return npu_op_template
88
80
 
89
81
 
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
 
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
  import torch
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,6 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
17
15
 
18
16
  import os
19
17
  import torch
@@ -1,10 +1,25 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from functools import wraps
2
17
 
3
18
  import torch
4
- from torch.utils.hooks import BackwardHook
5
-
6
19
  from msprobe.core.common.const import Const
7
20
  from msprobe.core.data_dump.scope import ModuleRangeScope
21
+ from torch.utils.hooks import BackwardHook
22
+
8
23
  torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
9
24
 
10
25
 
@@ -7,10 +7,10 @@ from collections import namedtuple
7
7
  from rich.table import Table
8
8
  from rich.console import Console
9
9
  from msprobe.core.common.const import CompareConst, FileCheckConst
10
- from msprobe.core.common.file_utils import FileOpen, change_mode
10
+ from msprobe.core.common.file_utils import FileOpen, change_mode, read_csv
11
11
  from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap
12
12
  from msprobe.pytorch.common.log import logger
13
- from msprobe.core.common.utils import CompareException
13
+ from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
14
14
 
15
15
  ELEMENT_NUM_THRESHOLD = 100
16
16
  ZERO_NUM_THRESHOLD = 0.1
@@ -107,19 +107,17 @@ class Saver:
107
107
 
108
108
  def get_statistics_from_result_csv(self):
109
109
  checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
110
- with FileOpen(self.save_path, 'r') as file:
111
- reader = csv.reader(file)
112
- result_csv_rows = [row for row in reader]
110
+ data = read_csv(self.save_path)
113
111
  result_csv_name = os.path.basename(self.save_path)
114
- for item in result_csv_rows[1:]:
115
- if not isinstance(item, list) or len(item) < 3:
112
+ for _, row in data.iterrows():
113
+ if len(row) < 3:
116
114
  raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
117
- if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
115
+ if not all(row[i] and row[i].upper() in checklist for i in (1, 2)):
118
116
  raise ValueError(
119
117
  "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
120
118
  % result_csv_name)
121
- column1 = item[1].upper()
122
- column2 = item[2].upper()
119
+ column1 = row[1].upper()
120
+ column2 = row[2].upper()
123
121
  if column1 == CompareConst.SKIP:
124
122
  continue
125
123
  self.test_result_cnt["total_num"] += 1
@@ -139,12 +137,13 @@ class Saver:
139
137
  if self.stack_info:
140
138
  test_rows[0].append(self.COLUMN_STACK_INFO)
141
139
 
142
- name = test_result.api_name
140
+ check_op_str_pattern_valid(test_result.api_name)
143
141
  df_row = [test_result.api_name, test_result.is_fwd_success, test_result.is_bwd_success]
144
142
  if test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP":
145
143
  df_row.append(test_result.fwd_compare_alg_results)
146
144
  if self.stack_info:
147
- stack_info = "\n".join(self.stack_info[name])
145
+ check_op_str_pattern_valid(self.stack_info[test_result.api_name])
146
+ stack_info = "\n".join(self.stack_info[test_result.api_name])
148
147
  df_row.append(stack_info)
149
148
  test_rows.append(df_row)
150
149
  write_csv(test_rows, self.save_path)
@@ -329,20 +329,20 @@ def single_benchmark_compare(npu_out: torch.Tensor, bench_out: torch.Tensor, hig
329
329
  return result, details
330
330
 
331
331
 
332
- def calc_status_details_list_tuple(npu_out, bench_out, high_precision, summary):
332
+ def calc_status_details_list_tuple(npu_out, bench_out, summary):
333
333
  status, details = [], []
334
334
  if len(bench_out) != len(npu_out):
335
335
  summary.result = False
336
336
  summary.failed_info = "bench and npu output structure is different."
337
337
  return False, summary.to_column_value()
338
338
  for b_out_i, n_out_i in zip(bench_out, npu_out):
339
- status_i, details_i = single_benchmark_compare_wrap(n_out_i, b_out_i, high_precision)
339
+ status_i, details_i = single_benchmark_compare_wrap(n_out_i, b_out_i)
340
340
  status.append(status_i)
341
341
  details.append(details_i)
342
342
  return status, details
343
343
 
344
344
 
345
- def calc_status_details_dict(npu_out, bench_out, high_precision, summary):
345
+ def calc_status_details_dict(npu_out, bench_out, summary):
346
346
  b_keys, n_keys = set(bench_out.keys()), set(npu_out.keys())
347
347
  if b_keys != n_keys:
348
348
  summary.result = False
@@ -353,7 +353,7 @@ def calc_status_details_dict(npu_out, bench_out, high_precision, summary):
353
353
  return status, details
354
354
 
355
355
 
356
- def calc_status_details_tensor(npu_out, bench_out, high_precision, summary):
356
+ def calc_status_details_tensor(npu_out, bench_out, summary):
357
357
  return single_benchmark_compare(npu_out, bench_out)
358
358
 
359
359
 
@@ -365,13 +365,13 @@ def calc_status_details_builtin(npu_out, bench_out, summary):
365
365
  return status, summary.to_column_value()
366
366
 
367
367
 
368
- def calc_status_details_none(npu_out, bench_out, high_precision, summary):
368
+ def calc_status_details_none(npu_out, bench_out, summary):
369
369
  summary.result = True
370
370
  summary.failed_info = "Output is None."
371
371
  return True, summary.to_column_value()
372
372
 
373
373
 
374
- def single_benchmark_compare_wrap(npu_output: torch.Tensor, bench_output: torch.Tensor, high_precision=True):
374
+ def single_benchmark_compare_wrap(npu_output: torch.Tensor, bench_output: torch.Tensor):
375
375
  type_method_dict = {
376
376
  (list, tuple): calc_status_details_list_tuple,
377
377
  dict: calc_status_details_dict,
@@ -384,7 +384,7 @@ def single_benchmark_compare_wrap(npu_output: torch.Tensor, bench_output: torch.
384
384
  bench_summary = SingleBenchSummary(result)
385
385
  for type1, func in type_method_dict.items():
386
386
  if isinstance(bench_output, type1):
387
- return func(npu_output, bench_output, high_precision, bench_summary)
387
+ return func(npu_output, bench_output, bench_summary)
388
388
 
389
389
  bench_summary.result = True
390
390
  bench_summary.failed_info = "Unexpected output type: {}".format(type(bench_output))
@@ -1,4 +1,12 @@
1
1
  aten_ops_blacklist:
2
+ - max_pool2d_with_indices
3
+ - detach
4
+ - allreduce_
5
+ - max
6
+ - npu_rotary_mul
7
+ - split_with_sizes
8
+ - npu_dtype_cast
9
+ - add_
2
10
  - _cudnn_rnn
3
11
  - _local_scalar_dense
4
12
  - _pin_memory
@@ -58,10 +58,7 @@ INT_TYPE = [np.int32, np.int64]
58
58
  def get_callstack():
59
59
  callstack = []
60
60
  for (_, path, line, func, code, _) in inspect.stack()[2:]:
61
- if code:
62
- stack_line = [path, str(line), func, code[0].strip() if code else code]
63
- else:
64
- stack_line = [path, str(line), func, code]
61
+ stack_line = [path, str(line), func, code[0].strip() if code else code]
65
62
  callstack.append(stack_line)
66
63
  return callstack
67
64
 
msprobe/pytorch/parse.py CHANGED
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.pytorch.parse_tool import cli
2
17
 
3
18
  if __name__ == '__main__':
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,7 +12,7 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  from msprobe.pytorch.parse_tool.lib.interactive_cli import InteractiveCli
18
17
  from msprobe.pytorch.common.log import logger
19
18
 
@@ -22,7 +22,7 @@ from collections import namedtuple
22
22
  from msprobe.pytorch.parse_tool.lib.utils import Util
23
23
  from msprobe.pytorch.parse_tool.lib.config import Const
24
24
  from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
25
- from msprobe.core.common.file_utils import FileChecker, create_directory, load_npy, save_npy_to_txt, write_csv
25
+ from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv
26
26
 
27
27
 
28
28
  class Compare:
@@ -49,10 +49,10 @@ class Compare:
49
49
  dump_file = self.util.path_strip(dump_file)
50
50
  file_name = ""
51
51
  if os.path.isfile(dump_file):
52
- self.log.info("Covert file is: %s", dump_file)
52
+ self.log.info("Covert file is: %s" % dump_file)
53
53
  file_name = os.path.basename(dump_file)
54
54
  elif os.path.isdir(dump_file):
55
- self.log.info("Convert all files in path: %s", dump_file)
55
+ self.log.info("Convert all files in path: %s" % dump_file)
56
56
  file_name = ""
57
57
  output = output if output else Const.DUMP_CONVERT_DIR
58
58
  convert = self.convert(dump_file, data_format, output, msaccucmp_path)
@@ -62,7 +62,7 @@ class Compare:
62
62
  summary_txt = ["SrcFile: %s" % dump_file]
63
63
  for convert_file in convert_files.values():
64
64
  summary_txt.append(" - %s" % convert_file.file_name)
65
- self.log.info("Transfer result is saved in : %s", os.path.realpath(output))
65
+ self.log.info("Transfer result is saved in : %s" % os.path.realpath(output))
66
66
  self.util.print_panel("\n".join(summary_txt))
67
67
 
68
68
  def convert(self, dump_file, data_format, output, msaccucmp_path):
@@ -114,11 +114,11 @@ class Compare:
114
114
  shape_left = data_left.shape
115
115
  shape_right = data_right.shape
116
116
  if shape_left != shape_right:
117
- self.log.warning("Data shape not equal: %s vs %s", data_left.shape, data_right.shape)
117
+ self.log.warning("Data shape not equal: %s vs %s" % (data_left.shape, data_right.shape))
118
118
  data_left = data_left.reshape(-1)
119
119
  data_right = data_right.reshape(-1)
120
120
  if data_left.shape[0] != data_right.shape[0]:
121
- self.log.warning("Data size not equal: %s vs %s", data_left.shape, data_right.shape)
121
+ self.log.warning("Data size not equal: %s vs %s" % (data_left.shape, data_right.shape))
122
122
  if data_left.shape[0] < data_right.shape[0]:
123
123
  data_left = np.pad(data_left, (0, data_right.shape[0] - data_left.shape[0]), 'constant')
124
124
  else:
@@ -160,7 +160,7 @@ class Compare:
160
160
  if shape != bench_shape or dtype != bench_dtype:
161
161
  self.log.error(
162
162
  "Shape or dtype between two npy files is inconsistent. Please check the two files."
163
- "File 1: %s, file 2: %s", file, bench_file)
163
+ "File 1: %s, file 2: %s" % (file, bench_file))
164
164
  self.util.deal_with_dir_or_file_inconsistency(output_path)
165
165
  return
166
166
  md5_consistency = False
@@ -236,13 +236,12 @@ class Compare:
236
236
  golden_subdir_path = os.path.join(golden_dump_dir, golden_subdir_name)
237
237
  self.compare_timestamp_directory(my_subdir_path, golden_subdir_path, output_path)
238
238
  self.util.change_filemode_safe(output_path)
239
- self.log.info("Compare result is saved in : %s", output_path)
239
+ self.log.info("Compare result is saved in : %s" % (output_path))
240
240
 
241
241
  def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path):
242
242
  dump_dir = self.util.path_strip(dump_dir)
243
243
  for root, _, files in os.walk(dump_dir, topdown=True):
244
- path_checker = FileChecker(root)
245
- path_checker.common_check()
244
+ self.util.check_path_valid(root)
246
245
  for file in files:
247
246
  file_path = os.path.join(root, file)
248
247
  file_name = os.path.basename(file_path)
@@ -110,6 +110,9 @@ class ParseTool:
110
110
  parser.add_argument('-al', '--atol', dest='atol', default=0.001, type=float, help='set rtol')
111
111
  parser.add_argument('-rl', '--rtol', dest='rtol', default=0.001, type=float, help='set atol')
112
112
  args = parser.parse_args(argv)
113
+ self.util.check_positive(args.count)
114
+ self.util.check_positive(args.rtol)
115
+ self.util.check_positive(args.atol)
113
116
  self.util.check_path_valid(args.my_dump_path)
114
117
  self.util.check_path_valid(args.golden_dump_path)
115
118
  self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
@@ -28,7 +28,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
28
28
  from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
29
29
  check_path_executable, check_path_owner_consistent
30
30
  from msprobe.core.common.const import FileCheckConst
31
- from msprobe.core.common.file_utils import FileChecker, check_file_or_directory_path, remove_path
31
+ from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type
32
32
  from msprobe.pytorch.common.log import logger
33
33
 
34
34
 
@@ -71,21 +71,19 @@ class Util:
71
71
  check_path_executable(path)
72
72
 
73
73
  @staticmethod
74
- def get_subdir_count(self, directory):
74
+ def get_subdir_count(directory):
75
75
  subdir_count = 0
76
- path_checker = FileChecker(directory)
77
- path_checker.common_check()
76
+ check_file_or_directory_path(directory, isdir=True)
78
77
  for _, dirs, _ in os.walk(directory):
79
78
  subdir_count += len(dirs)
80
79
  break
81
80
  return subdir_count
82
81
 
83
82
  @staticmethod
84
- def get_subfiles_count(self, directory):
83
+ def get_subfiles_count(directory):
85
84
  file_count = 0
86
85
  for root, _, files in os.walk(directory, topdown=True):
87
- path_checker = FileChecker(root)
88
- path_checker.common_check()
86
+ check_file_or_directory_path(root, isdir=True)
89
87
  file_count += len(files)
90
88
  path_depth = root.count(os.sep)
91
89
  if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
@@ -95,7 +93,7 @@ class Util:
95
93
  return file_count
96
94
 
97
95
  @staticmethod
98
- def get_sorted_subdirectories_names(self, directory):
96
+ def get_sorted_subdirectories_names(directory):
99
97
  subdirectories = []
100
98
  for item in os.listdir(directory):
101
99
  item_path = os.path.join(directory, item)
@@ -104,7 +102,7 @@ class Util:
104
102
  return sorted(subdirectories)
105
103
 
106
104
  @staticmethod
107
- def get_sorted_files_names(self, directory):
105
+ def get_sorted_files_names(directory):
108
106
  files = []
109
107
  for item in os.listdir(directory):
110
108
  item_path = os.path.join(directory, item)
@@ -113,7 +111,7 @@ class Util:
113
111
  return sorted(files)
114
112
 
115
113
  @staticmethod
116
- def check_npy_files_valid_in_dir(self, dir_path):
114
+ def check_npy_files_valid_in_dir(dir_path):
117
115
  for file_name in os.listdir(dir_path):
118
116
  file_path = os.path.join(dir_path, file_name)
119
117
  check_file_or_directory_path(file_path)
@@ -123,18 +121,18 @@ class Util:
123
121
  return True
124
122
 
125
123
  @staticmethod
126
- def get_md5_for_numpy(self, obj):
124
+ def get_md5_for_numpy(obj):
127
125
  np_bytes = obj.tobytes()
128
126
  md5_hash = hashlib.md5(np_bytes)
129
127
  return md5_hash.hexdigest()
130
128
 
131
129
  @staticmethod
132
- def deal_with_dir_or_file_inconsistency(self, output_path):
130
+ def deal_with_dir_or_file_inconsistency(output_path):
133
131
  remove_path(output_path)
134
132
  raise ParseException("Inconsistent directory structure or file.")
135
133
 
136
134
  @staticmethod
137
- def deal_with_value_if_has_zero(self, data):
135
+ def deal_with_value_if_has_zero(data):
138
136
  if data.dtype in Const.FLOAT_TYPE:
139
137
  zero_mask = (data == 0)
140
138
  # 给0的地方加上eps防止除0
@@ -147,10 +145,9 @@ class Util:
147
145
  return data
148
146
 
149
147
  @staticmethod
150
- def dir_contains_only(self, path, endfix):
148
+ def dir_contains_only(path, endfix):
151
149
  for root, _, files in os.walk(path, topdown=True):
152
- path_checker = FileChecker(root)
153
- path_checker.common_check()
150
+ check_file_or_directory_path(root, isdir=True)
154
151
  for file in files:
155
152
  if not file.endswith(endfix):
156
153
  return False
@@ -162,11 +159,11 @@ class Util:
162
159
  return True
163
160
 
164
161
  @staticmethod
165
- def localtime_str(self):
162
+ def localtime_str():
166
163
  return time.strftime("%Y%m%d%H%M%S", time.localtime())
167
164
 
168
165
  @staticmethod
169
- def change_filemode_safe(self, path):
166
+ def change_filemode_safe(path):
170
167
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
171
168
 
172
169
  @staticmethod
@@ -183,7 +180,7 @@ class Util:
183
180
  if not cmd:
184
181
  self.log.error("Commond is None")
185
182
  return -1
186
- self.log.info("[RUN CMD]: %s", cmd)
183
+ self.log.info("[RUN CMD]: %s" % cmd)
187
184
  cmd = cmd.split(" ")
188
185
  complete_process = subprocess.run(cmd, shell=False)
189
186
  return complete_process.returncode
@@ -205,7 +202,7 @@ class Util:
205
202
  result = subprocess.run(
206
203
  [self.python, target_file, "--help"], stdout=subprocess.PIPE, shell=False)
207
204
  if result.returncode == 0:
208
- self.log.info("Check [%s] success.", target_file)
205
+ self.log.info("Check [%s] success." % (target_file))
209
206
  else:
210
207
  self.log.error("Check msaccucmp failed in dir %s" % target_file)
211
208
  self.log.error("Please specify a valid msaccucmp.py path or install the cann package")
@@ -244,8 +241,11 @@ class Util:
244
241
 
245
242
  def check_path_valid(self, path):
246
243
  path = self.path_strip(path)
247
- path_checker = FileChecker(path)
248
- path_checker.common_check()
244
+ if not path or not os.path.exists(path):
245
+ self.log.error("The path %s does not exist." % path)
246
+ raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
247
+ isdir = check_file_type(path) == FileCheckConst.DIR
248
+ check_file_or_directory_path(path, isdir=isdir)
249
249
  return True
250
250
 
251
251
  def check_files_in_path(self, path):
@@ -274,8 +274,7 @@ class Util:
274
274
  file_list = {}
275
275
  re_pattern = re.compile(pattern)
276
276
  for dir_path, _, file_names in os.walk(path, topdown=True):
277
- path_checker = FileChecker(dir)
278
- path_checker.common_check()
277
+ check_file_or_directory_path(dir_path, isdir=True)
279
278
  for name in file_names:
280
279
  match = re_pattern.match(name)
281
280
  if not match:
@@ -314,3 +313,8 @@ class Util:
314
313
  dir1_count = self.get_subdir_count(dir1)
315
314
  dir2_count = self.get_subdir_count(dir2)
316
315
  return dir1_count == dir2_count
316
+
317
+ def check_positive(self, value):
318
+ if value <= 0.0:
319
+ self.log.error("Invalid value. It must be greater than 0.")
320
+ raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
@@ -28,7 +28,7 @@ class Visualization:
28
28
  self.util = Util()
29
29
 
30
30
  def print_npy_summary(self, target_file):
31
- np_data = load_npy(target_file, enable_pickle=True)
31
+ np_data = load_npy(target_file)
32
32
  table = self.util.create_table('', ['Index', 'Data'])
33
33
  flatten_data = np_data.flatten()
34
34
  tablesize = 8