mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/mindspore/dump/hook_cell/wrap_api.py
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,10 +23,16 @@ from mindspore.mint.nn import functional
  from msprobe.core.common.const import Const
  from msprobe.core.common.file_utils import load_yaml
  from msprobe.mindspore.common.const import Const as MsConst
+ from msprobe.mindspore.common.utils import is_mindtorch
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell

+ if is_mindtorch():
+     import torch
+     import torch_npu
+
  cur_path = os.path.dirname(os.path.realpath(__file__))
  yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
+ torch_yaml_path = os.path.join(cur_path, "../../../pytorch/hook_module", MsConst.SUPPORTED_API_LIST_FILE)


  class HOOKTensor(object):
@@ -53,6 +59,26 @@ class HOOKDistributedOP(object):
      pass


+ class HOOKTorchOP(object):
+     pass
+
+
+ class HOOKTorchTensor(object):
+     pass
+
+
+ class HOOKTorchFunctionalOP(object):
+     pass
+
+
+ class HOOKTorchDistributedOP(object):
+     pass
+
+
+ class HOOKTorchNpuOP(object):
+     pass
+
+
  class ApiTemplate(HOOKCell):
      def __init__(self, api_name, api_dict, prefix, hook):
          self.api_name = api_name
@@ -60,7 +86,30 @@ class ApiTemplate(HOOKCell):
          self.prefix_api_name = prefix + str(api_name.split(Const.SEP)[-1]) + Const.SEP
          super().__init__(hook)

+     @staticmethod
+     def async_to_sync(output):
+         # Fake handle, used to return after the CommHandle executes the wait method
+         fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()
+         if isinstance(output, tuple) and len(output) == 2 and hasattr(output[1], "wait"):
+             output[1].wait()
+             output = (output[0], fake_handle)
+         elif hasattr(output, "wait"):
+             output.wait()
+             output = fake_handle
+         return output
+
      def construct(self, *args, **kwargs):
+         if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
+             return args[0] if args else kwargs.get(Const.INPUT)
+
+         output = self.api_func(*args, **kwargs)
+
+         if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
+             if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
+                 output = self.async_to_sync(output)
+         return output
+
+     def forward(self, *args, **kwargs):
          if self.api_name.startswith(MsConst.DROPOUT_API_NAME_PREFIX):
              return args[0] if args else kwargs.get(Const.INPUT)
          return self.api_func(*args, **kwargs)
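The fake-handle pattern in async_to_sync above can be tried in isolation: a throwaway object whose wait() is a no-op stands in for the real communication handle once the asynchronous operation has been forced to complete. A minimal sketch in plain Python (RealHandle is a hypothetical stand-in for an async communication handle, not an msprobe class):

    # A throwaway object whose wait() does nothing; it replaces the real handle
    # once the asynchronous operation has already been forced to complete.
    fake_handle = type("FakeHandle", (), {"wait": lambda self: None})()

    class RealHandle:
        def wait(self):
            print("blocking until the collective finishes")

    output = RealHandle()
    if hasattr(output, "wait"):
        output.wait()          # force completion now
        output = fake_handle   # later wait() calls become harmless no-ops

    output.wait()  # safe: does nothing

Caller code that expects a handle keeps working, while the data is guaranteed to be materialized by the time the dump hook records it.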
@@ -77,6 +126,15 @@ class WrapApiName:
          self.distributed_api_names = distributed_api_names


+ class WrapTorchApiName:
+     def __init__(self, torch_api_names, tensor_api_names, functional_api_names, distributed_api_names, npu_api_names):
+         self.torch_api_names = torch_api_names
+         self.tensor_api_names = tensor_api_names
+         self.functional_api_names = functional_api_names
+         self.distributed_api_names = distributed_api_names
+         self.npu_api_names = npu_api_names
+
+
  def get_wrap_api_list():
      api_list = load_yaml(yaml_path)
      tensor_api = api_list.get(MsConst.SUPPORTED_TENSOR_LIST_KEY)
@@ -93,6 +151,21 @@ def get_wrap_api_list():
      return wrap_api_name


+ def get_wrap_torch_api_list():
+     api_list = load_yaml(torch_yaml_path)
+     torch_api = api_list.get("torch")
+     tensor_api = api_list.get("tensor")
+     functional_api = api_list.get("functional")
+     distributed_api = api_list.get("distributed")
+     npu_api = api_list.get("torch_npu")
+     wrap_api_name = WrapTorchApiName(set(torch_api) & set(dir(torch)),
+                                      set(tensor_api) & set(dir(torch.Tensor)),
+                                      set(functional_api) & set(dir(torch.nn.functional)),
+                                      set(distributed_api) & set(dir(torch.distributed)),
+                                      set(npu_api) & set(dir(torch_npu)))
+     return wrap_api_name
+
+
  def wrap_api_func(api_name, api_dict, prefix, hook):
      def api_function(*args, **kwargs):
          return ApiTemplate(api_name, api_dict, prefix, hook)(*args, **kwargs)
@@ -106,6 +179,24 @@ def wrap_api_func_and_bind(api_list, api_dict, prefix, hook, hook_class):


  def setup_hooks(hook):
+     if is_mindtorch():
+         torch_wrap_api_name = get_wrap_torch_api_list()
+         wrap_api_func_and_bind(torch_wrap_api_name.torch_api_names,
+                                {f: getattr(torch, f) for f in dir(torch)},
+                                MsConst.TORCH_DATA_PREFIX, hook, HOOKTorchOP)
+         wrap_api_func_and_bind(torch_wrap_api_name.tensor_api_names,
+                                {f: getattr(torch.Tensor, f) for f in dir(torch.Tensor)},
+                                MsConst.TENSOR_DATA_PREFIX, hook, HOOKTorchTensor)
+         wrap_api_func_and_bind(torch_wrap_api_name.functional_api_names,
+                                {f: getattr(torch.nn.functional, f) for f in dir(torch.nn.functional)},
+                                MsConst.OPS_DATA_PREFIX, hook, HOOKTorchFunctionalOP)
+         wrap_api_func_and_bind(torch_wrap_api_name.distributed_api_names,
+                                {f: getattr(torch.distributed, f) for f in dir(torch.distributed)},
+                                MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKTorchDistributedOP)
+         wrap_api_func_and_bind(torch_wrap_api_name.npu_api_names, {f: getattr(torch_npu, f) for f in dir(torch_npu)},
+                                MsConst.TORCH_NPU_DATA_PREFIX, hook, HOOKTorchNpuOP)
+         return
+
      wrap_api_name = get_wrap_api_list()
      wrap_api_func_and_bind(wrap_api_name.tensor_api_names, {f: getattr(Tensor, f) for f in dir(Tensor)},
                             MsConst.TENSOR_DATA_PREFIX, hook, HOOKTensor)
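The wrap-and-bind mechanism above replaces each listed API with a thin closure that instantiates a template cell per call. A distilled sketch of the same idea, with illustrative names rather than the msprobe ones:

    # Distilled wrap-and-bind pattern: each wrapped name becomes a closure that
    # forwards through a template object holding the original function.
    class Template:
        def __init__(self, name, api_dict):
            self.api_func = api_dict[name]

        def __call__(self, *args, **kwargs):
            print("hooked:", self.api_func.__name__)  # dump/hook logic would go here
            return self.api_func(*args, **kwargs)

    def wrap(name, api_dict):
        def api_function(*args, **kwargs):
            return Template(name, api_dict)(*args, **kwargs)
        return api_function

    import math
    api_dict = {f: getattr(math, f) for f in dir(math)}
    hooked_sqrt = wrap("sqrt", api_dict)
    print(hooked_sqrt(9.0))  # prints "hooked: sqrt" then 3.0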
msprobe/mindspore/dump/kernel_dump/kernel_config.py (new file)
@@ -0,0 +1,33 @@
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ from msprobe.core.common.file_utils import save_json
+
+
+ def create_kernel_config_json(dump_path, cur_rank):
+     kernel_config_name = "kernel_config.json" if cur_rank == '' else f"kernel_config_{cur_rank}.json"
+     kernel_config_path = os.path.join(dump_path, kernel_config_name)
+     config_info = {
+         "dump": {
+             "dump_list": [],
+             "dump_path": dump_path,
+             "dump_mode": "all",
+             "dump_op_switch": "on"
+         }
+     }
+     save_json(kernel_config_path, config_info, indent=4)
+     return kernel_config_path
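Based only on the code above, a usage sketch (the dump path is illustrative, and save_json is assumed to serialize the dictionary as JSON):

    # Hypothetical usage of create_kernel_config_json for rank 0.
    path = create_kernel_config_json("/tmp/msprobe_dump", "0")
    # path == "/tmp/msprobe_dump/kernel_config_0.json", containing:
    # {"dump": {"dump_list": [], "dump_path": "/tmp/msprobe_dump",
    #           "dump_mode": "all", "dump_op_switch": "on"}}
    # With cur_rank == '' the file is named plain "kernel_config.json".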
msprobe/mindspore/dump/kernel_graph_dump.py
@@ -56,6 +56,13 @@ class KernelGraphDump:
          self.dump_json["common_dump_settings"]["input_output"] = 2

      def handle(self):
+         try:
+             from msprobe.lib import _msprobe_c
+             return
+         except ImportError:
+             # If _msprobe_c is not available, fall back to the old MindSpore flow
+             logger.info("Module _msprobe_c has not been installed, use interface in mindspore instead.")
+
          if os.getenv("GRAPH_OP_RUN") == "1":
              raise Exception("Must run in graph mode, not kbk mode")
          json_path = self.dump_json["common_dump_settings"]["path"]
msprobe/mindspore/free_benchmark/api_pynative_self_check.py
@@ -19,7 +19,6 @@ import os
  import traceback

  import mindspore as ms
-
  from msprobe.core.common.const import Const
  from msprobe.core.common.exceptions import DistributedNotInitializedError
  from msprobe.core.common.file_utils import check_path_length, load_yaml
@@ -29,6 +28,7 @@ from msprobe.mindspore.common.log import logger
  from msprobe.mindspore.common.utils import get_rank_if_initialized
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
  from msprobe.mindspore.free_benchmark.common.config import Config
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
  from msprobe.mindspore.free_benchmark.common.utils import Tools
@@ -63,7 +63,10 @@ class ApiPyNativeSelfCheck:
          api_register.initialize_hook(self.build_hook)
          api_register.api_set_hook_func()

-     def build_hook(self, api_name_with_id):
+     def build_hook(self, api_name):
+         def pre_hook(cell, input_data):
+             return None
+
          def forward_hook(api_name_with_id, cell, input_data, output_data):
              ret = None

@@ -85,7 +88,10 @@
          def backward_hook(cell, grad_input, grad_output):
              pass

+         HOOKCell.get_cell_count(api_name)
+         api_name_with_id = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP
          forward_hook = functools.partial(forward_hook, api_name_with_id)
+         HOOKCell.add_cell_count(api_name)

          def wrap_forward_hook(cell, input_data, output_data):
              return forward_hook(cell, input_data, output_data)
@@ -93,7 +99,10 @@
          def wrap_backward_hook(cell, grad_input, grad_output):
              return backward_hook(cell, grad_input, grad_output)

-         return wrap_forward_hook, wrap_backward_hook
+         def pre_backward_hook(cell, grad_input):
+             return None
+
+         return pre_hook, wrap_forward_hook, wrap_backward_hook, pre_backward_hook

      def store_original_func(self):
          for api_name in self.api_list:
@@ -138,7 +147,7 @@ def get_module(api_name):
      module_obj = importlib.import_module(func_name_list[0])
      for i, module_name in enumerate(func_name_list[1:-1]):
          if not hasattr(module_obj, module_name):
-             importlib.import_module(f"{Const.SEP.join(func_name_list[:i+2])}")
+             importlib.import_module(f"{Const.SEP.join(func_name_list[:i + 2])}")
          module_obj = getattr(module_obj, module_name)
      orig_func = getattr(module_obj, func_name)

msprobe/mindspore/free_benchmark/perturbation/bit_noise.py
@@ -35,12 +35,12 @@ class BitNoisePerturbation(BasePerturbation):
          noise_type = list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.keys())[
              list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.values()).index(bit_len_type)]
          noise = ops.full(inputs.shape, 1, dtype=noise_type)
-         input_np = inputs.contiguous().asnumpy()
+         input_np = inputs.asnumpy()
          input_np_int = input_np.view(bit_len_type)
          result = Tensor(input_np_int)
          result = ops.where(ops.abs(inputs) > sub_normal,
                             ops.bitwise_xor(result, noise), result)
-         result_np = result.contiguous().asnumpy()
+         result_np = result.asnumpy()
          result_np_float = result_np.view(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.get(inputs.dtype))
          self.is_fuzzed = True
          return Tensor(result_np_float)
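The perturbation above flips the least significant mantissa bit of each element by reinterpreting the float buffer as integers and XOR-ing with 1, skipping values too close to zero. A standalone sketch of the same idea using NumPy only (taking sub_normal as the smallest normal float32 is an assumption about what that constant represents):

    import numpy as np

    x = np.array([1.0, -2.5, 3.0e-40], dtype=np.float32)  # last entry is subnormal
    as_int = x.view(np.int32)                # reinterpret bits at the same width
    flipped = (as_int ^ 1).view(np.float32)  # XOR with 1 flips the lowest mantissa bit
    sub_normal = np.finfo(np.float32).tiny   # smallest positive normal float32
    perturbed = np.where(np.abs(x) > sub_normal, flipped, x)
    print(perturbed - x)  # ~1 ULP difference for the normal entries, 0 for the subnormal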
msprobe/mindspore/grad_probe/grad_analyzer.py
@@ -16,6 +16,7 @@
  import multiprocessing
  import os
  import time
+ from dataclasses import dataclass
  from multiprocessing import Process
  from typing import List

@@ -23,6 +24,7 @@ import mindspore as ms
  import numpy as np
  from mindspore.common.parameter import Parameter
  from mindspore.communication import get_rank
+
  from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path,
                                              write_csv, remove_path, move_file, load_npy)
  from msprobe.core.grad_probe.constant import GradConst
@@ -31,6 +33,16 @@ from msprobe.mindspore.common.log import logger
  from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext


+ @dataclass
+ class GradDumpConfig:
+     dump_dir: str
+     g_name: str
+     dump_step: Parameter
+     grad: ms.Tensor
+     level: str
+     bounds: List
+
+
  def get_rank_id():
      try:
          rank_id = get_rank()
@@ -40,35 +52,35 @@ def get_rank_id():


  @ms.jit
- def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
+ def grad_dump(config: GradDumpConfig):
      """
      Dump gradient statistic data.
      level0: [step, max, min, norm, shape_dim, shape]
      level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
      level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
      """
-     dump_path = os.path.join(dump_dir, g_name)
+     dump_path = os.path.join(config.dump_dir, config.g_name)
      dump_dir_path = dump_path + "_dir"
      save_op = ms.ops.TensorDump()

-     grad_flat = grad.reshape(-1)
+     grad_flat = config.grad.reshape(-1)
      max_val = grad_flat.max(axis=0).float()
      min_val = grad_flat.min(axis=0).float()
      norm_val = grad_flat.norm(ord=2).float()
-     shape = grad.shape
-     extrem_list = [dump_step[0].float(), max_val, min_val, norm_val]
+     shape = config.grad.shape
+     extrem_list = [config.dump_step[0].float(), max_val, min_val, norm_val]
      extrem_stat = ms.ops.stack(extrem_list)
      shape_list = [len(shape)] + list(shape)
      shape_stat = ms.Tensor(shape_list).float()
      level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
      level_stat = level0_stat

-     if level == GradConst.LEVEL2:
-         zero_grad = (grad == 0).sum()
-         dist_dim = ms.Tensor([len(bounds) + 2]).float()
-         bucket_result = ms.ops.bucketize(grad.float(), bounds)
+     if config.level == GradConst.LEVEL2:
+         zero_grad = (config.grad == 0).sum()
+         dist_dim = ms.Tensor([len(config.bounds) + 2]).float()
+         bucket_result = ms.ops.bucketize(config.grad.float(), config.bounds)
          bucket_result = bucket_result.astype(ms.int8)
-         dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
+         dist_stat = [(bucket_result == i).sum() for i in range(len(config.bounds) + 1)]
          dist_stat.append(zero_grad)
          dist_stat.append(ms.Tensor(1, dtype=ms.int64))  # make sure dist_stat is not empty
          dist_stat = ms.ops.stack(dist_stat, axis=0).float()
@@ -76,8 +88,8 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
          level_stat = level2_stat

      save_op(dump_path, level_stat)
-     if level == GradConst.LEVEL1 or level == GradConst.LEVEL2:
-         grad_direction = grad > 0
+     if config.level == GradConst.LEVEL1 or config.level == GradConst.LEVEL2:
+         grad_direction = config.grad > 0
          save_op(dump_dir_path, grad_direction)

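As a worked example of the level0 layout documented in the docstring above, the statistic vector for a 2x3 gradient at step 7 is [step, max, min, norm, shape_dim, shape]. A NumPy-only sketch of the same arithmetic:

    import numpy as np

    grad = np.array([[0.5, -1.0, 2.0], [0.0, 3.0, -0.25]], dtype=np.float32)
    step = 7.0
    extrem = [step, float(grad.max()), float(grad.min()),
              float(np.linalg.norm(grad.reshape(-1), ord=2))]
    shape_stat = [float(grad.ndim)] + [float(s) for s in grad.shape]
    level0 = np.array(extrem + shape_stat, dtype=np.float32)
    print(level0)  # [7.  3.  -1.  3.783...  2.  2.  3.]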
msprobe/mindspore/grad_probe/hook.py
@@ -26,7 +26,7 @@ from msprobe.core.grad_probe.constant import GradConst
  from msprobe.mindspore.common.log import logger
  from msprobe.mindspore.grad_probe.global_context import grad_context
  from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
- from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
+ from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id, GradDumpConfig
  from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
  from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level

@@ -38,7 +38,14 @@ class HookInput:

      def __init__(self, opt) -> None:
          self.func = opt.construct
-         self.g_names = [param.name for param in opt._parameters]
+         if hasattr(opt, "_parameters"):
+             parameter_list = opt._parameters
+         elif hasattr(opt, "parameters"):
+             parameter_list = opt.parameters
+         else:
+             logger.error_log_with_exp("Given optimizer has no attributes: '_parameters' or 'parameters'. \
+                 Please check the type of the given optimizer.", ValueError)
+         self.g_names = [param.name for param in parameter_list]
          self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
          self.rank_id = get_rank_id()
          output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
@@ -59,8 +66,10 @@ def hook_graph_mode_optimizer(opt, hook_input):
          for index, grad_value in enumerate(gradients):
              if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list:
                  continue
-             grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step,
-                       grad_value, hook_input.level, hook_input.bounds)
+             conf = GradDumpConfig(dump_dir=hook_input.dump_dir, g_name=hook_input.g_names[index],
+                                   dump_step=self.dump_step, grad=grad_value, level=hook_input.level,
+                                   bounds=hook_input.bounds)
+             grad_dump(conf)
          ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
          self.assignadd(self.dump_step, self.global_step_increase_tensor)
          out = hook_input.func(gradients)
msprobe/mindspore/mindtorch/__init__.py (new file)
@@ -0,0 +1,18 @@
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .mindtorch_adaptor import (_call_impl,
+                                 register_full_backward_pre_hook,
+                                 register_full_backward_hook)
msprobe/mindspore/mindtorch/mindtorch_adaptor.py (new file)
@@ -0,0 +1,255 @@
+ # From PyTorch:
+
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd
+ # Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ # Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ # Copyright (c) 2011-2013 NYU (Clement Farabet)
+ # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ # Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ # From Caffe2:
+
+ # Copyright (c) 2016-present, Facebook Inc. All rights reserved.
+
+ # All contributions by Facebook:
+ # Copyright (c) 2016 Facebook Inc.
+
+ # All contributions by Google:
+ # Copyright (c) 2015 Google Inc.
+ # All rights reserved.
+
+ # All contributions by Yangqing Jia:
+ # Copyright (c) 2015 Yangqing Jia
+ # All rights reserved.
+
+ # All contributions by Kakao Brain:
+ # Copyright 2019-2020 Kakao Brain
+
+ # All contributions by Cruise LLC:
+ # Copyright (c) 2022 Cruise LLC.
+ # All rights reserved.
+
+ # All contributions by Tri Dao:
+ # Copyright (c) 2024 Tri Dao.
+ # All rights reserved.
+
+ # All contributions by Arm:
+ # Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
+
+ # All contributions from Caffe:
+ # Copyright(c) 2013, 2014, 2015, the respective contributors
+ # All rights reserved.
+
+ # All other contributions:
+ # Copyright(c) 2015, 2016 the respective contributors
+ # All rights reserved.
+
+ # Caffe2 uses a copyright model similar to Caffe: each contributor holds
+ # copyright over their contributions to Caffe2. The project versioning records
+ # all such contribution and copyright details. If a contributor wants to further
+ # mark their specific copyright on a particular contribution, they should
+ # indicate their copyright solely in the commit message of the change when it is
+ # committed.
+
+ # All rights reserved.
+
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+
+ # 1. Redistributions of source code must retain the above copyright
+ #    notice, this list of conditions and the following disclaimer.
+
+ # 2. Redistributions in binary form must reproduce the above copyright
+ #    notice, this list of conditions and the following disclaimer in the
+ #    documentation and/or other materials provided with the distribution.
+
+ # 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories
+ #    America, IDIAP Research Institute and Huawei nor the names of its contributors
+ #    may be used to endorse or promote products derived from this software without
+ #    specific prior written permission.
+
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ # POSSIBILITY OF SUCH DAMAGE.
+
+ import warnings
+
+ import mindspore as ms
+ from mindspore.ops.operations import _inner_ops as inner
+ from torch.nn.modules.module import (_global_backward_pre_hooks, _global_backward_hooks,
+                                      _global_is_full_backward_hook, _global_forward_pre_hooks,
+                                      _global_forward_hooks, _global_forward_hooks_always_called)
+ from torch.utils.hooks import RemovableHandle
+
+
+ def _call_impl(self, *args, **kwargs):
+     forward_call = self.forward
+     if self.__ms_class__:
+         return forward_call(*args, **kwargs)
+
+     # If we don't have any hooks, we want to skip the rest of the logic in
+     # this function, and just call forward.
+     if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
+             or _global_backward_pre_hooks or _global_backward_hooks
+             or _global_forward_hooks or _global_forward_pre_hooks):
+         return forward_call(*args, **kwargs)
+
+     try:
+         result = None
+         called_always_called_hooks = set()
+
+         if self._backward_pre_hooks or _global_backward_pre_hooks:
+             _get_backward_pre_hooks(self)
+
+         if self._backward_hooks or _global_backward_hooks:
+             _get_backward_hooks(self)
+
+         if _global_forward_pre_hooks or self._forward_pre_hooks:
+             for hook_id, hook in (
+                 *_global_forward_pre_hooks.items(),
+                 *self._forward_pre_hooks.items(),
+             ):
+                 if hook_id in self._forward_pre_hooks_with_kwargs:
+                     args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
+                     if args_kwargs_result is not None:
+                         if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2:
+                             args, kwargs = args_kwargs_result
+                         else:
+                             raise RuntimeError(
+                                 "forward pre-hook must return None or a tuple "
+                                 f"of (new_args, new_kwargs), but got {args_kwargs_result}."
+                             )
+                 else:
+                     args_result = hook(self, args)
+                     if args_result is not None:
+                         if not isinstance(args_result, tuple):
+                             args_result = (args_result,)
+                         args = args_result
+
+         bw_hook = None
+         if self._backward_hooks:
+             bw_hook = inner.CellBackwardHook(self.__class__.__name__ + "(" + str(id(self)) + ")",
+                                              self, self._backward_hooks)
+             bw_hook.register_backward_hook()
+             args = apply_backward_hook_on_tensors(bw_hook, args)
+
+         result = forward_call(*args, **kwargs)
+         if _global_forward_hooks or self._forward_hooks:
+             for hook_id, hook in (
+                 *_global_forward_hooks.items(),
+                 *self._forward_hooks.items(),
+             ):
+                 # mark that always called hook is run
+                 if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called:
+                     called_always_called_hooks.add(hook_id)
+
+                 if hook_id in self._forward_hooks_with_kwargs:
+                     hook_result = hook(self, args, kwargs, result)
+                 else:
+                     hook_result = hook(self, args, result)
+
+                 if hook_result is not None:
+                     result = hook_result
+
+         if bw_hook:
+             if not isinstance(result, (ms.Tensor, tuple)):
+                 warnings.warn("For backward hooks to be called,"
+                               " module output should be a Tensor or a tuple of Tensors"
+                               f" but received {type(result)}")
+             result = apply_backward_hook_on_tensors(bw_hook, result)
+
+         if self._backward_pre_hooks:
+             bw_pre_hook = inner.CellBackwardHook(self.__class__.__name__ + "(" + str(id(self)) + ")",
+                                                  self, self._backward_pre_hooks)
+             bw_pre_hook.register_backward_pre_hook()
+             result = apply_backward_hook_on_tensors(bw_pre_hook, result)
+
+         return result
+     except Exception:
+         # run always called hooks if they have not already been run
+         # For now only forward hooks have the always_call option but perhaps
+         # this functionality should be added to full backward hooks as well.
+         for hook_id, hook in _global_forward_hooks.items():
+             # type: ignore[possibly-undefined]
+             if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:
+                 try:
+                     hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
+                     if hook_result is not None:
+                         result = hook_result
+                 except Exception as e:
+                     warnings.warn("global module forward hook with ``always_call=True`` raised an exception "
+                                   f"that was silenced as another error was raised in forward: {str(e)}")
+                     continue
+
+         for hook_id, hook in self._forward_hooks.items():
+             # type: ignore[possibly-undefined]
+             if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:
+                 try:
+                     if hook_id in self._forward_hooks_with_kwargs:
+                         hook_result = hook(self, args, kwargs, result)  # type: ignore[possibly-undefined]
+                     else:
+                         hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
+                     if hook_result is not None:
+                         result = hook_result
+                 except Exception as e:
+                     warnings.warn("module forward hook with ``always_call=True`` raised an exception "
+                                   f"that was silenced as another error was raised in forward: {str(e)}")
+                     continue
+         # raise exception raised in try block
+         raise
+
+
+ def register_full_backward_pre_hook(self, hook, prepend: bool = False) -> RemovableHandle:
+     handle = RemovableHandle(self._backward_pre_hooks)
+     self._backward_pre_hooks[handle.id] = hook
+     if prepend:
+         self._backward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
+     return handle
+
+
+ def register_full_backward_hook(self, hook, prepend: bool = False) -> RemovableHandle:
+     if self._is_full_backward_hook is False:
+         raise RuntimeError(
+             "Cannot use both regular backward hooks and full backward hooks on a "
+             "single Module. Please use only one of them."
+         )
+
+     self._is_full_backward_hook = True
+
+     handle = RemovableHandle(self._backward_hooks)
+     self._backward_hooks[handle.id] = hook
+     if prepend:
+         self._backward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
+     return handle
+
+
+ def _get_backward_pre_hooks(self):
+     self._backward_pre_hooks.update(_global_backward_pre_hooks)
+
+
+ def _get_backward_hooks(self):
+     if (_global_is_full_backward_hook is True):
+         self._backward_hooks.update(_global_backward_hooks)
+
+
+ def apply_backward_hook_on_tensors(cell_backward_hook, args):
+     is_tuple = True
+     if not isinstance(args, tuple):
+         args = (args,)
+         is_tuple = False
+     hooked_args = cell_backward_hook(*args)
+     if is_tuple and len(args) == 1:
+         hooked_args = (hooked_args, )
+     return hooked_args
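The adaptor reimplements torch.nn.Module's hook dispatch on top of MindSpore, routing backward hooks through MindSpore's inner.CellBackwardHook. How msprobe wires these replacements in is not shown in this diff; a rough sketch under that assumption (the patch targets below are illustrative, not confirmed by the package):

    # Hypothetical wiring: substitute the adaptor's implementations for the
    # torch.nn.Module entry points so hook bookkeeping runs on MindSpore.
    import torch.nn as nn
    from msprobe.mindspore.mindtorch import (_call_impl,
                                             register_full_backward_pre_hook,
                                             register_full_backward_hook)

    nn.Module.__call__ = _call_impl
    nn.Module.register_full_backward_pre_hook = register_full_backward_pre_hook
    nn.Module.register_full_backward_hook = register_full_backward_hook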
msprobe/mindspore/ms_config.py
@@ -45,7 +45,11 @@ class StatisticsConfig(BaseConfig):
          self._check_config()

      def _check_config(self):
-         if self.summary_mode and self.summary_mode not in ["statistics", "md5"]:
+         single_opt = ["statistics", "md5"]
+         muti_opt = ["md5", "max", "min", "mean", "l2norm"]
+         if isinstance(self.summary_mode, str) and self.summary_mode not in single_opt:
+             raise Exception("summary_mode is invalid")
+         if isinstance(self.summary_mode, list) and not all(opt in muti_opt for opt in self.summary_mode):
              raise Exception("summary_mode is invalid")

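The relaxed check now accepts either a single string or a list of statistic names. Illustrative values, following only the validation logic above:

    valid_str = "statistics"               # or "md5"
    valid_list = ["max", "min", "l2norm"]  # every entry from ["md5", "max", "min", "mean", "l2norm"]
    invalid_list = ["max", "variance"]     # "variance" is not allowed -> Exception("summary_mode is invalid")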