mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/mindspore/dump/cell_dump_with_insert_gradient.py (new file)
@@ -0,0 +1,864 @@
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import atexit
+ from multiprocessing import Pool
+ import os
+ import re
+ import time
+
+ import numpy as np
+ import pandas as pd
+ import mindspore as ms
+ from mindspore import nn, ops
+
+ from msprobe.core.common.const import Const as CoreConst
+ from msprobe.core.common.const import FileCheckConst
+ from msprobe.core.common.file_utils import (
+     load_npy, save_json, remove_path, load_yaml,
+     create_directory, read_csv, write_df_to_csv, write_csv, move_file, move_directory)
+ from msprobe.mindspore.common.log import logger
+ from msprobe.mindspore.dump.cell_dump_process import CellDumpConfig
+
+
+ CONSTRUCT_FILE_NAME = "construct.json"
+ DEFAULT_RANK_DIR = "rank0"
+ KEY_LAYERS = "layers"
+ construct = {}
+ cell_list = []
+ KEY_SIDE_EFFECT = "side_effect_io"
+ KEY_TOPLAYER = "TopLayer"
+ KEY_FORWARD = CoreConst.FORWARD
+ KEY_BACKWARD = CoreConst.BACKWARD
+ KEY_INPUT = CoreConst.INPUT
+ KEY_OUTPUT = CoreConst.OUTPUT
+ KEY_DUMP_TENSOR_DATA = "dump_tensor_data_"
+ KEY_STATISTIC_CSV = "statistic.csv"
+ KEY_TD_FLAG = "td_flag"
+ td = ops.TensorDump()
+ if ms.__version__ >= "2.5.0":
+     td_in = ops.TensorDump("in")
+ else:
+     td_in = ops.TensorDump()
+ graph_step_flag = True
+ try:
+     from mindspore._c_expression import _set_init_iter
+ except ImportError:
+     graph_step_flag = False
+ td.add_prim_attr(KEY_SIDE_EFFECT, False)
+ td_in.add_prim_attr(KEY_SIDE_EFFECT, False)
+ td.add_prim_attr(KEY_TD_FLAG, True)
+ td_in.add_prim_attr(KEY_TD_FLAG, True)
+ dump_task = CoreConst.STATISTICS
+ np_ms_dtype_dict = {
+     "bool": ms.bool_,
+     "int8": ms.int8,
+     "byte": ms.byte,
+     "int16": ms.int16,
+     "short": ms.short,
+     "int32": ms.int32,
+     "intc": ms.intc,
+     "int64": ms.int64,
+     "intp": ms.intp,
+     "uint8": ms.uint8,
+     "ubyte": ms.ubyte,
+     "uint16": ms.uint16,
+     "ushort": ms.ushort,
+     "uint32": ms.uint32,
+     "uintc": ms.uintc,
+     "uint64": ms.uint64,
+     "uintp": ms.uintp,
+     "float16": ms.float16,
+     "half": ms.half,
+     "float32": ms.float32,
+     "single": ms.single,
+     "float64": ms.float64,
+     "double": ms.double,
+     "bfloat16": ms.bfloat16,
+     "complex64": ms.complex64,
+     "complex128": ms.complex128
+ }
+
+
+ def gen_file_path(dump_path, cell_prefix, suffix, io_type, index):
+     data_path = os.path.join(dump_path, '{step}', '{rank}', CoreConst.DUMP_TENSOR_DATA)
+     file_name = ""
+     if dump_task == CoreConst.TENSOR:
+         file_name = cell_prefix + CoreConst.SEP + suffix + CoreConst.SEP + io_type + CoreConst.SEP + str(index)
+     if dump_task == CoreConst.STATISTICS:
+         file_name = cell_prefix + CoreConst.HYPHEN + suffix + CoreConst.HYPHEN + io_type + CoreConst.HYPHEN + str(index)
+     return os.path.join(data_path, file_name)
+
+
+ def partial_func(func, dump_path, cell_prefix, index, io_type):
+     def newfunc(*args, **kwargs):
+         return func(dump_path, cell_prefix, index, io_type, *args, **kwargs)
+     return newfunc
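`partial_func` is a hand-rolled argument binder: it pre-binds the leading arguments of a callback so that MindSpore only has to call it with the gradient tensor. A minimal sketch of the equivalence, using `functools.partial` and illustrative argument values:

```python
from functools import partial

# Both callbacks behave identically when later wrapped by ops.InsertGradientOf
# and invoked as cb(dx); the path and prefix values here are made up.
cb_closure = partial_func(clip_gradient, "/tmp/dump", "Cell.net.Net", 0, KEY_INPUT)
cb_partial = partial(clip_gradient, "/tmp/dump", "Cell.net.Net", 0, KEY_INPUT)
```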
+
+
+ def clip_gradient(dump_path, cell_prefix, index, io_type, dx):
+     if io_type == KEY_OUTPUT:
+         temp = td(gen_file_path(dump_path, cell_prefix, KEY_BACKWARD, io_type, index), dx)
+         dx = ops.depend(dx, temp)
+     elif io_type == KEY_INPUT:
+         temp = td_in(gen_file_path(dump_path, cell_prefix, KEY_BACKWARD, io_type, index), dx)
+         dx = ops.depend(dx, temp)
+     return dx
+
+
+ def need_tensordump_in(cell_obj, attr):
+     return hasattr(cell_obj, attr) and getattr(cell_obj, attr) == "in"
+
+
+ def cell_construct_wrapper(func, self):
+     def new_construct(self, *args, **kwargs):
+         new_args = []
+         out_list = []
+
+         index = 0
+         item = None
+         backward_or_all = self.data_mode in ["backward", "all"]
+         forward_or_all = self.data_mode in ["forward", "all"]
+         # The inputs of the cell.
+         for index, item in enumerate(args):
+             if backward_or_all and ops.is_tensor(item):
+                 item = self.output_clips[index](item)
+             if forward_or_all and ops.is_tensor(item):
+                 if need_tensordump_in(self, 'input_dump_mode'):
+                     temp = td_in(
+                         gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index),
+                         item
+                     )
+                 else:
+                     temp = td(
+                         gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_INPUT, index),
+                         item
+                     )
+                 item = ops.depend(item, temp)
+             new_args.append(item)
+
+         out = func(*new_args, **kwargs)
+
+         # The outputs of the cell.
+         if isinstance(out, tuple):
+             for index, item in enumerate(out):
+                 if backward_or_all and ops.is_tensor(item):
+                     item = self.input_clips[index](item)
+                 if forward_or_all and ops.is_tensor(item):
+                     if need_tensordump_in(self, 'output_dump_mode'):
+                         temp = td_in(
+                             gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index),
+                             item
+                         )
+                     else:
+                         temp = td(
+                             gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, index),
+                             item
+                         )
+                     item = ops.depend(item, temp)
+                     out_list.append(item)
+                 else:
+                     # Pass through items that were not dumped (non-tensor outputs, or
+                     # tensors in backward-only mode), so the output tuple keeps its shape.
+                     out_list.append(item)
+             out_list = tuple(out_list)
+             return out_list
+         else:
+             if backward_or_all:
+                 out = self.input_clips[0](out)
+             if forward_or_all and ops.is_tensor(out):
+                 if need_tensordump_in(self, 'output_dump_mode'):
+                     temp = td_in(
+                         gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0),
+                         out
+                     )
+                 else:
+                     temp = td(
+                         gen_file_path(self.dump_path, self.cell_prefix, KEY_FORWARD, KEY_OUTPUT, 0),
+                         out
+                     )
+                 out = ops.depend(out, temp)
+             return out
+
+     return new_construct.__get__(self, type(self))
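The wrapper replaces a cell's bound `construct` with a version that dumps tensors on the way in and out; `__get__` rebinds the new function as a method of the same instance. A toy sketch of the rebinding pattern (the extra attributes shown are normally set by `start()` further below):

```python
import mindspore.nn as nn
from mindspore import Tensor

class Toy(nn.Cell):
    def construct(self, x):
        return x * 2

cell = Toy()
# Attributes that new_construct reads; start() normally sets these.
cell.dump_path, cell.data_mode, cell.cell_prefix = "/tmp/dump", "forward", "Cell.toy.Toy"
cell.input_clips, cell.output_clips = [], []
cell.construct = cell_construct_wrapper(cell.construct, cell)
out = cell(Tensor([1.0]))  # the forward pass now also emits TensorDump files
```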
+
+
+ # Collect all file names in the directory, sorted in ascending order by the
+ # auto-incremented id that TensorDump appends when writing to disk
+ def sort_filenames(path):
+     filenames = os.listdir(path)
+     id_pattern = re.compile(rf'{CoreConst.REPLACEMENT_CHARACTER}(\d+){CoreConst.NUMPY_SUFFIX}$')
+     filenames.sort(key=lambda x: int(id_pattern.findall(x)[0]))
+     return filenames
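Sorting by the extracted id keeps dump order numeric rather than lexicographic. A small illustration, assuming `CoreConst.REPLACEMENT_CHARACTER` is `"_"` and `CoreConst.NUMPY_SUFFIX` is `".npy"`:

```python
import re

names = ["x.forward.0.output.0_float32_10.npy", "x.forward.0.input.0_float32_2.npy"]
names.sort(key=lambda x: int(re.findall(r'_(\d+)\.npy$', x)[0]))
# -> the id-2 file now precedes the id-10 file, which plain string sorting would reverse
```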
+
+
+ def rename_filename(path="", data_df=None):
+     if dump_task == CoreConst.TENSOR:
+         filenames = sort_filenames(path)
+     if dump_task == CoreConst.STATISTICS:
+         filenames = data_df[CoreConst.OP_NAME].tolist()
+
+     filename_dict = {}
+     for index, filename in enumerate(filenames):
+         if dump_task == CoreConst.TENSOR:
+             name_field = filename.rsplit(CoreConst.REPLACEMENT_CHARACTER, 1)[0]
+         if dump_task == CoreConst.STATISTICS:
+             name_field = filename
+
+         if name_field in filename_dict:
+             filename_dict[name_field] += 1
+         else:
+             filename_dict[name_field] = 0
+
+         cell_index = filename_dict[name_field]
+
+         # Rename the file, appending the call index for cells that are invoked repeatedly
+         new_file_name = filename  # default when neither pattern matches
+         if CoreConst.FORWARD_PATTERN in filename:
+             # Format: Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index}_{dtype}_{id}.npy
+             new_file_name = filename.replace(CoreConst.FORWARD_PATTERN,
+                                              CoreConst.FORWARD_PATTERN + str(cell_index) + CoreConst.SEP)
+         if CoreConst.BACKWARD_PATTERN in filename:
+             new_file_name = filename.replace(CoreConst.BACKWARD_PATTERN,
+                                              CoreConst.BACKWARD_PATTERN + str(cell_index) + CoreConst.SEP)
+         if dump_task == CoreConst.TENSOR:
+             move_file(os.path.join(path, filename), os.path.join(path, new_file_name))
+         if dump_task == CoreConst.STATISTICS:
+             data_df.loc[index, CoreConst.OP_NAME] = new_file_name
+     logger.info("==========The rename_filename phase is Finished!==========")
+
+
+ # Extract the field between the first "." and the third-to-last ".", i.e. {cell_name}
+ def get_cell_name(string):
+     parts = string.split(CoreConst.SEP)
+     if len(parts) < 4:
+         return None
+     start_index = 1
+     end_index = len(parts) - 3
+     return CoreConst.SEP.join(parts[start_index:end_index])
+
+
+ # Extract the field between the second-to-last "." and the last ".", i.e. {data_mode}
+ def get_data_mode(string):
+     last_dot_index = string.rfind(CoreConst.SEP)
+     second_last_dot_index = string.rfind(CoreConst.SEP, 0, last_dot_index)
+     data_mode = string[second_last_dot_index + 1:last_dot_index]
+     return data_mode
+
+
+ # Determine whether the two cells have a parent-child relationship
+ def check_relation(cell_name, parent_cell_name):
+     layers_pattern = rf"{CoreConst.SEP}{KEY_LAYERS}{CoreConst.SEP}\d+$"
+     last_dot_index = cell_name.rfind(CoreConst.SEP)
+     if last_dot_index == -1:
+         return False
+     # If the field of cell_name before its last '.' equals parent_cell_name,
+     # treat them as parent and child
+     sub_cell_name = cell_name[:last_dot_index]
+     if sub_cell_name == parent_cell_name:
+         return True
+     elif re.search(layers_pattern, cell_name):
+         # If cell_name ends with ".layers.{layer_id}" and, with that suffix removed,
+         # equals parent_cell_name, treat them as parent and child as well
+         sub_cell_name = re.sub(layers_pattern, '', cell_name)
+         if sub_cell_name == parent_cell_name:
+             return True
+     return False
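A few illustrative calls, with hypothetical cell names and assuming `CoreConst.SEP` is `"."`:

```python
check_relation("network.embedding", "network")           # True: direct child
check_relation("network.layers.0", "network")            # True: ".layers.{id}" collapses to the parent
check_relation("network.layers.0.attention", "network")  # False: a grandchild, not a direct child
```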
+
+
+ def get_construct(cell_list_input):
+     for cell in cell_list_input:
+         cell_name = get_cell_name(cell)
+         cell_data_mode = get_data_mode(cell)
+         found_flag = False
+         for parent_cell in cell_list_input:
+             parent_cell_name = get_cell_name(parent_cell)
+             parent_data_mode = get_data_mode(parent_cell)
+             has_relation = check_relation(cell_name, parent_cell_name)
+             if has_relation and parent_data_mode == cell_data_mode:
+                 construct.update({cell: parent_cell})
+                 found_flag = True
+                 break
+         if not found_flag:
+             construct.update({cell: None})
+
+
+ def generate_construct(path):
+     global construct
+     if dump_task == CoreConst.TENSOR:
+         # filename format: Cell.clip_grad_norm.ClipGradNorm.forward.0.output.1_int32_0.npy
+         filenames = sort_filenames(path)
+         point_position = 3
+     if dump_task == CoreConst.STATISTICS:
+         df = read_csv(path)
+         # filename format: Cell.clip_grad_norm.ClipGradNorm.forward.0.output.1
+         filenames = df[CoreConst.OP_NAME].tolist()
+         point_position = 2
+
+     # Extract the Cell.{cell_name}.{class_name}.{data_mode}.{call index} field from each
+     # filename and store it in cell_list
+     for filename in filenames:
+         mid_field = filename.rsplit(CoreConst.SEP, point_position)[0]
+         if KEY_INPUT in filename:
+             if mid_field in cell_list:
+                 cell_list.remove(mid_field)
+             cell_list.append(mid_field)
+         else:
+             if mid_field not in cell_list:
+                 index = filenames.index(filename)
+                 output_field = mid_field + KEY_OUTPUT
+                 find_flag = False
+                 for filename_other in cell_list[index + 1:]:
+                     if output_field in filename_other:
+                         find_flag = True
+                 if find_flag is False:
+                     cell_list.append(mid_field)
+
+     get_construct(cell_list)
+
+     # Generate the JSON file
+     rank_dir = os.path.dirname(path)
+     json_path = os.path.join(rank_dir, CONSTRUCT_FILE_NAME)
+     save_json(json_path, construct, indent=1)
+
+     # Reset 'construct' before processing the data under the next path
+     construct = {}
+     logger.info(f"Construct data saved to {json_path}")
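The result is a child-to-parent mapping in which top-level cells map to null. Illustrative contents of what `construct` holds just before it is written out, with hypothetical cell names:

```python
construct_example = {
    "Cell.network.embedding.Embedding.forward.0": "Cell.network.Net.forward.0",
    "Cell.network.Net.forward.0": None,  # top-level cell: no parent found
}
```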
+
+
+ def process_file(file_path):
+     try:
+         # Load the .npy file content
+         npy_content = load_npy(file_path)
+         logger.debug(f"Loaded {file_path}: shape is {npy_content.shape}, dtype is {npy_content.dtype}")
+
+         # Example filename: Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy
+         parts = os.path.basename(file_path).split(CoreConst.SEP)
+         data_dtype = ""
+         # Extract 'float32' from '0_float32_165' or '0_in_float32_165'
+         data_dtype_list = parts[-2].split('_')
+         if len(data_dtype_list) > 1:
+             data_dtype = data_dtype_list[-2]
+         # op_name is Cell.network._backbone.loss.CrossEntropyLoss.forward.0
+         op_name = CoreConst.SEP.join(parts[:-3])
+         ms_dtype = np_ms_dtype_dict.get(data_dtype)
+         if ms_dtype is None:
+             logger.warning(f"Get dtype None from file {file_path}")
+
+         # Rename the dumped file, dropping the dtype and auto-incremented id fields added by TensorDump
+         data_file_name = os.path.basename(file_path)
+         data_file_dir = os.path.dirname(file_path)
+         parts = data_file_name.split(CoreConst.SEP)
+         if len(parts) >= 2:
+             param_index = parts[-2].split(CoreConst.REPLACEMENT_CHARACTER)[0]
+             pre_parts = CoreConst.SEP.join(parts[:-2])
+             new_file_name = pre_parts + CoreConst.SEP + param_index + CoreConst.NUMPY_SUFFIX
+             move_file(os.path.join(data_file_dir, data_file_name), os.path.join(data_file_dir, new_file_name))
+             logger.debug(f"{data_file_name} is renamed to {new_file_name}")
+         else:
+             logger.warning(f"Failed to rename {data_file_name}.")
+             new_file_name = data_file_name
+
+         tensor_json = {
+             CoreConst.TYPE: 'mindspore.Tensor',
+             CoreConst.DTYPE: str(ms_dtype),
+             CoreConst.SHAPE: list(npy_content.shape),
+             CoreConst.MAX: npy_content.max().item(),
+             CoreConst.MIN: npy_content.min().item(),
+             CoreConst.MEAN: npy_content.mean().item(),
+             CoreConst.NORM: np.linalg.norm(npy_content).item(),
+             CoreConst.DATA_NAME: new_file_name
+         }
+
+         # Use the last field of the filename (input or output) to decide whether the
+         # tensor is appended to input_args or to output
+         if parts[-3] == KEY_INPUT:
+             return op_name, CoreConst.INPUT_ARGS, tensor_json
+         elif parts[-3] == KEY_OUTPUT:
+             return op_name, KEY_OUTPUT, tensor_json
+         else:
+             return None, None, None
+
+     except Exception as e:
+         logger.error(f"Error reading {file_path}: {e}")
+         return None, None, None
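Walking through the filename parsing above with a concrete value (assuming `CoreConst.SEP` is `"."`):

```python
name = "Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0_float32_165.npy"
parts = name.split(".")
parts[-2]                 # "0_float32_165"
parts[-2].split("_")[-2]  # "float32"  -> looked up in np_ms_dtype_dict
".".join(parts[:-3])      # "Cell.network._backbone.loss.CrossEntropyLoss.forward.0"
# After renaming, the file on disk becomes "...forward.0.input.0.npy"
```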
+
+
+ def custom_sort(item, key_to_index):
+     key = item[0]
+     return key_to_index.get(key, float('inf'))
+
+
+ def convert_special_values(value):
+     if isinstance(value, str):
+         if value.lower() == "true":
+             return True
+         elif value.lower() == "false":
+             return False
+         try:
+             return float(value)
+         except ValueError:
+             return value
+     elif pd.isna(value):
+         return None
+     return value
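A few example conversions of the mixed string/NaN values that pandas reads out of statistic.csv:

```python
convert_special_values("true")        # True
convert_special_values("3.25")        # 3.25
convert_special_values("inf")         # float('inf') -- e.g. overflowed statistics stay numeric
convert_special_values(float("nan"))  # None
```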
+
+
+ def process_csv(path):
+     data_info = []
+     df = read_csv(path)
+     df = df.sort_values(by='Op Name', ascending=True)
+     columns = df.columns
+     column_to_json_key = {
+         'Max Value': CoreConst.MAX,
+         'Min Value': CoreConst.MIN,
+         'Avg Value': CoreConst.MEAN,
+         'L2Norm Value': CoreConst.NORM
+     }
+     for _, row in df.iterrows():
+         # op_name_value format: Cell.network._backbone.loss.CrossEntropyLoss.forward.0.input.0
+         op_name_value = row['Op Name']
+         op_name = op_name_value.rsplit(CoreConst.SEP, 2)[0]
+
+         # Extract the input/output field
+         io_key = op_name_value.split(CoreConst.SEP)[-2]
+
+         # Shape is read back as a string; convert it to a list, e.g. "(1,4096)" -> [1, 4096]
+         shape_num = re.findall(r'\d+', row['Shape'])
+         shape = [int(num) for num in shape_num]
+
+         tensor_json = {
+             CoreConst.TYPE: 'mindspore.Tensor',
+             CoreConst.DTYPE: str(np_ms_dtype_dict.get(row['Data Type'])),
+             CoreConst.SHAPE: shape
+         }
+         for col_name, json_key in column_to_json_key.items():
+             if col_name in columns:
+                 value = convert_special_values(row[col_name])
+                 tensor_json[json_key] = value
+
+         if io_key == KEY_INPUT:
+             data_info.append([op_name, CoreConst.INPUT_ARGS, tensor_json])
+         elif io_key == KEY_OUTPUT:
+             data_info.append([op_name, KEY_OUTPUT, tensor_json])
+         else:
+             data_info.append([None, None, None])
+     return data_info
+
+
+ def generate_dump_info(path):
+     if not os.path.exists(path):
+         logger.error("The provided path does not exist.")
+         return
+
+     if dump_task == CoreConst.TENSOR:
+         dump_data = {"task": "tensor", "level": "L0", "dump_data_dir": path, "data": {}}
+         with Pool(processes=10) as pool:
+             file_paths = []
+             for file in os.listdir(path):
+                 if file.endswith(FileCheckConst.NUMPY_SUFFIX):
+                     file_paths.append((os.path.join(path, file),))
+             file_paths.sort()
+             results = pool.starmap(process_file, file_paths)
+     if dump_task == CoreConst.STATISTICS:
+         dump_data = {"task": "statistics", "level": "L0", "framework": "mindspore", "dump_data_dir": None, "data": {}}
+         results = process_csv(path)
+
+     # Collect the results
+     for op_name, key, tensor_json in results:
+         if op_name:
+             if op_name not in dump_data.get(CoreConst.DATA, {}):
+                 dump_data.get(CoreConst.DATA, {})[op_name] = {CoreConst.INPUT_ARGS: [],
+                                                               CoreConst.INPUT_KWARGS: {},
+                                                               KEY_OUTPUT: []}
+             if key not in dump_data.get(CoreConst.DATA, {}).get(op_name, {}):
+                 dump_data.get(CoreConst.DATA, {}).get(op_name, {})[key] = []
+             dump_data.get(CoreConst.DATA, {}).get(op_name, {}).get(key, []).append(tensor_json)
+
+     # Sort according to cell_list
+     data_dict = dump_data.get(CoreConst.DATA, {})
+     key_to_index = {key: index for index, key in enumerate(cell_list)}
+     sorted_data_dict = dict(sorted(data_dict.items(), key=lambda item: custom_sort(item, key_to_index)))
+     dump_data[CoreConst.DATA] = sorted_data_dict
+
+     # Write the data to dump.json
+     json_path = os.path.join(os.path.dirname(path), 'dump.json')
+     save_json(json_path, dump_data, indent=1)
+
+     logger.info(f"Dump data saved to {json_path}")
+
+
+ def generate_stack_info(path):
+     if not os.path.exists(path):
+         logger.error("The provided path does not exist.")
+         return
+
+     stack_data = {}
+     for cell_name in cell_list:
+         stack_data.update({cell_name: []})
+
+     # Write the data to stack.json
+     json_path = os.path.join(os.path.dirname(path), 'stack.json')
+     save_json(json_path, stack_data, indent=1)
+
+     # Remove the csv file
+     if dump_task == CoreConst.STATISTICS:
+         remove_path(path)
+
+     logger.info(f"Stack data saved to {json_path}")
+
+
+ def is_download_finished(directory, interval=3):
+     """
+     Check whether the data in the given directory has finished downloading after a waiting period.
+     :param directory: path of the directory to check
+     :param interval: checking interval in seconds, 3 by default
+     :return: a tuple (is_finished, is_downloading); (True, False) if no new data arrived during
+              the interval, (False, True) if data is still arriving, and (False, False) if the
+              directory does not exist
+     """
+     # Check whether the directory exists
+     if not os.path.exists(directory):
+         logger.warning(f"The specified directory {directory} does not exist.")
+         return False, False
+     initial_modification_time = os.path.getmtime(directory)
+     time.sleep(interval)
+     current_modification_time = os.path.getmtime(directory)
+     # Compare the initial and current modification times
+     if current_modification_time > initial_modification_time:
+         return False, True
+     else:
+         return True, False
+
+
+ def process(dump_path):
+     rank_id = os.environ.get('RANK_ID')
+     rank_dir = DEFAULT_RANK_DIR
+     if rank_id is not None:
+         rank_dir = CoreConst.RANK + str(rank_id)
+
+     step_dir_list = os.listdir(dump_path)
+     for step_dir in step_dir_list:
+         step_path = os.path.join(dump_path, step_dir)
+         rank_path = os.path.join(step_path, rank_dir)
+         npy_path = os.path.join(rank_path, CoreConst.DUMP_TENSOR_DATA)
+         check_times = 0
+         while True:
+             is_finished, is_downloading = is_download_finished(npy_path)
+             if not is_finished:
+                 if not is_downloading:
+                     logger.warning(f'{npy_path} does not exist.')
+                     break
+                 check_times += 1
+                 if check_times < 1000:
+                     logger.info("There is data being downloaded in the specified directory, continue checking...")
+                 else:
+                     logger.warning('Download timeout, stop checking.')
+                     break
+             else:
+                 logger.info("There is no data being downloaded in the specified directory, stop checking.")
+                 break
+         logger.info("==========Start processing data that has already been stored on the disk!==========")
+         rename_filename(path=npy_path)
+         generate_construct(npy_path)
+         generate_dump_info(npy_path)
+         generate_stack_info(npy_path)
+         # Single-device scenario: the rank directory is simply named 'rank'
+         if rank_id is None:
+             new_rank_path = os.path.join(step_path, CoreConst.RANK)
+             try:
+                 move_directory(rank_path, new_rank_path)
+                 logger.debug(f"Directory was successfully renamed to: {new_rank_path}")
+             except Exception as e:
+                 logger.warning(f"Failed to rename to {new_rank_path}: {e}")
+     logger.info("==========JSON file generation completed!==========")
+
+
+ # Remove the trailing comma at the end of each data row in the csv file
+ def remove_trailing_commas(filename):
+     csv_data = read_csv(filename, as_pd=False)
+     for i in range(1, len(csv_data)):
+         if csv_data[i] and csv_data[i][-1] == "":
+             csv_data[i].pop()
+     write_csv(csv_data, filename, mode="w")
+
+
+ # Merge csv files that belong to the same step, post-process them, and store the result
+ # under the corresponding step directory
+ def merge_file(dump_path, rank_dir, file_dict):
+     rank_dir = rank_dir.replace(CoreConst.REPLACEMENT_CHARACTER, '')
+     for step_dir, file_list in file_dict.items():
+         step_dir = CoreConst.STEP + step_dir
+         rank_path = os.path.join(dump_path, step_dir, rank_dir)
+         create_directory(rank_path)
+         output_file = os.path.join(rank_path, KEY_STATISTIC_CSV)
+
+         all_dfs = []
+         try:
+             for file_path in file_list:
+                 remove_trailing_commas(file_path)
+                 df = read_csv(file_path)
+                 all_dfs.append(df)
+
+             # Merge all DataFrames
+             merged_df = pd.concat(all_dfs, ignore_index=True)
+
+             # Sort by the Timestamp field in ascending order
+             merged_df = merged_df.sort_values(by='Timestamp', ascending=True)
+             # Drop rows whose Slot field is 0
+             merged_df = merged_df[merged_df['Slot'] != 0]
+             # Reset the index so that it starts from 0 again
+             merged_df.reset_index(drop=True, inplace=True)
+
+             # Extract op_name and rewrite it in the Cell.network._backbone.LlamaForCausalLM.forward.input.0 format
+             merged_df[CoreConst.OP_NAME] = merged_df[CoreConst.OP_NAME].str.split(KEY_DUMP_TENSOR_DATA, expand=True)[1]
+             merged_df[CoreConst.OP_NAME] = (
+                 merged_df[CoreConst.OP_NAME].str.split(CoreConst.PIPE_SEPARATOR, expand=True)[0]
+             )
+             merged_df[CoreConst.OP_NAME] = (
+                 merged_df[CoreConst.OP_NAME].str.replace(CoreConst.HYPHEN, CoreConst.SEP, regex=False)
+             )
+             # Rename op_name to the Cell.{cell_name}.{class_name}.{forward/backward}.{number}.{input/output}.{index} format
+             rename_filename(data_df=merged_df)
+
+             # Save the merged and sorted DataFrame under the corresponding step directory
+             write_df_to_csv(merged_df, output_file)
+         except FileNotFoundError:
+             logger.error("One or more files not found.")
+         except KeyError:
+             logger.error("The value of the 'Op Name' field does not contain KEY_DUMP_TENSOR_DATA,"
+                          " and the index is out of bounds.")
+         except Exception as e:
+             logger.error(f"An error occurred: {e}")
+
+
+ def process_statistics(dump_path):
+     rank_id = os.environ.get('RANK_ID')
+     rank_dir_kbk = "rank_0"
+     if rank_id is not None:
+         rank_dir_kbk = CoreConst.RANK + CoreConst.REPLACEMENT_CHARACTER + str(rank_id)
+     rank_path_kbk = os.path.join(dump_path, rank_dir_kbk)
+
+     # Group the csv file paths by step and store them in file_dict
+     file_dict = {}
+     depth_limit = 4
+     base_depth = rank_path_kbk.count(os.sep)
+     for root, _, files in os.walk(rank_path_kbk):
+         current_depth = root.count(os.sep) - base_depth
+         if current_depth > depth_limit:
+             continue
+         for file in files:
+             if file == KEY_STATISTIC_CSV:
+                 file_path = os.path.join(root, file)
+                 step_dir = os.path.basename(os.path.dirname(file_path))
+                 if step_dir in file_dict:
+                     file_dict[step_dir].append(file_path)
+                 else:
+                     file_dict[step_dir] = [file_path]
+
+     # Merge csv files of the same step, post-process them, and store the result under
+     # the corresponding step directory
+     merge_file(dump_path, rank_dir_kbk, file_dict)
+
+     rank_dir = rank_dir_kbk.replace(CoreConst.REPLACEMENT_CHARACTER, '')
+     dir_list = os.listdir(dump_path)
+     step_dir_list = [d for d in dir_list if d.startswith(CoreConst.STEP)]
+     for step_dir in step_dir_list:
+         step_path = os.path.join(dump_path, step_dir)
+         rank_path = os.path.join(step_path, rank_dir)
+         csv_path = os.path.join(rank_path, KEY_STATISTIC_CSV)
+         logger.info("==========Start processing data csv!==========")
+         generate_construct(csv_path)
+         generate_dump_info(csv_path)
+         generate_stack_info(csv_path)
+         remove_path(rank_path_kbk)
+         # Single-device scenario: the rank directory is simply named 'rank'
+         if rank_id is None:
+             new_rank_path = os.path.join(step_path, CoreConst.RANK)
+             try:
+                 move_directory(rank_path, new_rank_path)
+                 logger.info(f"Directory was successfully renamed to: {new_rank_path}")
+             except Exception as e:
+                 logger.warning(f"Failed to rename to {new_rank_path}: {e}")
+     logger.info("==========JSON file generation completed!==========")
+
+
+ def get_yaml_keys(yaml_data):
+     keys = []
+     for key, _ in yaml_data.items():
+         keys.append(key)
+     return keys
+
+
+ def get_tensordump_mode(input_str):
+     left_index = input_str.find('(')
+     right_index = input_str.find(')')
+
+     # Extract the string inside the parentheses
+     if left_index != -1 and right_index != -1:
+         inner_str = input_str[left_index + 1:right_index]
+         # Split the string into a list of elements
+         elements = inner_str.split(',')
+         if len(elements) >= 2:
+             # Strip the whitespace around each element
+             first_element = elements[0].strip()
+             second_element = elements[1].strip()
+             return first_element, second_element
+     return None, None
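Example of parsing a layer_mapping.yaml value of the form `ClassName(in, out)` (the input strings here are hypothetical):

```python
get_tensordump_mode("LlamaModel(in, out)")  # -> ("in", "out")
get_tensordump_mode("LlamaModel")           # -> (None, None): no parentheses to parse
```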
+
+
+ def set_tensordump_mode(cell, input_str):
+     first_str, second_str = get_tensordump_mode(input_str)
+     if first_str and second_str:
+         cell.input_dump_mode = first_str
+         cell.output_dump_mode = second_str
+
+
+ def create_kbyk_json(dump_path, summary_mode, step):
+     if step:
+         step_str = ""
+         for s in step:
+             step_str += (str(s) + '|')
+         iteration = step_str[:-1]
+     else:
+         iteration = "all"
+
+     if summary_mode == "statistics":
+         statistic_category = ["max", "min", "avg", "l2norm"]
+     elif "mean" in summary_mode:
+         mean_index = summary_mode.index("mean")
+         summary_mode[mean_index] = "avg"
+         statistic_category = summary_mode
+     else:
+         statistic_category = summary_mode
+
+     config_json = {
+         "common_dump_settings": {
+             "op_debug_mode": 0,
+             "dump_mode": 1,
+             "path": dump_path,
+             "net_name": "Net",
+             "iteration": iteration,
+             "saved_data": "statistic",
+             "input_output": 0,
+             "kernels": ["TensorDump"],
+             "support_device": [0, 1, 2, 3, 4, 5, 6, 7],
+             "statistic_category": statistic_category
+         },
+         "e2e_dump_settings": {
+             "enable": False,
+             "trans_flag": True,
+             "stat_calc_mode": "device"
+         }
+     }
+
+     create_directory(dump_path)
+     rank_id = os.environ.get('RANK_ID')
+     if rank_id is None:
+         rank_id = 0
+     config_json_path = os.path.join(dump_path, str(rank_id) + "kernel_kbyk_dump.json")
+     save_json(config_json_path, config_json, indent=4)
+     logger.info(config_json_path + " has been created.")
+     return config_json_path
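A hypothetical call, dumping steps 0 and 2 while keeping only max/min statistics:

```python
path = create_kbyk_json("/tmp/dump", ["max", "min"], [0, 2])
# -> writes /tmp/dump/0kernel_kbyk_dump.json (for rank 0) with
#    "iteration": "0|2" and "statistic_category": ["max", "min"].
# Passing summary_mode="statistics" instead expands the category list to
# ["max", "min", "avg", "l2norm"], and "mean" entries are rewritten to "avg".
```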
+
+
+ def start(config: CellDumpConfig):
+     global dump_task
+     dump_task = config.task
+     net = config.net
+     dump_path = config.dump_path
+     data_mode = config.data_mode
+     summary_mode = config.summary_mode
+     step = config.step
+     if dump_task == CoreConst.STATISTICS:
+         # Enable KBK dump
+         config_json_path = create_kbyk_json(dump_path, summary_mode, step)
+         os.environ["MINDSPORE_DUMP_CONFIG"] = config_json_path
+
+         # Skip TensorDump operators during execution
+         os.environ["MS_KERNEL_LAUNCH_SKIP"] = "TensorDump"
+
+         # Initialize the step counter of graph-mode KBK dump, starting from 0
+         if not graph_step_flag:
+             raise Exception(
+                 "Importing _set_init_iter failed, "
+                 "please use the latest version package of MindSpore."
+             )
+         _set_init_iter(0)
+         remove_path(config_json_path)
+
+     if net is None:
+         return
+
+     if isinstance(net, nn.Cell):
+         net = (('', net),)
+
+     td_config_path = ""
+     try:
+         import mindformers
+         mindformers_file = mindformers.__file__
+         mindformers_dir = os.path.dirname(mindformers_file)
+         td_config_path = os.path.join(mindformers_dir, "configuration", "layer_mapping.yaml")
+         if not os.path.exists(td_config_path):
+             td_config_path = ""
+             logger.warning("The configuration file in mindformers was not found; the default mode will be used.")
+     except ImportError:
+         logger.warning("mindformers failed to load; the default mode will be used.")
+
+     if td_config_path == "":
+         yaml_data = {}
+     else:
+         yaml_data = load_yaml(td_config_path)
+     first_layer_key = get_yaml_keys(yaml_data)
+
+     black_list = ["grad_reducer", ""]
+
+     for name_and_model in net:
+         for name, cell in name_and_model[1].cells_and_names(name_prefix=name_and_model[0]):
+             class_name = cell.__class__.__name__
+             # Skip blacklisted cells
+             if name in black_list:
+                 logger.info(f"Cell {name}.{class_name} is skipped!")
+                 continue
+             # Skip framework-internal cells
+             if class_name.startswith(CoreConst.REPLACEMENT_CHARACTER):
+                 logger.info(f"Cell {name}.{class_name} is skipped!")
+                 continue
+             else:
+                 # Format: Cell.{cell_name}.{class_name}
+                 cell.cell_prefix = CoreConst.SEP.join([CoreConst.CELL, name, cell.__class__.__name__])
+                 if dump_task == CoreConst.STATISTICS:
+                     cell.cell_prefix = cell.cell_prefix.replace(CoreConst.SEP, CoreConst.HYPHEN)
+
+                 # Set the cell's TensorDump mode according to the yaml configuration file
+                 if class_name in first_layer_key:
+                     layer_data = yaml_data.get(class_name)
+                     if layer_data:
+                         for child_name, child_cell in cell.cells_and_names():
+                             if child_name in layer_data:
+                                 set_tensordump_mode(child_cell, layer_data[child_name])
+                 top_layer_data = yaml_data.get(KEY_TOPLAYER)
+                 if top_layer_data and name in top_layer_data:
+                     set_tensordump_mode(cell, top_layer_data[name])
+
+                 # Replace the construct function
+                 cell.construct = cell_construct_wrapper(cell.construct, cell)
+                 logger.info(f"Cell {name}: construct function is wrapped!")
+                 cell.dump_path = dump_path
+                 cell.data_mode = data_mode
+                 cell.input_clips = []
+                 cell.output_clips = []
+                 # It is assumed that each cell has a maximum of 50 outputs and 50 inputs.
+                 for i in range(50):
+                     cell.input_clips.append(
+                         ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, KEY_INPUT))
+                     )
+                     cell.output_clips.append(
+                         ops.InsertGradientOf(partial_func(clip_gradient, cell.dump_path, cell.cell_prefix, i, KEY_OUTPUT))
+                     )
+
+     logger.info("==========The cell_dump_process_start phase is Finished!==========")
+     if dump_task == CoreConst.TENSOR:
+         atexit.register(process, dump_path=dump_path)
+     if dump_task == CoreConst.STATISTICS:
+         atexit.register(process_statistics, dump_path=dump_path)
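For orientation, a minimal end-to-end sketch of how this module is driven. The `CellDumpConfig` field names below are inferred from the attribute reads at the top of `start()`; treat the exact constructor signature as an assumption rather than the package's documented API:

```python
import mindspore.nn as nn

net = nn.Dense(4, 2)  # any nn.Cell, or a tuple of (name_prefix, cell) pairs
config = CellDumpConfig(          # field names inferred from start(); hypothetical
    task="statistics",            # CoreConst.STATISTICS or CoreConst.TENSOR
    net=net,
    dump_path="/tmp/msprobe_dump",
    data_mode="all",              # "forward", "backward", or "all"
    summary_mode="statistics",    # or a list such as ["max", "min", "mean"]
    step=[0, 1],                  # steps to dump; empty means all
)
start(config)
# ... run training; construct.json, dump.json, and stack.json are generated
# by the atexit hooks once the process exits.
```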