mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
  3. msprobe/core/common/const.py +3 -0
  4. msprobe/core/common/file_utils.py +45 -5
  5. msprobe/core/common/utils.py +117 -13
  6. msprobe/core/common_config.py +15 -1
  7. msprobe/core/compare/acc_compare.py +21 -9
  8. msprobe/core/compare/compare_cli.py +10 -2
  9. msprobe/core/compare/merge_result/merge_result.py +1 -1
  10. msprobe/core/compare/utils.py +8 -2
  11. msprobe/core/config_check/checkers/base_checker.py +2 -0
  12. msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
  13. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
  14. msprobe/core/config_check/config_check_cli.py +1 -1
  15. msprobe/core/config_check/config_checker.py +1 -2
  16. msprobe/core/data_dump/data_collector.py +4 -1
  17. msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
  19. msprobe/core/debugger/precision_debugger.py +13 -8
  20. msprobe/core/hook_manager.py +112 -82
  21. msprobe/core/monitor/utils.py +338 -0
  22. msprobe/core/service.py +2 -1
  23. msprobe/core/single_save/single_comparator.py +5 -3
  24. msprobe/docs/01.installation.md +1 -0
  25. msprobe/docs/05.data_dump_PyTorch.md +4 -4
  26. msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
  27. msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
  29. msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
  30. msprobe/docs/12.overflow_check_PyTorch.md +3 -2
  31. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  32. msprobe/docs/14.data_parse_PyTorch.md +35 -32
  33. msprobe/docs/21.visualization_PyTorch.md +9 -8
  34. msprobe/docs/22.visualization_MindSpore.md +1 -0
  35. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  36. msprobe/docs/24.code_mapping_Mindspore.md +6 -5
  37. msprobe/docs/31.config_check.md +15 -5
  38. msprobe/docs/33.generate_operator_MindSpore.md +2 -2
  39. msprobe/docs/34.RL_collect.md +18 -9
  40. msprobe/docs/35.nan_analyze.md +4 -3
  41. msprobe/docs/FAQ.md +3 -0
  42. msprobe/docs/img/ms_layer.png +0 -0
  43. msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
  44. msprobe/mindspore/cell_processor.py +35 -14
  45. msprobe/mindspore/code_mapping/bind.py +23 -4
  46. msprobe/mindspore/code_mapping/graph_parser.py +6 -4
  47. msprobe/mindspore/common/utils.py +3 -0
  48. msprobe/mindspore/compare/common_dir_compare.py +32 -12
  49. msprobe/mindspore/compare/ms_graph_compare.py +7 -2
  50. msprobe/mindspore/compare/utils.py +9 -1
  51. msprobe/mindspore/debugger/debugger_config.py +13 -11
  52. msprobe/mindspore/debugger/precision_debugger.py +67 -45
  53. msprobe/mindspore/dump/dump_tool_factory.py +2 -0
  54. msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
  55. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
  56. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
  57. msprobe/mindspore/dump/jit_dump.py +6 -3
  58. msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
  59. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
  60. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
  61. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
  62. msprobe/mindspore/mindspore_service.py +2 -2
  63. msprobe/mindspore/monitor/common_func.py +1 -1
  64. msprobe/mindspore/monitor/module_hook.py +3 -3
  65. msprobe/mindspore/monitor/utils.py +0 -252
  66. msprobe/mindspore/ms_config.py +0 -1
  67. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  68. msprobe/nan_analyze/graph.py +4 -0
  69. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
  70. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
  71. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
  72. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
  73. msprobe/pytorch/common/utils.py +0 -16
  74. msprobe/pytorch/compare/pt_compare.py +5 -0
  75. msprobe/pytorch/debugger/debugger_config.py +12 -5
  76. msprobe/pytorch/debugger/precision_debugger.py +8 -1
  77. msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
  78. msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
  79. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
  80. msprobe/pytorch/hook_module/hook_module.py +9 -9
  81. msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
  82. msprobe/pytorch/monitor/csv2tb.py +3 -10
  83. msprobe/pytorch/monitor/features.py +5 -0
  84. msprobe/pytorch/monitor/module_hook.py +6 -7
  85. msprobe/pytorch/monitor/module_metric.py +0 -3
  86. msprobe/pytorch/monitor/optimizer_collect.py +1 -1
  87. msprobe/pytorch/monitor/utils.py +1 -317
  88. msprobe/pytorch/online_dispatch/dispatch.py +1 -1
  89. msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
  90. msprobe/pytorch/parse_tool/lib/utils.py +2 -4
  91. msprobe/visualization/graph_service.py +1 -1
  92. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
  93. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
  94. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
  95. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ import mindspore as ms
20
20
  from mindspore._c_expression import MSContext
21
21
 
22
22
  from msprobe.core.common.const import Const, MsgConst
23
- from msprobe.core.common.utils import check_token_range
23
+ from msprobe.core.common.utils import check_token_range, ThreadSafe
24
24
  from msprobe.core.common.runtime import Runtime
25
25
  from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
26
26
  from msprobe.mindspore.cell_processor import CellProcessor
@@ -57,18 +57,14 @@ ConfigParameters = namedtuple("ConfigParameters", ["config_path", "task", "dump_
57
57
 
58
58
  class PrecisionDebugger(BasePrecisionDebugger):
59
59
 
60
- def __new__(cls, config_path=None, task=None, dump_path=None,
61
- level=None, step=None, opt=None):
62
- if not cls._instance:
63
- cls._instance = super().__new__(cls)
64
- cls._instance.initialized = False
65
- cls._instance.config = None
66
- cls.service = None
67
- cls.first_start = False
68
- return cls._instance
69
-
70
- def __init__(self, config_path=None, task=None, dump_path=None,
71
- level=None, step=None):
60
+ def __init__(
61
+ self,
62
+ config_path=None,
63
+ task=None,
64
+ dump_path=None,
65
+ level=None,
66
+ step=None
67
+ ):
72
68
  if self.initialized:
73
69
  return
74
70
  set_register_backward_hook_functions()
@@ -81,13 +77,15 @@ class PrecisionDebugger(BasePrecisionDebugger):
81
77
  self.common_config.dump_path = dump_path if dump_path else self.common_config.dump_path
82
78
  self.config = DebuggerConfig(self.common_config, self.task_config)
83
79
 
84
- if self._need_msprobe_c() and _msprobe_c:
80
+ if self._is_kernel_dump() and not self.task_config.is_regex_valid:
81
+ raise ValueError('Illegal regular expressions exist in the list.')
82
+
83
+ if self._is_kernel_dump() and _msprobe_c:
85
84
  os.environ["MS_HOOK_ENABLE"] = "on"
86
85
  _msprobe_c._PrecisionDebugger(framework="MindSpore", config_path=config_path)
87
86
 
88
87
  self.config.execution_mode = self._get_execution_mode()
89
88
  if self._need_service():
90
- self.config.check_config_with_l2()
91
89
  self.service = MindsporeService(self.config)
92
90
 
93
91
  Runtime.step_count = 0
@@ -119,8 +117,6 @@ class PrecisionDebugger(BasePrecisionDebugger):
119
117
 
120
118
  @staticmethod
121
119
  def _is_graph_dump(config: DebuggerConfig):
122
- if config.level != MsConst.KERNEL:
123
- return False
124
120
  if not config.list:
125
121
  return True
126
122
  is_graph = any(item.startswith("name-regex") for item in config.list)
@@ -132,28 +128,23 @@ class PrecisionDebugger(BasePrecisionDebugger):
132
128
  instance = cls._get_instance()
133
129
  if instance is None:
134
130
  return
135
- if cls._need_msprobe_c() and _msprobe_c:
136
- _msprobe_c._PrecisionDebugger().start()
137
- check_token_range(token_range)
138
- instance.config.execution_mode = cls._get_execution_mode()
139
- if cls._need_service():
140
- if not instance.service:
141
- instance.service = MindsporeService(instance.config)
142
- instance.config.check_model(model, token_range)
143
- instance.service.start(model, token_range)
131
+ if cls._is_kernel_dump():
132
+ cls._start_kernel_dump()
144
133
  else:
145
- if not instance.first_start:
146
- get_api_register().restore_all_api()
147
- handler = TaskHandlerFactory.create(instance.config, model)
148
- handler.handle()
149
- if enable_dynamic_kbyk_dump:
150
- _set_init_iter(0)
151
- if enable_dynamic_kbyk_dump:
152
- is_valid_rank = (not instance.config.rank or Runtime.rank_id in instance.config.rank)
153
- is_valid_step = (not instance.config.step or Runtime.step_count in instance.config.step)
154
- if is_valid_rank and is_valid_step:
155
- _dump_start()
156
- Runtime.is_running = True
134
+ check_token_range(token_range)
135
+ instance.config.execution_mode = cls._get_execution_mode()
136
+ if cls._need_service():
137
+ with ThreadSafe():
138
+ if not instance.service:
139
+ instance.service = MindsporeService(instance.config)
140
+ instance.config.check_model(model, token_range)
141
+ instance.service.start(model, token_range)
142
+ else:
143
+ if not instance.first_start:
144
+ get_api_register().restore_all_api()
145
+ handler = TaskHandlerFactory.create(instance.config, model)
146
+ handler.handle()
147
+ Runtime.is_running = True
157
148
  instance.first_start = True
158
149
 
159
150
  @classmethod
@@ -165,14 +156,15 @@ class PrecisionDebugger(BasePrecisionDebugger):
165
156
  if instance.task == Const.GRAD_PROBE:
166
157
  instance.gm.stop()
167
158
  if instance.service:
168
- instance.service.stop()
159
+ with ThreadSafe():
160
+ instance.service.stop()
169
161
  else:
170
162
  Runtime.is_running = False
171
163
  if enable_dynamic_kbyk_dump:
172
164
  _dump_stop()
173
- if cls._need_msprobe_c() and _msprobe_c:
165
+ if cls._is_kernel_dump() and _msprobe_c:
174
166
  _msprobe_c._PrecisionDebugger().stop()
175
-
167
+
176
168
  @classmethod
177
169
  def step(cls):
178
170
  instance = cls._get_instance()
@@ -180,12 +172,13 @@ class PrecisionDebugger(BasePrecisionDebugger):
180
172
  return
181
173
 
182
174
  if instance.service:
183
- instance.service.step()
175
+ with ThreadSafe():
176
+ instance.service.step()
184
177
  if is_graph_mode_cell_dump_allowed(instance.config):
185
178
  GraphModeCellDump.step()
186
179
  if enable_dynamic_kbyk_dump:
187
180
  _dump_step(1)
188
- if cls._need_msprobe_c() and _msprobe_c:
181
+ if cls._is_kernel_dump() and _msprobe_c:
189
182
  _msprobe_c._PrecisionDebugger().step()
190
183
 
191
184
  HOOKCell.cell_count = defaultdict(int)
@@ -193,6 +186,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
193
186
  Runtime.step_count += 1
194
187
 
195
188
  @classmethod
189
+ @ThreadSafe.synchronized
196
190
  def monitor(cls, opt):
197
191
  instance = cls._instance
198
192
  if not instance:
@@ -202,6 +196,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
202
196
  instance.gm.monitor(opt)
203
197
 
204
198
  @classmethod
199
+ @ThreadSafe.synchronized
205
200
  def save(cls, variable, name, save_backward=True):
206
201
  instance = cls._instance
207
202
  if not instance:
@@ -224,14 +219,41 @@ class PrecisionDebugger(BasePrecisionDebugger):
224
219
  instance = cls._instance
225
220
  if not instance:
226
221
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
222
+ if instance.config.level_ori == Const.LEVEL_L2:
223
+ return False
227
224
  if instance.config.execution_mode != MsConst.PYNATIVE_MODE:
228
225
  return False
229
226
  else:
230
- return instance.config.task != Const.FREE_BENCHMARK and not instance._is_graph_dump(instance.config)
227
+ return instance.config.task != Const.FREE_BENCHMARK
231
228
 
232
229
  @classmethod
233
- def _need_msprobe_c(cls):
230
+ def _is_kernel_dump(cls):
234
231
  instance = cls._instance
235
232
  if not instance:
236
233
  raise Exception(MsgConst.NOT_CREATED_INSTANCE)
237
234
  return instance.config.level_ori == Const.LEVEL_L2
235
+
236
+ @classmethod
237
+ def _start_kernel_dump(cls):
238
+ instance = cls._get_instance()
239
+ is_graph_config = cls._is_graph_dump(instance.config)
240
+ instance.config.check_config_with_l2(is_graph_config)
241
+ if not is_graph_config:
242
+ if not instance.service:
243
+ instance.service = MindsporeService(instance.config)
244
+ instance.service.start()
245
+ else:
246
+ if _msprobe_c:
247
+ _msprobe_c._PrecisionDebugger().start()
248
+ if not instance.first_start:
249
+ get_api_register().restore_all_api()
250
+ handlers = TaskHandlerFactory.create(instance.config)
251
+ for handler in handlers:
252
+ handler.handle()
253
+ if enable_dynamic_kbyk_dump:
254
+ _set_init_iter(0)
255
+ if enable_dynamic_kbyk_dump:
256
+ is_valid_rank = (not instance.config.rank or Runtime.rank_id in instance.config.rank)
257
+ is_valid_step = (not instance.config.step or Runtime.step_count in instance.config.step)
258
+ if is_valid_rank and is_valid_step:
259
+ _dump_start()
@@ -51,6 +51,8 @@ class DumpToolFactory:
51
51
  else:
52
52
  if len(config.data_mode) != 1 or config.data_mode[0] not in Const.GRAPH_DATA_MODE_LIST:
53
53
  raise Exception("data_mode must be one of all, input, output.")
54
+ if config.level == Const.KERNEL:
55
+ return (KernelGraphDump(config), KernelKbykDump(config))
54
56
  tool = DumpToolFactory.tools.get(config.level)
55
57
  if not tool:
56
58
  raise Exception("Valid level is needed.")
@@ -13,15 +13,16 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import threading
16
17
  from collections import defaultdict
17
18
 
18
19
  import mindspore as ms
19
20
  from mindspore import nn
20
21
 
21
22
  from msprobe.core.common.runtime import Runtime
23
+ from msprobe.core.common.utils import ThreadSafe
22
24
  from msprobe.mindspore.common.utils import is_mindtorch, register_backward_hook_functions
23
25
 
24
-
25
26
  ms_version = ms.__version__
26
27
 
27
28
 
@@ -35,16 +36,17 @@ def get_cell_count(name):
35
36
 
36
37
  def __init__(self, hook_build_func) -> None:
37
38
  super(HOOKCell, self).__init__()
38
- self.changed_status = False
39
39
  self.msprobe_input_kwargs = {}
40
- if not HOOKCell.g_stop_hook:
41
- HOOKCell.g_stop_hook = True
42
- self.changed_status = True
40
+
41
+ self.tid = threading.get_ident()
42
+ self.stop_hook = HOOKCell.inner_stop_hook.get(self.tid, False)
43
+ if not self.stop_hook:
43
44
  self.forward_data_collected = False
44
45
 
45
46
  if not Runtime.is_running:
46
47
  return
47
48
  prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
49
+ ThreadSafe.acquire()
48
50
  if callable(hook_build_func):
49
51
  hook_set = hook_build_func(prefix)
50
52
  if ms_version < "2.6.0" and not is_mindtorch():
@@ -59,21 +61,24 @@ def __init__(self, hook_build_func) -> None:
59
61
 
60
62
  # 重载call,加全局标志。
61
63
  def __call__(self, *args, **kwargs):
64
+ changed = False
65
+ if not self.stop_hook:
66
+ HOOKCell.inner_stop_hook[self.tid] = True
67
+ changed = True
62
68
  try:
63
69
  self.msprobe_input_kwargs = kwargs
64
70
  out = super(HOOKCell, self).__call__(*args, **kwargs)
65
71
  except Exception as e:
66
72
  raise e
67
73
  finally:
68
- if self.changed_status:
69
- self.changed_status = False
70
- HOOKCell.g_stop_hook = False
74
+ if changed:
75
+ HOOKCell.inner_stop_hook[self.tid] = False
71
76
  return out
72
77
 
73
78
 
74
79
  hook_cell_dict = {
75
80
  "cell_count": defaultdict(int),
76
- "g_stop_hook": False,
81
+ "inner_stop_hook": defaultdict(bool),
77
82
  "add_cell_count": staticmethod(add_cell_count),
78
83
  "get_cell_count": staticmethod(get_cell_count),
79
84
  "__init__": __init__,
@@ -13,9 +13,11 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import threading
17
+
16
18
  from mindspore.common.api import _no_grad
17
19
  from msprobe.core.common.const import Const
18
- from msprobe.core.common.utils import replace_last_occurrence
20
+ from msprobe.core.common.utils import replace_last_occurrence, ThreadSafe
19
21
  from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputs
20
22
  from msprobe.core.hook_manager import BaseHookManager, HookSet
21
23
  from msprobe.mindspore.common.utils import has_kwargs_in_forward_hook
@@ -78,11 +80,14 @@ class MindsproeHookManager(BaseHookManager):
78
80
  def backward_pre_hook(module, grad_input):
79
81
  if self.config.level != Const.LEVEL_L2:
80
82
  return
81
- if not self._should_execute_hook(hook_type, module, False):
83
+ tid = threading.get_ident()
84
+ if not self._should_execute_hook(hook_type, module, False, tid):
82
85
  return
83
- BaseHookManager.inner_switch = True
84
- module_input = ModuleBackwardInputs(grad_input=grad_input)
85
- self.data_collector.update_api_or_module_name(name)
86
- self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
87
- BaseHookManager.inner_switch = False
86
+
87
+ with ThreadSafe():
88
+ BaseHookManager.inner_switch[tid] = True
89
+ module_input = ModuleBackwardInputs(grad_input=grad_input)
90
+ self.data_collector.update_api_or_module_name(name)
91
+ self.data_collector.backward_input_data_collect(name, module, self._pid, module_input)
92
+ BaseHookManager.inner_switch[tid] = False
88
93
  return backward_pre_hook
@@ -14,13 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ import threading
17
18
 
18
19
  from mindspore import ops
19
20
  from mindspore.common.tensor import Tensor
20
-
21
- from msprobe.core.common.utils import Const, DumpException
22
- from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
23
- ModuleForwardInputsOutputs)
21
+ from msprobe.core.common.utils import Const, DumpException, ThreadSafe
22
+ from msprobe.core.data_dump.data_processor.base import (
23
+ ModuleBackwardInputs,
24
+ ModuleBackwardOutputs,
25
+ ModuleForwardInputsOutputs
26
+ )
24
27
  from msprobe.core.hook_manager import BaseHookManager
25
28
  from msprobe.mindspore.common.log import logger
26
29
 
@@ -56,10 +59,13 @@ class PrimitiveHookService:
56
59
  callable: 反向 hook 函数。
57
60
  """
58
61
 
62
+ @ThreadSafe.synchronized
59
63
  def backward_hook(grad):
64
+ tid = threading.get_ident()
65
+ BaseHookManager.inner_switch[tid] = True
66
+
60
67
  captured_grads.extend(grad)
61
68
  backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
62
- self.service_instance.inner_switch = True
63
69
  try:
64
70
  if hook_type == Const.INPUT:
65
71
  self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
@@ -78,7 +84,7 @@ class PrimitiveHookService:
78
84
  logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
79
85
  f"updated_primitive_name: {updated_primitive_name}")
80
86
  raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
81
- self.service_instance.inner_switch = False
87
+ BaseHookManager.inner_switch[tid] = False
82
88
 
83
89
  return backward_hook
84
90
 
@@ -137,9 +143,13 @@ class PrimitiveHookService:
137
143
  return tuple(hooked_outputs)
138
144
  return out
139
145
 
146
+ @ThreadSafe.synchronized
140
147
  def pre_forward_hook(primitive_name, primitive_instance, args, kwargs):
148
+ tid = threading.get_ident()
149
+ BaseHookManager.inner_switch[tid] = True
150
+
151
+ self.service_instance.data_collector.update_api_or_module_name(primitive_name)
141
152
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
142
- self.service_instance.inner_switch = True
143
153
  try:
144
154
  self.service_instance.data_collector.forward_input_data_collect(
145
155
  primitive_name,
@@ -151,11 +161,15 @@ class PrimitiveHookService:
151
161
  logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, "
152
162
  f"primitive_name: {primitive_name}")
153
163
  raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
154
- self.service_instance.inner_switch = False
164
+ BaseHookManager.inner_switch[tid] = False
155
165
 
166
+ @ThreadSafe.synchronized
156
167
  def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output):
168
+ tid = threading.get_ident()
169
+ BaseHookManager.inner_switch[tid] = True
170
+
171
+ self.service_instance.data_collector.update_api_or_module_name(primitive_name)
157
172
  module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
158
- self.service_instance.inner_switch = True
159
173
  try:
160
174
  self.service_instance.data_collector.forward_output_data_collect(
161
175
  primitive_name,
@@ -167,7 +181,7 @@ class PrimitiveHookService:
167
181
  logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, "
168
182
  f"primitive_name: {primitive_name}")
169
183
  raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
170
- self.service_instance.inner_switch = False
184
+ BaseHookManager.inner_switch[tid] = False
171
185
 
172
186
  def wrapped_primitive_call(instance_self, *args, **kwargs):
173
187
  """
@@ -185,7 +199,8 @@ class PrimitiveHookService:
185
199
  current_count = self.primitive_counters.get(primitive_name, 0)
186
200
  updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
187
201
 
188
- if not self.service_instance.primitive_switch or BaseHookManager.inner_switch:
202
+ tid = threading.get_ident()
203
+ if not self.service_instance.primitive_switch or BaseHookManager.inner_switch[tid]:
189
204
  return origin_func(*args, **kwargs)
190
205
 
191
206
  captured_grads_input, captured_grads_output = [], []
@@ -198,8 +213,6 @@ class PrimitiveHookService:
198
213
  raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
199
214
 
200
215
  forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
201
- self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
202
-
203
216
  pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs)
204
217
  try:
205
218
  out = origin_func(*hooked_inputs, **kwargs)
@@ -220,6 +233,7 @@ class PrimitiveHookService:
220
233
 
221
234
  return wrapped_primitive_call
222
235
 
236
+ @ThreadSafe.synchronized
223
237
  def update_primitive_counters(self, primitive_name):
224
238
  if primitive_name not in self.primitive_counters:
225
239
  self.primitive_counters[primitive_name] = 0
@@ -13,13 +13,14 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from collections import defaultdict
17
16
  import os
18
17
  import types
18
+ from collections import defaultdict
19
19
 
20
20
  import mindspore
21
21
  from mindspore import nn
22
22
  from mindspore._c_expression import PyNativeExecutor_
23
+
23
24
  try:
24
25
  from mindspore.common.api import _MindsporeFunctionExecutor
25
26
  except ImportError:
@@ -27,19 +28,21 @@ except ImportError:
27
28
 
28
29
  from msprobe.core.common.log import logger
29
30
  from msprobe.core.common.const import Const
31
+ from msprobe.core.common.utils import ThreadSafe
30
32
  from msprobe.core.common.runtime import Runtime
31
33
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
32
34
  from msprobe.mindspore.common.const import Const as MsConst
33
35
  from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
34
36
 
35
-
36
37
  _api_register = get_api_register()
37
38
 
38
39
 
39
40
  def dump_jit(name, in_feat, out_feat, is_forward):
40
41
  pid = os.getpid()
41
42
  name = name if name else "JitFunction"
42
- if JitDump.need_dump():
43
+ if not JitDump.need_dump():
44
+ return
45
+ with ThreadSafe():
43
46
  if is_forward:
44
47
  if name in JitDump.jit_count:
45
48
  JitDump.jit_count[name] += 1
@@ -39,12 +39,19 @@ class KernelKbykDump:
39
39
  common_set["input_output"] = 0
40
40
  common_set["kernels"] = []
41
41
  common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7]
42
- e2e_set = {
43
- "enable": not config.async_dump,
44
- "trans_flag": True,
45
- "stat_calc_mode": config.stat_cal_mode,
46
- "device_stat_precision_mode": config.device_stat_precision_mode,
47
- }
42
+
43
+ if config.stat_cal_mode and config.device_stat_precision_mode:
44
+ e2e_set = {
45
+ "enable": not config.async_dump,
46
+ "trans_flag": True,
47
+ "stat_calc_mode": config.stat_cal_mode,
48
+ "device_stat_precision_mode": config.device_stat_precision_mode
49
+ }
50
+ else:
51
+ e2e_set = {
52
+ "enable": not config.async_dump,
53
+ "trans_flag": True
54
+ }
48
55
 
49
56
  if config.list:
50
57
  common_set["dump_mode"] = 1
@@ -23,13 +23,14 @@
23
23
 
24
24
  namespace py = pybind11;
25
25
 
26
- HookDynamicLoader &HookDynamicLoader::GetInstance()
26
+ HookDynamicLoader &HookDynamicLoader::GetInstance()
27
27
  {
28
28
  static HookDynamicLoader instance;
29
29
  return instance;
30
30
  }
31
31
 
32
- bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName) {
32
+ bool HookDynamicLoader::LoadFunction(void *handle, const std::string &functionName)
33
+ {
33
34
  void *func = dlsym(handle, functionName.c_str());
34
35
  if (!func) {
35
36
  MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
@@ -83,7 +84,7 @@ bool HookDynamicLoader::LoadLibrary()
83
84
  return true;
84
85
  }
85
86
 
86
- bool HookDynamicLoader::UnloadLibrary()
87
+ bool HookDynamicLoader::UnloadLibrary()
87
88
  {
88
89
  std::lock_guard<std::mutex> lock(mutex_);
89
90
  if (!handle_) {
@@ -103,8 +104,8 @@ void *HookDynamicLoader::GetHooker(const std::string &funcName)
103
104
  std::lock_guard<std::mutex> lock(mutex_);
104
105
  auto iter = funcMap_.find(funcName);
105
106
  if (iter == funcMap_.end()) {
106
- MS_LOG(WARNING) << "Function not found: " << funcName;
107
- return nullptr;
107
+ MS_LOG(WARNING) << "Function not found: " << funcName;
108
+ return nullptr;
108
109
  }
109
110
  return iter->second;
110
111
  }
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright 2024 Huawei Technologies Co., Ltd
1
+ /*
2
+ * Copyright (C) 2024-2025. Huawei Technologies Co., Ltd. All rights reserved.
3
3
  *
4
4
  * Licensed under the Apache License, Version 2.0 (the "License");
5
5
  * you may not use this file except in compliance with the License.
@@ -245,6 +245,8 @@ class CSVGenerator(Process):
245
245
  return ["Max", "Min", "Norm", "Shape"]
246
246
 
247
247
  def get_dist_header(self) -> List[str]:
248
+ if not self.bounds:
249
+ return []
248
250
  intervals = []
249
251
  for i, _ in enumerate(self.bounds):
250
252
  if i == 0:
@@ -39,11 +39,11 @@ else:
39
39
  pijit_label = True
40
40
 
41
41
 
42
- class MindsporeService(BaseService):
42
+ class MindsporeService(BaseService):
43
43
  @property
44
44
  def _get_framework_type(self):
45
45
  return Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
46
-
46
+
47
47
  @staticmethod
48
48
  def _get_current_rank():
49
49
  return get_rank_if_initialized()
@@ -16,7 +16,7 @@
16
16
 
17
17
  from mindspore import nn
18
18
  from mindspore import communication
19
- from msprobe.mindspore.monitor.utils import logger
19
+ from msprobe.core.common.log import logger
20
20
  from msprobe.mindspore.common.utils import is_mindtorch
21
21
  if is_mindtorch():
22
22
  import torch
@@ -28,11 +28,12 @@ from mindspore import nn, _no_grad
28
28
  from msprobe.core.common.log import logger
29
29
  from msprobe.core.common.const import MonitorConst, Const
30
30
  from msprobe.core.common.file_utils import load_json, save_json
31
+ from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir
31
32
  from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
32
33
  from msprobe.mindspore.common.utils import is_mindtorch
33
34
  from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
34
- from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
35
- is_skip_step, get_metrics, get_target_output_dir
35
+ from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \
36
+ get_metrics
36
37
  from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
37
38
  from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
38
39
  from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
@@ -250,7 +251,6 @@ class TrainerMon:
250
251
  self.has_collect_times = 0 # 重设采集计数器
251
252
  self.print_struct = self.config.get("print_struct", False)
252
253
  self.targets = self.config.get("targets", None)
253
- self.is_select = self.config.get("is_select", False)
254
254
  self.module_rank_list = self.config.get("module_ranks", [])
255
255
  self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore
256
256
  self.eps = self.config.get('eps', 1e-8)