mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,38 +12,33 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
- import os
17
16
  import copy
18
17
  import functools
18
+ import os
19
19
  from collections import defaultdict
20
20
 
21
21
  import mindspore as ms
22
- from mindspore.common.tensor import Tensor
23
- from mindspore import ops
24
22
  from mindspore import nn
25
23
  try:
26
24
  from mindspore.common._pijit_context import PIJitCaptureContext
27
- pijit_label = True
28
25
  except ImportError:
29
26
  pijit_label = False
27
+ else:
28
+ pijit_label = True
30
29
 
31
30
 
31
+ from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
32
+ from msprobe.core.common.file_utils import create_directory
33
+ from msprobe.core.common.utils import Const, print_tools_ends_info
32
34
  from msprobe.core.data_dump.data_collector import build_data_collector
35
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
33
36
  from msprobe.core.data_dump.scope import BaseScope
34
- from msprobe.mindspore.common.utils import get_rank_if_initialized
35
- from msprobe.core.common.file_utils import create_directory
37
+ from msprobe.mindspore.cell_processor import CellProcessor
36
38
  from msprobe.mindspore.common.log import logger
37
- from msprobe.core.common.utils import Const, print_tools_ends_info
38
- from msprobe.core.common.exceptions import DistributedNotInitializedError
39
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
39
40
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
40
41
  from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
41
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
42
- ModuleBackwardInputs, ModuleBackwardOutputs
43
- from msprobe.core.common.exceptions import MsprobeException
44
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
45
- from msprobe.mindspore.cell_processor import CellProcessor
46
42
  from msprobe.mindspore.dump.jit_dump import JitDump
47
43
 
48
44
 
@@ -79,22 +75,24 @@ class Service:
79
75
  )
80
76
 
81
77
  def build_hook(self, target_type, name):
82
- def forward_hook(api_or_cell_name, cell, input, output):
78
+ def forward_hook(api_or_cell_name, cell, input_data, output):
83
79
  if not self.should_excute_hook():
80
+ if hasattr(cell, 'input_kwargs'):
81
+ del cell.input_kwargs
84
82
  return None
85
83
 
86
84
  if target_type == BaseScope.Module_Type_Module:
87
- api_or_cell_name = cell.mindstudio_reserved_name
88
- module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
85
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
86
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
89
87
  else:
90
- module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs,
88
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
91
89
  output=output)
92
90
 
93
91
  self.data_collector.update_api_or_module_name(api_or_cell_name)
94
92
  self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
95
93
  if self.data_collector.if_return_forward_new_output():
96
94
  return self.data_collector.get_forward_new_output()
97
- if target_type == BaseScope.Module_Type_API:
95
+ if hasattr(cell, 'input_kwargs'):
98
96
  del cell.input_kwargs
99
97
  return output
100
98
 
@@ -102,12 +100,19 @@ class Service:
102
100
  if not self.should_excute_hook():
103
101
  return
104
102
 
103
+ need_exchange = True
105
104
  if target_type == BaseScope.Module_Type_Module:
106
- api_or_cell_name = cell.mindstudio_reserved_name
105
+ if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
106
+ need_exchange = False
107
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
108
+
107
109
  self.data_collector.update_api_or_module_name(api_or_cell_name)
108
110
  if self.data_collector:
109
111
  # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
110
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
112
+ if need_exchange:
113
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
114
+ else:
115
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
111
116
  self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
112
117
 
113
118
  pid = os.getpid()
@@ -116,15 +121,14 @@ class Service:
116
121
  forward_hook = functools.partial(forward_hook, forward_name_template)
117
122
  backward_hook = functools.partial(backward_hook, backward_name_template)
118
123
 
119
- def wrap_forward_hook(cell, input, output):
120
- return forward_hook(cell, input, output)
124
+ def wrap_forward_hook(cell, input_data, output_data):
125
+ return forward_hook(cell, input_data, output_data)
121
126
 
122
127
  def wrap_backward_hook(cell, grad_input, grad_output):
123
128
  return backward_hook(cell, grad_input, grad_output)
124
129
 
125
130
  return wrap_forward_hook, wrap_backward_hook
126
131
 
127
-
128
132
  def update_primitive_counters(self, primitive_name):
129
133
  if primitive_name not in self.primitive_counters:
130
134
  self.primitive_counters[primitive_name] = 0
@@ -138,15 +142,16 @@ class Service:
138
142
  primitive_set.add((pname, primitive))
139
143
 
140
144
  for pname, primitive in primitive_set:
141
- NewPrimitive = type('NewPrimitive', (primitive.__class__,),
142
- {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, pname)})
143
- primitive.__class__ = NewPrimitive
145
+ primitive_class_name = primitive.__class__.__name__
146
+ primitive_combined_name = pname + Const.SEP + primitive_class_name
147
+ new_primitive = type('NewPrimitive', (primitive.__class__,),
148
+ {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
149
+ primitive_combined_name)})
150
+ primitive.__class__ = new_primitive
144
151
 
145
152
  def step(self):
146
153
  self.current_iter += 1
147
154
  self.data_collector.update_iter(self.current_iter)
148
- HOOKCell.cell_count = defaultdict(int)
149
- CellProcessor.reset_cell_stats()
150
155
  self.primitive_hook_service.primitive_counters.clear()
151
156
  self.data_collector.data_writer.reset_cache()
152
157
  JitDump.jit_count = defaultdict(int)
@@ -212,6 +217,7 @@ class Service:
212
217
  return
213
218
  self.primitive_switch = False
214
219
  api_register.api_set_ori_func()
220
+ JitDump.jit_dump_switch = False
215
221
 
216
222
  def stop(self):
217
223
  if self.should_stop_service:
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
18
  from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
msprobe/msprobe.py CHANGED
@@ -45,10 +45,15 @@ def main():
45
45
  multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
46
46
  api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
47
47
  run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
48
+ graph_service_cmd_parser = subparsers.add_parser('graph')
48
49
  _compare_parser(compare_cmd_parser)
49
- is_torch_available=is_module_available("torch")
50
+ is_torch_available = is_module_available("torch")
50
51
  is_mindspore_available = is_module_available("mindspore")
51
- if is_torch_available:
52
+ if len(sys.argv) < 4:
53
+ parser.print_help()
54
+ sys.exit(0)
55
+ framework_args = parser.parse_args(sys.argv[1:3])
56
+ if framework_args.framework == Const.PT_FRAMEWORK:
52
57
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
53
58
  from msprobe.pytorch.parse_tool.cli import parse as cli_parse
54
59
  from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
@@ -56,20 +61,24 @@ def main():
56
61
  _api_precision_compare_command
57
62
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
58
63
  _run_overflow_check_command
64
+ from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
59
65
 
60
66
  _run_ut_parser(run_ut_cmd_parser)
61
67
  _run_ut_parser(multi_run_ut_cmd_parser)
62
68
  multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
63
- help='Number of splits for parallel processing. Range: 1-64')
69
+ help='Number of splits for parallel processing. Range: 1-64')
64
70
  _api_precision_compare_parser(api_precision_compare_cmd_parser)
65
71
  _run_overflow_check_parser(run_overflow_check_cmd_parser)
66
- elif is_mindspore_available:
72
+ _pt_graph_service_parser(graph_service_cmd_parser)
73
+ elif framework_args.framework == Const.MS_FRAMEWORK:
67
74
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
75
+ from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
68
76
  add_api_accuracy_checker_argument(run_ut_cmd_parser)
77
+ from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
78
+ multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
79
+
80
+ _ms_graph_service_parser(graph_service_cmd_parser)
69
81
 
70
- if len(sys.argv) == 1:
71
- parser.print_help()
72
- sys.exit(0)
73
82
  args = parser.parse_args(sys.argv[1:])
74
83
  if sys.argv[2] == Const.PT_FRAMEWORK:
75
84
  if not is_torch_available:
@@ -86,6 +95,8 @@ def main():
86
95
  _api_precision_compare_command(args)
87
96
  elif sys.argv[3] == "run_overflow_check":
88
97
  _run_overflow_check_command(args)
98
+ elif sys.argv[3] == "graph":
99
+ _pt_graph_service_command(args)
89
100
  elif sys.argv[3] == "compare":
90
101
  if args.cell_mapping is not None or args.api_mapping is not None:
91
102
  logger.error("Argument -cm or -am is not supported in PyTorch framework")
@@ -100,6 +111,12 @@ def main():
100
111
  elif sys.argv[3] == "run_ut":
101
112
  from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
102
113
  api_checker_main(args)
114
+ elif sys.argv[3] == "multi_run_ut":
115
+ from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main
116
+ mul_api_checker_main(args)
117
+ elif sys.argv[3] == "graph":
118
+ _ms_graph_service_command(args)
119
+
103
120
 
104
121
  if __name__ == "__main__":
105
122
  main()
@@ -16,8 +16,9 @@
16
16
  # limitations under the License.
17
17
 
18
18
 
19
- from .debugger.precision_debugger import PrecisionDebugger
20
- from .common.utils import seed_all
19
+ from msprobe.pytorch.monitor.module_hook import TrainerMon
21
20
  from .compare.distributed_compare import compare_distributed
22
21
  from .compare.pt_compare import compare
22
+ from .common.utils import seed_all
23
+ from .debugger.precision_debugger import PrecisionDebugger
23
24
  from .functional.module_dump import module_dump, module_dump_end
@@ -16,10 +16,18 @@
16
16
  # limitations under the License.
17
17
 
18
18
  import os
19
+ from collections import namedtuple
19
20
  from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
21
+ from msprobe.core.common.utils import is_int
20
22
  from msprobe.pytorch.pt_config import RunUTConfig
21
23
 
22
24
 
25
+ RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
26
+ 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
27
+ 'black_list', 'error_data_path', 'online_config'])
28
+ OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
29
+
30
+
23
31
  class Config:
24
32
  def __init__(self, yaml_file):
25
33
  check_file_or_directory_path(yaml_file, False)
@@ -50,6 +58,8 @@ class Config:
50
58
  raise ValueError(f"{key} must be one of {validators.keys()}")
51
59
  if not isinstance(value, validators.get(key)):
52
60
  raise ValueError(f"{key} must be {validators[key].__name__} type")
61
+ if key == 'precision' and not is_int(value):
62
+ raise ValueError("precision must be an integer")
53
63
  if key == 'precision' and (value < 0 or value > 20):
54
64
  raise ValueError("precision must be greater than or equal to 0 and less than 21")
55
65
  if key == 'white_list':
@@ -68,3 +78,55 @@ class Config:
68
78
  cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
69
79
  yaml_path = os.path.join(cur_path, "config.yaml")
70
80
  msCheckerConfig = Config(yaml_path)
81
+
82
+
83
+ class CheckerConfig:
84
+ def __init__(self, task_config=None):
85
+ self.white_list = msCheckerConfig.white_list
86
+ self.black_list = msCheckerConfig.black_list
87
+ self.error_data_path = msCheckerConfig.error_data_path
88
+ self.is_online = msCheckerConfig.is_online
89
+ self.nfs_path = msCheckerConfig.nfs_path
90
+ self.host = msCheckerConfig.host
91
+ self.port = msCheckerConfig.port
92
+ self.rank_list = msCheckerConfig.rank_list
93
+ self.tls_path = msCheckerConfig.tls_path
94
+
95
+ if task_config:
96
+ self.load_config(task_config)
97
+
98
+ def load_config(self, task_config):
99
+ self.white_list = task_config.white_list
100
+ self.black_list = task_config.black_list
101
+ self.error_data_path = task_config.error_data_path
102
+ self.is_online = task_config.is_online
103
+ self.nfs_path = task_config.nfs_path
104
+ self.host = task_config.host
105
+ self.port = task_config.port
106
+ self.rank_list = task_config.rank_list
107
+ self.tls_path = task_config.tls_path
108
+
109
+ def get_online_config(self):
110
+ return OnlineConfig(
111
+ is_online=self.is_online,
112
+ nfs_path=self.nfs_path,
113
+ host=self.host,
114
+ port=self.port,
115
+ rank_list=self.rank_list,
116
+ tls_path=self.tls_path
117
+ )
118
+
119
+ def get_run_ut_config(self, **config_params):
120
+ return RunUtConfig(
121
+ forward_content=config_params.get('forward_content'),
122
+ backward_content=config_params.get('backward_content'),
123
+ result_csv_path=config_params.get('result_csv_path'),
124
+ details_csv_path=config_params.get('details_csv_path'),
125
+ save_error_data=config_params.get('save_error_data'),
126
+ is_continue_run_ut=config_params.get('is_continue_run_ut'),
127
+ real_data_path=config_params.get('real_data_path'),
128
+ white_list=self.white_list,
129
+ black_list=self.black_list,
130
+ error_data_path=config_params.get('error_data_path'),
131
+ online_config=self.get_online_config()
132
+ )
@@ -34,7 +34,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECI
34
34
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
35
35
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
36
36
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
37
- from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory
37
+ from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
38
38
  from msprobe.pytorch.common.log import logger
39
39
  from msprobe.core.common.utils import CompareException
40
40
  from msprobe.core.common.const import Const, CompareConst, FileCheckConst
@@ -602,8 +602,7 @@ def _api_precision_compare(parser=None):
602
602
  def _api_precision_compare_command(args):
603
603
  npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail')
604
604
  gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail')
605
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
606
- check_path_before_create(out_path)
605
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
607
606
  create_directory(out_path)
608
607
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
609
608
  out_path = out_path_checker.common_check()
@@ -621,7 +620,7 @@ def _api_precision_compare_parser(parser):
621
620
  parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str,
622
621
  help="<Required> Accuracy_checking_details.csv generated on the GPU by using the "
623
622
  "api_accuracy_checker tool.",
624
- required=False)
623
+ required=True)
625
624
  parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
626
625
  help="<optional> The api precision compare task result out path.",
627
626
  required=False)
@@ -0,0 +1,9 @@
1
+ {
2
+ "dump_json_path": "./dump.json",
3
+ "api_name": "",
4
+ "extract_api_path": "",
5
+ "propagation": "forward",
6
+ "data_mode": "random_data",
7
+ "random_seed": 1234,
8
+ "iter_times": 1
9
+ }