mindstudio-probe 1.0.4 → 1.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/pt_config.py

@@ -1,12 +1,33 @@
-import json
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os

-from msprobe.core.common_config import CommonConfig, BaseConfig
-from msprobe.core.common.file_utils import FileOpen
 from msprobe.core.common.const import Const
-from msprobe.pytorch.hook_module.utils import get_ops
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common.file_utils import FileOpen, load_json
+from msprobe.core.common.log import logger
+from msprobe.core.common_config import BaseConfig, CommonConfig
 from msprobe.core.grad_probe.constant import level_adp
-from msprobe.core.grad_probe.utils import check_numeral_list_ascend
+from msprobe.core.grad_probe.utils import check_bounds
+from msprobe.pytorch.free_benchmark.common.enums import (
+    DeviceType,
+    HandlerType,
+    PytorchFreeBenchmarkConst,
+)
+from msprobe.pytorch.hook_module.utils import get_ops


 class TensorConfig(BaseConfig):
@@ -16,7 +37,7 @@ class TensorConfig(BaseConfig):
         self.nfs_path = json_config.get("nfs_path", "")
         self.host = json_config.get("host", "")
         self.port = json_config.get("port", -1)
-        self.tls_path = json_config.get("tls_path", "")
+        self.tls_path = json_config.get("tls_path", "./")
         self.check_config()
         self._check_file_format()
         self._check_tls_path_config()
@@ -26,13 +47,8 @@ class TensorConfig(BaseConfig):
             raise Exception("file_format is invalid")

     def _check_tls_path_config(self):
-        if self.tls_path:
-            if not os.path.exists(self.tls_path):
-                raise Exception("tls_path: %s does not exist" % self.tls_path)
-            if not os.path.exists(os.path.join(self.tls_path, "client.key")):
-                raise Exception("tls_path does not contain client.key")
-            if not os.path.exists(os.path.join(self.tls_path, "client.crt")):
-                raise Exception("tls_path does not contain client.crt")
+        if self.tls_path and not os.path.exists(self.tls_path):
+            raise Exception("tls_path: %s does not exist" % self.tls_path)


 class StatisticsConfig(BaseConfig):
@@ -61,23 +77,142 @@ class OverflowCheckConfig(BaseConfig):


 class FreeBenchmarkCheckConfig(BaseConfig):
+
     def __init__(self, json_config):
         super().__init__(json_config)
-        self.fuzz_device = json_config.get("fuzz_device")
-        self.pert_mode = json_config.get("pert_mode")
-        self.handler_type = json_config.get("handler_type")
-        self.fuzz_level = json_config.get("fuzz_level")
-        self.fuzz_stage = json_config.get("fuzz_stage")
-        self.if_preheat = json_config.get("if_preheat")
-        self.preheat_step = json_config.get("preheat_step")
-        self.max_sample = json_config.get("max_sample")
+        self.fuzz_device = json_config.get("fuzz_device", PytorchFreeBenchmarkConst.DEFAULT_DEVICE)
+        self.pert_mode = json_config.get("pert_mode", PytorchFreeBenchmarkConst.DEFAULT_MODE)
+        self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
+        self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
+        self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
+        self.if_preheat = json_config.get("if_preheat", False)
+        self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
+        self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
         self.check_freebenchmark_config()

     def check_freebenchmark_config(self):
-        if self.if_preheat and self.handler_type == "fix":
-            raise Exception("Preheating is not supported in fix handler type")
-        if self.preheat_step and self.preheat_step == 0:
-            raise Exception("preheat_step cannot be 0")
+        self._check_pert_mode()
+        self._check_fuzz_device()
+        self._check_handler_type()
+        self._check_fuzz_stage()
+        self._check_fuzz_level()
+        self._check_if_preheat()
+        if self.handler_type == HandlerType.FIX:
+            self._check_fix_config()
+        if self.if_preheat:
+            self._check_preheat_config()
+
+    def _check_pert_mode(self):
+        if self.pert_mode not in PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST:
+            msg = (
+                f"pert_mode is invalid, it should be one of"
+                f" {PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_fuzz_device(self):
+        if self.fuzz_device not in PytorchFreeBenchmarkConst.DEVICE_LIST:
+            msg = (
+                f"fuzz_device is invalid, it should be one of"
+                f" {PytorchFreeBenchmarkConst.DEVICE_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+        if (self.fuzz_device == DeviceType.CPU) ^ (
+            self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
+        ):
+            msg = (
+                f"You neet to and can only set fuzz_device as {DeviceType.CPU} "
+                f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_handler_type(self):
+        if self.handler_type not in PytorchFreeBenchmarkConst.HANDLER_LIST:
+            msg = (
+                f"handler_type is invalid, it should be one of"
+                f" {PytorchFreeBenchmarkConst.HANDLER_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_fuzz_stage(self):
+        if self.fuzz_stage not in PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST:
+            msg = (
+                f"fuzz_stage is invalid, it should be one of"
+                f" {PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_fuzz_level(self):
+        if self.fuzz_level not in PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST:
+            msg = (
+                f"fuzz_level is invalid, it should be one of"
+                f" {PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_if_preheat(self):
+        if not isinstance(self.if_preheat, bool):
+            msg = "if_preheat is invalid, it should be a boolean"
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_preheat_config(self):
+        if not isinstance(self.preheat_step, int):
+            msg = "preheat_step is invalid, it should be an integer"
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+        if self.preheat_step <= 0:
+            msg = "preheat_step must be greater than 0"
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+        if not isinstance(self.max_sample, int):
+            msg = "max_sample is invalid, it should be an integer"
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+        if self.max_sample <= 0:
+            msg = "max_sample must be greater than 0"
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+
+    def _check_fix_config(self):
+        if self.if_preheat:
+            msg = f"Preheating is not supported for {HandlerType.FIX} handler type"
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+        if self.fuzz_stage not in PytorchFreeBenchmarkConst.FIX_STAGE_LIST:
+            msg = (
+                f"The fuzz_stage when opening {HandlerType.FIX} handler must be one of "
+                f"{PytorchFreeBenchmarkConst.FIX_STAGE_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )
+        if self.pert_mode not in PytorchFreeBenchmarkConst.FIX_MODE_LIST:
+            msg = (
+                f"The pert_mode when opening {HandlerType.FIX} handler must be one of "
+                f"{PytorchFreeBenchmarkConst.FIX_MODE_LIST}"
+            )
+            logger.error_log_with_exp(
+                msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
+            )


 class RunUTConfig(BaseConfig):
@@ -93,7 +228,7 @@ class RunUTConfig(BaseConfig):
         self.host = json_config.get("host", "")
         self.port = json_config.get("port", -1)
         self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
-        self.tls_path = json_config.get("tls_path", "")
+        self.tls_path = json_config.get("tls_path", "./")
         self.check_run_ut_config()

     @classmethod
@@ -118,13 +253,8 @@ class RunUTConfig(BaseConfig):

     @classmethod
     def check_tls_path_config(cls, tls_path):
-        if tls_path:
-            if not os.path.exists(tls_path):
-                raise Exception("tls_path: %s does not exist" % tls_path)
-            if not os.path.exists(os.path.join(tls_path, "server.key")):
-                raise Exception("tls_path does not contain server.key")
-            if not os.path.exists(os.path.join(tls_path, "server.crt")):
-                raise Exception("tls_path does not contain server.crt")
+        if tls_path and not os.path.exists(tls_path):
+            raise Exception("tls_path: %s does not exist" % tls_path)

     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -141,13 +271,13 @@ class GradToolConfig(BaseConfig):
         self.param_list = json_config.get("param_list", [])
         self.bounds = json_config.get("bounds", [-1, 0, 1])
         self._check_config()
-
+
     def _check_config(self):
         if self.grad_level not in level_adp.keys():
             raise Exception(f"grad_level must be one of {level_adp.keys()}")
         if not isinstance(self.param_list, list):
             raise Exception(f"param_list must be a list")
-        check_numeral_list_ascend(self.bounds)
+        check_bounds(self.bounds)


 def parse_task_config(task, json_config):
@@ -178,10 +308,9 @@ def parse_json_config(json_file_path, task):
     if not json_file_path:
         config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
         json_file_path = os.path.join(config_dir, "config.json")
-    with FileOpen(json_file_path, 'r') as file:
-        json_config = json.load(file)
+    json_config = load_json(json_file_path)
     common_config = CommonConfig(json_config)
-    if task and task in Const.TASK_LIST:
+    if task:
         task_config = parse_task_config(task, json_config)
     else:
         task_config = parse_task_config(common_config.task, json_config)
msprobe/pytorch/service.py

@@ -1,3 +1,18 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 import os

@@ -6,6 +21,7 @@ import torch
 from msprobe.core.common.const import Const
 from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
 from msprobe.core.common.file_utils import create_directory
+from msprobe.core.common.utils import print_tools_ends_info
 from msprobe.core.data_dump.data_collector import build_data_collector
 from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 from msprobe.core.data_dump.scope import BaseScope
@@ -16,7 +32,10 @@ from msprobe.pytorch.hook_module.api_registry import api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.module_processer import ModuleProcesser
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+if torch_version_above_or_equal_2:
+    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch

 HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])

@@ -32,6 +51,7 @@ class Service:
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
+        self.should_stop_service = False
         self.attl = None

     @staticmethod
@@ -39,14 +59,29 @@
         logger.info_on_rank_0("Data needed ends here.")
         api_register.api_originality()

+    @staticmethod
+    def is_registered_backward_hook(module):
+        if hasattr(module, '_backward_hooks') and \
+                len(module._backward_hooks) > 0 and \
+                module._is_full_backward_hook is False:
+            return True
+        return False
+
+    def check_register_full_backward_hook(self, module):
+        if self.is_registered_backward_hook(module):
+            module._backward_hooks.clear()
+            module._is_full_backward_hook = None
+            logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
+
     def build_hook(self, module_type, name):
         def pre_hook(api_or_module_name, module, args, kwargs):
+            if not self.should_execute_hook():
+                return args, kwargs
+
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name
             self.data_collector.update_api_or_module_name(api_or_module_name)

-            if not self.switch:
-                return args, kwargs
             if self.config.online_run_ut:
                 return None, None
             if self.data_collector:
@@ -55,13 +90,13 @@ class Service:
                 return args, kwargs

         def forward_hook(api_or_module_name, module, args, kwargs, output):
+            if not self.should_execute_hook():
+                return None
+
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name
             self.data_collector.update_api_or_module_name(api_or_module_name)

-            if not self.switch:
-                return None
-
             if self.config.online_run_ut:
                 if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
                     return None
@@ -80,18 +115,14 @@ class Service:
                 return forward_hook(api_or_module_name, module, args, {}, output)

         def backward_hook(api_or_module_name, module, grad_input, grad_output):
+            if not self.should_execute_hook():
+                return
+
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name
             self.data_collector.update_api_or_module_name(api_or_module_name)

-            if not self.switch:
-                return
-
             if self.config.online_run_ut:
-                if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
-                    return
-                api_data = ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank)
-                self.attl_send(api_data)
                 return

             if self.data_collector:
@@ -105,26 +136,15 @@ class Service:
         pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template)
         forward_hook_fn = functools.partial(forward_hook, forward_name_template)
         backward_hook_fn = functools.partial(backward_hook, backward_name_template)
-        forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2, forward_name_template)
+        forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2,
+                                                                  forward_name_template)
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)

-    def step(self):
-        self.current_iter += 1
-        self.data_collector.update_iter(self.current_iter)
-
-        ModuleProcesser.reset_module_stats()
-        HOOKModule.reset_module_stats()
-
     def start(self, model, api_origin=False):
-        self.model = model
-        if self.config.step and self.current_iter > max(self.config.step):
-            if self.config.online_run_ut:
-                # send stop signal if online_run_ut
-                self.attl_stop()
-            self.stop()
-            raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
-        if self.config.step and self.current_iter not in self.config.step:
+        if self.need_stop_service():
             return
+
+        self.model = model
         if self.first_start:
             try:
                 self.current_rank = get_rank_if_initialized()
@@ -138,6 +158,8 @@ class Service:
         self.first_start = False
         if api_origin:
             api_register.api_modularity()
+        if self.config.online_run_ut and torch_version_above_or_equal_2:
+            run_ut_dispatch(self.attl, True)
         self.switch = True
         logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
         if self.config.level != "L2" and not self.config.online_run_ut:
@@ -145,6 +167,8 @@ class Service:
             logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")

     def stop(self):
+        if self.should_stop_service:
+            return
         if self.config.level == "L2":
             return
         if self.config.step and self.current_iter not in self.config.step:
@@ -152,10 +176,47 @@ class Service:
         if self.config.rank and self.current_rank not in self.config.rank:
             return
         self.switch = False
-        if self.config.online_run_ut:
+        if self.config.online_run_ut and torch_version_above_or_equal_2:
+            run_ut_dispatch(self.attl, False)
             return
         self.data_collector.write_json()

+    def step(self):
+        if self.should_stop_service:
+            return
+        self.current_iter += 1
+        self.data_collector.update_iter(self.current_iter)
+
+        ModuleProcesser.reset_module_stats()
+        HOOKModule.reset_module_stats()
+        self.data_collector.data_writer.reset_cache()
+
+    def need_stop_service(self):
+        if self.should_stop_service:
+            return True
+        end_service = self.config.step and self.current_iter > max(self.config.step) or \
+            self.data_collector and self.data_collector.data_processor.is_terminated
+        if end_service:
+            if self.config.online_run_ut:
+                # send stop signal if online_run_ut
+                self.attl_stop()
+            if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
+                api_register.api_originality()
+            self.switch = False
+            self.should_stop_service = True
+            print_tools_ends_info()
+            return True
+        if self.config.step and self.current_iter not in self.config.step:
+            return True
+        return False
+
+    def should_execute_hook(self):
+        if not self.switch:
+            return False
+        if self.data_collector and self.data_collector.data_processor.is_terminated:
+            return False
+        return True
+
     def create_dirs(self):
         create_directory(self.config.dump_path)
         self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
@@ -187,14 +248,16 @@ class Service:
             prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
                      module.__class__.__name__ + Const.SEP

-            pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 \
-                = self.build_hook(BaseScope.Module_Type_Module, prefix)
+            pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
+                BaseScope.Module_Type_Module, prefix)
             if torch_version_above_or_equal_2:
                 module.register_forward_hook(forward_hook, with_kwargs=True)
             else:
+                self.check_register_full_backward_hook(module)
                 module.register_full_backward_hook(
                     self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
                 module.register_forward_hook(forward_hook_torch_version_below_2)
+                self.check_register_full_backward_hook(module)
                 module.register_full_backward_hook(backward_hook)

             module.register_forward_pre_hook(
@@ -204,11 +267,13 @@ class Service:
             if torch_version_above_or_equal_2:
                 module.register_full_backward_pre_hook(
                     self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
+                self.check_register_full_backward_hook(module)
                 module.register_full_backward_hook(
                     self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))

         if self.config.level in ["mix", "L1", "L2"]:
-            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
+                                         self.config.online_run_ut)
             api_register.api_modularity()

         if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task: