mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,7 +21,7 @@ from dataclasses import dataclass

  import numpy as np

- from msprobe.core.common.const import Const, CompareConst
+ from msprobe.core.common.const import Const, CompareConst, FileCheckConst
  from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
  from msprobe.core.common.file_utils import check_file_or_directory_path

@@ -37,13 +37,20 @@ def extract_json(dirname, stack_json=False):
  # Provide robustness on invalid directory inputs
  if not json_path:
  if stack_json:
- logger.error(f'stack.json is not found in dump dir {dirname}.')
+ logger.warning(f'stack.json is not found in dump dir {dirname}.')
  else:
  logger.error(f'dump.json is not found in dump dir {dirname}.')
- raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
+ raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
  return json_path


+ def set_stack_json_path(input_param):
+ npu_data_dir = os.path.dirname(input_param.get("npu_json_path"))
+ stack_path = extract_json(npu_data_dir, stack_json=True)
+ input_param["stack_json_path"] = stack_path if stack_path else None
+ return bool(stack_path)
+
+
  def check_and_return_dir_contents(dump_dir, prefix):
  """
  check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
@@ -75,6 +82,10 @@ def check_and_return_dir_contents(dump_dir, prefix):


  def rename_api(npu_name, process):
+ """
+ Original api: {api_type}.{api_name}.{call count}.{forward/backward}.{input/output}.{parameter index}
+ After rename: {api_type}.{api_name}.{input/output}.{parameter index}
+ """
  npu_split = npu_name.split(process)
  try:
  torch_func_index, in_out = npu_split[0], npu_split[1]
@@ -87,17 +98,13 @@ def rename_api(npu_name, process):


  def read_op(op_data, op_name):
- io_name_mapping = {
- Const.INPUT_ARGS: '.input',
- Const.INPUT_KWARGS: '.input',
- Const.INPUT: '.input',
- Const.OUTPUT: '.output'
- }
-
- op_parsed_list = []
- for name in io_name_mapping:
- if name in op_data:
- op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
+ if Const.PARAMS_GRAD in op_name.split(Const.SEP):
+ op_parsed_list = op_item_parse(op_data, op_name)
+ else:
+ op_parsed_list = []
+ for name in CompareConst.IO_NAME_MAPPING:
+ if name in op_data:
+ op_parsed_list.extend(op_item_parse(op_data[name], op_name + CompareConst.IO_NAME_MAPPING[name]))
  return op_parsed_list


@@ -124,11 +131,14 @@ def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
  return [default_item]
  elif not op_data:
  return []
-
+
  item_list = []
  if isinstance(op_data, list):
  for i, data in enumerate(op_data):
- item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
+ if Const.PARAMS_GRAD not in op_name.split(Const.SEP):
+ item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
+ else:
+ item_list.extend(op_item_parse(data, op_name, depth + 1))
  elif isinstance(op_data, dict):
  if is_leaf_data(op_data):
  return [gen_op_item(op_data, op_name)]
@@ -144,14 +154,15 @@ def is_leaf_data(op_data):
  def gen_op_item(op_data, op_name):
  op_item = {}
  op_item.update(op_data)
- op_item['full_op_name'] = op_name
- op_item['data_name'] = op_data.get('data_name', '-1')
+ data_name = op_data.get('data_name') if op_data.get('data_name') else '-1'  # also return '-1' when data_name is ""
+ op_item['data_name'] = data_name
+ op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0] if data_name != '-1' else op_name

  params = ['Max', 'Min', 'Mean', 'Norm']
  for i in params:
  if i not in op_item:
  op_item[i] = None
-
+
  if not op_item.get('dtype'):
  if op_item.get('type') == 'torch.Size':
  op_item['dtype'] = op_data.get('type')
@@ -166,7 +177,7 @@ def gen_op_item(op_data, op_name):
  op_item[i] = op_data.get('value')
  if not op_item.get('md5'):
  op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
-
+
  return op_item

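For context on the md5 fallback kept in gen_op_item above: when a dumped value has no real md5 (for example a plain scalar), an 8-hex-digit CRC32 digest of the stringified value is stored instead. A minimal standalone sketch of that expression, using only the standard zlib module; the sample value is hypothetical:

import zlib

# Hypothetical scalar value from a dump entry; gen_op_item applies the same
# expression to op_data.get('value', '').
value = 3.14159
placeholder_md5 = f"{zlib.crc32(str(value).encode()):08x}"
print(placeholder_md5)  # prints the 8-hex-digit CRC32 digest of '3.14159'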
 
@@ -276,6 +287,22 @@ def result_item_init(n_info, b_info, dump_mode):
  return result_item


+ def count_struct(op_dict):
+ parts = [
+ CompareConst.OP_NAME,
+ CompareConst.INPUT_STRUCT,
+ CompareConst.OUTPUT_STRUCT,
+ CompareConst.PARAMS_STRUCT,
+ CompareConst.PARAMS_GRAD_STRUCT
+ ]
+ lengths = [len(op_dict.get(part, [])) for part in parts]
+ num = lengths[0]
+ if num != sum(lengths[1:]):
+ logger.error(f"Length of names and structs of op_dict not match. Please check! op_dict: {op_dict}")
+ raise CompareException(CompareException.NAMES_STRUCTS_MATCH_ERROR)
+ return tuple(lengths)
+
+
  def get_accuracy(result, n_dict, b_dict, dump_mode):
  def get_accuracy_core(n_start, n_len, b_start, b_len, key):
  min_len = min(n_len, b_len)
@@ -355,31 +382,50 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):

  result.append(result_item)

- n_num = len(n_dict['op_name'])
- b_num = len(b_dict['op_name'])
- n_num_input = len([name for name in n_dict['op_name']
- if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
- b_num_input = len([name for name in b_dict['op_name']
- if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
- n_num_output = n_num - n_num_input
- b_num_output = b_num - b_num_input
- get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
- get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
+ n_num, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
+ b_num, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)
+
+ get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
+ get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
+ CompareConst.PARAMS_STRUCT)
+ get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT)
+ get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad,
+ b_num_input + b_num_output + b_num_params, b_num_params_grad,
+ CompareConst.PARAMS_GRAD_STRUCT)
+
+
+ def append_stack_info(result_item, npu_stack_info, index):
+ """Append stack info to result_item"""
+ if npu_stack_info and index == 0:
+ result_item.extend(npu_stack_info)
+ else:
+ result_item.append(CompareConst.NONE)


  def get_un_match_accuracy(result, n_dict, dump_mode):
- index_out = 0
  npu_stack_info = n_dict.get("stack_info", None)
  bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
- err_msg = CompareConst.NO_BENCH
- accuracy_check_res = CompareConst.N_A
- for index, n_name in enumerate(n_dict["op_name"]):
- name_ele_list = n_name.split(Const.SEP)
- if Const.INPUT in name_ele_list or Const.KWARGS in name_ele_list:
- n_struct = safe_get_value(n_dict, index, "n_dict", key=CompareConst.INPUT_STRUCT)
- if Const.OUTPUT in name_ele_list:
- n_struct = safe_get_value(n_dict, index_out, "n_dict", key=CompareConst.OUTPUT_STRUCT)
- index_out += 1
+
+ struct_to_index_mapping = {
+ CompareConst.INPUT_STRUCT: 0,
+ CompareConst.OUTPUT_STRUCT: 0,
+ CompareConst.PARAMS_STRUCT: 0,
+ CompareConst.PARAMS_GRAD_STRUCT: 0
+ }
+
+ op_name_list = n_dict.get(CompareConst.OP_NAME)
+ summary_list = n_dict.get(Const.SUMMARY)
+ data_name_list = n_dict.get('data_name')
+ op_name_reorder, summary_reorder, _ = reorder_op_x_list(op_name_list,
+ summary_list,
+ data_name_list)
+ for index, n_name in enumerate(op_name_reorder):
+ _, state = get_name_and_state(n_name)
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
+ if not struct_key:
+ continue
+ n_struct = safe_get_value(n_dict, struct_to_index_mapping.get(struct_key), "n_dict", key=struct_key)
+ struct_to_index_mapping[struct_key] += 1

  try:
  result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
@@ -390,28 +436,26 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
  f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
  logger.error(err_msg)
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
+
  if dump_mode == Const.MD5:
  result_item.extend([CompareConst.N_A] * 3)
- if npu_stack_info and index == 0:
- result_item.extend(npu_stack_info)
- else:
- result_item.append(CompareConst.NONE)
+ append_stack_info(result_item, npu_stack_info, index)
  result.append(result_item)
  continue
  if dump_mode == Const.SUMMARY:
  result_item.extend([CompareConst.N_A] * 8)
- else:
+ if dump_mode == Const.ALL:
  result_item.extend([CompareConst.N_A] * 5)
- npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY)
- result_item.extend(npu_summary_data)
+
+ npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
  bench_summary_data = [CompareConst.N_A] * 4
+ result_item.extend(npu_summary_data)
  result_item.extend(bench_summary_data)
+ err_msg = CompareConst.NO_BENCH
+ accuracy_check_res = CompareConst.N_A
  result_item.append(accuracy_check_res)
  result_item.append(err_msg)
- if npu_stack_info and index == 0:
- result_item.extend(npu_stack_info)
- else:
- result_item.append(CompareConst.NONE)
+ append_stack_info(result_item, npu_stack_info, index)
  if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
  result_item.extend(["-1"])
  result.append(result_item)
@@ -423,6 +467,8 @@ def merge_tensor(tensor_list, dump_mode):
  op_dict[CompareConst.INPUT_STRUCT] = []
  op_dict[CompareConst.KWARGS_STRUCT] = []
  op_dict[CompareConst.OUTPUT_STRUCT] = []
+ op_dict[CompareConst.PARAMS_STRUCT] = []
+ op_dict[CompareConst.PARAMS_GRAD_STRUCT] = []
  op_dict[Const.SUMMARY] = []
  op_dict["stack_info"] = []

@@ -430,30 +476,25 @@
  op_dict["data_name"] = []

  for tensor in tensor_list:
+ # A dict(len=2) with 'full_op_name' and 'full_info' is added to the tensor only if self.stack_mode is True
  if len(tensor) == 2:
  op_dict['stack_info'].append(tensor['full_info'])
  break
+
  op_dict["op_name"].append(tensor['full_op_name'])
- name_ele_list = tensor['full_op_name'].split(Const.SEP)
- name_to_struct_mapping = {
- Const.INPUT: CompareConst.INPUT_STRUCT,
- Const.KWARGS: CompareConst.KWARGS_STRUCT,
- Const.OUTPUT: CompareConst.OUTPUT_STRUCT
- }
- for name_key, struct_key in name_to_struct_mapping.items():
- if name_key in name_ele_list:
- if dump_mode == Const.MD5:
- op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
- else:
- op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
- break
+
+ _, state = get_name_and_state(tensor['full_op_name'])
+ struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
+ if not struct_key:
+ continue
+ if dump_mode == Const.MD5:
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
+ else:
+ op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
  op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])

  if dump_mode == Const.ALL:
  op_dict["data_name"].append(tensor['data_name'])
- data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
- if data_name != "-1":
- op_dict["op_name"][-1] = data_name

  if not op_dict[CompareConst.KWARGS_STRUCT]:
  del op_dict[CompareConst.KWARGS_STRUCT]
@@ -467,11 +508,90 @@ def print_compare_ends_info():
  logger.info('*' * total_len)


+ def table_value_is_valid(value: str) -> bool:
+ if not isinstance(value, str):
+ return True
+ try:
+ # -1.00 or +1.00 should be consdiered as digit numbers
+ float(value)
+ except ValueError:
+ # otherwise, they will be considered as formular injections
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
+ return True
+
+
+ def get_name_and_state(name):
+ """
+ Get api/module name and state
+ example:
+ name = 'conv2d.forward.1.input.0'
+ return: ('conv2d.forward.1.', 'input')
+
+ name = 'Functional.pad.0.backward.output.0'
+ return: ('Functional.pad.0.backward.', 'output')
+
+ state type: input, output, kwargs, parameters, parameters_grad
+ """
+ if Const.PARAMS_GRAD in name.split(Const.SEP):
+ return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
+
+ split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
+ api = f'{split[0]}.{split[1]}.'
+ state_str = split[2]
+ match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
+ if not match:
+ raise CompareException(f'Invalid name string: {name}')
+ if match.group(1):
+ api = f'{api}{match.group(1)}'
+ state = match.group(2)
+ return api, state
+
+
+ def reorder_op_name_list(op_name_list):
+ if not op_name_list:
+ return op_name_list
+
+ parameters = []
+ output = []
+ parameters_grad = []
+ others = []
+ for x in op_name_list:
+ state = get_name_and_state(x)[1]
+ if state == Const.PARAMS:
+ parameters.append(x)
+ elif state == Const.OUTPUT:
+ output.append(x)
+ elif state == Const.PARAMS_GRAD:
+ parameters_grad.append(x)
+ else:
+ others.append(x)
+ # merge others, parameters and output, making sure parameters come before output
+ op_name_reorder = others + parameters + output + parameters_grad
+ return op_name_reorder
+
+
+ def reorder_op_x_list(op_name_list, summary_list, data_name_list):
+ """Reorder op_name, summary and data_name, placing parameters after input and before output; data_name is None in summary-statistics comparison and is handled separately"""
+ if not op_name_list or not summary_list:
+ return op_name_list, summary_list, data_name_list
+
+ index_map = {name: index for index, name in enumerate(op_name_list)}
+
+ op_name_reorder = reorder_op_name_list(op_name_list)
+ summary_reorder = [summary_list[index_map.get(name)] for name in op_name_reorder]
+ if data_name_list:
+ data_name_reorder = [data_name_list[index_map.get(name)] for name in op_name_reorder]
+ else:
+ data_name_reorder = data_name_list
+
+ return op_name_reorder, summary_reorder, data_name_reorder
+
+
  def _compare_parser(parser):
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
  help="<Required> The compare input path, a dict json.", required=True)
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
- help="<Required> The compare task result out path. Default path: ./output",
+ help="<Required> The compare task result out path. Default path: ./output",
  required=False, default="./output", nargs="?", const="./output")
  parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
  help="<optional> Whether to save stack info.", required=False)
@@ -38,6 +38,8 @@ class DataCollector:
  self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
  self.module_count = {}
  self.scope = ScopeFactory(self.config).build_scope()
+ self.backward_module_names = {}
+ self.optimizer_status = ""
  atexit.register(self.write_json)

  @property
@@ -52,10 +54,6 @@
  def check_scope_and_pid(scope, name, pid):
  return (not scope or scope.check(name)) and pid == os.getpid()

- @staticmethod
- def is_inplace(module):
- return getattr(module, "op_is_inplace", False)
-
  def if_return_forward_new_output(self):
  return self.data_processor.if_return_forward_new_output()

@@ -79,32 +77,38 @@
  logger.debug(msg)
  self.data_writer.update_data(data_info)

- def pre_forward_data_collect(self, name, module, pid, module_input_output):
- if self.config.level == Const.LEVEL_L2 and self.check_scope_and_pid(self.scope, name, pid):
- self.data_processor.analyze_pre_forward(name, module, module_input_output)
+ def forward_input_data_collect(self, name, module, pid, module_input_output):
+ if self.config.task == Const.FREE_BENCHMARK:
+ backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
+ if self.check_scope_and_pid(self.scope, backward_name, pid):
+ self.data_processor.analyze_forward_input(backward_name, module, module_input_output)
+ return
+
+ if not self.check_scope_and_pid(self.scope, name, pid):
  return

- backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
- if self.check_scope_and_pid(self.scope, backward_name, pid):
- self.data_processor.analyze_pre_forward(backward_name, module, module_input_output)
- if not self.is_inplace(module) or not self.check_scope_and_pid(self.scope, name, pid):
+ data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
+ if self.config.level == Const.LEVEL_L2:
  return
- logger.info(f"API {name} is inplace.")
- data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

- def forward_data_collect(self, name, module, pid, module_input_output):
+ def forward_output_data_collect(self, name, module, pid, module_input_output):
  self.update_construct(name)
  if not self.check_scope_and_pid(self.scope, name, pid):
  return
+
+ data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
  if self.config.level == Const.LEVEL_L2:
- self.data_processor.analyze_forward(name, module, module_input_output)
  return
+ self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

- if not self.is_inplace(module):
- data_info = self.data_processor.analyze_forward(name, module, module_input_output)
- else:
- data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
+ def forward_data_collect(self, name, module, pid, module_input_output):
+ self.update_construct(name)
+ if not self.check_scope_and_pid(self.scope, name, pid):
+ return
+
+ data_info = self.data_processor.analyze_forward(name, module, module_input_output)
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

@@ -116,6 +120,11 @@ class DataCollector:
  data_info = self.data_processor.analyze_backward(name, module, module_input_output)
  if self.config.level == Const.LEVEL_L2:
  return
+ # get the name of the module that ran backward
+ if data_info and name.split(Const.SEP)[0] in Const.MODULE_PREFIX:
+ module_name = name.rsplit(Const.SEP, 2)[0]
+ # add the module name to the set of backward module names, used during gradient collection to decide whether gradients need to be collected
+ self.backward_module_names[module_name] = True
  self.handle_data(name, data_info, flush=self.data_processor.is_terminated)

  def backward_input_data_collect(self, name, module, pid, module_input_output):
@@ -136,12 +145,17 @@

  def update_construct(self, name):
  if self.config.level not in DataCollector.level_without_construct:
- self.data_writer.update_construct({name: self.module_processor.api_parent_node})
+ if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
+ self.data_writer.update_construct({name: self.optimizer_status})
+ else:
+ self.data_writer.update_construct({name: self.module_processor.api_parent_node})
  self.data_writer.update_construct(self.module_processor.module_node)

  def handle_data(self, name, data_info, flush=False):
  if data_info:
  self.update_data(name, data_info)
+ if self.config.async_dump:
+ return
  if not flush:
  self.data_writer.flush_data_periodically()
  else:
@@ -149,7 +163,23 @@

  def update_dump_paths(self, *args):
  self.data_writer.update_dump_paths(*args)
- self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
+
+ def initialize_json_file(self, framework=Const.UNKNOWN_FRAMEWORK):
+ self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level, framework=framework)

  def update_iter(self, current_iter):
  self.data_processor.update_iter(current_iter)
+
+ def params_data_collect(self, name, param_name, pid, data):
+ grad_name = name + Const.SEP + Const.PARAMS_GRAD
+ # check scope and pid, and whether the current name has had a backward pass
+ if not self.check_scope_and_pid(self.scope, name, pid) and not self.backward_module_names.get(name):
+ # if there was no backward pass, the placeholder grad data written earlier needs to be cleared
+ if self.data_writer.cache_data.get("data"):
+ self.data_writer.cache_data.get("data").pop(grad_name, None)
+ return
+ data_info = self.data_processor.analyze_params(grad_name, param_name, data)
+ self.handle_data(grad_name, data_info, flush=self.data_processor.is_terminated)
+
+ def fill_stack_tensor_data(self):
+ self.data_writer.fill_stack_tensor_data()
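A minimal standalone sketch of the gating that the new params_data_collect adds above: a module's parameters_grad entry is kept only if that module was seen in a backward pass, otherwise the placeholder entry is dropped from the cache. The module names, the '.' separator, and the cache layout are simplified stand-ins for DataCollector's real state, and the scope/pid check from the real method is omitted here.

# Stand-ins for DataCollector state; 'parameters_grad' mirrors Const.PARAMS_GRAD.
backward_module_names = {"Module.layer1.Linear.0": True}   # filled by backward_data_collect
cache_data = {"data": {"Module.layer1.Linear.0.parameters_grad": {},
                       "Module.layer2.Linear.0.parameters_grad": {}}}

def params_data_collect_sketch(name):
    grad_name = name + "." + "parameters_grad"
    # keep the grad entry only for modules that actually ran backward
    if not backward_module_names.get(name):
        cache_data["data"].pop(grad_name, None)   # clear the placeholder entry
        return
    # the real method would now analyze the gradient and write data under grad_name

params_data_collect_sketch("Module.layer1.Linear.0")  # kept
params_data_collect_sketch("Module.layer2.Linear.0")  # placeholder removed
print(sorted(cache_data["data"]))  # ['Module.layer1.Linear.0.parameters_grad']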
@@ -39,9 +39,8 @@ class ModuleForwardInputsOutputs:
  def output_tuple(self):
  return convert_tuple(self.output)

- def concat_args_and_kwargs(self):
- args = self.args + tuple(self.kwargs.values())
- return args
+ def update_output_with_args_and_kwargs(self):
+ self.output = self.args + tuple(self.kwargs.values())


  @dataclass
@@ -77,11 +76,12 @@ class ModuleBackwardOutputs:


  class TensorStatInfo:
- def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
+ def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, stack_tensor_stat=None):
  self.max = max_val
  self.min = min_val
  self.mean = mean_val
  self.norm = norm_val
+ self.stack_tensor_stat = stack_tensor_stat


  class BaseDataProcessor:
@@ -102,6 +102,7 @@ class BaseDataProcessor:
  self.current_iter = 0
  self._return_forward_new_output = False
  self._forward_new_output = None
+ self.save_name = None
  if hasattr(config, "data_mode"):
  self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode)

@@ -223,7 +224,7 @@
  elif isinstance(args, dict):
  return cls.apply_transform_dict(args, transform, depth)
  elif args is not None:
- logger.warning(f"Data type {type(args)} is not supported.")
+ logger.debug(f"Data type {type(args)} is not supported.")
  return None
  else:
  return None
@@ -273,13 +274,10 @@
  """
  return forward_backward in self.allowed_data_mode and input_output in self.allowed_data_mode

- def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
- pass
-
  def analyze_element(self, element):
  return self.recursive_apply_transform(element, self.analyze_single_element)

- def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+ def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs):
  api_info_struct = {}
  # check whether data_mode contains forward or input
  if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -291,16 +289,22 @@
  kwargs_info_list = self.analyze_element(module_input_output.kwargs)
  api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list

- # check whether data_mode contains forward or output
+ return api_info_struct
+
+ def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+ api_info_struct = {}
+ # check whether data_mode contains forward or input
  if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
- api_info_struct[name] = api_info_struct.get(name, {})
+ api_info_struct[name] = {}
  self.api_data_category = Const.OUTPUT
  output_info_list = self.analyze_element(module_input_output.output_tuple)
  api_info_struct[name][Const.OUTPUT] = output_info_list
+
  return api_info_struct

- def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
+ def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
  api_info_struct = {}
+ # check whether data_mode contains forward or input
  if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
  api_info_struct[name] = {}
  self.api_data_category = Const.INPUT
@@ -309,16 +313,18 @@
  self.api_data_category = Const.KWARGS
  kwargs_info_list = self.analyze_element(module_input_output.kwargs)
  api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
- return api_info_struct

- def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
- concat_args = module_input_output.concat_args_and_kwargs()
- api_info_struct = {}
+ # check whether data_mode contains forward or output
  if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
- api_info_struct[name] = {}
+ api_info_struct[name] = api_info_struct.get(name, {})
  self.api_data_category = Const.OUTPUT
- output_info_list = self.analyze_element(concat_args)
+ output_info_list = self.analyze_element(module_input_output.output_tuple)
  api_info_struct[name][Const.OUTPUT] = output_info_list
+
+ if name in api_info_struct and hasattr(module_input_output, Const.PARAMS):
+ self.api_data_category = Const.PARAMS
+ api_info_struct[name][Const.PARAMS] = self.analyze_element(getattr(module_input_output, Const.PARAMS))
+
  return api_info_struct

  def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -359,9 +365,21 @@
  api_info_struct[name][Const.OUTPUT] = output_info_list
  return api_info_struct

+ def analyze_params(self, name, param_name, grad):
+ api_info_struct = {}
+ self.save_name = name + Const.SEP + param_name
+ data_info = self.analyze_element(grad)
+ grad_info_dict = {param_name: [data_info]}
+ api_info_struct[name] = grad_info_dict
+ return api_info_struct
+
  def get_save_file_path(self, suffix):
  file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
- dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
- suffix + file_format)
+ if self.save_name is not None:
+ dump_data_name = (self.save_name + file_format)
+ self.save_name = None
+ else:
+ dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
+ suffix + file_format)
  file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
  return dump_data_name, file_path
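To make the new naming behavior above concrete, here is a minimal sketch of the two file-name forms get_save_file_path now produces: a parameter-gradient dump saved under the save_name prepared by analyze_params, and the usual api/category/suffix form otherwise. The separator, the ".pt" suffix, the dump directory, and the example names are assumptions for illustration, not values read from Const.

import os

# Assumed stand-ins for Const.SEP, the file suffix, and the dump directory.
SEP, FILE_FORMAT = ".", ".pt"
dump_tensor_data_dir = "./dump_tensor_data"   # hypothetical

def save_file_path_sketch(save_name, current_api_or_module_name, api_data_category, suffix):
    if save_name is not None:
        # parameters_grad path: name was pre-built as "<module>.parameters_grad.<param_name>"
        dump_data_name = save_name + FILE_FORMAT
    else:
        dump_data_name = (current_api_or_module_name + SEP + api_data_category + SEP
                          + str(suffix) + FILE_FORMAT)
    return dump_data_name, os.path.join(dump_tensor_data_dir, dump_data_name)

print(save_file_path_sketch("Module.fc.Linear.0.parameters_grad.weight", None, None, None)[0])
# Module.fc.Linear.0.parameters_grad.weight.pt
print(save_file_path_sketch(None, "Functional.conv2d.0.forward", "input", 0)[0])
# Functional.conv2d.0.forward.input.0.pt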