mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  import multiprocessing
3
17
  from dataclasses import dataclass
4
- from functools import partial
5
- import numpy as np
6
18
  import pandas as pd
19
+ from tqdm import tqdm
7
20
  from msprobe.core.common.log import logger
8
21
  from msprobe.core.common.utils import CompareException
9
22
  from msprobe.core.common.const import CompareConst
@@ -29,11 +42,19 @@ def _handle_multi_process(func, input_parma, result_df, lock):
29
42
  except OSError as e:
30
43
  logger.error("pool terminate failed")
31
44
 
45
+ progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
46
+
47
+ def update_progress(size, progress_lock):
48
+ with progress_lock:
49
+ progress_bar.update(size)
50
+
32
51
  for process_idx, df_chunk in enumerate(df_chunks):
33
52
  idx = df_chunk_size * process_idx
53
+ chunk_size = len(df_chunk)
34
54
  result = pool.apply_async(func,
35
55
  args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
36
- error_callback=err_call)
56
+ error_callback=err_call,
57
+ callback=update_progress(chunk_size, lock))
37
58
  results.append(result)
38
59
  final_results = [r.get() for r in results]
39
60
  pool.close()
@@ -42,7 +63,7 @@ def _handle_multi_process(func, input_parma, result_df, lock):
42
63
 
43
64
 
44
65
  def _ms_graph_handle_multi_process(func, result_df, mode):
45
- process_num = int((multiprocessing.cpu_count() + 1) // 2)
66
+ process_num = int((multiprocessing.cpu_count() + 1) // 4)
46
67
  df_chunk_size = len(result_df) // process_num
47
68
  if df_chunk_size > 0:
48
69
  df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
@@ -84,7 +105,8 @@ def read_dump_data(result_df):
84
105
  except IndexError as e:
85
106
  logger.error('result dataframe elements can not be access.')
86
107
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
87
-
108
+
109
+
88
110
  @dataclass
89
111
  class ComparisonResult:
90
112
  cos_result: list
@@ -116,9 +138,12 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
116
138
  result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
117
139
  result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
118
140
  result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
119
- result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(result.cos_result[i], result.max_err_result[i])
120
- result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[i]
121
- result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[i]
141
+ result_df.loc[process_index, CompareConst.ACCURACY] = (
142
+ check_accuracy(result.cos_result[i], result.max_err_result[i]))
143
+ result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = (
144
+ result.one_thousand_err_ratio_result)[i]
145
+ result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = (
146
+ result.five_thousand_err_ratio_result)[i]
122
147
  return result_df
123
148
  except ValueError as e:
124
149
  logger.error('result dataframe is not found.')
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import abc
2
17
  import numpy as np
3
18
  from msprobe.core.common.utils import format_value
@@ -78,10 +93,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
78
93
 
79
94
  def npy_data_check(n_value, b_value):
80
95
  error_message = ""
81
- if n_value is None or b_value is None:
82
- error_message += "Dump file not found.\n"
83
- if n_value == "" or b_value == "":
84
- error_message += "Dump file not found.\n"
96
+ if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
97
+ error_message += "Dump file is not ndarray.\n"
85
98
 
86
99
  # 检查 n_value 和 b_value 是否为空
87
100
  if not error_message and (n_value.size == 0 or b_value.size == 0):
@@ -97,7 +110,8 @@ def npy_data_check(n_value, b_value):
97
110
 
98
111
  if not error_message:
99
112
  n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有 nan/inf 数据
100
- if CompareConst.NAN in (n_value, b_value):
113
+ # handle_inf_nan 会返回'Nan'或ndarray类型,使用类型判断是否存在无法处理的nan/inf数据
114
+ if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
101
115
  error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
102
116
  if error_message == "":
103
117
  error_flag = False
@@ -273,7 +287,8 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
273
287
  relative_err = get_relative_err(n_value, b_value)
274
288
  if not np.size(relative_err):
275
289
  return CompareConst.NAN, ""
276
- return format_value(np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
290
+ return format_value(
291
+ np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
277
292
 
278
293
 
279
294
  class CompareOps:
@@ -1,3 +1,17 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
1
15
 
2
16
  import os
3
17
  import re
@@ -59,14 +73,18 @@ def check_and_return_dir_contents(dump_dir, prefix):
59
73
 
60
74
  def rename_api(npu_name, process):
61
75
  npu_split = npu_name.split(process)
62
- torch_func_index, in_out = npu_split[0], npu_split[1]
76
+ try:
77
+ torch_func_index, in_out = npu_split[0], npu_split[1]
78
+ except IndexError as error:
79
+ logger.error(f'{npu_name} can not be split with {process}, please check!')
80
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
63
81
  torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
64
82
  torch_func = str(torch_func_split[0]) + str(in_out)
65
83
  return torch_func
66
84
 
67
85
 
68
86
  def read_op(op_data, op_name):
69
- op_parsed_list = Const.DEFAULT_LIST
87
+ op_parsed_list = []
70
88
  if Const.FORWARD in op_name:
71
89
  if Const.INPUT_ARGS in op_data:
72
90
  input_item = op_data[Const.INPUT_ARGS]
@@ -103,16 +121,23 @@ def read_op(op_data, op_name):
103
121
  return op_parsed_list
104
122
 
105
123
 
106
- def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
124
+ def op_item_parse(item, op_name, index, item_list=None, top_bool=True, depth=0):
125
+ if depth > Const.MAX_DEPTH:
126
+ logger.error(f"parse of api/module of {op_name} exceeds the recursion limit.")
127
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
107
128
  if item_list is None:
108
129
  item_list = []
109
130
  if item is None or (isinstance(item, dict) and not item):
110
131
  if not top_bool:
111
- tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
112
- 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}
132
+ tmp = {
133
+ 'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
134
+ 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'
135
+ }
113
136
  else:
114
- tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
115
- 'shape': None, 'md5': None, 'data_name': '-1'}
137
+ tmp = {
138
+ 'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
139
+ 'shape': None, 'md5': None, 'data_name': '-1'
140
+ }
116
141
  item_list.append(tmp)
117
142
  return item_list
118
143
  if index is None:
@@ -125,7 +150,7 @@ def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
125
150
  if isinstance(item, dict):
126
151
  if 'type' not in item:
127
152
  for kwarg in item:
128
- kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
153
+ kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None, depth=depth+1)
129
154
  item_list += kwarg_parsed_list
130
155
  kwarg_parsed_list.clear()
131
156
  elif 'dtype' in item:
@@ -171,7 +196,7 @@ def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
171
196
  resolve_api_special_parameters(item, full_op_name, item_list)
172
197
  else:
173
198
  for j, item_spec in enumerate(item):
174
- op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
199
+ op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False, depth=depth+1)
175
200
  return item_list
176
201
 
177
202
 
@@ -226,9 +251,10 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals
226
251
  b_struct = b_dict[key][index]
227
252
  err_msg = ""
228
253
  if md5_compare:
229
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
230
- n_struct[2], b_struct[2],
231
- CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
254
+ result_item = [
255
+ n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], n_struct[2], b_struct[2],
256
+ CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF
257
+ ]
232
258
  if has_stack and index == 0 and key == "input_struct":
233
259
  result_item.extend(npu_stack_info)
234
260
  else:
@@ -237,15 +263,19 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals
237
263
  continue
238
264
 
239
265
  if summary_compare:
240
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
241
- " ", " ", " ", " ", " ", " ", " ", " "]
266
+ result_item = [
267
+ n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
268
+ " ", " ", " ", " ", " ", " ", " ", " "
269
+ ]
242
270
  else:
243
- result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
244
- " ", " ", " ", " ", " "]
271
+ result_item = [
272
+ n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
273
+ " ", " ", " ", " ", " "
274
+ ]
245
275
 
246
- npu_summary_data = n_dict.get("summary")[n_start + index]
276
+ npu_summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
247
277
  result_item.extend(npu_summary_data)
248
- bench_summary_data = b_dict.get("summary")[b_start + index]
278
+ bench_summary_data = b_dict.get(CompareConst.SUMMARY)[b_start + index]
249
279
  result_item.extend(bench_summary_data)
250
280
 
251
281
  if summary_compare:
@@ -257,7 +287,7 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals
257
287
  if bench_val != 0:
258
288
  relative = str(abs((diff / bench_val) * 100)) + '%'
259
289
  else:
260
- relative = "N/A"
290
+ relative = CompareConst.N_A
261
291
  result_item[start_idx + i] = diff
262
292
  result_item[start_idx + i + 4] = relative
263
293
  magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
@@ -287,15 +317,19 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals
287
317
  n_name = n_dict['op_name'][n_start + index]
288
318
  n_struct = n_dict[key][index]
289
319
  if md5_compare:
290
- result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
291
- n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
320
+ result_item = [
321
+ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
322
+ n_struct[2], CompareConst.NAN, CompareConst.NAN
323
+ ]
292
324
  result.append(result_item)
293
325
  continue
294
- result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
295
- n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
296
- summary_data = n_dict.get("summary")[n_start + index]
326
+ result_item = [
327
+ n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
328
+ " ", " ", " ", " ", " "
329
+ ]
330
+ summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
297
331
  result_item.extend(summary_data)
298
- summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
332
+ summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
299
333
  result_item.extend(summary_data)
300
334
 
301
335
  err_msg = ""
@@ -313,15 +347,12 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals
313
347
 
314
348
  n_num = len(n_dict['op_name'])
315
349
  b_num = len(b_dict['op_name'])
316
- n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name])
317
- b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name])
318
- n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
319
- b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
320
- n_num_output = n_num - n_num_input - n_num_kwarg
321
- b_num_output = b_num - b_num_input - b_num_kwarg
350
+ n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
351
+ b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
352
+ n_num_output = n_num - n_num_input
353
+ b_num_output = b_num - b_num_input
322
354
  get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
323
- get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
324
- get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
355
+ get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
325
356
 
326
357
 
327
358
  def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
@@ -331,7 +362,8 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
331
362
  err_msg = CompareConst.NO_BENCH
332
363
  accuracy_check_res = CompareConst.N_A
333
364
  for index, n_name in enumerate(n_dict["op_name"]):
334
- if n_name.find("input") != -1:
365
+ name_ele_list = n_name.split(Const.SEP)
366
+ if "input" in name_ele_list:
335
367
  n_struct = n_dict["input_struct"][index]
336
368
  else:
337
369
  n_struct = n_dict["output_struct"][index_out]
@@ -383,25 +415,28 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
383
415
  op_dict['stack_info'].append(tensor['full_info'])
384
416
  break
385
417
  op_dict["op_name"].append(tensor['full_op_name'])
418
+ name_ele_list = tensor['full_op_name'].split(Const.SEP)
386
419
  if not md5_compare:
387
- if tensor['full_op_name'].find("input") != -1:
420
+ if "input" in name_ele_list:
388
421
  op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
389
- elif tensor['full_op_name'].find("kwarg") != -1:
422
+ elif "kwarg" in name_ele_list:
390
423
  op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
391
- elif tensor['full_op_name'].find("output") != -1:
424
+ elif "output" in name_ele_list:
392
425
  op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
393
426
  else:
394
- if tensor['full_op_name'].find("input") != -1:
427
+ if "input" in name_ele_list:
395
428
  op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
396
- elif tensor['full_op_name'].find("kwarg") != -1:
429
+ if "kwarg" in name_ele_list:
397
430
  op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
398
- elif tensor['full_op_name'].find("output") != -1:
431
+ elif "output" in name_ele_list:
399
432
  op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
400
-
401
433
  op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
402
434
 
403
435
  if all_mode_bool:
404
436
  op_dict["data_name"].append(tensor['data_name'])
437
+ data_name = op_dict["data_name"][-1].rsplit(Const.SEP, 1)[0]
438
+ if data_name != "-1":
439
+ op_dict["op_name"][-1] = data_name
405
440
 
406
441
  if not op_dict["kwargs_struct"]:
407
442
  del op_dict["kwargs_struct"]
@@ -410,7 +445,7 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
410
445
 
411
446
  def _compare_parser(parser):
412
447
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
413
- help="<Required> The compare input path, a dict json.", required=True)
448
+ help="<Required> The compare input path, a dict json.", required=True)
414
449
  parser.add_argument("-o", "--output_path", dest="output_path", type=str,
415
450
  help="<Required> The compare task result out path.", required=True)
416
451
  parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
@@ -422,9 +457,8 @@ def _compare_parser(parser):
422
457
  parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
423
458
  help="<optional> The cell mapping file path.", required=False)
424
459
  parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
425
- help="<optional> The api mapping file path.", required=False)
426
-
427
-
428
-
429
-
430
-
460
+ help="<optional> The api mapping file path.", required=False)
461
+ parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
462
+ help="<optional> The data mapping file path.", required=False)
463
+ parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
464
+ help="<optional> The layer mapping file path.", required=False)
@@ -1,9 +1,24 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  from msprobe.core.data_dump.scope import build_scope, ListScope
4
19
  from msprobe.core.data_dump.json_writer import DataWriter
5
20
  from msprobe.core.common.log import logger
6
- from msprobe.core.common.const import Const, MsgConst
21
+ from msprobe.core.common.const import Const
7
22
  from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory
8
23
 
9
24
 
@@ -14,14 +29,13 @@ def build_data_collector(config):
14
29
  class DataCollector:
15
30
  multi_output_apis = ["_sort_", "npu_flash_attention"]
16
31
  tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK]
17
- level_without_construct = ["L1", "L2"]
32
+ level_without_construct = [Const.LEVEL_L1, Const.LEVEL_L2]
18
33
 
19
34
  def __init__(self, config):
20
35
  self.config = config
21
36
  self.data_writer = DataWriter()
22
37
  self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
23
- self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
24
- if self.config.framework == Const.PT_FRAMEWORK else None
38
+ self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework)
25
39
  self.module_count = {}
26
40
  if self.config.task == Const.FREE_BENCHMARK:
27
41
  self.scope = build_scope(ListScope, self.config.scope, self.config.list)
@@ -59,16 +73,16 @@ class DataCollector:
59
73
  def write_json(self):
60
74
  self.data_writer.write_json()
61
75
 
62
- def update_data(self, data_info, msg=''):
76
+ def update_data(self, name, data_info):
77
+ msg = f"msprobe is collecting data on {name}."
63
78
  if self.config.task == Const.OVERFLOW_CHECK:
64
79
  if self.data_processor.has_overflow:
80
+ msg += " Overflow detected."
81
+ logger.warning(msg)
65
82
  self.data_writer.update_data(data_info)
66
- msg += "Overflow detected."
67
- else:
68
- msg += "No Overflow, OK."
69
- else:
70
- self.data_writer.update_data(data_info)
71
- return msg
83
+ return
84
+ logger.debug(msg)
85
+ self.data_writer.update_data(data_info)
72
86
 
73
87
  def pre_forward_data_collect(self, name, module, pid, module_input_output):
74
88
  backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
@@ -78,7 +92,7 @@ class DataCollector:
78
92
  return
79
93
  logger.info(f"API {name} is inplace.")
80
94
  data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output)
81
- self.handle_data(name, data_info)
95
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
82
96
 
83
97
  def forward_data_collect(self, name, module, pid, module_input_output):
84
98
  self.update_construct(name)
@@ -92,13 +106,7 @@ class DataCollector:
92
106
  if self.config.level == "L2":
93
107
  return
94
108
  self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
95
- if self.config.framework == Const.MS_FRAMEWORK:
96
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
97
- else:
98
- if self.data_processor.is_terminated:
99
- self.handle_data(name, data_info, flush=True)
100
- raise Exception(f"[{Const.TOOL_NAME}] exit")
101
- self.handle_data(name, data_info)
109
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
102
110
 
103
111
  def backward_data_collect(self, name, module, pid, module_input_output):
104
112
  self.update_construct(name)
@@ -106,13 +114,7 @@ class DataCollector:
106
114
  return
107
115
 
108
116
  data_info = self.data_processor.analyze_backward(name, module, module_input_output)
109
- if self.config.framework == Const.MS_FRAMEWORK:
110
- self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
111
- else:
112
- if self.data_processor.is_terminated:
113
- self.handle_data(name, data_info, flush=True)
114
- raise Exception(f"[{Const.TOOL_NAME}] exit")
115
- self.handle_data(name, data_info)
117
+ self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
116
118
 
117
119
  def backward_input_data_collect(self, name, module, pid, module_input_output):
118
120
  self.update_construct(name)
@@ -131,18 +133,15 @@ class DataCollector:
131
133
  self.handle_data(name, data_info)
132
134
 
133
135
  def update_construct(self, name):
134
- if self.config.framework == Const.PT_FRAMEWORK and \
135
- self.config.level not in DataCollector.level_without_construct:
136
+ if self.config.level not in DataCollector.level_without_construct:
136
137
  self.data_writer.update_construct({name: self.module_processor.api_parent_node})
137
138
  self.data_writer.update_construct(self.module_processor.module_node)
138
139
 
139
140
  def handle_data(self, name, data_info, flush=False):
140
141
  if data_info:
141
- msg = f"msprobe is collecting data on {name}. "
142
- msg = self.update_data(data_info, msg)
143
- logger.info(MsgConst.CLEAR_SYMBOL + msg, end='\r')
142
+ self.update_data(name, data_info)
144
143
  if not flush:
145
- self.data_writer.flush_data_when_buffer_is_full()
144
+ self.data_writer.flush_data_periodically()
146
145
  else:
147
146
  self.write_json()
148
147
 
@@ -1,11 +1,27 @@
1
- import os
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import inspect
17
+ import os
3
18
  from dataclasses import dataclass
4
19
  from typing import Tuple, Dict, Optional, Any
20
+
5
21
  import numpy as np
6
- from msprobe.core.common.log import logger
7
- from msprobe.core.common.utils import convert_tuple
8
22
  from msprobe.core.common.const import Const
23
+ from msprobe.core.common.log import logger
24
+ from msprobe.core.common.utils import convert_tuple, CompareException
9
25
 
10
26
 
11
27
  @dataclass
@@ -69,8 +85,11 @@ class TensorStatInfo:
69
85
 
70
86
  class BaseDataProcessor:
71
87
  _recursive_key_stack = []
72
- special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
73
- bool, int, float, str, slice, type(Ellipsis))
88
+ special_type = (
89
+ np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
90
+ bool, int, float, str, slice,
91
+ type(Ellipsis)
92
+ )
74
93
 
75
94
  def __init__(self, config, data_writer):
76
95
  self.data_writer = data_writer
@@ -86,26 +105,27 @@ class BaseDataProcessor:
86
105
  @property
87
106
  def data_path(self):
88
107
  return self.data_writer.dump_tensor_data_dir
89
-
108
+
90
109
  @property
91
110
  def is_terminated(self):
92
111
  return False
93
112
 
94
113
  @staticmethod
95
114
  def analyze_api_call_stack(name):
115
+ try:
116
+ api_stack = inspect.stack()[5:]
117
+ except Exception as e:
118
+ logger.warning(f"The call stack of <{name}> failed to retrieve, {e}.")
119
+ api_stack = None
96
120
  stack_str = []
97
- for (_, path, line, func, code, _) in inspect.stack()[5:]:
98
- if not code:
99
- continue
100
- stack_line = " ".join([
101
- "File", ", ".join([
102
- path,
103
- " ".join(["line", str(line)]),
104
- " ".join(["in", func]),
105
- " ".join(["\n", code[0].strip()])
106
- ])
107
- ])
108
- stack_str.append(stack_line)
121
+ if api_stack:
122
+ for (_, path, line, func, code, _) in api_stack:
123
+ if not code:
124
+ continue
125
+ stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
126
+ stack_str.append(stack_line)
127
+ else:
128
+ stack_str.append(Const.WITHOUT_CALL_STACK)
109
129
  stack_info_struct = {name: stack_str}
110
130
  return stack_info_struct
111
131
 
@@ -167,7 +187,10 @@ class BaseDataProcessor:
167
187
  return cls.special_type
168
188
 
169
189
  @classmethod
170
- def recursive_apply_transform(cls, args, transform):
190
+ def recursive_apply_transform(cls, args, transform, depth=0):
191
+ if depth > Const.MAX_DEPTH:
192
+ logger.error(f"The maximum depth of recursive transform, {Const.MAX_DEPTH} is reached.")
193
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
171
194
  if isinstance(args, cls.get_special_types()):
172
195
  arg_transform = transform(args, cls._recursive_key_stack)
173
196
  return arg_transform
@@ -175,14 +198,14 @@ class BaseDataProcessor:
175
198
  result_list = []
176
199
  for i, arg in enumerate(args):
177
200
  cls._recursive_key_stack.append(str(i))
178
- result_list.append(cls.recursive_apply_transform(arg, transform))
201
+ result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
179
202
  cls._recursive_key_stack.pop()
180
203
  return type(args)(result_list)
181
204
  elif isinstance(args, dict):
182
205
  result_dict = {}
183
206
  for k, arg in args.items():
184
207
  cls._recursive_key_stack.append(str(k))
185
- result_dict[k] = cls.recursive_apply_transform(arg, transform)
208
+ result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
186
209
  cls._recursive_key_stack.pop()
187
210
  return result_dict
188
211
  elif args is not None:
@@ -222,7 +245,7 @@ class BaseDataProcessor:
222
245
 
223
246
  def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
224
247
  pass
225
-
248
+
226
249
  def analyze_element(self, element):
227
250
  return self.recursive_apply_transform(element, self.analyze_single_element)
228
251