mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const
2
17
 
3
18
 
@@ -34,14 +49,14 @@ class DataProcessorFactory:
34
49
  @classmethod
35
50
  def register_processors(cls, framework):
36
51
  if framework == Const.PT_FRAMEWORK:
37
- from .pytorch_processor import (
52
+ from msprobe.core.data_dump.data_processor.pytorch_processor import (
38
53
  StatisticsDataProcessor as PytorchStatisticsDataProcessor,
39
54
  TensorDataProcessor as PytorchTensorDataProcessor,
40
55
  OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
41
56
  FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
42
57
  KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
43
58
  )
44
- from ....pytorch.module_processer import ModuleProcesser
59
+ from msprobe.pytorch.module_processer import ModuleProcesser
45
60
  cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
46
61
  cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor)
47
62
  cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
@@ -49,11 +64,13 @@ class DataProcessorFactory:
49
64
  cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
50
65
  cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
51
66
  elif framework == Const.MS_FRAMEWORK:
52
- from .mindspore_processor import (
67
+ from msprobe.core.data_dump.data_processor.mindspore_processor import (
53
68
  StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
54
69
  TensorDataProcessor as MindsporeTensorDataProcessor,
55
70
  OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
56
71
  )
72
+ from msprobe.mindspore.cell_processor import CellProcessor
57
73
  cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
58
74
  cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
59
75
  cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
76
+ cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
@@ -17,6 +17,7 @@ import zlib
17
17
 
18
18
  import mindspore as ms
19
19
  from mindspore import mint, ops
20
+ from mindspore._c_expression.typing import Number
20
21
  import numpy as np
21
22
 
22
23
  from msprobe.core.common.const import Const
@@ -29,7 +30,7 @@ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
29
30
 
30
31
 
31
32
  class MindsporeDataProcessor(BaseDataProcessor):
32
- mindspore_special_type = tuple([ms.Tensor])
33
+ mindspore_special_type = tuple([ms.Tensor, Number])
33
34
 
34
35
  def __init__(self, config, data_writer):
35
36
  super().__init__(config, data_writer)
@@ -69,13 +70,16 @@ class MindsporeDataProcessor(BaseDataProcessor):
69
70
  tensor_stat.mean = np.mean(data_abs).item()
70
71
  tensor_stat.norm = np.linalg.norm(data_abs).item()
71
72
  else:
72
- if data.dtype == ms.bfloat16 or not ops.is_floating_point(data):
73
+ if not ops.is_floating_point(data):
73
74
  data = data.to(ms.float32)
74
75
  api_register.norm_inner_op_set_ori_func()
75
76
  get_max_value = api_register.mint_ops_ori_attr.get("max", mint.max)
76
77
  get_min_value = api_register.mint_ops_ori_attr.get("min", mint.min)
77
78
  get_mean_value = api_register.mint_ops_ori_attr.get("mean", mint.mean)
78
- get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
79
+ if hasattr(mint, "norm"):
80
+ get_norm_value = api_register.mint_ops_ori_attr.get("norm", mint.norm)
81
+ else:
82
+ get_norm_value = api_register.functional_ori_attr.get("norm", ops.norm)
79
83
  tensor_stat.max = get_max_value(data).item()
80
84
  tensor_stat.min = get_min_value(data).item()
81
85
  tensor_stat.mean = get_mean_value(data).item()
@@ -90,9 +94,10 @@ class MindsporeDataProcessor(BaseDataProcessor):
90
94
  converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
91
95
  if converted_numpy is not element:
92
96
  return self._analyze_numpy(converted_numpy, numpy_type)
97
+ if isinstance(element, Number):
98
+ return self.analyze_dtype_in_kwargs(element)
93
99
  if isinstance(element, ms.Tensor):
94
100
  return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
95
-
96
101
  if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
97
102
  return self._analyze_builtin(element)
98
103
  return {}
@@ -163,7 +168,8 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor):
163
168
  save_tensor_as_npy(tensor, file_path)
164
169
  self.real_overflow_nums += 1
165
170
  if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
166
- logger.info(f"[{Const.TOOL_NAME}] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
171
+ logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
172
+ f"current overflow times: {self.real_overflow_nums}.")
167
173
  self.cached_tensors_and_file_paths = {}
168
174
 
169
175
  def _analyze_maybe_overflow_tensor(self, tensor_json):
@@ -1,20 +1,35 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import zlib
2
17
  from dataclasses import asdict
3
18
  from typing import List
4
19
 
5
20
  import numpy as np
6
21
  import torch
7
- from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode
22
+ from msprobe.core.common.const import Const
23
+ from msprobe.core.common.file_utils import path_len_exceeds_limit
8
24
  from msprobe.core.common.log import logger
9
- from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
10
25
  from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
11
26
  ModuleForwardInputsOutputs, TensorStatInfo
12
- from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
13
27
  from msprobe.pytorch.common.utils import save_pt, load_pt
28
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
14
29
 
30
+ is_gpu = False
15
31
  try:
16
32
  import torch_npu
17
- is_gpu = False
18
33
  except ImportError:
19
34
  is_gpu = True
20
35
 
@@ -153,7 +168,7 @@ class StatisticsDataProcessor(PytorchDataProcessor):
153
168
  class TensorDataProcessor(PytorchDataProcessor):
154
169
  def _analyze_tensor(self, tensor, suffix):
155
170
  dump_data_name, file_path = self.get_save_file_path(suffix)
156
- saved_tensor = tensor.contiguous().detach()
171
+ saved_tensor = tensor.clone().contiguous().detach()
157
172
  save_pt(saved_tensor, file_path)
158
173
  single_arg = super()._analyze_tensor(tensor, suffix)
159
174
  single_arg.update({"data_name": dump_data_name})
@@ -178,7 +193,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
178
193
  if self.overflow_nums == -1:
179
194
  return False
180
195
  if self.real_overflow_nums >= self.overflow_nums:
181
- logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}")
182
196
  return True
183
197
  return False
184
198
 
@@ -219,6 +233,9 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
219
233
  for file_path, tensor in self.cached_tensors_and_file_paths.items():
220
234
  save_pt(tensor, file_path)
221
235
  self.real_overflow_nums += 1
236
+ if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
237
+ logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, "
238
+ f"current overflow times: {self.real_overflow_nums}.")
222
239
  self.cached_tensors_and_file_paths = {}
223
240
 
224
241
  def _is_support_inf_nan(self):
@@ -243,7 +260,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
243
260
  if tensor_json['Max'] is None or tensor_json['Min'] is None:
244
261
  return
245
262
  self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \
246
- np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])
263
+ np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min'])
247
264
 
248
265
  def _analyze_tensor(self, tensor, suffix):
249
266
  dump_data_name, file_path = self.get_save_file_path(suffix)
@@ -1,24 +1,36 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import csv
17
+ import os
3
18
 
4
- from msprobe.core.common.file_utils import change_mode, FileOpen
5
- from msprobe.core.common.log import logger
6
19
  from msprobe.core.common.const import Const, FileCheckConst
7
- from msprobe.core.common.file_utils import remove_path, load_json, save_json
20
+ from msprobe.core.common.file_utils import change_mode, FileOpen, save_json
21
+ from msprobe.core.common.log import logger
8
22
 
9
23
 
10
24
  class DataWriter:
11
25
 
12
- def __init__(self, init_json=None) -> None:
13
- self.dump_count = 0
14
- self.init_json = init_json
15
- self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name)
16
- self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name)
17
- self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name)
26
+ def __init__(self) -> None:
27
+ self.dump_file_path = None
28
+ self.stack_file_path = None
29
+ self.construct_file_path = None
18
30
  self.free_benchmark_file_path = None
19
31
  self.dump_tensor_data_dir = None
20
- self.buffer_size = 1000
21
- self.cache_data = {Const.DATA: {}}
32
+ self.flush_size = 1000
33
+ self.cache_data = {}
22
34
  self.cache_stack = {}
23
35
  self.cache_construct = {}
24
36
 
@@ -37,18 +49,22 @@ class DataWriter:
37
49
  if is_new_file:
38
50
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
39
51
 
40
- def initialize_json_file(self, **kwargs):
41
- kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
42
- save_json(self.dump_file_path, kwargs)
43
-
44
- empty_dict = {}
45
- remove_path(self.stack_file_path)
46
- save_json(self.stack_file_path, empty_dict)
47
-
48
- remove_path(self.construct_file_path)
49
- save_json(self.construct_file_path, empty_dict)
52
+ def reset_cache(self):
53
+ self.cache_data = {}
54
+ self.cache_stack = {}
55
+ self.cache_construct = {}
50
56
 
51
- def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
57
+ def initialize_json_file(self, **kwargs):
58
+ if not self.cache_data:
59
+ kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
60
+ self.cache_data = kwargs
61
+ save_json(self.dump_file_path, self.cache_data, indent=1)
62
+ if not self.cache_stack:
63
+ save_json(self.stack_file_path, self.cache_stack, indent=1)
64
+ if not self.cache_construct:
65
+ save_json(self.construct_file_path, self.cache_construct, indent=1)
66
+
67
+ def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
52
68
  free_benchmark_file_path):
53
69
  self.dump_file_path = dump_file_path
54
70
  self.stack_file_path = stack_file_path
@@ -56,16 +72,25 @@ class DataWriter:
56
72
  self.dump_tensor_data_dir = dump_data_dir
57
73
  self.free_benchmark_file_path = free_benchmark_file_path
58
74
 
75
+ def flush_data_periodically(self):
76
+ dump_data = self.cache_data.get(Const.DATA)
77
+ if dump_data and isinstance(dump_data, dict) and len(dump_data) % self.flush_size == 0:
78
+ self.write_json()
79
+
59
80
  def update_data(self, new_data):
60
- key = next(iter(new_data.keys())) # assert len(new_data.keys()) == 1
61
- if key in self.cache_data[Const.DATA]:
62
- self.cache_data[Const.DATA][key].update(new_data[key])
63
- else:
64
- self.cache_data[Const.DATA].update(new_data)
81
+ if not isinstance(new_data, dict) or len(new_data.keys()) != 1:
82
+ logger.warning(f"The data info({new_data}) should be a dict with only one outer key.")
83
+ return
84
+ dump_data = self.cache_data.get(Const.DATA)
85
+ if not isinstance(dump_data, dict):
86
+ logger.warning(f"The dump data({dump_data}) should be a dict.")
87
+ return
65
88
 
66
- def flush_data_when_buffer_is_full(self):
67
- if len(self.cache_data[Const.DATA]) >= self.buffer_size:
68
- self.write_data_json(self.dump_file_path)
89
+ key = next(iter(new_data.keys()))
90
+ if key in dump_data:
91
+ dump_data.get(key).update(new_data.get(key))
92
+ else:
93
+ dump_data.update(new_data)
69
94
 
70
95
  def update_stack(self, new_data):
71
96
  self.cache_stack.update(new_data)
@@ -75,14 +100,7 @@ class DataWriter:
75
100
 
76
101
  def write_data_json(self, file_path):
77
102
  logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
78
- if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
79
- data_to_write = load_json(file_path)
80
- else:
81
- self.init_json['data_path'] = self.dump_tensor_data_dir
82
- data_to_write = self.init_json
83
- data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
84
- save_json(file_path, data_to_write, indent=1)
85
- self.cache_data[Const.DATA].clear()
103
+ save_json(file_path, self.cache_data, indent=1)
86
104
 
87
105
  def write_stack_info_json(self, file_path):
88
106
  save_json(file_path, self.cache_stack, indent=1)
@@ -91,6 +109,9 @@ class DataWriter:
91
109
  save_json(file_path, self.cache_construct, indent=1)
92
110
 
93
111
  def write_json(self):
94
- self.write_data_json(self.dump_file_path)
95
- self.write_stack_info_json(self.stack_file_path)
96
- self.write_construct_info_json(self.construct_file_path)
112
+ if self.cache_data:
113
+ self.write_data_json(self.dump_file_path)
114
+ if self.cache_stack:
115
+ self.write_stack_info_json(self.stack_file_path)
116
+ if self.cache_construct:
117
+ self.write_construct_info_json(self.construct_file_path)
@@ -1,6 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
- from msprobe.core.common.exceptions import ScopeException
17
+
3
18
  from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import ScopeException
4
20
 
5
21
 
6
22
  def build_scope(scope_class, scope=None, api_list=None):
@@ -33,6 +49,7 @@ def build_range_scope_according_to_scope_name(scope, api_list):
33
49
  class BaseScope(ABC):
34
50
  Module_Type_Module = "Module"
35
51
  Module_Type_API = "api"
52
+ module_type = ["Module", "Cell"]
36
53
 
37
54
  def __init__(self, scope, api_list):
38
55
  scope, api_list = self.rectify_args(scope, api_list)
@@ -81,9 +98,9 @@ class ListScope(BaseScope):
81
98
  f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
82
99
  return super(ListScope, ListScope).rectify_args(scope, api_list)
83
100
 
84
- def check(self, module_name):
85
- if not self.scope or module_name in self.scope:
86
- return self.check_api_list(module_name)
101
+ def check(self, name):
102
+ if not self.scope or name in self.scope:
103
+ return self.check_api_list(name)
87
104
  return False
88
105
 
89
106
 
@@ -94,7 +111,6 @@ class RangeScope(BaseScope, ABC):
94
111
  self.in_scope = False
95
112
  self.is_valid = self.check_scope_is_valid()
96
113
 
97
-
98
114
  @staticmethod
99
115
  def rectify_args(scope, api_list):
100
116
  scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
@@ -104,7 +120,6 @@ class RangeScope(BaseScope, ABC):
104
120
  elif len(scope) > 2:
105
121
  raise ScopeException(ScopeException.InvalidScope,
106
122
  f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
107
-
108
123
  return scope, api_list
109
124
 
110
125
  @abstractmethod
@@ -123,23 +138,23 @@ class APIRangeScope(RangeScope):
123
138
  if not self.scope:
124
139
  return True
125
140
  scope_start_type = self.scope[0].split(Const.SEP)[0]
126
- if scope_start_type == BaseScope.Module_Type_Module:
141
+ if scope_start_type in BaseScope.module_type:
127
142
  return False
128
143
  scope_stop_type = self.scope[1].split(Const.SEP)[0]
129
- if scope_stop_type == BaseScope.Module_Type_Module:
144
+ if scope_stop_type in BaseScope.module_type:
130
145
  return False
131
146
  return True
132
147
 
133
- def check(self, api_name):
134
- if self.scope and api_name == self.scope[0]:
148
+ def check(self, name):
149
+ if self.scope and name == self.scope[0]:
135
150
  self.in_scope = True
136
151
 
137
152
  if not self.scope or self.in_scope:
138
- result = self.check_api_list(api_name)
153
+ result = self.check_api_list(name)
139
154
  else:
140
155
  result = False
141
156
 
142
- if self.scope and api_name == self.scope[1]:
157
+ if self.scope and name == self.scope[1]:
143
158
  self.in_scope = False
144
159
  return result
145
160
 
@@ -150,13 +165,14 @@ class ModuleRangeScope(RangeScope):
150
165
  需要用pre_hook和full_backward_hook来精确控制module的开始和结束,
151
166
  在这些hook触发时调用begin_module和end_module做区间控制
152
167
  """
168
+
153
169
  def check_scope_is_valid(self):
154
170
  if not self.scope:
155
171
  return True
156
172
  scope_start_type = self.scope[0].split(Const.SEP)[0]
157
173
  scope_stop_type = self.scope[1].split(Const.SEP)[0]
158
- if scope_start_type == BaseScope.Module_Type_Module and \
159
- scope_stop_type == BaseScope.Module_Type_Module:
174
+ if scope_start_type in BaseScope.module_type and \
175
+ scope_stop_type in BaseScope.module_type:
160
176
  return True
161
177
  return False
162
178
 
@@ -172,7 +188,7 @@ class ModuleRangeScope(RangeScope):
172
188
  if module_name == self.scope[1]:
173
189
  self.in_scope = False
174
190
 
175
- def check(self, module_name):
191
+ def check(self, name):
176
192
  if not self.scope or self.in_scope:
177
- return self.check_api_list(module_name)
193
+ return self.check_api_list(name)
178
194
  return False
@@ -33,6 +33,10 @@ class GradConst:
33
33
  # direction suffix
34
34
  DIR_SUFFIX = "dir.npy"
35
35
 
36
+ # bounds safety
37
+ BOUNDS_MINIMUM = -2**63
38
+ BOUNDS_MAXIMUM = 2**63 - 1
39
+
36
40
  # file safety
37
41
  DATA_DIR_AUTHORITY = 0o750
38
42
  DATA_FILE_AUTHORITY = 0o640
@@ -2,12 +2,11 @@ import os
2
2
  from typing import List
3
3
 
4
4
  from tqdm import tqdm
5
- import pandas as pd
6
5
  import matplotlib.pyplot as plt
7
6
 
8
7
  from msprobe.core.common.file_utils import create_directory, check_path_before_create, check_file_or_directory_path
9
8
  from msprobe.core.common.log import logger
10
- from msprobe.core.common.file_utils import remove_path, load_npy, write_csv
9
+ from msprobe.core.common.file_utils import remove_path, load_npy, write_csv, read_csv
11
10
  from msprobe.core.grad_probe.constant import GradConst
12
11
  from msprobe.core.grad_probe.utils import plt_savefig
13
12
 
@@ -21,7 +20,7 @@ class GradComparator:
21
20
  continue
22
21
  if not os.path.exists(os.path.join(path2, summary_file)):
23
22
  continue
24
- summary_csv = pd.read_csv(os.path.join(path1, summary_file))
23
+ summary_csv = read_csv(os.path.join(path1, summary_file))
25
24
  return summary_csv["param_name"]
26
25
  raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration")
27
26
 
@@ -20,12 +20,25 @@ def check_numeral_list_ascend(lst):
20
20
  def check_param(param_name):
21
21
  if not re.match(GradConst.PARAM_VALID_PATTERN, param_name):
22
22
  raise RuntimeError("The parameter name contains special characters.")
23
-
23
+
24
24
 
25
25
  def check_str(string, variable_name):
26
26
  if not isinstance(string, str):
27
27
  raise ValueError(f'The variable: "{variable_name}" is not a string.')
28
-
28
+
29
def check_bounds_element(bound):
    """Return True when ``bound`` lies in the inclusive range
    [``GradConst.BOUNDS_MINIMUM``, ``GradConst.BOUNDS_MAXIMUM``]."""
    return GradConst.BOUNDS_MINIMUM <= bound <= GradConst.BOUNDS_MAXIMUM
31
+
32
def check_bounds(bounds):
    """Validate that ``bounds`` is a strictly ascending list of in-range numbers.

    Each element must be an ``int`` or ``float`` (``bool`` is rejected), must
    pass ``check_bounds_element``, and must be strictly greater than the
    element before it.

    Args:
        bounds: iterable of candidate bound values.

    Raises:
        TypeError: if an element is not an int or float.
        ValueError: if an element is out of range, or the sequence is not
            strictly ascending.

    Both exception types are subclasses of ``Exception``, so callers that
    catch ``Exception`` keep working.
    """
    prev = None
    for element in bounds:
        # bool is a subclass of int, so it must be excluded explicitly;
        # otherwise [False, True] would be accepted as the bounds [0, 1].
        if isinstance(element, bool) or not isinstance(element, (int, float)):
            raise TypeError("bounds element is not int or float")
        # NaN fails this check as well, because every comparison with NaN is False.
        if not check_bounds_element(element):
            raise ValueError("bounds element is out of int64 range")
        if prev is not None and prev >= element:
            raise ValueError("bounds list is not ascending")
        prev = element
29
42
 
30
43
  class ListCache(list):
31
44
  threshold = 1000
@@ -50,7 +63,7 @@ class ListCache(list):
50
63
  list.append(self, data)
51
64
  if len(self) >= ListCache.threshold:
52
65
  self.flush()
53
-
66
+
54
67
  def set_output_file(self, output_file):
55
68
  self._output_file = output_file
56
69
 
@@ -3,19 +3,20 @@
3
3
  推荐使用 [miniconda](https://docs.anaconda.com/miniconda/) 管理环境依赖。
4
4
 
5
5
  ```bash
6
- conda create -n msprobe python=3.8
6
+ conda create -n msprobe python
7
7
  conda activate msprobe
8
8
  ```
9
9
 
10
- ## 1. 从 PyPI 安装
10
+ ## 1 从 PyPI 安装
11
11
  ```shell
12
- pip install mindstudio-probe[==版本号]
12
+ pip install mindstudio-probe
13
13
  ```
14
14
 
15
- ## 2. 下载 whl 包安装
15
+ ## 2 下载 whl 包安装
16
16
 
17
17
  |版本|发布日期|支持 PyTorch 版本|支持 MindSpore 版本|下载链接|校验码|
18
18
  |:--:|:--:|:--:|:--:|:--:|:--:|
19
+ |1.0.4|2024.09.09|1.11/2.0/2.1/2.2|2.4.0|[mindstudio_probe-1.0.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.4-py3-none-any.whl)|4e1909566a71a855b356597750c20ee43d964a22b2c2b02ac08312a5def75fd6|
19
20
  | 1.0.3 | 2024.08.23 | 1.11/2.0/2.1/2.2 | 2.4.0 | [mindstudio_probe-1.0.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.3-py3-none-any.whl) | 7060cc141a5b98ef770cd9220995d299393f32a61938261e632c7e8b5160bef2 |
20
21
  | 1.0.2 | 2024.08.09 | 1.11/2.0/2.1/2.2 | 2.4.0 | [mindstudio_probe-1.0.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.2-py3-none-any.whl) | e4a980e5d98c426ce5ce9842520d9bc031d3b3de621c74b3d59414cc6e238e0e |
21
22
  | 1.0.1 | 2024.07.25 | 2.0/2.1/2.2 | 2.4.0 | [mindstudio_probe-1.0.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.1-py3-none-any.whl) | b699e224e4d4e3bcf9412c54fa858a1ee370f0d7a2bc69cb3f1273ac14a6dc82 |
@@ -31,7 +32,7 @@ pip install ./mindstudio_probe-{version}-py3-none-any.whl # 安装whl包
31
32
 
32
33
  若覆盖安装,请在命令行末尾添加 `--force-reinstall` 参数。
33
34
 
34
- ## 3. 从源码安装
35
+ ## 3 从源码安装
35
36
 
36
37
  ```shell
37
38
  git clone https://gitee.com/ascend/mstt.git
@@ -40,9 +41,18 @@ cd mstt/debug/accuracy_tools
40
41
  pip install setuptools wheel
41
42
 
42
43
  python setup.py bdist_wheel
43
- pip install ./dist/mindstudio_probe*.whl
44
+ cd ./dist
45
+ pip install ./mindstudio_probe*.whl
44
46
  ```
45
47
 
48
+ # 历史版本特性
49
+
50
+ <table>
51
+ <tr><th>版本</th><th>特性</th></tr>
52
+ <tr><td rowspan="2">1.0.3</td><td>【精度预检】<br/>1. 落盘数据小;<br/>2. 支持随机生成模式和真实数据模式;<br/>3. 单 API 测试,排除整网中的累计误差问题。</td></tr>
53
+ <tr><td>【梯度检测】<br/>1. 使用便捷,无需在训练流程里插入代码。<br/>2. 可以精准定位问题出现的 step。</td></tr>
54
+ </table>
55
+
46
56
  # 查看 msprobe 工具信息
47
57
 
48
58
  ```bash
@@ -59,7 +69,7 @@ Home-page: https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprob
59
69
  Author: Ascend Team
60
70
  Author-email: pmail_mindstudio@huawei.com
61
71
  License: Apache License 2.0
62
- Location: /home/xxx/miniconda3/envs/xxx/lib/python3.8/site-packages/mindstudio_probe-1.0.0-py3.8.egg
72
+ Location: /home/xxx/miniconda3/envs/xxx/lib/python3.x/site-packages/mindstudio_probe-1.0.x-py3.x.egg
63
73
  Requires: einops, matplotlib, numpy, openpyxl, pandas, pyOpenSSL, pyyaml, rich, tqdm, twisted, wheel
64
74
  Required-by:
65
75
  ```
@@ -68,11 +78,11 @@ Required-by:
68
78
 
69
79
  ## 1 安装 CANN 包
70
80
 
71
- 1.1 根据 CPU 架构和 NPU 型号选择 toolkit 和 kernal 包,可以参考 [CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F700%2Fenvdeployment%2Finstg%2Finstg_0001.html)和[昇腾社区](https://www.hiascend.cn/developer/download/community/result?module=cann)。
81
+ 1.1 根据 CPU 架构和 NPU 型号选择 toolkit 和 kernel,可以参考 [CANN 软件安装指南](https://gitee.com/link?target=https%3A%2F%2Fwww.hiascend.com%2Fdocument%2Fdetail%2Fzh%2Fcanncommercial%2F700%2Fenvdeployment%2Finstg%2Finstg_0001.html)和[昇腾社区](https://www.hiascend.cn/developer/download/community/result?module=cann)。
72
82
 
73
83
  1.2 运行示例
74
84
  ```bash
75
- Ascend-cann-toolkit_x.x.x_linux-aarch64.run --full --install-path={cann_path}
85
+ Ascend-cann-toolkit_x.x.x_linux-xxxx.run --full --install-path={cann_path}
76
86
  Ascend-cann-kernels_x.x.x_linux.run --install --install-path={cann_path}
77
87
  ```
78
88