mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  import mindspore as ms
4
19
  from mindspore import Tensor, ops
5
20
 
6
- from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
7
- from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
8
- from msprobe.mindspore.common.const import FreeBenchmarkConst
9
- from msprobe.mindspore.common.log import logger
10
21
  from msprobe.mindspore.common.const import Const
22
+ from msprobe.mindspore.common.log import logger
23
+ from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
24
+ from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
11
25
 
12
26
 
13
27
  class ImprovePrecisionPerturbation(BasePerturbation):
@@ -1,7 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
- from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
4
18
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
19
+ from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
5
20
 
6
21
 
7
22
  class NoChangePerturbation(BasePerturbation):
@@ -1,10 +1,25 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import FreeBenchmarkConst
2
17
  from msprobe.mindspore.free_benchmark.common.config import Config
3
- from .add_noise import AddNoisePerturbation
4
- from .bit_noise import BitNoisePerturbation
5
- from .no_change import NoChangePerturbation
6
- from .improve_precision import ImprovePrecisionPerturbation
7
- from .exchange_value import ExchangeValuePerturbation
18
+ from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
19
+ from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
20
+ from msprobe.mindspore.free_benchmark.perturbation.exchange_value import ExchangeValuePerturbation
21
+ from msprobe.mindspore.free_benchmark.perturbation.improve_precision import ImprovePrecisionPerturbation
22
+ from msprobe.mindspore.free_benchmark.perturbation.no_change import NoChangePerturbation
8
23
 
9
24
 
10
25
  class PerturbationFactory:
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
18
  from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelFCheck
@@ -1,8 +1,8 @@
1
1
  import os
2
2
  import threading
3
- from typing import Dict, Union
3
+ from typing import Dict, Union, Tuple
4
4
 
5
- from msprobe.core.grad_probe.utils import check_str
5
+ from msprobe.core.grad_probe.utils import check_str, check_bounds_element
6
6
  from msprobe.core.grad_probe.constant import GradConst
7
7
  from msprobe.mindspore.common.log import logger
8
8
  from msprobe.core.common.file_utils import create_directory, check_path_before_create
@@ -18,7 +18,7 @@ class GlobalContext:
18
18
  GradConst.STEP: None,
19
19
  GradConst.RANK: None,
20
20
  GradConst.CURRENT_STEP: 0,
21
- GradConst.BOUNDS: [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10],
21
+ GradConst.BOUNDS: [-1, 0, 1],
22
22
  GradConst.OUTPUT_PATH: None
23
23
  }
24
24
 
@@ -31,19 +31,19 @@ class GlobalContext:
31
31
 
32
32
  def init_context(self, config_dict: Dict):
33
33
  level = config_dict.get(GradConst.LEVEL)
34
- check_str(level, variable_name = "level in yaml")
34
+ check_str(level, variable_name="level in yaml")
35
35
  if level in GradConst.SUPPORTED_LEVEL:
36
36
  self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL)
37
37
  else:
38
38
  raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")
39
39
 
40
40
  self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
41
- self._set_input_list(config_dict, GradConst.BOUNDS, float)
41
+ self._set_input_list(config_dict, GradConst.BOUNDS, (float, int), element_check=check_bounds_element)
42
42
  self._set_input_list(config_dict, GradConst.STEP, int)
43
43
  self._set_input_list(config_dict, GradConst.RANK, int)
44
44
 
45
45
  output_path = config_dict.get(GradConst.OUTPUT_PATH)
46
- check_str(output_path, variable_name = "output_path in yaml")
46
+ check_str(output_path, variable_name="output_path in yaml")
47
47
  try:
48
48
  check_path_before_create(output_path)
49
49
  except RuntimeError as err:
@@ -70,19 +70,29 @@ class GlobalContext:
70
70
  dump_rank_list = self.get_context(GradConst.RANK)
71
71
  return (not dump_rank_list) or (rank in dump_rank_list)
72
72
 
73
- def _set_input_list(self, config_dict: Dict, name: str, dtype: Union[int, str, float]):
74
- value = config_dict.get(name)
73
+ def _get_type_str(self, dtype: Union[int, str, float, Tuple[int, str, float]]):
74
+ if isinstance(dtype, tuple):
75
+ return "/".join([self._get_type_str(element) for element in dtype])
75
76
  if dtype == int:
76
77
  type_str = "integer"
77
78
  elif dtype == float:
78
79
  type_str = "float"
79
80
  else:
80
81
  type_str = "string"
82
+ return type_str
83
+
84
+ def _set_input_list(self, config_dict: Dict, name: str,
85
+ dtype: Union[int, str, float, Tuple[int, str, float]], element_check=None):
86
+ value = config_dict.get(name)
87
+ type_str = self._get_type_str(dtype)
81
88
  if value and isinstance(value, list):
82
89
  for val in value:
83
90
  if not isinstance(val, dtype):
84
91
  logger.warning(f"Invalid {name} which must be None or list of {type_str}")
85
92
  return
93
+ if element_check and not element_check(val):
94
+ logger.warning(f"Given {name} violates some rules.")
95
+ return
86
96
  self._setting[name] = value
87
97
  else:
88
98
  logger.warning(f"{name} is None or not a list with valid items, use default value.")
@@ -1,8 +1,24 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import json
3
- from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
4
- from msprobe.mindspore.common.log import logger
17
+ import os
18
+
5
19
  from msprobe.core.common.file_utils import FileOpen, create_directory
20
+ from msprobe.mindspore.common.log import logger
21
+ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
6
22
 
7
23
 
8
24
  class KernelGraphOverflowCheck:
@@ -16,7 +32,7 @@ class KernelGraphOverflowCheck:
16
32
  self.dump_json["common_dump_settings"]["saved_data"] = "full"
17
33
  self.dump_json["common_dump_settings"]["input_output"] = 0
18
34
  self.dump_json["common_dump_settings"]["kernels"] = []
19
- self.dump_json["common_dump_settings"]["support_device"] = [0,1,2,3,4,5,6,7]
35
+ self.dump_json["common_dump_settings"]["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
20
36
  self.dump_json["common_dump_settings"]["op_debug_mode"] = 3
21
37
  self.dump_json["common_dump_settings"]["file_format"] = "npy"
22
38
 
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
18
  from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
@@ -34,9 +34,10 @@ from msprobe.core.data_dump.scope import BaseScope
34
34
  from msprobe.mindspore.common.utils import get_rank_if_initialized
35
35
  from msprobe.core.common.file_utils import create_directory
36
36
  from msprobe.mindspore.common.log import logger
37
- from msprobe.core.common.utils import Const
37
+ from msprobe.core.common.utils import Const, print_tools_ends_info
38
38
  from msprobe.core.common.exceptions import DistributedNotInitializedError
39
39
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
40
+ from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
40
41
  from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
41
42
  ModuleBackwardInputs, ModuleBackwardOutputs
42
43
  from msprobe.core.common.exceptions import MsprobeException
@@ -52,11 +53,12 @@ class Service:
52
53
  self.config.level = self.config.level_ori
53
54
  self.data_collector = build_data_collector(self.config)
54
55
  self.cell_processor = CellProcessor(self.data_collector.scope)
56
+ self.primitive_hook_service = PrimitiveHookService(self)
55
57
  self.switch = False
58
+ self.primitive_switch = False
56
59
  self.current_iter = 0
57
60
  self.first_start = True
58
61
  self.current_rank = None
59
- self.primitive_counters = {}
60
62
  self.dump_iter_dir = None
61
63
  self.start_call = False
62
64
  self.check_level_valid()
@@ -71,7 +73,7 @@ class Service:
71
73
  )
72
74
 
73
75
  def check_level_valid(self):
74
- if self.config.level == "L2":
76
+ if self.config.level == Const.LEVEL_L2:
75
77
  raise MsprobeException(
76
78
  MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
77
79
  )
@@ -122,113 +124,6 @@ class Service:
122
124
 
123
125
  return wrap_forward_hook, wrap_backward_hook
124
126
 
125
- def wrap_primitive(self, origin_func, primitive_name):
126
- service_instance = self
127
-
128
- def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
129
- def backward_hook(grad):
130
- captured_grads.append(grad)
131
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
132
- try:
133
- if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
134
- service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
135
- new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
136
- service_instance.data_collector.backward_output_data_collect(
137
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
138
- )
139
- captured_grads.clear()
140
- elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
141
- service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
142
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
143
- service_instance.data_collector.backward_input_data_collect(
144
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
145
- )
146
- captured_grads.clear()
147
-
148
- except Exception as exception:
149
- raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
150
- f" updated_primitive_name: {updated_primitive_name}") from exception
151
-
152
- return backward_hook
153
-
154
- def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
155
- hooked_inputs = []
156
- num_tensors = sum(isinstance(arg, Tensor) for arg in args)
157
- input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
158
- Const.INPUT)
159
- for _, arg in enumerate(args):
160
- if isinstance(arg, Tensor):
161
- arg_hooked = ops.HookBackward(input_backward_hook)(arg)
162
- hooked_inputs.append(arg_hooked)
163
- else:
164
- hooked_inputs.append(arg)
165
- return hooked_inputs
166
-
167
- def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
168
- if isinstance(out, tuple):
169
- num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
170
- else:
171
- num_output_tensors = 1
172
- output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
173
- updated_primitive_name, Const.OUTPUT)
174
-
175
- if isinstance(out, Tensor):
176
- return ops.HookBackward(output_backward_hook)(out)
177
- elif isinstance(out, tuple):
178
- hooked_outputs = []
179
- for tensor in out:
180
- if isinstance(tensor, Tensor):
181
- hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
182
- else:
183
- hooked_outputs.append(tensor)
184
- return tuple(hooked_outputs)
185
- return out
186
-
187
- def wrapped_primitive_call(instance_self, *args, **kwargs):
188
- service_instance.update_primitive_counters(primitive_name)
189
- current_count = service_instance.primitive_counters.get(primitive_name, 0)
190
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
191
-
192
- if not service_instance.switch:
193
- return origin_func(*args, **kwargs)
194
-
195
- captured_grads_input, captured_grads_output = [], []
196
-
197
- try:
198
- hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
199
- except Exception as exception:
200
- raise Exception("This is a primitive op dump error during input hooking: {},"
201
- " primitive_name: {}".format(exception, primitive_name)) from exception
202
-
203
- try:
204
- out = origin_func(*hooked_inputs, **kwargs)
205
- except Exception as exception:
206
- raise Exception("This is a primitive op dump error during function call: {},"
207
- " primitive_name: {}".format(exception, primitive_name)) from exception
208
-
209
- forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
210
- service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
211
- if service_instance.data_collector:
212
- module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
213
- try:
214
- service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
215
- os.getpid(), module_input_output)
216
- except Exception as exception:
217
- raise Exception("This is a primitive op dump error during forward data collection: {},"
218
- " primitive_name: {}".format(exception, primitive_name)) from exception
219
-
220
- if service_instance.data_collector.if_return_forward_new_output():
221
- out = service_instance.data_collector.get_forward_new_output()
222
-
223
- try:
224
- out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
225
- except Exception as exception:
226
- raise Exception("This is a primitive op dump error during output hooking: {},"
227
- " primitive_name: {}".format(exception, primitive_name)) from exception
228
-
229
- return out
230
-
231
- return wrapped_primitive_call
232
127
 
233
128
  def update_primitive_counters(self, primitive_name):
234
129
  if primitive_name not in self.primitive_counters:
@@ -236,7 +131,7 @@ class Service:
236
131
  else:
237
132
  self.primitive_counters[primitive_name] += 1
238
133
 
239
- def register_hooks(self):
134
+ def register_primitive_hooks(self):
240
135
  primitive_set = set()
241
136
  for _, cell in self.model.cells_and_names():
242
137
  for pname, primitive in cell._primitives.items():
@@ -244,15 +139,17 @@ class Service:
244
139
 
245
140
  for pname, primitive in primitive_set:
246
141
  NewPrimitive = type('NewPrimitive', (primitive.__class__,),
247
- {'__call__': self.wrap_primitive(primitive.__call__, pname)})
142
+ {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, pname)})
248
143
  primitive.__class__ = NewPrimitive
249
144
 
250
145
  def step(self):
251
146
  self.current_iter += 1
252
147
  self.data_collector.update_iter(self.current_iter)
253
148
  HOOKCell.cell_count = defaultdict(int)
254
- CellProcessor.cell_count = {}
255
- self.primitive_counters.clear()
149
+ CellProcessor.reset_cell_stats()
150
+ self.primitive_hook_service.primitive_counters.clear()
151
+ self.data_collector.data_writer.reset_cache()
152
+ JitDump.jit_count = defaultdict(int)
256
153
 
257
154
  def start(self, model=None):
258
155
  self.start_call = True
@@ -262,9 +159,8 @@ class Service:
262
159
  api_register.api_set_ori_func()
263
160
  self.should_stop_service = True
264
161
  self.switch = False
265
- logger.info("************************************************")
266
- logger.info(f"* {Const.TOOL_NAME} ends successfully. *")
267
- logger.info("************************************************")
162
+ self.primitive_switch = False
163
+ print_tools_ends_info()
268
164
  return
269
165
  if self.config.step and self.current_iter not in self.config.step:
270
166
  return
@@ -281,7 +177,7 @@ class Service:
281
177
  if self.config.rank and self.current_rank not in self.config.rank:
282
178
  return
283
179
  self.register_hook_new()
284
- if self.config.level == "L1":
180
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
285
181
  JitDump.set_config(self.config)
286
182
  JitDump.set_data_collector(self.data_collector)
287
183
  ms.common.api._MindsporeFunctionExecutor = JitDump
@@ -291,10 +187,31 @@ class Service:
291
187
  PIJitCaptureContext.__exit__ = self.empty
292
188
  self.first_start = False
293
189
 
190
+ api_register.api_set_hook_func()
294
191
  self.switch = True
192
+ self.primitive_switch = True
295
193
  logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
296
194
  self.create_dirs()
297
195
  logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
196
+ JitDump.jit_dump_switch = True
197
+
198
+ def forward_backward_dump_end(self):
199
+ if self.should_stop_service:
200
+ return
201
+ logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
202
+ if not self.start_call:
203
+ logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
204
+ raise Exception("debugger.start() is not set in the current scope.")
205
+ if not self.switch:
206
+ logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
207
+ "debugger.start() and debugger.stop() ")
208
+ raise Exception("debugger.stop() is already called. ")
209
+ if self.config.step and self.current_iter not in self.config.step:
210
+ return
211
+ if self.config.rank and self.current_rank not in self.config.rank:
212
+ return
213
+ self.primitive_switch = False
214
+ api_register.api_set_ori_func()
298
215
 
299
216
  def stop(self):
300
217
  if self.should_stop_service:
@@ -309,8 +226,10 @@ class Service:
309
226
  if self.config.rank and self.current_rank not in self.config.rank:
310
227
  return
311
228
  self.switch = False
229
+ self.primitive_switch = False
312
230
  self.start_call = False
313
231
  self.data_collector.write_json()
232
+ JitDump.jit_dump_switch = False
314
233
 
315
234
  def need_end_service(self):
316
235
  if self.config.step and self.current_iter > max(self.config.step):
@@ -349,16 +268,16 @@ class Service:
349
268
 
350
269
  def register_hook_new(self):
351
270
  logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
352
- if self.config.level == "L1":
271
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
353
272
  api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
354
273
  api_register.api_set_hook_func()
355
- if self.model:
356
- self.register_hooks()
274
+ if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
275
+ self.register_primitive_hooks()
357
276
 
358
- if self.config.level == "L0":
277
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
359
278
  if not self.model:
360
279
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
361
- "The current level is L0, the model cannot be None")
280
+ f"The current level is {self.config.level}, the model cannot be None")
362
281
  for name, cell in self.model.cells_and_names():
363
282
  if cell == self.model:
364
283
  continue
@@ -1,4 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
1
19
  from .debugger.precision_debugger import PrecisionDebugger
2
20
  from .common.utils import seed_all
3
21
  from .compare.distributed_compare import compare_distributed
4
- from .compare.pt_compare import compare
22
+ from .compare.pt_compare import compare
23
+ from .functional.module_dump import module_dump, module_dump_end
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import os
2
19
  from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
3
20
  from msprobe.pytorch.pt_config import RunUTConfig
@@ -33,8 +50,8 @@ class Config:
33
50
  raise ValueError(f"{key} must be one of {validators.keys()}")
34
51
  if not isinstance(value, validators.get(key)):
35
52
  raise ValueError(f"{key} must be {validators[key].__name__} type")
36
- if key == 'precision' and value < 0:
37
- raise ValueError("precision must be greater than 0")
53
+ if key == 'precision' and (value < 0 or value > 20):
54
+ raise ValueError("precision must be greater than or equal to 0 and less than 21")
38
55
  if key == 'white_list':
39
56
  RunUTConfig.check_filter_list_config(key, value)
40
57
  if key == 'black_list':