mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -25,13 +25,15 @@ from msprobe.core.common.exceptions import MsprobeException
|
|
|
25
25
|
from msprobe.core.common.runtime import Runtime
|
|
26
26
|
from msprobe.core.common.utils import ModuleQueue, ThreadSafe
|
|
27
27
|
from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope, BaseScope
|
|
28
|
+
from msprobe.core.common.megatron_utils import wrap_megatron_step, get_micro_step, is_megatron
|
|
28
29
|
from msprobe.mindspore.common.const import Const as MsConst
|
|
29
30
|
from msprobe.mindspore.common.log import logger
|
|
30
31
|
from msprobe.mindspore.common.utils import (
|
|
31
32
|
is_mindtorch,
|
|
32
33
|
get_cells_and_names_with_index,
|
|
33
34
|
has_kwargs_in_forward_hook,
|
|
34
|
-
is_graph_mode_cell_dump_allowed
|
|
35
|
+
is_graph_mode_cell_dump_allowed,
|
|
36
|
+
is_backward_hook_output_a_view
|
|
35
37
|
)
|
|
36
38
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
37
39
|
from msprobe.mindspore.dump.graph_mode_cell_dump import GraphModeCellDump
|
|
@@ -46,6 +48,28 @@ def get_cell_construct(construct):
|
|
|
46
48
|
return _construct
|
|
47
49
|
|
|
48
50
|
|
|
51
|
+
def patch_schedules_step():
|
|
52
|
+
try:
|
|
53
|
+
from mindspeed.mindspore.core.pipeline_parallel import schedules
|
|
54
|
+
schedules.forward_step = wrap_megatron_step(schedules.forward_step)
|
|
55
|
+
schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
|
|
56
|
+
logger.info_on_rank_0("Patch mindspeed.mindspore method success.")
|
|
57
|
+
except ImportError:
|
|
58
|
+
logger.info_on_rank_0("No mindspeed.mindspore find.")
|
|
59
|
+
except Exception as e:
|
|
60
|
+
logger.info_on_rank_0(f"Patch mindspeed.mindspore method failed, detail:{str(e)}")
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
from megatron.core.pipeline_parallel import schedules
|
|
64
|
+
schedules.forward_step = wrap_megatron_step(schedules.forward_step)
|
|
65
|
+
schedules.backward_step = wrap_megatron_step(schedules.backward_step, is_forward=False)
|
|
66
|
+
logger.info_on_rank_0("Patch megatron method success.")
|
|
67
|
+
except ImportError:
|
|
68
|
+
logger.info_on_rank_0("No megatron find.")
|
|
69
|
+
except Exception as e:
|
|
70
|
+
logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}")
|
|
71
|
+
|
|
72
|
+
|
|
49
73
|
class CellProcessor:
|
|
50
74
|
cell_queue = ModuleQueue()
|
|
51
75
|
cell_count = {}
|
|
@@ -83,6 +107,8 @@ class CellProcessor:
|
|
|
83
107
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
84
108
|
'The model cannot be None, when level is "L0" or "mix"')
|
|
85
109
|
|
|
110
|
+
patch_schedules_step()
|
|
111
|
+
|
|
86
112
|
is_registered = False
|
|
87
113
|
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
88
114
|
cells_with_index_in_pynative_mode, cells_with_index_in_graph_mode = get_cells_and_names_with_index(models)
|
|
@@ -116,19 +142,23 @@ class CellProcessor:
|
|
|
116
142
|
cells_and_names_in_graph_mode = []
|
|
117
143
|
for index, cells_and_names in cells_with_index_in_graph_mode.items():
|
|
118
144
|
model = models if index == "-1" else models[int(index)]
|
|
119
|
-
for name, cell in cells_and_names:
|
|
145
|
+
for name, cell, parent_cell in cells_and_names:
|
|
120
146
|
if cell == model:
|
|
121
147
|
continue
|
|
122
148
|
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
123
|
-
cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell))
|
|
149
|
+
cells_and_names_in_graph_mode.append((f'{cell_index}{name}', cell, parent_cell))
|
|
124
150
|
|
|
125
151
|
if cells_and_names_in_graph_mode:
|
|
126
152
|
Runtime.run_mode = MsConst.PYNATIVE_GRAPH_MODE
|
|
127
153
|
GraphModeCellDump(config, cells_and_names_in_graph_mode, strict=False).handle()
|
|
128
154
|
|
|
155
|
+
|
|
129
156
|
def build_cell_hook(self, cell_name, build_data_hook):
|
|
130
157
|
@ThreadSafe.synchronized
|
|
131
158
|
def forward_pre_hook(cell, args):
|
|
159
|
+
if not Runtime.is_running:
|
|
160
|
+
return args
|
|
161
|
+
|
|
132
162
|
index = CellProcessor.set_and_get_calls_number(cell_name)
|
|
133
163
|
full_forward_name = f'{cell_name}{Const.FORWARD}{Const.SEP}{index}'
|
|
134
164
|
full_backward_name = f'{cell_name}{Const.BACKWARD}{Const.SEP}{index}'
|
|
@@ -174,7 +204,7 @@ class CellProcessor:
|
|
|
174
204
|
bw_hook.register_backward_hook()
|
|
175
205
|
CellProcessor.cell_bw_hook_kernels[full_forward_name] = bw_hook
|
|
176
206
|
|
|
177
|
-
args = bw_hook(*args)
|
|
207
|
+
args = bw_hook(args) if is_backward_hook_output_a_view() else bw_hook(*args)
|
|
178
208
|
|
|
179
209
|
return args
|
|
180
210
|
|
|
@@ -199,12 +229,15 @@ class CellProcessor:
|
|
|
199
229
|
logger.warning("For backward hooks to be called,"
|
|
200
230
|
" cell output should be a Tensor or a tuple of Tensors"
|
|
201
231
|
f" but received {type(outputs)}")
|
|
202
|
-
if
|
|
203
|
-
new_outputs = bw_hook(*outputs)
|
|
204
|
-
else:
|
|
232
|
+
if is_backward_hook_output_a_view():
|
|
205
233
|
new_outputs = bw_hook(outputs)
|
|
206
|
-
|
|
207
|
-
|
|
234
|
+
else:
|
|
235
|
+
if isinstance(outputs, tuple):
|
|
236
|
+
new_outputs = bw_hook(*outputs)
|
|
237
|
+
else:
|
|
238
|
+
new_outputs = bw_hook(outputs)
|
|
239
|
+
if isinstance(outputs, tuple) and len(outputs) == 1:
|
|
240
|
+
new_outputs = (new_outputs,)
|
|
208
241
|
outputs = new_outputs
|
|
209
242
|
|
|
210
243
|
def get_backward_pre_hook(full_backward_name, backward_data_hook):
|
|
@@ -227,18 +260,21 @@ class CellProcessor:
|
|
|
227
260
|
self.cell_backward_pre_hook[-1])
|
|
228
261
|
bw_pre_hook.register_backward_pre_hook()
|
|
229
262
|
|
|
230
|
-
if
|
|
231
|
-
result = bw_pre_hook(*outputs)
|
|
232
|
-
else:
|
|
263
|
+
if is_backward_hook_output_a_view():
|
|
233
264
|
result = bw_pre_hook(outputs)
|
|
234
|
-
|
|
235
|
-
if
|
|
236
|
-
result = (
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
265
|
+
else:
|
|
266
|
+
if isinstance(outputs, tuple):
|
|
267
|
+
result = bw_pre_hook(*outputs)
|
|
268
|
+
else:
|
|
269
|
+
result = bw_pre_hook(outputs)
|
|
270
|
+
if isinstance(outputs, tuple):
|
|
271
|
+
if len(outputs) == 1:
|
|
272
|
+
result = (result,)
|
|
273
|
+
if len(result) != len(outputs):
|
|
274
|
+
raise TypeError(
|
|
275
|
+
f"The backward pre hook return value size is {len(result)} "
|
|
276
|
+
f"not equal to output size {len(outputs)}"
|
|
277
|
+
)
|
|
242
278
|
return result
|
|
243
279
|
|
|
244
280
|
return forward_pre_hook
|
|
@@ -249,23 +285,26 @@ class CellProcessor:
|
|
|
249
285
|
CellProcessor.cell_stack[tid] = []
|
|
250
286
|
|
|
251
287
|
if self.cell_stack[tid]:
|
|
252
|
-
CellProcessor.module_node[full_name] = self.cell_stack[tid][-1]
|
|
288
|
+
CellProcessor.module_node[full_name] = self.cell_stack[tid][-1] if not is_megatron() \
|
|
289
|
+
else [self.cell_stack[tid][-1], get_micro_step()]
|
|
253
290
|
else:
|
|
254
291
|
parent_name = CellProcessor.cell_queue.find_last(full_name)
|
|
255
|
-
CellProcessor.module_node[full_name] = parent_name
|
|
292
|
+
CellProcessor.module_node[full_name] = parent_name if not is_megatron() else [parent_name, get_micro_step()]
|
|
256
293
|
|
|
257
294
|
CellProcessor.cell_queue.add_name(full_name)
|
|
258
295
|
CellProcessor.cell_stack[tid].append(full_name)
|
|
259
|
-
CellProcessor.api_parent_node[tid] = full_name
|
|
296
|
+
CellProcessor.api_parent_node[tid] = full_name if not is_megatron() else [full_name, get_micro_step()]
|
|
260
297
|
if self.scope:
|
|
261
298
|
self.scope.begin_module(full_name)
|
|
262
299
|
|
|
263
300
|
def set_construct_info_in_hook(self, full_name):
|
|
264
301
|
tid = threading.get_ident()
|
|
265
|
-
CellProcessor.
|
|
302
|
+
CellProcessor.cell_queue.remove_name(full_name)
|
|
303
|
+
CellProcessor.api_parent_node[tid] = None if not is_megatron() else [None, get_micro_step()]
|
|
266
304
|
if self.cell_stack.get(tid):
|
|
267
305
|
CellProcessor.cell_stack[tid].pop()
|
|
268
306
|
if self.cell_stack.get(tid):
|
|
269
|
-
CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1]
|
|
307
|
+
CellProcessor.api_parent_node[tid] = CellProcessor.cell_stack[tid][-1] if not is_megatron() \
|
|
308
|
+
else [CellProcessor.cell_stack[tid][-1], get_micro_step()]
|
|
270
309
|
if self.scope:
|
|
271
310
|
self.scope.end_module(full_name)
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import inspect
|
|
17
17
|
import os
|
|
18
18
|
import random
|
|
19
|
+
import sys
|
|
19
20
|
import types
|
|
20
21
|
|
|
21
22
|
import mindspore as ms
|
|
@@ -41,6 +42,7 @@ else:
|
|
|
41
42
|
mindtorch_check_result = None
|
|
42
43
|
register_backward_hook_functions = {}
|
|
43
44
|
kwargs_exist_in_forward_hook = None
|
|
45
|
+
is_output_of_backward_hook_a_view = None
|
|
44
46
|
|
|
45
47
|
|
|
46
48
|
class MsprobeStep(ms.train.Callback):
|
|
@@ -129,7 +131,7 @@ def list_lowest_level_directories(root_dir):
|
|
|
129
131
|
return lowest_level_dirs
|
|
130
132
|
|
|
131
133
|
|
|
132
|
-
def seed_all(seed=1234, mode=False, rm_dropout=
|
|
134
|
+
def seed_all(seed=1234, mode=False, rm_dropout=False):
|
|
133
135
|
check_seed_all(seed, mode, rm_dropout)
|
|
134
136
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
135
137
|
ms.set_seed(seed)
|
|
@@ -179,6 +181,8 @@ def is_mindtorch():
|
|
|
179
181
|
global mindtorch_check_result
|
|
180
182
|
if mindtorch_check_result is None:
|
|
181
183
|
mindtorch_check_result = False
|
|
184
|
+
if 'torch' not in sys.modules:
|
|
185
|
+
return mindtorch_check_result
|
|
182
186
|
try:
|
|
183
187
|
import torch
|
|
184
188
|
except ImportError:
|
|
@@ -254,14 +258,14 @@ def is_decorated_by_jit(func):
|
|
|
254
258
|
|
|
255
259
|
|
|
256
260
|
@recursion_depth_decorator('msprobe.mindspore.common.utils.get_cells_and_names')
|
|
257
|
-
def get_cells_and_names(model, cells_set=None, name_prefix=''):
|
|
261
|
+
def get_cells_and_names(model, cells_set=None, name_prefix='', parent_cell=None):
|
|
258
262
|
cells_set = cells_set if cells_set else set()
|
|
259
263
|
if model in cells_set:
|
|
260
264
|
return
|
|
261
265
|
|
|
262
266
|
cells_set.add(model)
|
|
263
267
|
jit_decorated = is_decorated_by_jit(model.construct)
|
|
264
|
-
yield name_prefix, model, jit_decorated
|
|
268
|
+
yield name_prefix, model, jit_decorated, parent_cell
|
|
265
269
|
if jit_decorated:
|
|
266
270
|
return
|
|
267
271
|
|
|
@@ -271,9 +275,9 @@ def get_cells_and_names(model, cells_set=None, name_prefix=''):
|
|
|
271
275
|
cells_name_prefix = f'{name_prefix}{Const.SEP}{name}' if name_prefix else name
|
|
272
276
|
jit_decorated = is_decorated_by_jit(model.construct)
|
|
273
277
|
if jit_decorated:
|
|
274
|
-
yield cells_name_prefix, cell, jit_decorated
|
|
278
|
+
yield cells_name_prefix, cell, jit_decorated, model
|
|
275
279
|
else:
|
|
276
|
-
for ele in get_cells_and_names(cell, cells_set, cells_name_prefix):
|
|
280
|
+
for ele in get_cells_and_names(cell, cells_set, cells_name_prefix, model):
|
|
277
281
|
yield ele
|
|
278
282
|
|
|
279
283
|
|
|
@@ -284,9 +288,9 @@ def get_cells_and_names_with_index(models):
|
|
|
284
288
|
def distinguish_cells(cells):
|
|
285
289
|
cells_in_pynative_mode = []
|
|
286
290
|
cells_in_graph_mode = []
|
|
287
|
-
for name, cell, jit_decorated in cells:
|
|
291
|
+
for name, cell, jit_decorated, parent_cell in cells:
|
|
288
292
|
if jit_decorated:
|
|
289
|
-
cells_in_graph_mode.append((name, cell))
|
|
293
|
+
cells_in_graph_mode.append((name, cell, parent_cell))
|
|
290
294
|
else:
|
|
291
295
|
cells_in_pynative_mode.append((name, cell))
|
|
292
296
|
return cells_in_pynative_mode, cells_in_graph_mode
|
|
@@ -329,3 +333,43 @@ def has_kwargs_in_forward_hook():
|
|
|
329
333
|
return kwargs_exist_in_forward_hook
|
|
330
334
|
|
|
331
335
|
return kwargs_exist_in_forward_hook
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def is_backward_hook_output_a_view():
|
|
339
|
+
global is_output_of_backward_hook_a_view
|
|
340
|
+
|
|
341
|
+
if is_output_of_backward_hook_a_view is None:
|
|
342
|
+
is_output_of_backward_hook_a_view = False
|
|
343
|
+
if getattr(ms, '__version__', '2.4.0') < '2.7.0':
|
|
344
|
+
return is_output_of_backward_hook_a_view
|
|
345
|
+
try:
|
|
346
|
+
from mindspore.ops.operations import _inner_ops as inner
|
|
347
|
+
call_func = getattr(inner.CellBackwardHook, '__call__')
|
|
348
|
+
func_params = inspect.signature(call_func).parameters
|
|
349
|
+
except Exception:
|
|
350
|
+
return is_output_of_backward_hook_a_view
|
|
351
|
+
if 'args' in func_params and func_params['args'].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
|
352
|
+
is_output_of_backward_hook_a_view = True
|
|
353
|
+
|
|
354
|
+
return is_output_of_backward_hook_a_view
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def wrap_backward_hook_call_func(call_func):
|
|
358
|
+
if not is_backward_hook_output_a_view():
|
|
359
|
+
return call_func
|
|
360
|
+
|
|
361
|
+
from mindspore.common.api import _pynative_executor as executor
|
|
362
|
+
from mindspore._c_expression import CreationType
|
|
363
|
+
|
|
364
|
+
def new_call(self, args):
|
|
365
|
+
outputs = call_func(self, args)
|
|
366
|
+
if isinstance(outputs, ms.Tensor):
|
|
367
|
+
executor.set_creation_type(outputs, CreationType.DEFAULT)
|
|
368
|
+
elif isinstance(outputs, tuple):
|
|
369
|
+
for item in outputs:
|
|
370
|
+
if isinstance(item, ms.Tensor):
|
|
371
|
+
executor.set_creation_type(item, CreationType.DEFAULT)
|
|
372
|
+
return outputs
|
|
373
|
+
new_call.__name__ = '__call__'
|
|
374
|
+
|
|
375
|
+
return new_call
|
|
@@ -154,21 +154,34 @@ def find_npy_files(directory):
|
|
|
154
154
|
dirs.clear()
|
|
155
155
|
for file in files:
|
|
156
156
|
if file.endswith(".npy"):
|
|
157
|
-
#
|
|
158
|
-
|
|
159
|
-
if len(
|
|
157
|
+
# 正确移除文件扩展名
|
|
158
|
+
base_name = os.path.splitext(file)
|
|
159
|
+
if not base_name or len(base_name) < 1:
|
|
160
|
+
logger.warning("Invalid file encountered.")
|
|
160
161
|
continue
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
162
|
+
file_name = base_name[0]
|
|
163
|
+
|
|
164
|
+
logger.info(f"Generating file info for file: {file}")
|
|
165
|
+
|
|
166
|
+
# 使用一致的分割逻辑
|
|
167
|
+
file_ele = file_name.split('_')
|
|
168
|
+
|
|
169
|
+
if len(file_ele) < 2:
|
|
170
|
+
continue
|
|
171
|
+
|
|
172
|
+
key = '_'.join(file_ele[:-2])
|
|
173
|
+
if key:
|
|
174
|
+
# 文件的完整路径
|
|
175
|
+
value = os.path.join(root, file)
|
|
176
|
+
# 添加到字典中
|
|
177
|
+
if key not in npy_files_dict:
|
|
178
|
+
npy_files_dict[key] = []
|
|
179
|
+
npy_files_dict[key].append(value)
|
|
168
180
|
return npy_files_dict
|
|
169
181
|
|
|
170
182
|
|
|
171
183
|
def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
|
|
184
|
+
result_dict = {}
|
|
172
185
|
for k, npu_file_list in npu_file_dict.items():
|
|
173
186
|
bench_file_list = bench_file_dict.get(k)
|
|
174
187
|
if not bench_file_list and k in name_map_dict:
|
|
@@ -176,7 +189,6 @@ def generate_map_dict(npu_file_dict, bench_file_dict, name_map_dict=None):
|
|
|
176
189
|
bench_length = len(bench_file_list)
|
|
177
190
|
if not (bench_file_list and bench_length):
|
|
178
191
|
continue
|
|
179
|
-
result_dict = {}
|
|
180
192
|
for i, npu_file in enumerate(npu_file_list):
|
|
181
193
|
if i >= bench_length:
|
|
182
194
|
break
|
|
@@ -200,14 +212,14 @@ def do_multi_process(func, map_dict):
|
|
|
200
212
|
df_chunks = [result_df]
|
|
201
213
|
process_num = 1
|
|
202
214
|
logger.info(f"Using {process_num} processes with chunk size {df_chunk_size}")
|
|
203
|
-
|
|
215
|
+
|
|
204
216
|
# 分割字典
|
|
205
217
|
map_chunks = split_dict(map_dict, df_chunk_size)
|
|
206
|
-
|
|
218
|
+
|
|
207
219
|
# 创建结果列表和进程池
|
|
208
220
|
results = []
|
|
209
221
|
pool = multiprocessing.Pool(process_num)
|
|
210
|
-
|
|
222
|
+
|
|
211
223
|
progress_bar = tqdm(total=len(result_df), desc="API/Module Item Compare Process", unit="row", ncols=100)
|
|
212
224
|
|
|
213
225
|
def update_progress(size, progress_lock, extra_param=None):
|
|
@@ -216,34 +228,30 @@ def do_multi_process(func, map_dict):
|
|
|
216
228
|
|
|
217
229
|
def err_call(args):
|
|
218
230
|
logger.error('multiprocess compare failed! Reason: {}'.format(args))
|
|
219
|
-
|
|
220
|
-
pool.close()
|
|
221
|
-
except OSError as e:
|
|
222
|
-
logger.error(f'pool terminate failed: {str(e)}')
|
|
231
|
+
|
|
223
232
|
results = []
|
|
233
|
+
|
|
234
|
+
# 提交任务到进程池
|
|
235
|
+
for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
|
|
236
|
+
start_idx = df_chunk_size * process_idx
|
|
237
|
+
result = pool.apply_async(
|
|
238
|
+
func,
|
|
239
|
+
args=(df_chunk, start_idx, map_chunk, lock),
|
|
240
|
+
error_callback=err_call,
|
|
241
|
+
callback=partial(update_progress, len(map_chunk), lock)
|
|
242
|
+
)
|
|
243
|
+
results.append(result)
|
|
244
|
+
pool.close()
|
|
245
|
+
|
|
224
246
|
try:
|
|
225
|
-
|
|
226
|
-
for process_idx, (df_chunk, map_chunk) in enumerate(zip(df_chunks, map_chunks)):
|
|
227
|
-
start_idx = df_chunk_size * process_idx
|
|
228
|
-
result = pool.apply_async(
|
|
229
|
-
func,
|
|
230
|
-
args=(df_chunk, start_idx, map_chunk, lock),
|
|
231
|
-
error_callback=err_call,
|
|
232
|
-
callback=partial(update_progress, len(map_chunk), lock)
|
|
233
|
-
)
|
|
234
|
-
results.append(result)
|
|
235
|
-
|
|
236
|
-
final_results = [r.get() for r in results]
|
|
237
|
-
# 等待所有任务完成
|
|
238
|
-
pool.close()
|
|
239
|
-
pool.join()
|
|
240
|
-
return pd.concat(final_results, ignore_index=True)
|
|
247
|
+
final_results = [r.get(timeout=3600) for r in results]
|
|
241
248
|
except Exception as e:
|
|
242
|
-
logger.error(f"
|
|
249
|
+
logger.error(f"Task failed with exception: {e}")
|
|
243
250
|
pool.terminate()
|
|
244
251
|
return pd.DataFrame({})
|
|
245
|
-
|
|
246
|
-
|
|
252
|
+
# 等待所有任务完成
|
|
253
|
+
pool.join()
|
|
254
|
+
return pd.concat(final_results, ignore_index=True)
|
|
247
255
|
|
|
248
256
|
|
|
249
257
|
def initialize_result_df(total_size):
|
|
@@ -35,8 +35,16 @@ def ms_compare(input_param, output_path, **kwargs):
|
|
|
35
35
|
config.data_mapping = generate_data_mapping_by_layer_mapping(input_param, config.layer_mapping, output_path)
|
|
36
36
|
|
|
37
37
|
is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
|
|
38
|
-
|
|
39
|
-
|
|
38
|
+
|
|
39
|
+
config_dict = {
|
|
40
|
+
'stack_mode': config.stack_mode,
|
|
41
|
+
'auto_analyze': config.auto_analyze,
|
|
42
|
+
'fuzzy_match': config.fuzzy_match,
|
|
43
|
+
'highlight': config.highlight,
|
|
44
|
+
'dump_mode': config.dump_mode,
|
|
45
|
+
'compared_file_type': config.compared_file_type
|
|
46
|
+
}
|
|
47
|
+
mode_config = ModeConfig(**config_dict)
|
|
40
48
|
mapping_config = MappingConfig(config.cell_mapping, config.api_mapping, config.data_mapping)
|
|
41
49
|
ms_comparator = Comparator(read_real_data, mode_config, mapping_config, is_cross_framework)
|
|
42
50
|
ms_comparator.compare_core(input_param, output_path, suffix=config.suffix)
|
|
@@ -34,10 +34,11 @@ class RowData:
|
|
|
34
34
|
self.basic_data = copy.deepcopy(CompareConst.MS_GRAPH_BASE)
|
|
35
35
|
self.npy_data = copy.deepcopy(CompareConst.MS_GRAPH_NPY)
|
|
36
36
|
self.statistic_data = copy.deepcopy(CompareConst.MS_GRAPH_STATISTIC)
|
|
37
|
+
self.csv = copy.deepcopy(CompareConst.MS_GRAPH_CSV)
|
|
37
38
|
if mode == GraphMode.NPY_MODE:
|
|
38
39
|
self.data = {**self.basic_data, **self.npy_data}
|
|
39
40
|
else:
|
|
40
|
-
self.data = {**self.basic_data, **self.statistic_data}
|
|
41
|
+
self.data = {**self.basic_data, **self.statistic_data, **self.csv}
|
|
41
42
|
|
|
42
43
|
def __call__(self):
|
|
43
44
|
return self.data
|
|
@@ -80,8 +81,8 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
80
81
|
data_list = []
|
|
81
82
|
statistic_data_list = []
|
|
82
83
|
header_index = {
|
|
83
|
-
'Data Type': None, 'Shape': None,
|
|
84
|
-
'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
|
|
84
|
+
'Data Type': None, 'Shape': None,
|
|
85
|
+
'Max Value': None, 'Min Value': None, 'Avg Value': None, 'L2Norm Value': None
|
|
85
86
|
}
|
|
86
87
|
for statistic_file in statistic_file_list:
|
|
87
88
|
content = read_csv(statistic_file, as_pd=False)
|
|
@@ -107,7 +108,7 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
107
108
|
logger.error(f'Dump file {statistic_file_path} has been modified into incorrect format!')
|
|
108
109
|
raise CompareException(f'Dump file {statistic_file_path} has been modified into incorrect format!')
|
|
109
110
|
compare_key = f"{data[1]}.{data[2]}.{data[5]}.{data[6]}" # OpName, TaskId, IO, Slot
|
|
110
|
-
op_name = f"{compare_key}
|
|
111
|
+
op_name = f"{compare_key}"
|
|
111
112
|
timestamp = int(data[4])
|
|
112
113
|
result_data = [op_name, compare_key, timestamp]
|
|
113
114
|
for key in header_index.keys():
|
|
@@ -115,6 +116,8 @@ def statistic_data_read(statistic_file_list, statistic_file_path):
|
|
|
115
116
|
result_data.append(np.nan)
|
|
116
117
|
else:
|
|
117
118
|
result_data.append(data[header_index[key]])
|
|
119
|
+
csv_file = f"{statistic_file_path}"
|
|
120
|
+
result_data.append(csv_file)
|
|
118
121
|
data_list.append(result_data)
|
|
119
122
|
return data_list
|
|
120
123
|
|
|
@@ -230,6 +233,17 @@ class GraphMSComparator:
|
|
|
230
233
|
result[f'{prefix} min'] = np.float32(rows[f'{prefix} min'])
|
|
231
234
|
result[f'{prefix} mean'] = np.float32(rows[f'{prefix} mean'])
|
|
232
235
|
result[f'{prefix} l2norm'] = np.float32(rows[f'{prefix} l2norm'])
|
|
236
|
+
result[f'{prefix} CSV File'] = rows[f'{prefix} CSV File']
|
|
237
|
+
|
|
238
|
+
def calculate_relative_error(numerator, denominator):
|
|
239
|
+
"""Calculates relative error, handling division by zero and NaN."""
|
|
240
|
+
if denominator != 0:
|
|
241
|
+
result = numerator / denominator
|
|
242
|
+
if not np.isnan(result):
|
|
243
|
+
return str(abs(result * 100)) + "%"
|
|
244
|
+
else:
|
|
245
|
+
return CompareConst.NAN
|
|
246
|
+
return CompareConst.N_A
|
|
233
247
|
|
|
234
248
|
# 使用示例
|
|
235
249
|
update_result_dict(result_dict, row, 'NPU')
|
|
@@ -237,34 +251,26 @@ class GraphMSComparator:
|
|
|
237
251
|
error_flag, error_message = statistics_data_check(result_dict)
|
|
238
252
|
result_dict[CompareConst.ERROR_MESSAGE] += error_message
|
|
239
253
|
if not error_flag:
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
CompareConst.
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
if not np.isnan(result_dict[CompareConst.MEAN_RELATIVE_ERR]):
|
|
261
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR] = str(
|
|
262
|
-
result_dict[CompareConst.MEAN_RELATIVE_ERR] * 100) + "%"
|
|
263
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR] = result_dict[CompareConst.NORM_DIFF] / result_dict[
|
|
264
|
-
CompareConst.BENCH_NORM] if result_dict[CompareConst.BENCH_NORM] > 0 else 0
|
|
265
|
-
if not np.isnan(result_dict[CompareConst.NORM_RELATIVE_ERR]):
|
|
266
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR] = str(
|
|
267
|
-
result_dict[CompareConst.NORM_RELATIVE_ERR] * 100) + "%"
|
|
254
|
+
metrics = [
|
|
255
|
+
(CompareConst.MAX_DIFF, CompareConst.NPU_MAX, CompareConst.BENCH_MAX),
|
|
256
|
+
(CompareConst.MIN_DIFF, CompareConst.NPU_MIN, CompareConst.BENCH_MIN),
|
|
257
|
+
(CompareConst.MEAN_DIFF, CompareConst.NPU_MEAN, CompareConst.BENCH_MEAN),
|
|
258
|
+
(CompareConst.NORM_DIFF, CompareConst.NPU_NORM, CompareConst.BENCH_NORM),
|
|
259
|
+
]
|
|
260
|
+
relative_error_metrics = [
|
|
261
|
+
(CompareConst.MAX_RELATIVE_ERR, CompareConst.MAX_DIFF, CompareConst.BENCH_MAX),
|
|
262
|
+
(CompareConst.MIN_RELATIVE_ERR, CompareConst.MIN_DIFF, CompareConst.BENCH_MIN),
|
|
263
|
+
(CompareConst.MEAN_RELATIVE_ERR, CompareConst.MEAN_DIFF, CompareConst.BENCH_MEAN),
|
|
264
|
+
(CompareConst.NORM_RELATIVE_ERR, CompareConst.NORM_DIFF, CompareConst.BENCH_NORM),
|
|
265
|
+
]
|
|
266
|
+
|
|
267
|
+
for diff_metric, npu_metric, bench_metric in metrics:
|
|
268
|
+
result_dict[diff_metric] = result_dict[npu_metric] - result_dict[bench_metric]
|
|
269
|
+
|
|
270
|
+
for rel_metric, diff_metric, bench_metric in relative_error_metrics:
|
|
271
|
+
result_dict[rel_metric] = calculate_relative_error(result_dict[diff_metric],
|
|
272
|
+
result_dict[bench_metric])
|
|
273
|
+
|
|
268
274
|
magnitude_diff = result_dict[CompareConst.MAX_DIFF] / (
|
|
269
275
|
max(result_dict[CompareConst.NPU_MAX], result_dict[CompareConst.BENCH_MAX]) + 1e-10)
|
|
270
276
|
if np.isnan(result_dict[CompareConst.NPU_MAX]) and np.isnan(result_dict[CompareConst.BENCH_MAX]):
|
|
@@ -296,20 +302,8 @@ class GraphMSComparator:
|
|
|
296
302
|
compare_result_df = self.do_multi_process(compare_result_df, mode)
|
|
297
303
|
compare_result_name = add_time_with_xlsx(f"compare_result_{str(rank_id)}_{str(step_id)}")
|
|
298
304
|
compare_result_path = os.path.join(os.path.realpath(self.output_path), f"{compare_result_name}")
|
|
299
|
-
self.to_excel(compare_result_df, compare_result_path)
|
|
300
|
-
logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
|
|
301
|
-
|
|
302
|
-
def to_excel(self, compare_result_df: pd.DataFrame, compare_result_path: str, slice_num=0, need_slice=False) -> int:
|
|
303
|
-
size = len(compare_result_df)
|
|
304
|
-
# sheet size cannot be larger than 1048576
|
|
305
|
-
if size < CompareConst.MAX_EXCEL_LENGTH:
|
|
306
|
-
compare_result_path = compare_result_path.replace('.xlsx', f'_slice_{slice_num}.xlsx') if \
|
|
307
|
-
need_slice else compare_result_path
|
|
308
305
|
save_excel(compare_result_path, compare_result_df)
|
|
309
|
-
|
|
310
|
-
else:
|
|
311
|
-
slice_num = self.to_excel(compare_result_df.iloc[0: size // 2], compare_result_path, slice_num, True)
|
|
312
|
-
return self.to_excel(compare_result_df.iloc[size // 2:], compare_result_path, slice_num, True)
|
|
306
|
+
logger.info(f"Compare rank: {rank_id} step: {step_id} finish. Compare result: {compare_result_path}.")
|
|
313
307
|
|
|
314
308
|
def compare_process(self, rank_id, step_id):
|
|
315
309
|
# generate data_path
|
|
@@ -331,7 +325,7 @@ class GraphMSComparator:
|
|
|
331
325
|
bench_data_list.extend(data_list)
|
|
332
326
|
|
|
333
327
|
if npu_mode == GraphMode.ERROR_MODE or bench_mode == GraphMode.ERROR_MODE:
|
|
334
|
-
logger.warning(f"
|
|
328
|
+
logger.warning(f"Data path: npu_data_path or bench_data_path does not exist.")
|
|
335
329
|
return [], ''
|
|
336
330
|
if npu_mode != bench_mode:
|
|
337
331
|
logger.error(f"NPU mode {npu_mode} not equal to MATCH mode {bench_mode}.")
|
|
@@ -344,14 +338,15 @@ class GraphMSComparator:
|
|
|
344
338
|
npu_data_df = pd.DataFrame(npu_data_list,
|
|
345
339
|
columns=[CompareConst.NPU_NAME, 'Compare Key', 'TimeStamp',
|
|
346
340
|
CompareConst.NPU_DTYPE, CompareConst.NPU_SHAPE,
|
|
347
|
-
CompareConst.NPU_MAX, CompareConst.NPU_MIN,
|
|
348
|
-
CompareConst.NPU_NORM
|
|
341
|
+
CompareConst.NPU_MAX, CompareConst.NPU_MIN,
|
|
342
|
+
CompareConst.NPU_MEAN, CompareConst.NPU_NORM,
|
|
343
|
+
CompareConst.NPU_CSV_FILE])
|
|
349
344
|
bench_data_df = pd.DataFrame(bench_data_list,
|
|
350
345
|
columns=[CompareConst.BENCH_NAME, 'Compare Key', 'TimeStamp',
|
|
351
|
-
CompareConst.BENCH_DTYPE,
|
|
352
|
-
CompareConst.
|
|
353
|
-
CompareConst.
|
|
354
|
-
CompareConst.
|
|
346
|
+
CompareConst.BENCH_DTYPE, CompareConst.BENCH_SHAPE,
|
|
347
|
+
CompareConst.BENCH_MAX, CompareConst.BENCH_MIN,
|
|
348
|
+
CompareConst.BENCH_MEAN, CompareConst.BENCH_NORM,
|
|
349
|
+
CompareConst.BENCH_CSV_FILE])
|
|
355
350
|
|
|
356
351
|
npu_float_type = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
357
352
|
npu_float_data_df = npu_data_df[npu_float_type].astype(str)
|
|
@@ -49,8 +49,9 @@ class DebuggerConfig:
|
|
|
49
49
|
self.summary_mode = task_config.summary_mode
|
|
50
50
|
self.stat_cal_mode = task_config.stat_cal_mode if hasattr(task_config, 'stat_cal_mode') else None
|
|
51
51
|
self.device_stat_precision_mode = task_config.device_stat_precision_mode \
|
|
52
|
-
|
|
52
|
+
if hasattr(task_config, 'device_stat_precision_mode') else None
|
|
53
53
|
self.async_dump = common_config.async_dump if common_config.async_dump else False
|
|
54
|
+
self.precision = common_config.precision if common_config.precision else Const.DUMP_PRECISION_LOW
|
|
54
55
|
self.check()
|
|
55
56
|
self._check_statistics_config(task_config)
|
|
56
57
|
create_directory(self.dump_path)
|
|
@@ -115,18 +116,28 @@ class DebuggerConfig:
|
|
|
115
116
|
self.check_mode = "all"
|
|
116
117
|
if not isinstance(self.async_dump, bool):
|
|
117
118
|
raise Exception("The parameters async_dump should be bool.")
|
|
118
|
-
if self.async_dump and self.task == Const.TENSOR:
|
|
119
|
-
if self.level_ori == Const.LEVEL_DEBUG:
|
|
120
|
-
self.list = [] # async_dump + debug level case ignore list
|
|
121
|
-
if not self.list and self.level_ori != Const.LEVEL_DEBUG:
|
|
122
|
-
raise Exception("The parameters async_dump is true in tensor task,"
|
|
123
|
-
" the parameters list cannot be empty.")
|
|
124
119
|
if self.task == Const.STRUCTURE and self.level_ori not in [Const.LEVEL_L0, Const.LEVEL_MIX]:
|
|
125
120
|
logger.warning_on_rank_0(
|
|
126
121
|
f"When the task is set to structure, the level should be one of {[Const.LEVEL_L0, Const.LEVEL_MIX]}. "
|
|
127
122
|
f"If not, the default level is {Const.LEVEL_MIX}."
|
|
128
123
|
)
|
|
129
124
|
self.level_ori = Const.LEVEL_MIX
|
|
125
|
+
if self.async_dump:
|
|
126
|
+
if self.task == Const.TENSOR:
|
|
127
|
+
if self.level_ori == Const.LEVEL_DEBUG:
|
|
128
|
+
self.list = [] # async_dump + debug level case ignore list
|
|
129
|
+
if not self.list and self.level_ori != Const.LEVEL_DEBUG:
|
|
130
|
+
raise MsprobeException(
|
|
131
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
132
|
+
"The parameters async_dump is true in tensor task, the parameters list cannot be empty."
|
|
133
|
+
)
|
|
134
|
+
is_unsupported_mode = self.summary_mode == Const.MD5 or \
|
|
135
|
+
isinstance(self.summary_mode, list) and Const.MD5 in self.summary_mode
|
|
136
|
+
if is_unsupported_mode:
|
|
137
|
+
raise MsprobeException(
|
|
138
|
+
MsprobeException.INVALID_PARAM_ERROR,
|
|
139
|
+
f"The parameters async_dump is true, the parameters summary_mode cannot be/contain md5."
|
|
140
|
+
)
|
|
130
141
|
return True
|
|
131
142
|
|
|
132
143
|
def check_config_with_l2(self, is_graph_config):
|