mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -13,418 +13,30 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import os
17
- import re
18
- from collections import defaultdict
19
-
20
- import numpy as np
21
- import pandas as pd
22
-
23
- from msprobe.core.common.const import CompareConst, Const
24
- from msprobe.core.common.exceptions import FileCheckException
25
- from msprobe.core.common.file_utils import create_directory, load_json, load_npy, load_yaml
26
- from msprobe.core.common.log import logger
27
- from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
28
- check_op_str_pattern_valid, get_dump_mode, set_dump_path, detect_framework_by_dump_json
29
- from msprobe.core.compare.acc_compare import Comparator, ModeConfig
30
- from msprobe.core.compare.check import dtype_mapping
16
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
31
17
  from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
32
- from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list
33
-
34
-
35
- class MappingConfig:
36
- def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
37
- self.cell_mapping = cell_mapping
38
- self.api_mapping = api_mapping
39
- self.data_mapping = data_mapping
40
-
41
-
42
- class MSComparator(Comparator):
43
- """
44
- 用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
45
- cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
46
- api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
47
- data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
48
- is_cross_framework: 是否跨框架。
49
- """
50
- def __init__(self, mode_config, mapping_config=None, is_cross_framework=False):
51
- super().__init__(mode_config)
52
- self.frame_name = MSComparator.__name__
53
-
54
- self.stack_mode = mode_config.stack_mode
55
- self.auto_analyze = mode_config.auto_analyze
56
- self.fuzzy_match = mode_config.fuzzy_match
57
- self.dump_mode = mode_config.dump_mode
58
-
59
- if mapping_config:
60
- self.cell_mapping = mapping_config.cell_mapping
61
- self.api_mapping = mapping_config.api_mapping
62
- self.data_mapping = mapping_config.data_mapping
63
-
64
- if self.data_mapping:
65
- self.cross_frame = is_cross_framework
66
- else:
67
- self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None
68
- self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
69
- self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
70
- if self.api_mapping is not None:
71
- self.ms_to_pt_mapping = self.load_internal_api()
72
-
73
- if isinstance(self.data_mapping, str) or self.data_mapping is None:
74
- self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
75
- elif isinstance(self.data_mapping, dict):
76
- self.data_mapping_dict = self.data_mapping
77
- else:
78
- raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
79
- f"{type(self.data_mapping)}")
80
-
81
- @staticmethod
82
- def process_data_name(result):
83
- result['data_name_x'] = result.apply(lambda row: [row['data_name_x'], row['data_name_y']], axis=1)
84
- return result
85
-
86
- def calc_accuracy(self, result_df, header):
87
- condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
88
- result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
89
- result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
90
-
91
- def calc_summary_diff(data_type: str):
92
- def type_check(val):
93
- check_series = pd.Series(False, index=val.index)
94
- val_str = val.astype(str)
95
- check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
96
- return check_series
97
-
98
- def get_number(val):
99
- return pd.to_numeric(val.astype(str), errors='coerce')
100
-
101
- ms_val = result_df['NPU ' + data_type]
102
- pt_val = result_df['Bench ' + data_type]
103
- diff_name = data_type.capitalize() + ' diff'
104
- rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
105
- condition_na = ~type_check(ms_val) | ~type_check(pt_val)
106
- result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
107
- result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
108
- condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
109
- condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
110
- result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
111
- condition_pt_zero = pt_val == 0
112
- result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
113
- condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
114
- result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
115
- pt_val[condition_ref_err] * 100)
116
- result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
117
- .abs().astype(str) + '%')
118
- magnitude = get_number(result_df[diff_name]).abs() / (
119
- pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
120
- return magnitude > CompareConst.MAGNITUDE
121
-
122
- if self.dump_mode == Const.MD5:
123
- condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
124
- result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
125
- result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
126
- elif self.dump_mode == Const.SUMMARY:
127
- warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
128
- warning_flag = pd.DataFrame(warning_list).any()
129
- result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
130
- result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
131
- result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
132
- else:
133
- fill_cols = [CompareConst.COSINE, CompareConst.EUC_DIST,
134
- CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
135
- CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
136
- CompareConst.ERROR_MESSAGE]
137
- result_df.loc[~condition_no_bench, fill_cols] = ''
138
- result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
139
- return result_df[header]
140
-
141
- def make_result_df(self, result):
142
- header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
143
-
144
- if self.stack_mode:
145
- header.append(CompareConst.STACK)
146
- if self.dump_mode == Const.ALL:
147
- header.append(CompareConst.DATA_NAME)
148
- result = self.process_data_name(result)
149
-
150
- result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
151
- 'op_name_y': CompareConst.BENCH_NAME,
152
- 'dtype_x': CompareConst.NPU_DTYPE,
153
- 'dtype_y': CompareConst.BENCH_DTYPE,
154
- 'shape_x': CompareConst.NPU_SHAPE,
155
- 'shape_y': CompareConst.BENCH_SHAPE,
156
- 'md5_x': CompareConst.NPU_MD5,
157
- 'md5_y': CompareConst.BENCH_MD5,
158
- 'data_name_x': CompareConst.DATA_NAME,
159
- 'stack_info_x': CompareConst.STACK}, inplace=True)
160
-
161
- npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
162
- bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
163
- CompareConst.BENCH_NORM]
164
-
165
- def set_summary(summary):
166
- if summary == CompareConst.N_A:
167
- return [CompareConst.N_A] * 4
168
- summary_list = []
169
- for i in summary:
170
- if i is None:
171
- summary_list.append(CompareConst.N_A)
172
- elif str(i).lower() == 'nan':
173
- summary_list.append(CompareConst.NAN)
174
- else:
175
- summary_list.append(i)
176
- return summary_list
177
-
178
- result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
179
- result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
180
-
181
- result_df = pd.DataFrame(columns=header)
182
- for h in header:
183
- if h in result.columns:
184
- result_df[h] = result[h]
185
- return self.calc_accuracy(result_df, header)
186
-
187
- def load_internal_api(self):
188
- cur_path = os.path.dirname(os.path.realpath(__file__))
189
- yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
190
- return load_yaml(yaml_path)
191
-
192
- def load_mapping_file(self, mapping_file):
193
- if isinstance(mapping_file, str):
194
- mapping_dict = load_yaml(mapping_file)
195
- else:
196
- mapping_dict = {}
197
- return mapping_dict
198
-
199
- def process_cell_mapping(self, npu_op_name):
200
- if not npu_op_name:
201
- return CompareConst.N_A
202
- param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
203
- if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
204
- return CompareConst.N_A
205
- npu_op_name = npu_op_name.replace("Cell", "Module", 1)
206
- if self.cell_mapping_dict:
207
- # get cell name & class name from op_name
208
- # Cell.fc1.Dense.forward.0.input.0
209
- cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
210
- if cell_name in self.cell_mapping_dict:
211
- npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
212
- return npu_op_name
213
-
214
- def read_npy_data(self, dir_path, file_name, load_pt_file=False):
215
- if not file_name:
216
- return None
217
- data_path = os.path.join(dir_path, file_name)
218
- if load_pt_file:
219
- import torch
220
- from msprobe.pytorch.common.utils import load_pt
221
- data_value = load_pt(data_path, True).detach()
222
- if data_value.dtype == torch.bfloat16:
223
- data_value = data_value.to(torch.float32)
224
- data_value = data_value.numpy()
225
- else:
226
- data_value = load_npy(data_path)
227
- return data_value
228
-
229
- def process_internal_api_mapping(self, npu_op_name):
230
- # get api name & class name from op_name
231
- # Functional.addcmul.0.forward.input.0
232
- ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
233
- class_name = ms_api_name.split(Const.SEP)[0]
234
- if class_name == "Mint":
235
- return npu_op_name.replace("Mint", "Torch")
236
- elif class_name == "MintFunctional":
237
- return npu_op_name.replace("MintFunctional", "Functional")
238
- elif self.ms_to_pt_mapping.get(ms_api_name):
239
- return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
240
- else:
241
- return npu_op_name
242
-
243
- def get_api_name(self, api_list):
244
- try:
245
- api_name = api_list[0] + Const.SEP + api_list[1]
246
- except IndexError as error:
247
- logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
248
- raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
249
- return api_name
250
-
251
- def compare_process(self, file_lists):
252
- npu_json_path, bench_json_path, stack_json_path = file_lists
253
- npu_json_data = load_json(npu_json_path)
254
- bench_json_data = load_json(bench_json_path)
255
- stack_json_data = load_json(stack_json_path) if self.stack_mode else None
256
-
257
- npu_df = self.gen_data_df(npu_json_data, stack_json_data)
258
- bench_df = self.gen_data_df(bench_json_data, stack_json_data)
259
- if self.cell_mapping:
260
- npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
261
- elif self.api_mapping:
262
- npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
263
- if isinstance(self.api_mapping, str):
264
- self.modify_compare_data_with_user_mapping(npu_df, bench_df)
265
- else:
266
- npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]
267
- npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
268
- bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
269
- npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
270
- bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
271
- bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
272
- match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
273
- how='outer')
274
- match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
275
-
276
- def gen_dtype_condition():
277
- npu_dtype = match_result['dtype_x']
278
- bench_dtype = match_result['dtype_y']
279
- if self.cross_frame:
280
- npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
281
-
282
- equal_condition = npu_dtype == bench_dtype
283
- match_condition = (
284
- (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[0]) & bench_dtype.isin(
285
- CompareConst.DTYPE_MATCH_GROUPS[0])) |
286
- (npu_dtype.isin(CompareConst.DTYPE_MATCH_GROUPS[1]) & bench_dtype.isin(
287
- CompareConst.DTYPE_MATCH_GROUPS[1]))
288
- )
289
- return equal_condition | match_condition
290
-
291
- match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
292
- return self.make_result_df(match_result)
293
-
294
- def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
295
- def get_api_indices_dict(op_name_df):
296
- api_indices_dict = defaultdict(list)
297
- for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
298
- api = self.get_api_name(name.split(Const.SEP))
299
- api_indices_dict[api].append(op_index)
300
- return api_indices_dict
301
-
302
- ms_api_indices_dict = get_api_indices_dict(npu_df)
303
- pt_api_indices_dict = get_api_indices_dict(bench_df)
304
-
305
- def gen_input_compare_key(pattern, term):
306
- flag = True
307
- for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
308
- if op_name.split(pattern)[1].startswith(str(prefix)):
309
- npu_df.loc[index, CompareConst.COMPARE_KEY] = (
310
- op_name.replace(pattern + str(prefix),
311
- pattern + str(mapping_dict.get(f'pt_{term}')[i])))
312
- flag = False
313
- return flag
314
-
315
- for mapping_dict in self.api_mapping_dict:
316
- keys_to_compare = [
317
- ('ms_args', 'pt_args'),
318
- ('ms_output', 'pt_output'),
319
- ('ms_parameters', 'pt_parameters'),
320
- ('ms_parameters_grad', 'pt_parameters_grad'),
321
- ]
322
- if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare):
323
- logger.warning('The user-defined mapping table is incorrect,\
324
- make sure that the number of parameters is equal')
325
- continue
326
-
327
- ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
328
- if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
329
- continue
330
- for index in ms_api_indices_dict.get(ms_api):
331
- op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
332
- if CompareConst.INPUT_PATTERN in op_name:
333
- is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
334
- elif CompareConst.KWARGS_PATTERN in op_name:
335
- is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
336
- elif CompareConst.OUTPUT_PATTERN in op_name:
337
- is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
338
- elif CompareConst.PARAMS_PATTERN in op_name:
339
- is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
340
- elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
341
- is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
342
- else:
343
- logger.error(f'Excepted op_name: {op_name}')
344
- raise CompareException(CompareException.INVALID_DATA_ERROR)
345
- if is_abandoned:
346
- npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
347
-
348
- def gen_data_df(self, data_json, stack_json_data):
349
- result = {
350
- CompareConst.OP_NAME: [],
351
- Const.DTYPE: [],
352
- Const.SHAPE: [],
353
- Const.SUMMARY: [],
354
- 'stack_info': []
355
- }
356
- if self.dump_mode == Const.ALL:
357
- result['data_name'] = []
358
- elif self.dump_mode == Const.MD5:
359
- result[Const.MD5] = []
360
- for data_name in data_json['data']:
361
- check_op_str_pattern_valid(data_name)
362
- merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
363
- if not merge_list:
364
- continue
365
-
366
- op_name_list = merge_list.get(CompareConst.OP_NAME)
367
- summary_list = merge_list.get(Const.SUMMARY)
368
- data_name_list = merge_list.get('data_name')
369
- op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
370
- summary_list,
371
- data_name_list)
372
- for op_name in op_name_reorder:
373
- result[CompareConst.OP_NAME].append(op_name)
374
- if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
375
- struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
376
- elif CompareConst.OUTPUT_PATTERN in op_name:
377
- struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
378
- elif CompareConst.PARAMS_PATTERN in op_name:
379
- struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
380
- else:
381
- struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
382
- result[Const.DTYPE].append(struct[0])
383
- result[Const.SHAPE].append(struct[1])
384
- if self.dump_mode == Const.MD5:
385
- result[Const.MD5].append(struct[2])
386
- result[Const.SUMMARY].append(summary_reorder.pop(0))
387
- result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None)
388
- if self.dump_mode == Const.ALL:
389
- result['data_name'].append(data_name_reorder.pop(0))
390
- return pd.DataFrame(result)
18
+ from msprobe.mindspore.compare.utils import read_npy_data, check_cross_framework
391
19
 
392
20
 
393
- def check_cross_framework(bench_json_path):
394
- framework = detect_framework_by_dump_json(bench_json_path)
395
- if framework == Const.PT_FRAMEWORK:
396
- return True
21
+ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, cross_frame) -> tuple:
22
+ n_value = read_npy_data(npu_dir, npu_data_name)
23
+ if cross_frame:
24
+ from msprobe.pytorch.compare.utils import read_pt_data
25
+ b_value = read_pt_data(bench_dir, bench_data_name)
397
26
  else:
398
- return False
27
+ b_value = read_npy_data(bench_dir, bench_data_name)
28
+ return n_value, b_value
399
29
 
400
30
 
401
31
  def ms_compare(input_param, output_path, **kwargs):
402
- try:
403
- auto_analyze = kwargs.get('auto_analyze', True)
404
- fuzzy_match = kwargs.get('fuzzy_match', False)
405
- cell_mapping = kwargs.get('cell_mapping', None)
406
- api_mapping = kwargs.get('api_mapping', None)
407
- data_mapping = kwargs.get('data_mapping', None)
408
- layer_mapping = kwargs.get('layer_mapping', None)
409
- suffix = kwargs.get('suffix', '')
32
+ config = setup_comparison(input_param, output_path, **kwargs)
410
33
 
411
- set_dump_path(input_param)
412
- dump_mode = get_dump_mode(input_param)
413
- if 'stack_json_path' in input_param:
414
- stack_mode = kwargs.get('stack_mode', False)
415
- else:
416
- stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
417
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
418
- create_directory(output_path)
419
- check_compare_param(input_param, output_path, dump_mode, stack_mode)
420
- except (CompareException, FileCheckException) as error:
421
- logger.error('Compare failed. Please check the arguments and do it again!')
422
- raise CompareException(error.code) from error
423
- if layer_mapping:
424
- data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
34
+ if config.layer_mapping:
35
+ config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path)
425
36
 
426
- mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
427
- mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping)
428
37
  is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
429
- ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework)
430
- ms_comparator.compare_core(input_param, output_path, suffix=suffix)
38
+ mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
39
+ config.dump_mode, config.compared_file_type)
40
+ mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping)
41
+ ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework)
42
+ ms_comparator.compare_core(input_param, output_path, suffix=config.suffix)
@@ -85,11 +85,13 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
85
85
  }
86
86
  for statistic_file in statistic_file_list:
87
87
  content = read_csv(statistic_file, as_pd=False)
88
+ if not content:
89
+ logger.error(f'Empty dump file: {statistic_file}')
90
+ raise CompareException(f'Empty dump file: {statistic_file}')
88
91
  header = content[0]
89
- for key in header_index.keys():
90
- for index, value in enumerate(header):
91
- if key == value:
92
- header_index[key] = index
92
+ for index, value in enumerate(header):
93
+ if value in header_index:
94
+ header_index[value] = index
93
95
  statistic_data_list.extend(content[1:])
94
96
 
95
97
  for key in header_index.keys():
@@ -97,7 +99,14 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
97
99
  logger.warning(f"Data_path {statistic_file_path} has no key {key}.")
98
100
 
99
101
  for data in statistic_data_list:
100
- compare_key = f"{data[1]}.{data[2]}.{data[3]}.{data[5]}"
102
+ '''
103
+ 13列分别是OpType, OpName, TaskId, StreamId, TimeStamp, IO, Slot, DataSize,
104
+ DataType, Shape, MaxValue, MinValue, L2NormValue
105
+ '''
106
+ if len(data) < 13:
107
+ logger.error(f'Dump file {statistic_file_path} has been modified into incorrect format!')
108
+ raise CompareException(f'Dump file {statistic_file_path} has been modified into incorrect format!')
109
+ compare_key = f"{data[1]}.{data[2]}.{data[5]}.{data[6]}" # OpName, TaskId, IO, Slot
101
110
  op_name = f"{compare_key} {statistic_file_path}"
102
111
  timestamp = int(data[4])
103
112
  result_data = [op_name, compare_key, timestamp]
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import load_npy, FileChecker, FileCheckConst
20
+ from msprobe.core.common.utils import detect_framework_by_dump_json
21
+
22
+
23
+ def read_npy_data(dir_path, file_name):
24
+ if not file_name:
25
+ return None
26
+
27
+ data_path = os.path.join(dir_path, file_name)
28
+ path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
29
+ FileCheckConst.NUMPY_SUFFIX, False)
30
+ data_path = path_checker.common_check()
31
+ data_value = load_npy(data_path)
32
+ return data_value
33
+
34
+
35
+ def check_cross_framework(bench_json_path):
36
+ framework = detect_framework_by_dump_json(bench_json_path)
37
+ return framework == Const.PT_FRAMEWORK
@@ -15,12 +15,18 @@
15
15
 
16
16
  import os
17
17
 
18
+ from mindspore import nn
19
+
18
20
  from msprobe.core.common.const import Const
19
21
  from msprobe.core.common.exceptions import MsprobeException
20
22
  from msprobe.core.common.file_utils import create_directory
23
+ from msprobe.core.common.log import logger
21
24
  from msprobe.mindspore.common.const import Const as MsConst
22
25
  from msprobe.mindspore.common.const import FreeBenchmarkConst
23
- from msprobe.core.common.log import logger
26
+ from msprobe.mindspore.common.utils import is_mindtorch
27
+
28
+ if is_mindtorch():
29
+ import torch
24
30
 
25
31
 
26
32
  class DebuggerConfig:
@@ -41,8 +47,12 @@ class DebuggerConfig:
41
47
  self.check_mode = task_config.check_mode
42
48
  self.framework = Const.MS_FRAMEWORK
43
49
  self.summary_mode = task_config.summary_mode
50
+ self.stat_cal_mode = task_config.stat_cal_mode if hasattr(task_config, 'stat_cal_mode') else None
51
+ self.device_stat_precision_mode = task_config.device_stat_precision_mode \
52
+ if hasattr(task_config, 'device_stat_precision_mode') else None
44
53
  self.async_dump = common_config.async_dump if common_config.async_dump else False
45
54
  self.check()
55
+ self._check_statistics_config(task_config)
46
56
  create_directory(self.dump_path)
47
57
 
48
58
  if self.task == Const.FREE_BENCHMARK:
@@ -62,6 +72,31 @@ class DebuggerConfig:
62
72
  raise ValueError
63
73
  self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL
64
74
 
75
+ @staticmethod
76
+ def check_model(models, token_range=None):
77
+ if token_range and not models:
78
+ error_info = "The 'model' parameter must be provided when token_range is not None"
79
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
80
+
81
+ target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
82
+ if models is None or isinstance(models, target_module_type[0]):
83
+ return models
84
+ error_model = None
85
+ if isinstance(models, (list, tuple)):
86
+ for model in models:
87
+ if not isinstance(model, target_module_type[0]):
88
+ error_model = model
89
+ break
90
+ else:
91
+ error_model = models
92
+
93
+ if error_model is not None:
94
+ error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
95
+ f"type, currently there is a {type(error_model)} type.")
96
+ raise MsprobeException(
97
+ MsprobeException.INVALID_PARAM_ERROR, error_info)
98
+ return models
99
+
65
100
  def check(self):
66
101
  if not self.dump_path:
67
102
  raise Exception("Dump path is empty.")
@@ -76,8 +111,12 @@ class DebuggerConfig:
76
111
  self.check_mode = "all"
77
112
  if not isinstance(self.async_dump, bool):
78
113
  raise Exception("The parameters async_dump should be bool.")
79
- if self.async_dump and self.task == Const.TENSOR and not self.list:
80
- raise Exception("The parameters async_dump is true in tensor task, the parameters list cannot be empty.")
114
+ if self.async_dump and self.task == Const.TENSOR:
115
+ if self.level_ori == Const.LEVEL_DEBUG:
116
+ self.list = [] # async_dump + debug level case ignore list
117
+ if not self.list and self.level_ori != Const.LEVEL_DEBUG:
118
+ raise Exception("The parameters async_dump is true in tensor task,"
119
+ " the parameters list cannot be empty.")
81
120
  if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
82
121
  logger.warning_on_rank_0(
83
122
  f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
@@ -98,3 +137,14 @@ class DebuggerConfig:
98
137
  if not self.list or len(self.list) != 1:
99
138
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
100
139
  f"When level is set to L2, the list must be configured as a list with one api name.")
140
+
141
+ def _check_statistics_config(self, task_config):
142
+ if self.task != Const.STATISTICS:
143
+ return
144
+ self.tensor_list = []
145
+ if not hasattr(task_config, "tensor_list"):
146
+ return
147
+ if self.level_ori == Const.LEVEL_DEBUG and task_config.tensor_list:
148
+ logger.warning_on_rank_0("When level is set to debug, the tensor_list will be invalid.")
149
+ return
150
+ self.tensor_list = task_config.tensor_list