mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ import mindspore as ms
22
22
  from mindspore import nn
23
23
  from mindspore.common.api import _no_grad
24
24
  from mindspore.ops.primitive import Primitive
25
+
25
26
  try:
26
27
  from mindspore.common._pijit_context import PIJitCaptureContext
27
28
  except ImportError:
@@ -31,7 +32,7 @@ else:
31
32
 
32
33
  from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
33
34
  from msprobe.core.common.file_utils import create_directory
34
- from msprobe.core.common.utils import Const, print_tools_ends_info
35
+ from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation
35
36
  from msprobe.core.data_dump.data_collector import build_data_collector
36
37
  from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
37
38
  ModuleBackwardInputs)
@@ -40,7 +41,7 @@ from msprobe.mindspore.cell_processor import CellProcessor
40
41
  from msprobe.mindspore.common.log import logger
41
42
  from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
42
43
  is_mindtorch, register_backward_hook_functions)
43
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
44
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
44
45
  from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
45
46
  from msprobe.mindspore.dump.jit_dump import JitDump
46
47
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
@@ -62,14 +63,19 @@ class Service:
62
63
  self.inner_switch = False
63
64
  self.primitive_switch = False
64
65
  self.current_iter = 0
66
+ self.loop = 0
67
+ self.init_step = 0
65
68
  self.first_start = True
66
69
  self.current_rank = None
67
70
  self.dump_iter_dir = None
68
71
  self.start_call = False
69
72
  self.should_stop_service = False
70
73
  self.params_grad_info = {}
74
+ self.hook_handle_dict = {}
71
75
  # 提前注册,确保注册尽可能多的API hook
76
+ self.api_register = get_api_register()
72
77
  self.register_api_hook()
78
+ self.init_for_debug_level()
73
79
 
74
80
  @staticmethod
75
81
  def check_model_valid(models):
@@ -138,7 +144,12 @@ class Service:
138
144
  if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
139
145
  for param_name, param in params_dict.items():
140
146
  if param.requires_grad:
141
- param.register_hook(grad_hook(cell, ori_name, param_name))
147
+ name = ori_name + Const.SEP + param_name
148
+ old_handle = self.hook_handle_dict.get(name)
149
+ if old_handle and hasattr(old_handle, "remove"):
150
+ old_handle.remove()
151
+ handle = param.register_hook(grad_hook(cell, ori_name, param_name))
152
+ self.hook_handle_dict[name] = handle
142
153
 
143
154
  def init_params_grad_info(cell, params_dict):
144
155
  '''
@@ -168,11 +179,15 @@ class Service:
168
179
  module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
169
180
  if target_type == BaseScope.Module_Type_Module:
170
181
  api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
171
- params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(
172
- recurse=False).items()}
173
- setattr(module_input_output, Const.PARAMS, params_dict)
182
+ params_dict = {}
183
+ if self.config.task != Const.STRUCTURE:
184
+ params_dict = {
185
+ key.split(Const.SEP)[-1]: value
186
+ for key, value in cell.parameters_dict(recurse=False).items()
187
+ }
188
+ setattr(module_input_output, Const.PARAMS, params_dict)
174
189
  # 判断是否需要注册参数hook
175
- if not hasattr(cell, 'params_grad_name') and params_dict:
190
+ if params_dict:
176
191
  ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
177
192
  grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
178
193
  # 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
@@ -257,15 +272,33 @@ class Service:
257
272
  self.primitive_counters[primitive_name] += 1
258
273
 
259
274
  def step(self):
275
+ if self.config.level == Const.LEVEL_DEBUG:
276
+ return
260
277
  if self.config.async_dump:
261
278
  self.data_collector.fill_stack_tensor_data()
262
- self.data_collector.data_processor.dump_async_data()
279
+ if self.config.task == Const.TENSOR:
280
+ self.data_collector.data_processor.dump_async_data()
263
281
  self.data_collector.write_json()
264
- self.current_iter += 1
265
- self.data_collector.update_iter(self.current_iter)
282
+ self.loop += 1
266
283
  self.reset_status()
267
284
 
268
285
  def start(self, model=None):
286
+ if self.current_iter == 0:
287
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
288
+ JitDump.set_config(self.config)
289
+ JitDump.set_data_collector(self.data_collector)
290
+ if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
291
+ ms.common.api._MindsporeFunctionExecutor = JitDump
292
+ else:
293
+ ms.common.api._JitExecutor = JitDump
294
+ ms.common.api._PyNativeExecutor.grad = JitDump.grad
295
+ if pijit_label:
296
+ PIJitCaptureContext.__enter__ = self.empty
297
+ PIJitCaptureContext.__exit__ = self.empty
298
+ self.current_iter = self.loop + self.init_step
299
+ self.data_collector.update_iter(self.current_iter)
300
+ if self.config.level == Const.LEVEL_DEBUG:
301
+ return
269
302
  self.start_call = True
270
303
  if self.should_stop_service:
271
304
  return
@@ -276,6 +309,7 @@ class Service:
276
309
  print_tools_ends_info()
277
310
  return
278
311
  if self.config.step and self.current_iter not in self.config.step:
312
+ JitDump.jit_dump_switch = False
279
313
  return
280
314
  self.model = self.check_model_valid(model)
281
315
 
@@ -291,17 +325,9 @@ class Service:
291
325
  return
292
326
  self.register_primitive_hook()
293
327
  self.register_cell_hook()
294
- if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
295
- JitDump.set_config(self.config)
296
- JitDump.set_data_collector(self.data_collector)
297
- ms.common.api._MindsporeFunctionExecutor = JitDump
298
- ms.common.api._PyNativeExecutor.grad = JitDump.grad
299
- if pijit_label:
300
- PIJitCaptureContext.__enter__ = self.empty
301
- PIJitCaptureContext.__exit__ = self.empty
302
328
  self.first_start = False
303
329
 
304
- api_register.api_set_hook_func()
330
+ self.api_register.register_all_api()
305
331
  self.switch = True
306
332
  self.primitive_switch = True
307
333
  logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
@@ -310,6 +336,8 @@ class Service:
310
336
  JitDump.jit_dump_switch = True
311
337
 
312
338
  def stop(self):
339
+ if self.config.level == Const.LEVEL_DEBUG:
340
+ return
313
341
  if self.should_stop_service:
314
342
  return
315
343
  logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
@@ -326,7 +354,8 @@ class Service:
326
354
  self.start_call = False
327
355
  if self.config.async_dump:
328
356
  self.data_collector.fill_stack_tensor_data()
329
- self.data_collector.data_processor.dump_async_data()
357
+ if self.config.task == Const.TENSOR:
358
+ self.data_collector.data_processor.dump_async_data()
330
359
  self.data_collector.write_json()
331
360
  JitDump.jit_dump_switch = False
332
361
 
@@ -370,12 +399,13 @@ class Service:
370
399
  else:
371
400
  dump_data_dir = None
372
401
 
373
- dump_file_path = os.path.join(dump_dir, "dump.json")
374
- stack_file_path = os.path.join(dump_dir, "stack.json")
375
- construct_file_path = os.path.join(dump_dir, "construct.json")
376
- self.data_collector.update_dump_paths(
377
- dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
378
- )
402
+ dump_path_aggregation = DumpPathAggregation()
403
+ dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json")
404
+ dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json")
405
+ dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json")
406
+ dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
407
+ self.data_collector.update_dump_paths(dump_path_aggregation)
408
+
379
409
  self.data_collector.initialize_json_file(
380
410
  framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
381
411
  )
@@ -386,21 +416,21 @@ class Service:
386
416
  def register_api_hook(self):
387
417
  if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
388
418
  logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
389
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
390
- api_register.api_set_hook_func()
419
+ self.api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
420
+ self.api_register.register_all_api()
391
421
 
392
422
  def get_cells_and_names(self):
393
423
  cells_and_names_with_index = {}
394
424
 
395
425
  def get_cell_or_module(model):
396
426
  return model.named_modules() if is_mindtorch() else model.cells_and_names()
397
-
427
+
398
428
  if isinstance(self.model, (list, tuple)):
399
429
  for index, model in enumerate(self.model):
400
430
  cells_and_names_with_index[str(index)] = get_cell_or_module(model)
401
431
  else:
402
432
  cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
403
- return cells_and_names_with_index
433
+ return cells_and_names_with_index
404
434
 
405
435
  def register_primitive_hook(self):
406
436
  if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
@@ -430,7 +460,7 @@ class Service:
430
460
  if not self.model:
431
461
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
432
462
  f"The current level is {self.config.level}, the model cannot be None")
433
- model_type = Const.MODULE if is_mindtorch() else Const.CELL
463
+ model_type = Const.MODULE if is_mindtorch() else Const.CELL
434
464
  cells_and_names_with_index = self.get_cells_and_names()
435
465
 
436
466
  for index, cells_and_names in cells_and_names_with_index.items():
@@ -439,7 +469,7 @@ class Service:
439
469
  if cell == model:
440
470
  continue
441
471
  cell_index = (index + Const.SEP) if index != "-1" else ""
442
- prefix = (model_type + Const.SEP + cell_index + name +
472
+ prefix = (model_type + Const.SEP + cell_index + name +
443
473
  Const.SEP + cell.__class__.__name__ + Const.SEP)
444
474
  _, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
445
475
  cell.register_forward_hook(forward_hook)
@@ -456,10 +486,9 @@ class Service:
456
486
 
457
487
  def reset_status(self):
458
488
  self.primitive_hook_service.primitive_counters.clear()
459
- self.data_collector.data_writer.reset_cache()
489
+ self.data_collector.reset_status()
460
490
  JitDump.jit_count = defaultdict(int)
461
491
  self.params_grad_info.clear()
462
-
463
492
  if self.config.level == Const.LEVEL_L2:
464
493
  self.data_collector.data_processor.reset_status()
465
494
  return
@@ -467,3 +496,54 @@ class Service:
467
496
  return
468
497
  if self.config.rank and self.current_rank not in self.config.rank:
469
498
  return
499
+
500
+ def init_for_debug_level(self):
501
+ if not (self.config.level == Const.LEVEL_DEBUG and self.config.task in [Const.TENSOR, Const.STATISTICS]):
502
+ return
503
+ try:
504
+ self.current_rank = get_rank_if_initialized()
505
+ except DistributedNotInitializedError:
506
+ self.current_rank = None
507
+ # dir: dump_path -- rank{} -- debug.json
508
+ self.dump_iter_dir = self.config.dump_path
509
+ cur_rank = self.current_rank if self.current_rank is not None else ''
510
+ dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
511
+ create_directory(dump_dir)
512
+ if self.config.task in self.data_collector.tasks_need_tensor_data:
513
+ dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
514
+ create_directory(dump_data_dir)
515
+ else:
516
+ dump_data_dir = None
517
+
518
+ dump_path_aggregation = DumpPathAggregation()
519
+ dump_path_aggregation.dump_tensor_data_dir = dump_data_dir
520
+ dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json")
521
+ self.data_collector.update_dump_paths(dump_path_aggregation)
522
+ self.data_collector.initialize_json_file(
523
+ framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
524
+ )
525
+ self.debug_variable_counter = defaultdict(int)
526
+
527
+ def save(self, variable, name, save_backward):
528
+ '''
529
+ Args:
530
+ variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int]
531
+ name: str
532
+ save_backward: boolean
533
+ Return:
534
+ void
535
+ '''
536
+ if self.config.level != Const.LEVEL_DEBUG:
537
+ return
538
+ count = self.debug_variable_counter[name]
539
+ self.debug_variable_counter[name] += 1
540
+
541
+ name_with_count = f"{name}.{count}"
542
+ grad_name_with_count = f"{name}_grad.{count}"
543
+
544
+ # forward save
545
+ self.data_collector.debug_data_collect_forward(variable, name_with_count)
546
+
547
+ # backward save
548
+ if save_backward:
549
+ self.data_collector.debug_data_collect_backward(variable, grad_name_with_count)
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
-
17
16
  import torch
18
17
  from .compare.distributed_compare import compare_distributed
19
18
  from .compare.pt_compare import compare
@@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat
40
40
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
41
41
  from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
42
42
  from msprobe.pytorch.common.log import logger
43
- from msprobe.core.common.utils import CompareException
43
+ from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
44
44
  from msprobe.core.common.const import Const, CompareConst, FileCheckConst
45
45
 
46
46
  CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
@@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config):
151
151
  message = ''
152
152
  compare_column = ApiPrecisionOutputColumn()
153
153
  full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
154
+ check_op_str_pattern_valid(full_api_name_with_direction_status)
154
155
  row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
155
156
  api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
156
157
  if not api_full_name:
@@ -430,6 +431,7 @@ def _api_precision_compare(parser=None):
430
431
  _api_precision_compare_parser(parser)
431
432
  args = parser.parse_args(sys.argv[1:])
432
433
  _api_precision_compare_command(args)
434
+ logger.info("Compare task completed.")
433
435
 
434
436
 
435
437
  def _api_precision_compare_command(args):
@@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser):
457
459
  parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
458
460
  help="<optional> The api precision compare task result out path.",
459
461
  required=False)
460
-
461
-
462
- if __name__ == '__main__':
463
- _api_precision_compare()
464
- logger.info("Compare task completed.")
@@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st
28
28
  ulp_standard_api, thousandth_standard_api
29
29
  from msprobe.core.common.file_utils import FileOpen, load_json, save_json
30
30
  from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
31
- from msprobe.core.common.const import Const, MonitorConst, MsgConst
31
+ from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
32
32
  from msprobe.core.common.log import logger
33
- from msprobe.core.common.file_utils import make_dir
34
- from msprobe.core.common.utils import recursion_depth_decorator
33
+ from msprobe.core.common.file_utils import make_dir, change_mode
34
+ from msprobe.core.common.decorator import recursion_depth_decorator
35
35
 
36
36
  TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
37
37
  TORCH_BOOL_TYPE = ["torch.bool"]
@@ -50,6 +50,7 @@ DATA_NAME = "data_name"
50
50
  API_MAX_LENGTH = 30
51
51
  PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
52
52
  DATAMODE_LIST = ["random_data", "real_data"]
53
+ ITER_MAX_TIMES = 1000
53
54
 
54
55
 
55
56
  class APIInfo:
@@ -97,6 +98,8 @@ class CommonConfig:
97
98
  iter_t = self.iter_times
98
99
  if iter_t <= 0:
99
100
  raise ValueError("iter_times should be an integer bigger than zero!")
101
+ if iter_t > ITER_MAX_TIMES:
102
+ raise ValueError("iter_times should not be greater than 1000!")
100
103
 
101
104
  json_file = self.extract_api_path
102
105
  propagation = self.propagation
@@ -117,7 +120,7 @@ class CommonConfig:
117
120
 
118
121
  # Retrieve the first API name and dictionary
119
122
  forward_item = next(iter(json_content.items()), None)
120
- if not forward_item or not isinstance(forward_item[1], dict):
123
+ if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
121
124
  raise ValueError(f'Invalid forward API data in json_content!')
122
125
 
123
126
  # if propagation is backward, ensure json file contains forward and backward info
@@ -127,7 +130,7 @@ class CommonConfig:
127
130
  # if propagation is backward, ensure it has valid data
128
131
  if propagation == Const.BACKWARD:
129
132
  backward_item = list(json_content.items())[1]
130
- if not isinstance(backward_item[1], dict):
133
+ if not isinstance(backward_item[1], dict) or not backward_item[1]:
131
134
  raise ValueError(f'Invalid backward API data in json_content!')
132
135
 
133
136
  return json_content
@@ -169,7 +172,7 @@ class APIExtractor:
169
172
  value = self.load_real_data_path(value, real_data_path)
170
173
  new_data[key] = value
171
174
  if not new_data:
172
- logger.error(f"Error: The api '{self.api_name}' does not exist in the file.")
175
+ logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
173
176
  else:
174
177
  save_json(self.output_file, new_data, indent=4)
175
178
  logger.info(
@@ -183,6 +186,7 @@ class APIExtractor:
183
186
  self.update_data_name(v, dump_data_dir)
184
187
  return value
185
188
 
189
+ @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
186
190
  def update_data_name(self, data, dump_data_dir):
187
191
  if isinstance(data, list):
188
192
  for item in data:
@@ -399,7 +403,7 @@ class OperatorScriptGenerator:
399
403
  def generate_kwargs_dict(self, kwargs_info, flag_device):
400
404
  kwargs_dict_generator = ""
401
405
  for key, value in kwargs_info.items():
402
- kwargs_dict_generator += '"' + key + '"' + MonitorConst.VPP_SEP
406
+ kwargs_dict_generator += '"' + key + '"' + MonitorConst.NAME_SEP
403
407
  if flag_device:
404
408
  kwargs_dict_generator += self.recursive_kwargs_dict(value, flag_device=True) + Const.COMMA
405
409
  else:
@@ -467,6 +471,7 @@ def _run_operator_generate_commond(cmd_args):
467
471
  fout.write(code_template.format(**internal_settings))
468
472
  except OSError:
469
473
  logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
474
+ change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
470
475
 
471
476
  logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
472
477
 
@@ -37,9 +37,9 @@ def load_pt(pt_path, to_cpu=False):
37
37
  pt_path = os.path.realpath(pt_path)
38
38
  try:
39
39
  if to_cpu:
40
- pt = torch.load(pt_path, map_location=torch.device("cpu"))
40
+ pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
41
41
  else:
42
- pt = torch.load(pt_path)
42
+ pt = torch.load(pt_path, weights_only=True)
43
43
  except Exception as e:
44
44
  raise RuntimeError(f"load pt file {{pt_path}} failed") from e
45
45
  return pt
@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
50
50
  backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
51
51
 
52
52
  input_data = load_json(input_file)
53
+ if "dump_data_dir" not in input_data.keys():
54
+ logger.error("Invalid input file, 'dump_data_dir' field is missing")
55
+ raise CompareException("Invalid input file, 'dump_data_dir' field is missing")
53
56
  if input_data.get("data") is None:
54
57
  logger.error("Invalid input file, 'data' field is missing")
55
58
  raise CompareException("Invalid input file, 'data' field is missing")
@@ -97,7 +100,7 @@ def run_parallel_ut(config):
97
100
  processes = []
98
101
  device_id_cycle = cycle(config.device_id)
99
102
  if config.save_error_data_flag:
100
- logger.info("UT task error datas will be saved")
103
+ logger.info("UT task error data will be saved")
101
104
  logger.info(f"Starting parallel UT with {config.num_splits} processes")
102
105
  progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
103
106
 
@@ -221,7 +224,3 @@ def main():
221
224
  args = parser.parse_args()
222
225
  config = prepare_config(args)
223
226
  run_parallel_ut(config)
224
-
225
-
226
- if __name__ == '__main__':
227
- main()
@@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i
34
34
  from msprobe.core.common.file_utils import check_link, FileChecker
35
35
  from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
36
36
  from msprobe.core.common.const import FileCheckConst, Const
37
+ from msprobe.core.common.utils import check_op_str_pattern_valid
37
38
  from msprobe.pytorch.common.log import logger
38
39
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
40
+ from msprobe.core.common.decorator import recursion_depth_decorator
39
41
 
40
42
 
41
43
  def check_tensor_overflow(x):
@@ -75,6 +77,7 @@ def check_data_overflow(x, device):
75
77
  return torch_npu.npu.utils.npu_check_overflow(x)
76
78
 
77
79
 
80
+ @recursion_depth_decorator("is_bool_output")
78
81
  def is_bool_output(x):
79
82
  if isinstance(x, (tuple, list)):
80
83
  if not x:
@@ -91,6 +94,7 @@ def run_overflow_check(forward_file):
91
94
  dump_path = os.path.dirname(forward_file)
92
95
  real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
93
96
  for api_full_name, api_info_dict in tqdm(forward_content.items()):
97
+ check_op_str_pattern_valid(api_full_name)
94
98
  if is_unsupported_api(api_full_name, is_overflow_check=True):
95
99
  continue
96
100
  try:
@@ -161,6 +165,7 @@ def _run_overflow_check(parser=None):
161
165
  _run_overflow_check_parser(parser)
162
166
  args = parser.parse_args(sys.argv[1:])
163
167
  _run_overflow_check_command(args)
168
+ logger.info("UT task completed.")
164
169
 
165
170
 
166
171
  def _run_overflow_check_command(args):
@@ -175,8 +180,3 @@ def _run_overflow_check_command(args):
175
180
  logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
176
181
  raise NotImplementedError from error
177
182
  run_overflow_check(api_info)
178
-
179
-
180
- if __name__ == '__main__':
181
- _run_overflow_check()
182
- logger.info("UT task completed.")
@@ -49,7 +49,7 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
49
49
  from msprobe.pytorch.common.log import logger
50
50
  from msprobe.pytorch.pt_config import parse_json_config
51
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
52
- from msprobe.core.common.utils import safe_get_value, CompareException
52
+ from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
53
53
  from msprobe.pytorch.common.utils import seed_all
54
54
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
55
55
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
@@ -65,6 +65,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
65
65
 
66
66
  not_backward_list = ['repeat_interleave']
67
67
  unsupported_backward_list = ['masked_select']
68
+ unsupported_api_list = ["to"]
68
69
 
69
70
 
70
71
  tqdm_params = {
@@ -83,6 +84,9 @@ tqdm_params = {
83
84
  }
84
85
 
85
86
 
87
+ seed_all()
88
+
89
+
86
90
  def run_ut(config):
87
91
  logger.info("start UT test")
88
92
  if config.online_config.is_online:
@@ -93,7 +97,7 @@ def run_ut(config):
93
97
  logger.info(f"UT task details will be saved in {config.details_csv_path}")
94
98
 
95
99
  if config.save_error_data:
96
- logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
100
+ logger.info(f"UT task error_data will be saved in {config.error_data_path}")
97
101
  compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
98
102
 
99
103
  if config.online_config.is_online:
@@ -117,6 +121,7 @@ def run_ut(config):
117
121
  def run_api_offline(config, compare, api_name_set):
118
122
  err_column = CompareColumn()
119
123
  for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
124
+ check_op_str_pattern_valid(api_full_name)
120
125
  if api_full_name in api_name_set:
121
126
  continue
122
127
  if is_unsupported_api(api_full_name):
@@ -218,6 +223,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
218
223
  If api is both in black_list and black_list, black_list first.
219
224
  return: False for exec api, True for not exec
220
225
  """
226
+ black_list.extend(unsupported_api_list)
221
227
  if black_list and api_name in black_list:
222
228
  return True
223
229
  if white_list and api_name not in white_list:
@@ -317,7 +323,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
317
323
  if kwargs.get("device"):
318
324
  del kwargs["device"]
319
325
 
320
- device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs)
326
+ device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
327
+ device_out = exec_api(device_exec_params)
321
328
  device_out = move2device_exec(device_out, "cpu")
322
329
  return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
323
330
 
@@ -344,6 +351,9 @@ def need_to_backward(grad_index, out):
344
351
 
345
352
  def run_backward(args, grad, grad_index, out):
346
353
  if grad_index is not None:
354
+ if not is_int(grad_index):
355
+ logger.error(f"{grad_index} dtype is not int")
356
+ raise TypeError(f"{grad_index} dtype is not int")
347
357
  if grad_index >= len(out):
348
358
  logger.error(f"Run backward error when grad_index is {grad_index}")
349
359
  raise IndexError(f"Run backward error when grad_index is {grad_index}")
@@ -430,6 +440,7 @@ def preprocess_forward_content(forward_content):
430
440
  arg_cache = {}
431
441
 
432
442
  for key, value in forward_content.items():
443
+ check_op_str_pattern_valid(key)
433
444
  base_key = key.rsplit(Const.SEP, 1)[0]
434
445
 
435
446
  if key not in arg_cache:
@@ -469,6 +480,7 @@ def _run_ut(parser=None):
469
480
  _run_ut_parser(parser)
470
481
  args = parser.parse_args(sys.argv[1:])
471
482
  run_ut_command(args)
483
+
472
484
 
473
485
 
474
486
  def checked_online_config(online_config):
@@ -492,6 +504,7 @@ def checked_online_config(online_config):
492
504
  check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
493
505
  check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
494
506
  check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
507
+ check_crt_valid(os.path.join(online_config.tls_path, "server.key"), True)
495
508
 
496
509
  # host and port
497
510
  if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
@@ -561,7 +574,14 @@ def run_ut_command(args):
561
574
  error_data_path = checker_config.error_data_path
562
575
  if save_error_data:
563
576
  if args.result_csv_path:
564
- time_info = result_csv_path.split('.')[0].split('_')[-1]
577
+ parts_by_dot = result_csv_path.split(Const.SEP)
578
+ if len(parts_by_dot) < 2 or not parts_by_dot[0]:
579
+ raise ValueError("result_csv_path does not contain a valid file name with an extension.")
580
+ file_name_part = parts_by_dot[0]
581
+ parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER)
582
+ if len(parts_by_underscore) < 2:
583
+ raise ValueError("File name part does not contain enough '_' separated segments.")
584
+ time_info = parts_by_underscore[-1]
565
585
  global UT_ERROR_DATA_DIR
566
586
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
567
587
  error_data_path = initialize_save_error_data(error_data_path)
@@ -579,9 +599,8 @@ def run_ut_command(args):
579
599
  }
580
600
  run_ut_config = checker_config.get_run_ut_config(**config_params)
581
601
  run_ut(run_ut_config)
602
+ logger.info("UT task completed.")
582
603
 
583
604
 
584
605
  if __name__ == '__main__':
585
- seed_all()
586
606
  _run_ut()
587
- logger.info("UT task completed.")