mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/utils.py

@@ -12,13 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
+import re
+from datetime import datetime
 from mindspore import dtype as mstype, Tensor
 
 from msprobe.mindspore.monitor.features import FUNC_MAP
 from msprobe.core.common.const import MonitorConst
 from msprobe.core.common.utils import is_int
 from msprobe.core.common.log import logger
+from msprobe.core.common.file_utils import check_file_or_directory_path
 
 
 def get_single_metrics(op_list, tag, tensor, output=None):
@@ -95,8 +98,8 @@ def validate_ranks(ranks):
     if not isinstance(ranks, list):
         raise TypeError("module_ranks should be a list")
     for rank in ranks:
-        if not isinstance(rank, str):
-            raise TypeError(f"element in module_ranks should be a str, get {type(rank)}")
+        if not isinstance(rank, int):
+            raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
 
 
 def validate_targets(targets):
@@ -209,6 +212,11 @@ def validate_collect_times(collect_times):
         raise ValueError("collect_times must greater than 1")
 
 
+def validate_dynamic_on(dynamic_on):
+    if not isinstance(dynamic_on, bool):
+        raise TypeError('dynamic_on should be a bool')
+
+
 def validate_config(config):
     config['ops'] = validate_ops(config.get('ops', []))
 
@@ -255,9 +263,12 @@ def validate_config(config):
     step_interval = config.get('step_interval', 1)
     validate_step_interval(step_interval)
 
-    collect_times = config.get('collect_times', 1e8)
+    collect_times = config.get('collect_times', int(1e8))
     validate_collect_times(collect_times)
 
+    dynamic_on = config.get('dynamic_on', False)
+    validate_dynamic_on(dynamic_on)
+
     if not targets:
         if xy_distribution:
             config["all_xy"] = True
@@ -265,3 +276,34 @@ def validate_config(config):
             config["is_select"] = False
         else:
             config["is_select"] = True
+
+
+def time_str2time_digit(time_str):
+    time_format = '%b%d_%H-%M-%S'
+    try:
+        time_digit = datetime.strptime(time_str, time_format)
+    except Exception as e:
+        raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
+            of existing output dirpath, like 'Dec03_21-34-40'.") from e
+    return time_digit
+
+
+def get_target_output_dir(monitor_path, time_start, time_end):
+    check_file_or_directory_path(monitor_path, isdir=True)
+    time_start = time_str2time_digit(time_start) if time_start is not None else time_start
+    time_end = time_str2time_digit(time_end) if time_end is not None else time_end
+    if time_start and time_end and time_start > time_end:
+        raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
+    result = {}
+    for dirname in os.listdir(monitor_path):
+        match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
+        if not match:
+            continue
+        time_tag = match.group(1)
+        rank = match.group(2)
+        target_time = time_str2time_digit(time_tag)
+        start_ok = time_start is None or target_time >= time_start
+        end_ok = time_end is None or target_time <= time_end
+        if start_ok and end_ok:
+            result[rank] = os.path.join(monitor_path, dirname)
+    return result
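The new get_target_output_dir helper maps each rank to its monitor output directory, keeping only directories whose time tag falls inside the requested window. A minimal usage sketch; the import path follows the files-changed list above and the monitor_path value is purely illustrative:

# Hypothetical usage of the helper added above.
from msprobe.mindspore.monitor.utils import get_target_output_dir

# Keep outputs whose time tag (format '%b%d_%H-%M-%S') lies inside the window.
target_dirs = get_target_output_dir("./monitor_output", "Dec03_21-00-00", "Dec03_22-00-00")
for rank, dirpath in target_dirs.items():
    print(rank, dirpath)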
msprobe/mindspore/overflow_check/overflow_check_tool_factory.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from msprobe.core.common.log import logger
 from msprobe.mindspore.common.const import Const
 from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck
@@ -44,6 +45,7 @@ class OverflowCheckToolFactory:
             raise Exception("Valid level is needed.")
         tool = tool.get(config.execution_mode)
         if not tool:
-            raise Exception(f"Overflow check is not supported in {config.execution_mode} mode "
-                            f"when level is {config.level}.")
+            logger.error(f"Overflow check is not supported in {config.execution_mode} mode "
+                         f"when level is {config.level}.")
+            raise ValueError
         return tool(config)
msprobe/mindspore/service.py

@@ -41,7 +41,7 @@ from msprobe.mindspore.cell_processor import CellProcessor
 from msprobe.mindspore.common.log import logger
 from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
                                             is_mindtorch, register_backward_hook_functions)
-from msprobe.mindspore.dump.hook_cell.api_registry import api_register
+from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
 from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
 from msprobe.mindspore.dump.jit_dump import JitDump
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
@@ -63,6 +63,8 @@ class Service:
         self.inner_switch = False
         self.primitive_switch = False
         self.current_iter = 0
+        self.loop = 0
+        self.init_step = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
@@ -71,6 +73,7 @@ class Service:
         self.params_grad_info = {}
         self.hook_handle_dict = {}
         # 提前注册,确保注册尽可能多的API hook
+        self.api_register = get_api_register()
         self.register_api_hook()
         self.init_for_debug_level()
 
@@ -276,11 +279,24 @@ class Service:
         if self.config.task == Const.TENSOR:
             self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
-        self.current_iter += 1
-        self.data_collector.update_iter(self.current_iter)
+        self.loop += 1
         self.reset_status()
 
     def start(self, model=None):
+        if self.current_iter == 0:
+            if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
+                JitDump.set_config(self.config)
+                JitDump.set_data_collector(self.data_collector)
+                if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
+                    ms.common.api._MindsporeFunctionExecutor = JitDump
+                else:
+                    ms.common.api._JitExecutor = JitDump
+                ms.common.api._PyNativeExecutor.grad = JitDump.grad
+                if pijit_label:
+                    PIJitCaptureContext.__enter__ = self.empty
+                    PIJitCaptureContext.__exit__ = self.empty
+        self.current_iter = self.loop + self.init_step
+        self.data_collector.update_iter(self.current_iter)
         if self.config.level == Const.LEVEL_DEBUG:
             return
         self.start_call = True
@@ -293,6 +309,7 @@ class Service:
             print_tools_ends_info()
             return
         if self.config.step and self.current_iter not in self.config.step:
+            JitDump.jit_dump_switch = False
             return
         self.model = self.check_model_valid(model)
 
@@ -308,20 +325,9 @@ class Service:
                 return
             self.register_primitive_hook()
             self.register_cell_hook()
-            if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
-                JitDump.set_config(self.config)
-                JitDump.set_data_collector(self.data_collector)
-                if hasattr(ms.common.api, "_MindsporeFunctionExecutor"):
-                    ms.common.api._MindsporeFunctionExecutor = JitDump
-                else:
-                    ms.common.api._JitExecutor = JitDump
-                ms.common.api._PyNativeExecutor.grad = JitDump.grad
-                if pijit_label:
-                    PIJitCaptureContext.__enter__ = self.empty
-                    PIJitCaptureContext.__exit__ = self.empty
             self.first_start = False
 
-        api_register.api_set_hook_func()
+        self.api_register.register_all_api()
         self.switch = True
         self.primitive_switch = True
         logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
@@ -410,8 +416,8 @@ class Service:
     def register_api_hook(self):
        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
            logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
-            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
-            api_register.api_set_hook_func()
+            self.api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+            self.api_register.register_all_api()
 
     def get_cells_and_names(self):
         cells_and_names_with_index = {}
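These hunks, together with the deleted hook_cell/api_registry.py and wrap_api.py entries in the file list, replace the module-level api_register singleton with a get_api_register() factory and rename the registration methods. A rough sketch of the new call pattern inferred from the diff; build_hook below is a stand-in for Service.build_hook, not the real method:

# Sketch only: fetch the registry, install a hook factory, then wrap the supported APIs.
import functools

from msprobe.mindspore.dump.hook_cell.api_register import get_api_register


def build_hook(scope_type, api_name):
    # placeholder for Service.build_hook(BaseScope.Module_Type_API, api_name)
    return None


api_register = get_api_register()
api_register.initialize_hook(functools.partial(build_hook, "api"))
api_register.register_all_api()  # formerly api_register.api_set_hook_func()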
msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py

@@ -40,7 +40,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validat
 from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
 from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
 from msprobe.pytorch.common.log import logger
-from msprobe.core.common.utils import CompareException
+from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid
 from msprobe.core.common.const import Const, CompareConst, FileCheckConst
 
 CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
@@ -151,6 +151,7 @@ def analyse_csv(npu_data, gpu_data, config):
         message = ''
         compare_column = ApiPrecisionOutputColumn()
         full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
+        check_op_str_pattern_valid(full_api_name_with_direction_status)
         row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
         api_name, api_full_name, direction_status = extract_detailed_api_segments(full_api_name_with_direction_status)
         if not api_full_name:
@@ -430,6 +431,7 @@ def _api_precision_compare(parser=None):
     _api_precision_compare_parser(parser)
     args = parser.parse_args(sys.argv[1:])
     _api_precision_compare_command(args)
+    logger.info("Compare task completed.")
 
 
 def _api_precision_compare_command(args):
@@ -457,8 +459,3 @@ def _api_precision_compare_parser(parser):
     parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
                         help="<optional> The api precision compare task result out path.",
                         required=False)
-
-
-if __name__ == '__main__':
-    _api_precision_compare()
-    logger.info("Compare task completed.")
msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py

@@ -28,10 +28,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import binary_st
     ulp_standard_api, thousandth_standard_api
 from msprobe.core.common.file_utils import FileOpen, load_json, save_json
 from msprobe.core.common.utils import check_file_or_directory_path, check_op_str_pattern_valid, is_int
-from msprobe.core.common.const import Const, MonitorConst, MsgConst
+from msprobe.core.common.const import Const, MonitorConst, MsgConst, FileCheckConst
 from msprobe.core.common.log import logger
-from msprobe.core.common.file_utils import make_dir
-from msprobe.core.common.utils import recursion_depth_decorator
+from msprobe.core.common.file_utils import make_dir, change_mode
+from msprobe.core.common.decorator import recursion_depth_decorator
 
 TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
 TORCH_BOOL_TYPE = ["torch.bool"]
@@ -50,6 +50,7 @@ DATA_NAME = "data_name"
 API_MAX_LENGTH = 30
 PROPAGATION_LIST = [Const.FORWARD, Const.BACKWARD]
 DATAMODE_LIST = ["random_data", "real_data"]
+ITER_MAX_TIMES = 1000
 
 
 class APIInfo:
@@ -97,6 +98,8 @@ class CommonConfig:
         iter_t = self.iter_times
         if iter_t <= 0:
             raise ValueError("iter_times should be an integer bigger than zero!")
+        if iter_t > ITER_MAX_TIMES:
+            raise ValueError("iter_times should not be greater than 1000!")
 
         json_file = self.extract_api_path
         propagation = self.propagation
@@ -117,7 +120,7 @@ class CommonConfig:
 
         # Retrieve the first API name and dictionary
         forward_item = next(iter(json_content.items()), None)
-        if not forward_item or not isinstance(forward_item[1], dict):
+        if not forward_item or not isinstance(forward_item[1], dict) or not forward_item[1]:
             raise ValueError(f'Invalid forward API data in json_content!')
 
         # if propagation is backward, ensure json file contains forward and backward info
@@ -127,7 +130,7 @@ class CommonConfig:
         # if propagation is backward, ensure it has valid data
         if propagation == Const.BACKWARD:
             backward_item = list(json_content.items())[1]
-            if not isinstance(backward_item[1], dict):
+            if not isinstance(backward_item[1], dict) or not backward_item[1]:
                 raise ValueError(f'Invalid backward API data in json_content!')
 
         return json_content
@@ -169,7 +172,7 @@ class APIExtractor:
                 value = self.load_real_data_path(value, real_data_path)
                 new_data[key] = value
         if not new_data:
-            logger.error(f"Error: The api '{self.api_name}' does not exist in the file.")
+            logger.warning(f"Warning: The api '{self.api_name}' does not exist in the file.")
         else:
             save_json(self.output_file, new_data, indent=4)
             logger.info(
@@ -183,6 +186,7 @@ class APIExtractor:
                 self.update_data_name(v, dump_data_dir)
         return value
 
+    @recursion_depth_decorator("OpGenerator: APIExtractor.update_data_name")
     def update_data_name(self, data, dump_data_dir):
         if isinstance(data, list):
             for item in data:
@@ -467,6 +471,7 @@ def _run_operator_generate_commond(cmd_args):
             fout.write(code_template.format(**internal_settings))
     except OSError:
         logger.error(f"Failed to open file. Please check file {template_path} or {operator_script_path}.")
+    change_mode(operator_script_path, FileCheckConst.DATA_FILE_AUTHORITY)
 
     logger.info(f"Generate operator script successfully and the name is {operator_script_path}.")
 
msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template

@@ -37,9 +37,9 @@ def load_pt(pt_path, to_cpu=False):
     pt_path = os.path.realpath(pt_path)
     try:
         if to_cpu:
-            pt = torch.load(pt_path, map_location=torch.device("cpu"))
+            pt = torch.load(pt_path, map_location=torch.device("cpu"), weights_only=True)
         else:
-            pt = torch.load(pt_path)
+            pt = torch.load(pt_path, weights_only=True)
     except Exception as e:
         raise RuntimeError(f"load pt file {{pt_path}} failed") from e
     return pt
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py

@@ -50,6 +50,9 @@ def split_json_file(input_file, num_splits, filter_api):
         backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
 
     input_data = load_json(input_file)
+    if "dump_data_dir" not in input_data.keys():
+        logger.error("Invalid input file, 'dump_data_dir' field is missing")
+        raise CompareException("Invalid input file, 'dump_data_dir' field is missing")
     if input_data.get("data") is None:
         logger.error("Invalid input file, 'data' field is missing")
         raise CompareException("Invalid input file, 'data' field is missing")
@@ -97,7 +100,7 @@ def run_parallel_ut(config):
     processes = []
     device_id_cycle = cycle(config.device_id)
     if config.save_error_data_flag:
-        logger.info("UT task error datas will be saved")
+        logger.info("UT task error data will be saved")
     logger.info(f"Starting parallel UT with {config.num_splits} processes")
     progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items")
 
@@ -221,7 +224,3 @@ def main():
     args = parser.parse_args()
     config = prepare_config(args)
     run_parallel_ut(config)
-
-
-if __name__ == '__main__':
-    main()
msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py

@@ -34,8 +34,10 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, i
 from msprobe.core.common.file_utils import check_link, FileChecker
 from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
 from msprobe.core.common.const import FileCheckConst, Const
+from msprobe.core.common.utils import check_op_str_pattern_valid
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
+from msprobe.core.common.decorator import recursion_depth_decorator
 
 
 def check_tensor_overflow(x):
@@ -75,6 +77,7 @@ def check_data_overflow(x, device):
         return torch_npu.npu.utils.npu_check_overflow(x)
 
 
+@recursion_depth_decorator("is_bool_output")
 def is_bool_output(x):
     if isinstance(x, (tuple, list)):
         if not x:
@@ -91,6 +94,7 @@ def run_overflow_check(forward_file):
     dump_path = os.path.dirname(forward_file)
     real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
     for api_full_name, api_info_dict in tqdm(forward_content.items()):
+        check_op_str_pattern_valid(api_full_name)
         if is_unsupported_api(api_full_name, is_overflow_check=True):
             continue
         try:
@@ -161,6 +165,7 @@ def _run_overflow_check(parser=None):
     _run_overflow_check_parser(parser)
     args = parser.parse_args(sys.argv[1:])
     _run_overflow_check_command(args)
+    logger.info("UT task completed.")
 
 
 def _run_overflow_check_command(args):
@@ -175,8 +180,3 @@ def _run_overflow_check_command(args):
         logger.error(f"Set NPU device id failed. device id is: {args.device_id}")
         raise NotImplementedError from error
     run_overflow_check(api_info)
-
-
-if __name__ == '__main__':
-    _run_overflow_check()
-    logger.info("UT task completed.")
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py

@@ -49,7 +49,7 @@ from msprobe.core.common.file_utils import FileChecker, change_mode, \
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.pt_config import parse_json_config
 from msprobe.core.common.const import Const, FileCheckConst, CompareConst
-from msprobe.core.common.utils import safe_get_value, CompareException
+from msprobe.core.common.utils import safe_get_value, CompareException, is_int, check_op_str_pattern_valid
 from msprobe.pytorch.common.utils import seed_all
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
@@ -65,6 +65,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
 
 not_backward_list = ['repeat_interleave']
 unsupported_backward_list = ['masked_select']
+unsupported_api_list = ["to"]
 
 
 tqdm_params = {
@@ -83,6 +84,9 @@ tqdm_params = {
 }
 
 
+seed_all()
+
+
 def run_ut(config):
     logger.info("start UT test")
     if config.online_config.is_online:
@@ -93,7 +97,7 @@ def run_ut(config):
     logger.info(f"UT task details will be saved in {config.details_csv_path}")
 
     if config.save_error_data:
-        logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
+        logger.info(f"UT task error_data will be saved in {config.error_data_path}")
     compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
 
     if config.online_config.is_online:
@@ -117,6 +121,7 @@ def run_ut(config):
 def run_api_offline(config, compare, api_name_set):
     err_column = CompareColumn()
     for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
+        check_op_str_pattern_valid(api_full_name)
         if api_full_name in api_name_set:
             continue
         if is_unsupported_api(api_full_name):
@@ -218,6 +223,7 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
     If api is both in black_list and black_list, black_list first.
     return: False for exec api, True for not exec
     """
+    black_list.extend(unsupported_api_list)
     if black_list and api_name in black_list:
         return True
     if white_list and api_name not in white_list:
@@ -317,7 +323,8 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
     if kwargs.get("device"):
         del kwargs["device"]
 
-    device_out = exec_api(api_type, api_name, Const.CUDA_LOWERCASE, args, kwargs)
+    device_exec_params = ExecParams(api_type, api_name, current_device, args, kwargs, False, None)
+    device_out = exec_api(device_exec_params)
     device_out = move2device_exec(device_out, "cpu")
     return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
 
@@ -344,6 +351,9 @@ def need_to_backward(grad_index, out):
 
 def run_backward(args, grad, grad_index, out):
     if grad_index is not None:
+        if not is_int(grad_index):
+            logger.error(f"{grad_index} dtype is not int")
+            raise TypeError(f"{grad_index} dtype is not int")
         if grad_index >= len(out):
             logger.error(f"Run backward error when grad_index is {grad_index}")
             raise IndexError(f"Run backward error when grad_index is {grad_index}")
@@ -430,6 +440,7 @@ def preprocess_forward_content(forward_content):
     arg_cache = {}
 
     for key, value in forward_content.items():
+        check_op_str_pattern_valid(key)
         base_key = key.rsplit(Const.SEP, 1)[0]
 
         if key not in arg_cache:
@@ -469,6 +480,7 @@ def _run_ut(parser=None):
     _run_ut_parser(parser)
     args = parser.parse_args(sys.argv[1:])
     run_ut_command(args)
+
 
 
 def checked_online_config(online_config):
@@ -492,6 +504,7 @@ def checked_online_config(online_config):
     check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
     check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
     check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
+    check_crt_valid(os.path.join(online_config.tls_path, "server.key"), True)
 
     # host and port
     if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
@@ -561,7 +574,14 @@ def run_ut_command(args):
     error_data_path = checker_config.error_data_path
     if save_error_data:
         if args.result_csv_path:
-            time_info = result_csv_path.split('.')[0].split('_')[-1]
+            parts_by_dot = result_csv_path.split(Const.SEP)
+            if len(parts_by_dot) < 2 or not parts_by_dot[0]:
+                raise ValueError("result_csv_path does not contain a valid file name with an extension.")
+            file_name_part = parts_by_dot[0]
+            parts_by_underscore = file_name_part.split(Const.REPLACEMENT_CHARACTER)
+            if len(parts_by_underscore) < 2:
+                raise ValueError("File name part does not contain enough '_' separated segments.")
+            time_info = parts_by_underscore[-1]
         global UT_ERROR_DATA_DIR
         UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
         error_data_path = initialize_save_error_data(error_data_path)
@@ -579,9 +599,8 @@ def run_ut_command(args):
     }
     run_ut_config = checker_config.get_run_ut_config(**config_params)
     run_ut(run_ut_config)
+    logger.info("UT task completed.")
 
 
 if __name__ == '__main__':
-    seed_all()
     _run_ut()
-    logger.info("UT task completed.")
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py

@@ -1,9 +1,7 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -18,8 +16,8 @@
 import os
 from collections import namedtuple
 import re
-import torch
 
+import torch
 try:
     import torch_npu
 except ImportError:
@@ -33,11 +31,9 @@ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
 from msprobe.core.common.file_utils import FileChecker
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import CompareException
+from msprobe.pytorch.hook_module.api_register import ApiTemplate, get_api_register
 from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
-from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
-from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
-from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
-from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
+
 
 hf_32_standard_api = ["conv1d", "conv2d"]
 not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
@@ -108,17 +104,30 @@ def exec_api(exec_params):
     kwargs = exec_params.kwargs
     is_autocast = exec_params.is_autocast
     autocast_dtype = exec_params.autocast_dtype
-
-    if api_type == "Functional":
-        torch_api = FunctionalOPTemplate(api_name, str, False)
-    if api_type == "Tensor":
-        torch_api = TensorOPTemplate(api_name, str, False)
-    if api_type == "Torch":
-        torch_api = TorchOPTemplate(api_name, str, False)
-    if api_type == "Aten":
+    out = None
+
+    prefix_map = Const.API_DATA_PREFIX.get(Const.PT_FRAMEWORK, {})
+    if not prefix_map or api_type not in prefix_map.values() or \
+            api_type not in (
+                Const.FUNCTIONAL_API_TYPE_PREFIX,
+                Const.TENSOR_API_TYPE_PREFIX,
+                Const.TORCH_API_TYPE_PREFIX,
+                Const.ATEN_API_TYPE_PREFIX,
+                Const.NPU_API_TYPE_PREFIX
+            ):
+        return out
+
+    if api_type == Const.ATEN_API_TYPE_PREFIX:
         torch_api = AtenOPTemplate(api_name, None, False)
-    if api_type == "NPU":
-        torch_api = NpuOPTemplate(api_name, None, False, device)
+    else:
+        api_register = get_api_register()
+        api_register.initialize_hook(None)
+        api_func_type = list(prefix_map.keys())[list(prefix_map.values()).index(api_type)]
+        api_func = api_register.ori_api_attr.get(Const.PT_FRAMEWORK + Const.SEP + api_func_type, {}).get(api_name)
+        if api_func is None:
+            return out
+
+        torch_api = ApiTemplate(api_name, api_func, api_type, None, need_hook=False, device=device)
     if is_autocast:
         with autocast(dtype=autocast_dtype):
             out = torch_api.forward(*args, **kwargs)
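run_ut.py now builds a single ExecParams bundle and hands it to exec_api, which resolves the original callable through the API register instead of the removed per-type OP templates. A hedged sketch of a call, assuming ExecParams is defined next to exec_api in run_ut_utils and that a plain CPU run needs no autocast; shapes and the device string are illustrative:

# Field order follows the construction shown in the run_ut.py hunk:
# (api_type, api_name, device, args, kwargs, is_autocast, autocast_dtype).
import torch

from msprobe.core.common.const import Const
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import ExecParams, exec_api

args = (torch.randn(4, 4),)
params = ExecParams(Const.FUNCTIONAL_API_TYPE_PREFIX, "relu", "cpu", args, {}, False, None)
out = exec_api(params)  # returns None when the API type or name cannot be resolved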
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py

@@ -27,6 +27,7 @@ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import T
 from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
 from msprobe.core.common.file_utils import remove_path
 from msprobe.pytorch.common.utils import logger, save_api_data, load_api_data, save_pkl, load_pkl
+from msprobe.core.common.decorator import recursion_depth_decorator
 
 BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
 
@@ -168,11 +169,12 @@ class ATTL:
         return buffer
 
 
+@recursion_depth_decorator("move2device_exec")
 def move2device_exec(obj, device):
     if isinstance(obj, (tuple, list)):
         data_list = [move2device_exec(val, device) for val in obj]
         return data_list if isinstance(obj, list) else tuple(data_list)
-    if isinstance(obj, dict):
+    if isinstance(obj, dict):
         return {key: move2device_exec(val, device) for key, val in obj.items()}
     elif isinstance(obj, torch.Tensor):
         obj = obj.detach()
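Several hunks in this release apply the decorator that moved into the new msprobe.core.common.decorator module, tagging recursive helpers by name. A small hedged sketch of the usage pattern; the depth-limiting behaviour is inferred from the decorator's name and is not documented in this diff:

from msprobe.core.common.decorator import recursion_depth_decorator


@recursion_depth_decorator("example.flatten")
def flatten(nested):
    # Walk arbitrarily nested lists/tuples; the decorator is assumed to guard
    # against runaway recursion on malformed, self-referencing inputs.
    if isinstance(nested, (list, tuple)):
        return [leaf for item in nested for leaf in flatten(item)]
    return [nested]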
msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py

@@ -29,6 +29,8 @@ def softmax_func(x, axis=None):
 
 def npu_moe_gating_top_k_softmax(x, finished_optional, k):
     input_dtype = x.dtype
+    if x.dim() < 1:
+        raise ValueError("Input x must have at least 1 dimensions.")
     num_expert = x.shape[-1]
     softmax = softmax_func(x, -1)
     softmax = softmax.to(input_dtype)
@@ -36,9 +38,13 @@ def npu_moe_gating_top_k_softmax(x, finished_optional, k):
     expert_idx = expert_idx[:, :k]
     y = torch.gather(softmax, index=expert_idx, dim=-1)
     if finished_optional is not None:
+        if finished_optional.dim() < 1:
+            raise ValueError("Finished_optional must have at least 1 dimensions.")
         finished_optional = finished_optional.view(finished_optional.shape[0], 1)
         finished_optional = finished_optional.expand(-1, k)
         expert_idx = torch.where(finished_optional, num_expert, expert_idx)
+    if y.dim() < 2:
+        raise ValueError("Variable y must have at least 2 dimensions.")
     row_idx = torch.arange(y.shape[0] * y.shape[1]).reshape(y.shape[1], y.shape[0]).t()
 
     return y, expert_idx, row_idx
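The added dimension checks guard the top-k gather against degenerate inputs. A small sketch of the expected call, with shapes chosen only for illustration (8 tokens routed over 16 experts, top-2 gating):

import torch

from msprobe.pytorch.bench_functions.moe_gating_top_k_softmax import npu_moe_gating_top_k_softmax

x = torch.randn(8, 16)                       # (num_tokens, num_experts) router logits
finished = torch.zeros(8, dtype=torch.bool)  # finished tokens have their expert index masked to num_experts
y, expert_idx, row_idx = npu_moe_gating_top_k_softmax(x, finished, 2)
print(y.shape, expert_idx.shape, row_idx.shape)  # each is (8, 2)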