mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/mindspore/free_benchmark/perturbation/add_noise.py
@@ -1,11 +1,26 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any

 from mindspore import Tensor, ops

+from msprobe.mindspore.common.const import FreeBenchmarkConst
 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
-from msprobe.mindspore.common.const import FreeBenchmarkConst
+from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation


 class AddNoisePerturbation(BasePerturbation):
@@ -17,7 +32,7 @@ class AddNoisePerturbation(BasePerturbation):
         """
         params.fuzzed_value = self.add_noise(params.args[params.index])
         if not self.is_fuzzed:
-            logger.warning(f"{self.api_name} can not add noise.")
+            logger.warning(f"{self.api_name_with_id} can not add noise.")
             return False
         return self.get_fuzzed_result(params)

@@ -43,25 +58,25 @@

         return inputs

-    def _get_noise(self, input):
+    def _get_noise(self, tensor):
         """
         得到要添加的噪声值

         """
         if self.is_fuzzed:
             return False
-        if not ops.is_floating_point(input) or ops.numel(input) == 0:
+        if not ops.is_floating_point(tensor) or ops.numel(tensor) == 0:
             return False

-        pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(input.dtype)
+        pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
         if not pert_value:
             return False
         else:
             self.perturbation_value = pert_value

-        max_val = ops.max(ops.abs(input))[0].item()
+        max_val = ops.max(ops.abs(tensor))[0].item()
         if max_val < pert_value:
             return False

-        noise = ops.full(input.shape, self.perturbation_value, dtype=input.dtype)
+        noise = ops.full(tensor.shape, self.perturbation_value, dtype=tensor.dtype)
         return noise

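For reference, the gating rule that the renamed `_get_noise(tensor)` applies can be reproduced outside MindSpore. The sketch below is illustrative only: it mirrors the checks visible in the hunk above, but the epsilon table is a placeholder rather than the real `FreeBenchmarkConst.PERT_VALUE_DICT`.

```python
# Illustrative NumPy sketch of the noise-gating logic in AddNoisePerturbation._get_noise.
# The epsilon table is a stand-in for FreeBenchmarkConst.PERT_VALUE_DICT, not its real values.
import numpy as np

PERT_VALUE = {np.dtype("float16"): 1e-4, np.dtype("float32"): 1e-8}

def get_noise(tensor):
    if not np.issubdtype(tensor.dtype, np.floating) or tensor.size == 0:
        return None                                   # only non-empty float tensors qualify
    pert_value = PERT_VALUE.get(tensor.dtype)
    if pert_value is None:
        return None                                   # unsupported dtype
    if np.abs(tensor).max() < pert_value:
        return None                                   # all values smaller than the noise itself
    return np.full(tensor.shape, pert_value, dtype=tensor.dtype)

print(get_noise(np.ones((2, 2), dtype=np.float32)))   # constant 1e-08 noise
print(get_noise(np.zeros(3, dtype=np.float32)))       # None: max |x| below the threshold
```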
msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py
@@ -1,20 +1,44 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any

+from msprobe.core.common.const import Const
+from msprobe.mindspore.free_benchmark.common.config import Config
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
+from msprobe.mindspore.free_benchmark.common.utils import Tools


 class BasePerturbation:

-    def __init__(self, api_name: str):
-        self.api_name = api_name
+    def __init__(self, api_name_with_id: str):
+        self.api_name_with_id = api_name_with_id
         self.is_fuzzed = False
         self.perturbation_value = None

     @staticmethod
     def get_fuzzed_result(params: HandlerParams):
-        args_front = params.args[:params.index]
-        args_rear = params.args[params.index + 1:]
-        fuzzed_result = params.original_func(*args_front, params.fuzzed_value, *args_rear, **params.kwargs)
+        if Config.stage == Const.BACKWARD:
+            fuzzed_result = Tools.get_grad(params.original_func, *params.args[:params.index],
+                                           params.fuzzed_value, *params.args[params.index + 1:], **params.kwargs)
+
+            if fuzzed_result is None:
+                return False
+        else:
+            fuzzed_result = params.original_func(*params.args[:params.index], params.fuzzed_value,
+                                                 *params.args[params.index + 1:], **params.kwargs)
         return fuzzed_result

     def handler(self, params: HandlerParams) -> Any:

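The reworked `get_fuzzed_result` splices `params.fuzzed_value` into position `params.index` of the original call and, when `Config.stage` is the backward stage, routes the call through `Tools.get_grad` instead of invoking the wrapped API directly. A minimal sketch of the forward-path splicing under those assumptions (the names mirror the HandlerParams fields; `Tools.get_grad` is not reproduced here):

```python
# Minimal sketch of the forward-path argument splicing in BasePerturbation.get_fuzzed_result.
def call_with_fuzzed_value(original_func, args, kwargs, index, fuzzed_value):
    # Re-invoke the wrapped API with the positional argument at `index`
    # replaced by the perturbed value; everything else is passed through unchanged.
    return original_func(*args[:index], fuzzed_value, *args[index + 1:], **kwargs)

# Example: perturb the second operand of an addition.
print(call_with_fuzzed_value(lambda a, b: a + b, (1.0, 2.0), {}, 1, 2.0 + 1e-8))  # 3.00000001
```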
msprobe/mindspore/free_benchmark/perturbation/bit_noise.py
@@ -1,10 +1,25 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any

 import numpy as np
 from mindspore import Tensor, ops

-from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.common.const import FreeBenchmarkConst
+from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
 from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation

@@ -20,12 +35,12 @@ class BitNoisePerturbation(BasePerturbation):
         noise_type = list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.keys())[
             list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.values()).index(bit_len_type)]
         noise = ops.full(inputs.shape, 1, dtype=noise_type)
-        input_np = inputs.asnumpy()
+        input_np = inputs.contiguous().asnumpy()
         input_np_int = input_np.view(bit_len_type)
         result = Tensor(input_np_int)
         result = ops.where(ops.abs(inputs) > sub_normal,
                            ops.bitwise_xor(result, noise), result)
-        result_np = result.asnumpy()
+        result_np = result.contiguous().asnumpy()
         result_np_float = result_np.view(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.get(inputs.dtype))
         self.is_fuzzed = True
         return Tensor(result_np_float)
@@ -40,24 +55,24 @@
         args = params.args
         params.fuzzed_value = self.add_bit_noise(params.args[params.index])
         if not self.is_fuzzed:
-            logger.warning(f"{self.api_name} can not add bit noise.")
+            logger.warning(f"{self.api_name_with_id} can not add bit noise.")
             return False
         params.args = args
         return self.get_fuzzed_result(params)

-    def _get_bit_len_type(self, input):
+    def _get_bit_len_type(self, tensor):
         if self.is_fuzzed:
             return False
-        if not isinstance(input, Tensor) or not ops.is_floating_point(input) or \
-                input.numel() == 0:
+        if not isinstance(tensor, Tensor) or not ops.is_floating_point(tensor) or \
+                tensor.numel() == 0:
             return False
-        bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(input.dtype)
+        bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(tensor.dtype)
         if not bit_len_type:
             return False
-        pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(input.dtype)
+        pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
         if not pert_value:
             return False
-        max_val = ops.max(ops.abs(input))[0].item()
+        max_val = ops.max(ops.abs(tensor))[0].item()
         if max_val < pert_value:
             return False
         return bit_len_type

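The functional change in `add_bit_noise` is the added `.contiguous()` before `.asnumpy()`; the perturbation itself is unchanged: the float payload is reinterpreted as integers and the lowest bit is XOR-flipped wherever the magnitude exceeds a sub-normal threshold. A NumPy-only sketch of that idea, for float32 and with an illustrative threshold rather than the tool's constant:

```python
# NumPy sketch of the bit-flip perturbation performed by BitNoisePerturbation.add_bit_noise.
import numpy as np

def flip_lowest_bit(x, sub_normal=1e-30):
    bits = x.view(np.uint32)                          # reinterpret float32 bits as integers
    flipped = (bits ^ np.uint32(1)).view(np.float32)  # XOR the least significant mantissa bit
    return np.where(np.abs(x) > sub_normal, flipped, x)

x = np.array([1.0, 0.0, -2.5], dtype=np.float32)
print(flip_lowest_bit(x) - x)                         # tiny deltas, zero where |x| <= threshold
```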
msprobe/mindspore/free_benchmark/perturbation/exchange_value.py
@@ -1,14 +1,39 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any

-from mindspore import Tensor
+from mindspore import Tensor, ops

 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
+from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation


 class ExchangeValuePerturbation(BasePerturbation):

+    @staticmethod
+    def _check_tensor_shape(inputs):
+        dims = len(inputs.shape)
+        if dims == 1 and inputs.shape[0] > 1:
+            return True
+        if dims > 1 and inputs.shape[1] > 0:
+            if inputs.shape[0] > 1 or inputs.shape[1] > 1:
+                return True
+        return False
+
     def handle(self, params: HandlerParams) -> Any:
         """
         返回首尾交换后的api输出
@@ -16,7 +41,7 @@ class ExchangeValuePerturbation(BasePerturbation):
         """
         params.fuzzed_value = self.exchange_value(params.args[params.index])
         if not self.is_fuzzed:
-            logger.warning(f"{self.api_name} can not exchange value.")
+            logger.warning(f"{self.api_name_with_id} can not exchange value.")
             return False
         return self.get_fuzzed_result(params)

@@ -25,22 +50,23 @@
         返回首尾交换后的api输入

         """
-        if isinstance(inputs, Tensor):
-            if not self.is_fuzzed and len(inputs.shape) > 0 and inputs.shape[0] > 1:
-                result = inputs.copy()
-                if len(inputs.shape) == 1:
-                    first_element = inputs[0]
-                    last_element = inputs[-1]
-                    result[0] = last_element
-                    result[-1] = first_element
-                else:
-                    first_element = inputs[0][0]
-                    last_element = inputs[-1][-1]
-                    result[0][0] = last_element
-                    result[-1][-1] = first_element
-
-                self.is_fuzzed = True
-                return result
+        if isinstance(inputs, Tensor) and ops.is_floating_point(inputs):
+            if self.is_fuzzed or not self._check_tensor_shape(inputs):
+                return inputs
+            result = inputs.copy()
+            if len(inputs.shape) == 1:
+                first_element = inputs[0]
+                last_element = inputs[-1]
+                result[0] = last_element
+                result[-1] = first_element
+            else:
+                first_element = inputs[0][0]
+                last_element = inputs[-1][-1]
+                result[0][0] = last_element
+                result[-1][-1] = first_element
+
+            self.is_fuzzed = True
+            return result

         if isinstance(inputs, dict):
             return {k: self.exchange_value(v) for k, v in inputs.items()}

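The rewritten `exchange_value` now perturbs only floating-point tensors and delegates the shape gate to the new `_check_tensor_shape`; the swap itself (first element with last element) is unchanged. A standalone NumPy sketch of those two pieces, illustrative rather than the package code:

```python
# Standalone sketch of the shape gate and first/last swap in ExchangeValuePerturbation.
import numpy as np

def check_tensor_shape(x):
    dims = len(x.shape)
    if dims == 1 and x.shape[0] > 1:
        return True
    if dims > 1 and x.shape[1] > 0:
        return x.shape[0] > 1 or x.shape[1] > 1
    return False

def exchange_value(x):
    if not np.issubdtype(x.dtype, np.floating) or not check_tensor_shape(x):
        return x                                      # nothing to swap
    result = x.copy()
    if x.ndim == 1:
        result[0], result[-1] = x[-1], x[0]           # swap head and tail
    else:
        result[0][0], result[-1][-1] = x[-1][-1], x[0][0]
    return result

print(exchange_value(np.array([1.0, 2.0, 3.0])))      # [3. 2. 1.]
```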
msprobe/mindspore/free_benchmark/perturbation/improve_precision.py
@@ -1,13 +1,29 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any

 import mindspore as ms
 from mindspore import Tensor, ops

-from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
-from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
-from msprobe.mindspore.common.const import FreeBenchmarkConst
+from msprobe.core.common.const import Const
 from msprobe.mindspore.common.log import logger
-from msprobe.mindspore.common.const import Const
+from msprobe.mindspore.free_benchmark.common.config import Config
+from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
+from msprobe.mindspore.free_benchmark.common.utils import Tools
+from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation


 class ImprovePrecisionPerturbation(BasePerturbation):
@@ -26,10 +42,15 @@ class ImprovePrecisionPerturbation(BasePerturbation):
     def handle(self, params: HandlerParams) -> Any:
         args = self.improve_tensor_precision(params.args)
         kwargs = self.improve_tensor_precision(params.kwargs)
-        fuzzed_value = args
-        if self.api_name in Const.COMMUNICATION_API_LIST:
-            params.fuzzed_value = fuzzed_value
         if not self.is_fuzzed:
-            logger.warning(f"{self.api_name} can not improve precision.")
+            logger.warning(f"{self.api_name_with_id} can not improve precision.")
             return False
+
+        if Config.stage == Const.BACKWARD:
+            fuzzed_result = Tools.get_grad(params.original_func, *args, **kwargs)
+            if fuzzed_result is not None:
+                return fuzzed_result
+            else:
+                return False
+
         return params.original_func(*args, **kwargs)

msprobe/mindspore/free_benchmark/perturbation/no_change.py
@@ -1,7 +1,22 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Any

-from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
 from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
+from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation


 class NoChangePerturbation(BasePerturbation):

msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py
@@ -1,10 +1,25 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from msprobe.mindspore.common.const import FreeBenchmarkConst
 from msprobe.mindspore.free_benchmark.common.config import Config
-from .add_noise import AddNoisePerturbation
-from .bit_noise import BitNoisePerturbation
-from .no_change import NoChangePerturbation
-from .improve_precision import ImprovePrecisionPerturbation
-from .exchange_value import ExchangeValuePerturbation
+from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
+from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
+from msprobe.mindspore.free_benchmark.perturbation.exchange_value import ExchangeValuePerturbation
+from msprobe.mindspore.free_benchmark.perturbation.improve_precision import ImprovePrecisionPerturbation
+from msprobe.mindspore.free_benchmark.perturbation.no_change import NoChangePerturbation


 class PerturbationFactory:
@@ -21,9 +36,9 @@ class PerturbationFactory:
     }

     @staticmethod
-    def create(api_name: str):
+    def create(api_name_with_id: str):
         perturbation = PerturbationFactory.perturbations.get(Config.pert_type)
         if perturbation:
-            return perturbation(api_name)
+            return perturbation(api_name_with_id)
         else:
             raise Exception(f'{Config.pert_type} is a invalid perturbation type')

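Beyond the switch to absolute imports and the `api_name_with_id` rename, the factory logic is unchanged: look up the class registered for `Config.pert_type` and instantiate it with the API name. A generic sketch of that registry pattern, with illustrative keys and a hypothetical sample API name rather than the tool's constants:

```python
# Generic sketch of the registry lookup performed by PerturbationFactory.create.
class NoChangePerturbation:
    def __init__(self, api_name_with_id):
        self.api_name_with_id = api_name_with_id

PERTURBATIONS = {"no_change": NoChangePerturbation}   # stands in for the FreeBenchmarkConst-keyed dict
PERT_TYPE = "no_change"                               # stands in for Config.pert_type

def create(api_name_with_id):
    perturbation = PERTURBATIONS.get(PERT_TYPE)
    if perturbation is None:
        raise ValueError(f"{PERT_TYPE} is an invalid perturbation type")
    return perturbation(api_name_with_id)

print(create("Functional.add.0").api_name_with_id)    # Functional.add.0
```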
msprobe/mindspore/free_benchmark/self_check_tool_factory.py
@@ -1,6 +1,21 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from msprobe.mindspore.common.const import Const
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
-from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelFCheck
+from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck


 class SelfCheckToolFactory:
@@ -13,7 +28,7 @@ class SelfCheckToolFactory:
         Const.API: {
             Const.GRAPH_KBYK_MODE: None,
             Const.GRAPH_GE_MODE: None,
-            Const.PYNATIVE_MODE: ApiPyNativeSelFCheck
+            Const.PYNATIVE_MODE: ApiPyNativeSelfCheck
         },
         Const.KERNEL: {
             Const.GRAPH_KBYK_MODE: None,

msprobe/mindspore/grad_probe/global_context.py
@@ -1,15 +1,30 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import threading
-from typing import Dict, Union
+from typing import Dict, Union, Tuple

-from msprobe.core.grad_probe.utils import check_str
+from msprobe.core.common.utils import is_int
+from msprobe.core.common.file_utils import create_directory, check_path_before_create
 from msprobe.core.grad_probe.constant import GradConst
+from msprobe.core.grad_probe.utils import check_str, check_bounds_element, check_param_element
 from msprobe.mindspore.common.log import logger
-from msprobe.core.common.file_utils import create_directory, check_path_before_create


 class GlobalContext:
-
     _instance = None
     _instance_lock = threading.Lock()
     _setting = {
@@ -18,7 +33,7 @@ class GlobalContext:
         GradConst.STEP: None,
         GradConst.RANK: None,
         GradConst.CURRENT_STEP: 0,
-        GradConst.BOUNDS: [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10],
+        GradConst.BOUNDS: [-1, 0, 1],
         GradConst.OUTPUT_PATH: None
     }

@@ -31,19 +46,19 @@ class GlobalContext:

     def init_context(self, config_dict: Dict):
         level = config_dict.get(GradConst.LEVEL)
-        check_str(level, variable_name = "level in yaml")
+        check_str(level, variable_name="level in yaml")
         if level in GradConst.SUPPORTED_LEVEL:
             self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL)
         else:
             raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")

-        self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
-        self._set_input_list(config_dict, GradConst.BOUNDS, float)
-        self._set_input_list(config_dict, GradConst.STEP, int)
-        self._set_input_list(config_dict, GradConst.RANK, int)
+        self._set_input_list(config_dict, GradConst.PARAM_LIST, (str,), element_check=check_param_element)
+        self._set_input_list(config_dict, GradConst.BOUNDS, (float, int), element_check=check_bounds_element)
+        self._set_input_list(config_dict, GradConst.STEP, (int,))
+        self._set_input_list(config_dict, GradConst.RANK, (int,))

         output_path = config_dict.get(GradConst.OUTPUT_PATH)
-        check_str(output_path, variable_name = "output_path in yaml")
+        check_str(output_path, variable_name="output_path in yaml")
         try:
             check_path_before_create(output_path)
         except RuntimeError as err:
@@ -70,21 +85,36 @@ class GlobalContext:
         dump_rank_list = self.get_context(GradConst.RANK)
         return (not dump_rank_list) or (rank in dump_rank_list)

-    def _set_input_list(self, config_dict: Dict, name: str, dtype: Union[int, str, float]):
-        value = config_dict.get(name)
+    def _get_type_str(self, dtype: Union[int, str, float, Tuple[int, str, float]]):
+        if isinstance(dtype, tuple):
+            return "/".join([self._get_type_str(element) for element in dtype])
         if dtype == int:
             type_str = "integer"
         elif dtype == float:
             type_str = "float"
         else:
             type_str = "string"
+        return type_str
+
+    def _set_input_list(self, config_dict: Dict, name: str,
+                        dtype: Union[int, str, float, Tuple[int, str, float]], element_check=None):
+        value = config_dict.get(name)
+        type_str = self._get_type_str(dtype)
         if value and isinstance(value, list):
             for val in value:
                 if not isinstance(val, dtype):
-                    logger.warning(f"Invalid {name} which must be None or list of {type_str}")
+                    logger.warning(f"Invalid {name} which must be None or list of {type_str}, use default value.")
+                    return
+                elif isinstance(val, int) and not is_int(val):
+                    logger.warning(f"Invalid {name} which must be None or list of int, use default value.")
+                    return
+                if element_check and not element_check(val):
+                    logger.warning(f"Given {name} violates some rules, use default value.")
                     return
+
             self._setting[name] = value
         else:
             logger.warning(f"{name} is None or not a list with valid items, use default value.")

+
 grad_context = GlobalContext()

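`_set_input_list` now accepts a tuple of allowed element types plus an optional `element_check` callback, rejects bools masquerading as ints via `is_int`, and falls back to the default setting on the first bad element. A standalone sketch of that validation flow; the helpers here are simplified stand-ins for the msprobe utilities, not their actual implementations:

```python
# Simplified sketch of the validation flow added to GlobalContext._set_input_list.
# `is_int` here is a stand-in that captures the intent of msprobe's helper (reject bools).
def is_int(value):
    return isinstance(value, int) and not isinstance(value, bool)

def validate_list(value, dtypes, element_check=None):
    """Return True only when value is a non-empty list whose elements pass every check."""
    if not value or not isinstance(value, list):
        return False
    for val in value:
        if not isinstance(val, dtypes):
            return False                              # wrong type: caller keeps the default
        if isinstance(val, int) and not is_int(val):
            return False                              # e.g. True/False sneaking in as ints
        if element_check and not element_check(val):
            return False                              # element violates an extra rule
    return True

print(validate_list([-1, 0, 1], (float, int), element_check=lambda v: -100 <= v <= 100))  # True
print(validate_list([True, 1], (int,)))                                                   # False
```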
msprobe/mindspore/grad_probe/grad_analyzer.py
@@ -1,20 +1,33 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
 import os
 import time
-from typing import List, Tuple
-import multiprocessing
 from multiprocessing import Process
+from typing import List

-import numpy as np
 import mindspore as ms
-from mindspore.communication import get_rank
-from mindspore.ops import operations as P
+import numpy as np
 from mindspore.common.parameter import Parameter
-
-from msprobe.core.grad_probe.utils import ListCache
-from msprobe.core.grad_probe.constant import GradConst
-from msprobe.mindspore.common.log import logger
+from mindspore.communication import get_rank
 from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path,
                                             write_csv, remove_path, move_file, load_npy)
+from msprobe.core.grad_probe.constant import GradConst
+from msprobe.core.grad_probe.utils import ListCache
+from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext


@@ -28,12 +41,12 @@ def get_rank_id():

 @ms.jit
 def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
-    '''
+    """
     Dump gradient statistic data.
     level0: [step, max, min, norm, shape_dim, shape]
     level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
     level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
-    '''
+    """
     dump_path = os.path.join(dump_dir, g_name)
     dump_dir_path = dump_path + "_dir"
     save_op = ms.ops.TensorDump()
@@ -182,7 +195,7 @@ class CSVGenerator(Process):
         shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
         file_name = os.path.basename(file_path)
         prefix_idx = len(file_name.split("_")[0])
-        param_name = file_name[(prefix_idx + 1) : -(len(GradConst.NPY_SUFFIX) + 1)]
+        param_name = file_name[(prefix_idx + 1): -(len(GradConst.NPY_SUFFIX) + 1)]
         if not param_name:
             raise RuntimeError("Invalid gradient statistic file name.")
         csv_line = [param_name]
@@ -224,8 +237,9 @@ class CSVGenerator(Process):
             if i == 0:
                 intervals.append(f"(-inf, {self.bounds[i]}]")
             else:
-                intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]")
+                intervals.append(f"({self.bounds[i - 1]}, {self.bounds[i]}]")
         intervals.extend([f"({self.bounds[-1]}, inf)", "=0"])
         return intervals

+
 csv_generator = CSVGenerator()

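Only the spacing of the f-string changes in the interval builder; for reference, the interval headers that `CSVGenerator` derives from the bounds list (with the new default `[-1, 0, 1]` from global_context.py) can be reproduced directly from the hunk above:

```python
# Standalone reproduction of the interval-header construction shown in the hunk above.
def build_intervals(bounds):
    intervals = []
    for i in range(len(bounds)):
        if i == 0:
            intervals.append(f"(-inf, {bounds[i]}]")
        else:
            intervals.append(f"({bounds[i - 1]}, {bounds[i]}]")
    intervals.extend([f"({bounds[-1]}, inf)", "=0"])
    return intervals

print(build_intervals([-1, 0, 1]))
# ['(-inf, -1]', '(-1, 0]', '(0, 1]', '(1, inf)', '=0']
```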
msprobe/mindspore/grad_probe/grad_monitor.py
@@ -1,7 +1,22 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.core.grad_probe.constant import GradConst
 from msprobe.mindspore.grad_probe.global_context import grad_context
 from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
 from msprobe.mindspore.grad_probe.hook import hook_optimizer
-from msprobe.core.grad_probe.constant import GradConst


 class GradientMonitor: