mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -170,6 +170,16 @@ def gen_op_item(op_data, op_name):
|
|
|
170
170
|
elif op_item.get('type') == 'slice':
|
|
171
171
|
op_item['dtype'] = op_data.get('type')
|
|
172
172
|
op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
|
|
173
|
+
elif op_item.get('type') == 'ellipsis':
|
|
174
|
+
op_item['dtype'] = op_data.get('type')
|
|
175
|
+
op_item['shape'] = '[]'
|
|
176
|
+
for i in params:
|
|
177
|
+
op_item[i] = op_data.get('value')
|
|
178
|
+
elif op_item.get('type') == 'torch.ProcessGroup':
|
|
179
|
+
op_item['dtype'] = op_data.get('type')
|
|
180
|
+
op_item['shape'] = '[]'
|
|
181
|
+
for i in params:
|
|
182
|
+
op_item[i] = str(op_data.get('group_ranks'))
|
|
173
183
|
else:
|
|
174
184
|
op_item['dtype'] = str(type(op_data.get('value')))
|
|
175
185
|
op_item['shape'] = '[]'
|
|
@@ -275,9 +285,9 @@ def result_item_init(n_info, b_info, dump_mode):
|
|
|
275
285
|
md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
|
|
276
286
|
result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
|
|
277
287
|
elif dump_mode == Const.SUMMARY:
|
|
278
|
-
result_item.extend([" "] * 8)
|
|
288
|
+
result_item.extend([" "] * 8) # 8个统计量数据情况的比对指标
|
|
279
289
|
else:
|
|
280
|
-
result_item.extend([" "] *
|
|
290
|
+
result_item.extend([" "] * 6) # 6个真实数据情况的比对指标
|
|
281
291
|
else:
|
|
282
292
|
err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
|
|
283
293
|
f"npu_info_struct is {n_info.struct}\n" \
|
|
@@ -311,8 +321,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
311
321
|
has_stack = npu_stack_info and bench_stack_info
|
|
312
322
|
|
|
313
323
|
if dump_mode == Const.ALL:
|
|
314
|
-
|
|
315
|
-
|
|
324
|
+
npu_data_name_list = n_dict.get("data_name", None)
|
|
325
|
+
bench_data_name_list = b_dict.get("data_name", None)
|
|
316
326
|
|
|
317
327
|
for index in range(min_len):
|
|
318
328
|
n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
|
|
@@ -343,7 +353,9 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
343
353
|
result_item.append(err_msg)
|
|
344
354
|
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
345
355
|
if dump_mode == Const.ALL:
|
|
346
|
-
|
|
356
|
+
npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
|
|
357
|
+
bench_data_name = safe_get_value(bench_data_name_list, b_start + index, "bench_data_name_list")
|
|
358
|
+
result_item.append([npu_data_name, bench_data_name])
|
|
347
359
|
|
|
348
360
|
result.append(result_item)
|
|
349
361
|
|
|
@@ -361,7 +373,7 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
361
373
|
continue
|
|
362
374
|
result_item = [
|
|
363
375
|
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
364
|
-
" ", " ", " ", " ", " "
|
|
376
|
+
" ", " ", " ", " ", " ", " "
|
|
365
377
|
]
|
|
366
378
|
summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
|
|
367
379
|
result_item.extend(summary_data)
|
|
@@ -378,7 +390,8 @@ def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
|
378
390
|
result_item.append(err_msg)
|
|
379
391
|
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
380
392
|
if dump_mode == Const.ALL:
|
|
381
|
-
|
|
393
|
+
npu_data_name = safe_get_value(npu_data_name_list, n_start + index, "npu_data_name_list")
|
|
394
|
+
result_item.append([npu_data_name, "-1"])
|
|
382
395
|
|
|
383
396
|
result.append(result_item)
|
|
384
397
|
|
|
@@ -443,9 +456,9 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
|
443
456
|
result.append(result_item)
|
|
444
457
|
continue
|
|
445
458
|
if dump_mode == Const.SUMMARY:
|
|
446
|
-
result_item.extend([CompareConst.N_A] * 8)
|
|
459
|
+
result_item.extend([CompareConst.N_A] * 8) # 8个统计量数据情况的比对指标
|
|
447
460
|
if dump_mode == Const.ALL:
|
|
448
|
-
result_item.extend([CompareConst.N_A] *
|
|
461
|
+
result_item.extend([CompareConst.N_A] * 6) # 6个真实数据情况的比对指标
|
|
449
462
|
|
|
450
463
|
npu_summary_data = safe_get_value(summary_reorder, index, "summary_reorder")
|
|
451
464
|
bench_summary_data = [CompareConst.N_A] * 4
|
|
@@ -457,7 +470,7 @@ def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
|
457
470
|
result_item.append(err_msg)
|
|
458
471
|
append_stack_info(result_item, npu_stack_info, index)
|
|
459
472
|
if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
|
|
460
|
-
result_item.extend(["-1"])
|
|
473
|
+
result_item.extend([["-1", "-1"]])
|
|
461
474
|
result.append(result_item)
|
|
462
475
|
|
|
463
476
|
|
|
@@ -532,10 +545,17 @@ def get_name_and_state(name):
|
|
|
532
545
|
|
|
533
546
|
state type: input, output, kwargs, parameters, parameters_grad
|
|
534
547
|
"""
|
|
548
|
+
if not isinstance(name, str):
|
|
549
|
+
logger.error(f'Invalid name: {name}, type should be string, please check.')
|
|
550
|
+
raise CompareException(CompareException.INVALID_API_NAME_ERROR)
|
|
551
|
+
|
|
535
552
|
if Const.PARAMS_GRAD in name.split(Const.SEP):
|
|
536
553
|
return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD
|
|
537
554
|
|
|
538
555
|
split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
|
|
556
|
+
if len(split) < 3:
|
|
557
|
+
logger.error(f'Invalid name string: {name}, can not be split by forward/backward, please check.')
|
|
558
|
+
raise CompareException(CompareException.INVALID_API_NAME_ERROR)
|
|
539
559
|
api = f'{split[0]}.{split[1]}.'
|
|
540
560
|
state_str = split[2]
|
|
541
561
|
match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from typing import Dict, Any, Optional, Callable, Union, List, Tuple
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.file_utils import load_yaml
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _get_attr(module, attr_name):
    """Look up ``attr_name`` on ``module``, returning None when it is missing.

    A name containing ``Const.SEP`` is split at its last separator: the left
    part is fetched from ``module`` with a single ``getattr`` call and the
    right part is then fetched from the object that call returned.

    Args:
        module: Object to read the attribute from.
        attr_name: Attribute name, optionally ``"<sub_module><SEP><attr>"``.

    Returns:
        The resolved attribute, or None if any lookup step fails.
    """
    if Const.SEP not in attr_name:
        return getattr(module, attr_name, None)

    sub_module_name, sub_attr = attr_name.rsplit(Const.SEP, 1)
    sub_module = getattr(module, sub_module_name, None)
    if sub_module is None:
        # Without this guard getattr(None, sub_attr, None) would hand back
        # NoneType attributes (e.g. "__class__") for a missing sub-module.
        return None
    return getattr(sub_module, sub_attr, None)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ApiWrapper:
    """Create hook-wrapped stand-ins for the APIs named in the support lists.

    ``api_types`` maps a framework name to ``{api_type: api_modules}``. Only
    ``api_modules[0]`` is used by this class: it is the object the original
    APIs are read from.
    """

    def __init__(
        self, api_types: Dict[str, Dict[str, Any]],
        api_list_paths: Union[str, List[str], Tuple[str]]
    ):
        """Store the configuration and resolve the API names that really exist.

        Args:
            api_types: ``{framework: {api_type: api_modules}}``.
            api_list_paths: One YAML path shared by every framework, or a
                list/tuple holding exactly one path per framework.

        Raises:
            RuntimeError: If a list/tuple of paths does not have one entry per
                framework in ``api_types``.
        """
        self.api_types = api_types
        if not isinstance(api_list_paths, (list, tuple)):
            # A single path is reused for every framework.
            api_list_paths = [api_list_paths] * len(self.api_types)
        elif len(api_list_paths) != len(self.api_types):
            raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                               "when api_list_paths is a list or tuple.")
        self.api_list_paths = api_list_paths
        self.api_names = self._get_api_names()
        self.wrapped_api_functions = dict()

    @staticmethod
    def _wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
        """Return a function named ``api_name`` that dispatches through ``api_template``.

        The template is invoked on every call of the returned function, and
        its result is then called with the caller's arguments.
        """
        def api_function(*args, **kwargs):
            return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
        api_function.__name__ = api_name
        return api_function

    def wrap_api(
        self, api_templates, hook_build_func: Optional[Callable]
    ):
        """Build the wrapped functions for every framework and api type.

        Args:
            api_templates: One template shared by all api types, or a
                list/tuple with one template per api type, ordered by
                iteration over ``api_types``.
            hook_build_func: Passed through unchanged to the template.

        Returns:
            ``{framework: {api_type: {api_name: wrapped_function}}}``; the same
            dict is also kept on ``self.wrapped_api_functions``.

        Raises:
            RuntimeError: If a list/tuple of templates does not have one entry
                per api type.
        """
        api_types_num = sum(len(v) for v in self.api_types.values())
        if not isinstance(api_templates, (list, tuple)):
            api_templates = [api_templates] * api_types_num
        elif len(api_templates) != api_types_num:
            raise RuntimeError("The number of api_templates must be equal to the number of api_types, "
                               "when api_templates is a list or tuple.")

        self.wrapped_api_functions.clear()
        # One template is consumed per (framework, api_type) pair, in order.
        templates = iter(api_templates)
        for framework, api_types in self.api_types.items():
            wrapped_functions_in_framework = dict()
            for api_type, api_modules in api_types.items():
                wrapped_functions = dict()
                name_prefix = Const.API_DATA_PREFIX.get(framework, {}).get(api_type, "API")
                api_template = next(templates)
                for api_name in self.api_names.get(framework, {}).get(api_type, []):
                    ori_api = _get_attr(api_modules[0], api_name)
                    # Only callables are wrapped; other attributes are left out.
                    if callable(ori_api):
                        wrapped_functions[api_name] = self._wrap_api_func(
                            api_name, ori_api, name_prefix, hook_build_func, api_template
                        )
                wrapped_functions_in_framework[api_type] = wrapped_functions
            self.wrapped_api_functions[framework] = wrapped_functions_in_framework
        return self.wrapped_api_functions

    def _get_api_names(self):
        """Collect, per framework and api type, the listed names that exist on the module.

        Returns:
            ``{framework: {api_type: set_of_api_names}}``.
        """
        api_names = dict()

        for (framework, api_types), api_list_path in zip(self.api_types.items(), self.api_list_paths):
            api_list = load_yaml(api_list_path)
            valid_names = dict()
            for api_type, api_modules in api_types.items():
                list_key = Const.SUPPORT_API_DICT_KEY_MAP.get(framework, {}).get(api_type)
                # A YAML key with an empty value loads as None; treat it as "no APIs".
                api_from_file = api_list.get(list_key, []) or []
                names = set()
                for api_name in api_from_file:
                    target_attr = api_name
                    target_module = api_modules[0]
                    if Const.SEP in api_name:
                        sub_module_name, target_attr = api_name.rsplit(Const.SEP, 1)
                        target_module = getattr(api_modules[0], sub_module_name, None)
                    if target_module and target_attr in dir(target_module):
                        names.add(api_name)
                valid_names[api_type] = names
            api_names[framework] = valid_names

        return api_names
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ApiRegistry:
    """
    Base class for api registry.

    Keeps the original and the hook-wrapped attribute of every supported API
    so the wrapped versions can be installed on the target modules and the
    originals put back later.
    """

    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates):
        """Store the configuration; nothing is wrapped until ``initialize_hook`` runs.

        Args:
            api_types: ``{framework: {api_type: api_modules}}`` where
                ``api_modules[0]`` is the object originals are read from and
                ``api_modules[1]`` is an iterable of objects to patch.
            inner_used_api: ``{api_type_key: sequence}`` whose first element is
                the object to patch and whose remaining elements are API names.
            supported_api_list_path: Forwarded to ``ApiWrapper``.
            api_templates: Forwarded to ``ApiWrapper.wrap_api``.
        """
        # Both keyed by "<framework><SEP><api_type>" -> {api_name: attribute}.
        self.ori_api_attr = dict()
        self.wrapped_api_attr = dict()
        # Both keyed by inner-used api type -> {api_name: attribute}.
        self.inner_used_ori_attr = dict()
        self.inner_used_wrapped_attr = dict()
        self.api_types = api_types
        self.inner_used_api = inner_used_api
        self.supported_api_list_path = supported_api_list_path
        self.api_templates = api_templates

    @staticmethod
    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
        """Fill ``api_ori_attr`` with the current attribute of each name in ``api_list``."""
        for api in api_list:
            api_ori_attr[api] = _get_attr(ori_api_group, api)

    @staticmethod
    def set_api_attr(api_group, attr_dict):
        """Assign every ``{name: attribute}`` pair of ``attr_dict`` onto ``api_group``.

        A name containing ``Const.SEP`` is set on the sub-object named by the
        part before the last separator; it is skipped if that sub-object is
        missing.
        """
        for api, api_attr in attr_dict.items():
            if Const.SEP in api:
                sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
                sub_module = getattr(api_group, sub_module_name, None)
                if sub_module is not None:
                    setattr(sub_module, sub_op, api_attr)
            else:
                setattr(api_group, api, api_attr)

    @staticmethod
    def _pick_inner_used_attr(api_names, ori_candidates, wrapped_candidates):
        """Select the inner-used APIs that have both an original and a wrapper.

        An API whose original exists but that was never wrapped is skipped;
        recording ``None`` as its wrapper would make
        ``register_inner_used_api`` overwrite the real attribute with None.

        Returns:
            Tuple ``(ori_attr, wrapped_attr)`` of ``{api_name: attribute}`` dicts.
        """
        ori_attr = dict()
        wrapped_attr = dict()
        for api_name in api_names:
            ori_api = ori_candidates.get(api_name)
            wrapped_api = wrapped_candidates.get(api_name)
            if ori_api and wrapped_api is not None:
                ori_attr[api_name] = ori_api
                wrapped_attr[api_name] = wrapped_api
        return ori_attr, wrapped_attr

    def _set_all_api(self, attr_by_type):
        """Apply ``attr_by_type["<framework><SEP><api_type>"]`` to every module to patch."""
        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                api_type_with_framework = framework + Const.SEP + api_type
                attr_dict = attr_by_type.get(api_type_with_framework, {})
                for module in api_modules[1]:
                    self.set_api_attr(module, attr_dict)

    def register_all_api(self):
        """Install the wrapped APIs on every module listed in ``api_types``."""
        self._set_all_api(self.wrapped_api_attr)

    def register_inner_used_api(self):
        """Install the wrapped versions of the inner-used APIs."""
        for api_type, api_entry in self.inner_used_api.items():
            self.set_api_attr(api_entry[0], self.inner_used_wrapped_attr.get(api_type, {}))

    def restore_all_api(self):
        """Put the stored original APIs back on every module listed in ``api_types``."""
        self._set_all_api(self.ori_api_attr)

    def restore_inner_used_api(self):
        """Put the stored originals of the inner-used APIs back."""
        for api_type, api_entry in self.inner_used_api.items():
            self.set_api_attr(api_entry[0], self.inner_used_ori_attr.get(api_type, {}))

    def initialize_hook(self, hook_build_func):
        """Wrap all supported APIs and cache both originals and wrappers.

        Args:
            hook_build_func: Forwarded to ``ApiWrapper.wrap_api``.
        """
        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path)
        wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)

        for framework, api_types in self.api_types.items():
            for api_type, api_modules in api_types.items():
                ori_attr = dict()
                self.store_ori_attr(api_modules[0], api_wrapper.api_names.get(framework).get(api_type), ori_attr)
                api_type_with_framework = framework + Const.SEP + api_type
                self.ori_api_attr[api_type_with_framework] = ori_attr
                self.wrapped_api_attr[api_type_with_framework] = wrapped_api_functions.get(framework).get(api_type)

        for inner_used_api_type, inner_used_api_list in self.inner_used_api.items():
            # Element 0 is the object to patch; the API names follow it.
            ori_attr, wrapped_attr = self._pick_inner_used_attr(
                inner_used_api_list[1:],
                self.ori_api_attr.get(inner_used_api_type, {}),
                self.wrapped_api_attr.get(inner_used_api_type) or {},
            )
            self.inner_used_ori_attr[inner_used_api_type] = ori_attr
            self.inner_used_wrapped_attr[inner_used_api_type] = wrapped_attr
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -40,6 +40,7 @@ class DataCollector:
|
|
|
40
40
|
self.scope = ScopeFactory(self.config).build_scope()
|
|
41
41
|
self.backward_module_names = {}
|
|
42
42
|
self.optimizer_status = ""
|
|
43
|
+
self.optimizer_status_first_start = {Const.OPTIMIZER: True, Const.CLIP_GRAD: True}
|
|
43
44
|
atexit.register(self.write_json)
|
|
44
45
|
|
|
45
46
|
@property
|
|
@@ -54,6 +55,17 @@ class DataCollector:
|
|
|
54
55
|
def check_scope_and_pid(scope, name, pid):
|
|
55
56
|
return (not scope or scope.check(name)) and pid == os.getpid()
|
|
56
57
|
|
|
58
|
+
@staticmethod
def set_is_recomputable(data_info, is_recompute):
    """Tag the single entry of ``data_info`` with its recompute flag.

    Nothing is changed when ``data_info`` is empty/None, holds more than one
    entry, or when ``is_recompute`` is None.

    Args:
        data_info: Dict whose only value is itself a dict to be tagged.
        is_recompute: Value stored under the ``"is_recompute"`` key.
    """
    # Under normal circumstances data_info is expected to hold exactly one entry.
    if data_info and len(data_info) == 1 and is_recompute is not None:
        only_key = next(iter(data_info))
        data_info[only_key]["is_recompute"] = is_recompute
|
|
62
|
+
|
|
63
|
+
def reset_status(self):
    """Return the collector's status tracking to its initial state.

    Clears the optimizer status, re-arms the "first start" flags for the
    optimizer and clip-grad phases, asks the data writer to reset its cache
    and empties the record of backward module names.
    """
    first_start_flags = dict.fromkeys((Const.OPTIMIZER, Const.CLIP_GRAD), True)
    self.optimizer_status = ""
    self.optimizer_status_first_start = first_start_flags
    self.data_writer.reset_cache()
    self.backward_module_names.clear()
|
|
68
|
+
|
|
57
69
|
def if_return_forward_new_output(self):
|
|
58
70
|
return self.data_processor.if_return_forward_new_output()
|
|
59
71
|
|
|
@@ -77,7 +89,7 @@ class DataCollector:
|
|
|
77
89
|
logger.debug(msg)
|
|
78
90
|
self.data_writer.update_data(data_info)
|
|
79
91
|
|
|
80
|
-
def forward_input_data_collect(self, name, module, pid, module_input_output):
|
|
92
|
+
def forward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
81
93
|
if self.config.task == Const.FREE_BENCHMARK:
|
|
82
94
|
backward_name = name.replace(Const.FORWARD, Const.BACKWARD)
|
|
83
95
|
if self.check_scope_and_pid(self.scope, backward_name, pid):
|
|
@@ -87,37 +99,48 @@ class DataCollector:
|
|
|
87
99
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
88
100
|
return
|
|
89
101
|
|
|
90
|
-
data_info =
|
|
102
|
+
data_info = {}
|
|
103
|
+
if self.config.task != Const.STRUCTURE:
|
|
104
|
+
data_info = self.data_processor.analyze_forward_input(name, module, module_input_output)
|
|
105
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
91
106
|
if self.config.level == Const.LEVEL_L2:
|
|
92
107
|
return
|
|
93
108
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
94
109
|
|
|
95
|
-
def forward_output_data_collect(self, name, module, pid, module_input_output):
|
|
110
|
+
def forward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
96
111
|
self.update_construct(name)
|
|
97
112
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
98
113
|
return
|
|
99
114
|
|
|
100
|
-
data_info =
|
|
115
|
+
data_info = {}
|
|
116
|
+
if self.config.task != Const.STRUCTURE:
|
|
117
|
+
data_info = self.data_processor.analyze_forward_output(name, module, module_input_output)
|
|
118
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
101
119
|
if self.config.level == Const.LEVEL_L2:
|
|
102
120
|
return
|
|
103
121
|
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
|
|
104
122
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
105
123
|
|
|
106
|
-
def forward_data_collect(self, name, module, pid, module_input_output):
|
|
124
|
+
def forward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
107
125
|
self.update_construct(name)
|
|
108
126
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
109
127
|
return
|
|
110
128
|
|
|
111
|
-
data_info =
|
|
129
|
+
data_info = {}
|
|
130
|
+
if self.config.task != Const.STRUCTURE:
|
|
131
|
+
data_info = self.data_processor.analyze_forward(name, module, module_input_output)
|
|
132
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
112
133
|
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
|
|
113
134
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
114
135
|
|
|
115
|
-
def backward_data_collect(self, name, module, pid, module_input_output):
|
|
136
|
+
def backward_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
116
137
|
self.update_construct(name)
|
|
117
138
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
118
139
|
return
|
|
119
140
|
|
|
120
|
-
data_info =
|
|
141
|
+
data_info = {}
|
|
142
|
+
if self.config.task != Const.STRUCTURE:
|
|
143
|
+
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
|
|
121
144
|
if self.config.level == Const.LEVEL_L2:
|
|
122
145
|
return
|
|
123
146
|
# 获取执行反向的模块名称
|
|
@@ -127,25 +150,34 @@ class DataCollector:
|
|
|
127
150
|
self.backward_module_names[module_name] = True
|
|
128
151
|
self.handle_data(name, data_info, flush=self.data_processor.is_terminated)
|
|
129
152
|
|
|
130
|
-
def backward_input_data_collect(self, name, module, pid, module_input_output):
|
|
153
|
+
def backward_input_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
131
154
|
self.update_construct(name)
|
|
132
155
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
133
156
|
return
|
|
134
157
|
|
|
135
|
-
data_info =
|
|
158
|
+
data_info = {}
|
|
159
|
+
if self.config.task != Const.STRUCTURE:
|
|
160
|
+
data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
|
|
161
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
136
162
|
self.handle_data(name, data_info)
|
|
137
163
|
|
|
138
|
-
def backward_output_data_collect(self, name, module, pid, module_input_output):
|
|
164
|
+
def backward_output_data_collect(self, name, module, pid, module_input_output, is_recompute=None):
|
|
139
165
|
self.update_construct(name)
|
|
140
166
|
if not self.check_scope_and_pid(self.scope, name, pid):
|
|
141
167
|
return
|
|
142
168
|
|
|
143
|
-
data_info =
|
|
169
|
+
data_info = {}
|
|
170
|
+
if self.config.task != Const.STRUCTURE:
|
|
171
|
+
data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
|
|
172
|
+
self.set_is_recomputable(data_info, is_recompute)
|
|
144
173
|
self.handle_data(name, data_info)
|
|
145
174
|
|
|
146
175
|
def update_construct(self, name):
|
|
147
176
|
if self.config.level not in DataCollector.level_without_construct:
|
|
148
177
|
if self.optimizer_status in [Const.OPTIMIZER, Const.CLIP_GRAD]:
|
|
178
|
+
if self.optimizer_status_first_start[self.optimizer_status]:
|
|
179
|
+
self.data_writer.update_construct({self.optimizer_status: None})
|
|
180
|
+
self.optimizer_status_first_start[self.optimizer_status] = False
|
|
149
181
|
self.data_writer.update_construct({name: self.optimizer_status})
|
|
150
182
|
else:
|
|
151
183
|
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
|
|
@@ -183,3 +215,16 @@ class DataCollector:
|
|
|
183
215
|
|
|
184
216
|
def fill_stack_tensor_data(self):
|
|
185
217
|
self.data_writer.fill_stack_tensor_data()
|
|
218
|
+
|
|
219
|
+
def debug_data_collect_forward(self, variable, name_with_count):
|
|
220
|
+
|
|
221
|
+
data_info = self.data_processor.analyze_debug_forward(variable, name_with_count)
|
|
222
|
+
self.data_writer.update_debug({name_with_count: data_info})
|
|
223
|
+
|
|
224
|
+
def debug_data_collect_backward(self, variable, grad_name_with_count):
|
|
225
|
+
# prepare all None nested data structure
|
|
226
|
+
all_none_data_info = self.data_processor.analyze_element_to_all_none(variable)
|
|
227
|
+
self.data_writer.update_debug({grad_name_with_count: all_none_data_info})
|
|
228
|
+
|
|
229
|
+
# register tensor backward hook
|
|
230
|
+
self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data'])
|
|
@@ -17,6 +17,9 @@ import inspect
|
|
|
17
17
|
import os
|
|
18
18
|
from dataclasses import dataclass, is_dataclass
|
|
19
19
|
from typing import Tuple, Dict, Optional, Any
|
|
20
|
+
from functools import partial
|
|
21
|
+
import copy
|
|
22
|
+
from typing import Union
|
|
20
23
|
|
|
21
24
|
import numpy as np
|
|
22
25
|
|
|
@@ -87,7 +90,7 @@ class TensorStatInfo:
|
|
|
87
90
|
class BaseDataProcessor:
|
|
88
91
|
_recursive_key_stack = []
|
|
89
92
|
special_type = (
|
|
90
|
-
np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
|
|
93
|
+
np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, np.ndarray,
|
|
91
94
|
bool, int, float, str, slice,
|
|
92
95
|
type(Ellipsis)
|
|
93
96
|
)
|
|
@@ -143,6 +146,37 @@ class BaseDataProcessor:
|
|
|
143
146
|
else:
|
|
144
147
|
return data
|
|
145
148
|
|
|
149
|
+
@staticmethod
|
|
150
|
+
def set_value_into_nested_structure(data_structure, indexes, value):
|
|
151
|
+
'''
|
|
152
|
+
Args:
|
|
153
|
+
data_structure: nested data structure
|
|
154
|
+
indexes: List
|
|
155
|
+
value: value to be set
|
|
156
|
+
'''
|
|
157
|
+
if not indexes:
|
|
158
|
+
raise ValueError("set_value_into_nested_structure failed: "
|
|
159
|
+
"indexes need to be non empty when set value to nested data structure")
|
|
160
|
+
current_level = data_structure
|
|
161
|
+
for i, index in enumerate(indexes):
|
|
162
|
+
valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index
|
|
163
|
+
valid_for_dict = isinstance(current_level, dict) and index in current_level
|
|
164
|
+
is_last = i == len(indexes) - 1
|
|
165
|
+
if valid_for_dict or valid_for_list:
|
|
166
|
+
if is_last:
|
|
167
|
+
try:
|
|
168
|
+
current_level[index] = value
|
|
169
|
+
except Exception as e:
|
|
170
|
+
raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
|
|
171
|
+
else:
|
|
172
|
+
try:
|
|
173
|
+
current_level = current_level[index]
|
|
174
|
+
except Exception as e:
|
|
175
|
+
raise IndexError("set_value_into_nested_structure failed: passed indexes wrong") from e
|
|
176
|
+
else:
|
|
177
|
+
raise ValueError("set_value_into_nested_structure failed: "
|
|
178
|
+
"invalid data_structure type or invalid index")
|
|
179
|
+
|
|
146
180
|
@staticmethod
|
|
147
181
|
def _convert_numpy_to_builtin(arg):
|
|
148
182
|
type_mapping = {
|
|
@@ -183,8 +217,22 @@ class BaseDataProcessor:
|
|
|
183
217
|
return single_arg
|
|
184
218
|
|
|
185
219
|
@staticmethod
|
|
186
|
-
def _analyze_numpy(
|
|
187
|
-
|
|
220
|
+
def _analyze_numpy(ndarray, numpy_type):
|
|
221
|
+
ndarray_json = {}
|
|
222
|
+
ndarray_json.update({'type': 'numpy.ndarray'})
|
|
223
|
+
ndarray_json.update({'dtype': str(ndarray.dtype)})
|
|
224
|
+
ndarray_json.update({'shape': ndarray.shape})
|
|
225
|
+
if ndarray.size > 0:
|
|
226
|
+
ndarray_json.update({"Max": np.max(ndarray).item()})
|
|
227
|
+
ndarray_json.update({"Min": np.min(ndarray).item()})
|
|
228
|
+
ndarray_json.update({"Mean": np.mean(ndarray).item()})
|
|
229
|
+
ndarray_json.update({"Norm": np.linalg.norm(ndarray).item()})
|
|
230
|
+
else:
|
|
231
|
+
ndarray_json.update({"Max": None})
|
|
232
|
+
ndarray_json.update({"Min": None})
|
|
233
|
+
ndarray_json.update({"Mean": None})
|
|
234
|
+
ndarray_json.update({"Norm": None})
|
|
235
|
+
return ndarray_json
|
|
188
236
|
|
|
189
237
|
@staticmethod
|
|
190
238
|
def _get_allowed_data_mode(data_mode):
|
|
@@ -203,9 +251,9 @@ class BaseDataProcessor:
|
|
|
203
251
|
return cls.special_type
|
|
204
252
|
|
|
205
253
|
@classmethod
|
|
206
|
-
def recursive_apply_transform(cls, args, transform, depth=0):
|
|
207
|
-
if depth > Const.
|
|
208
|
-
logger.error(f"The maximum depth of recursive transform, {Const.
|
|
254
|
+
def recursive_apply_transform(cls, args, transform, depth=0) -> Union[dict, list, None]:
|
|
255
|
+
if depth > Const.DUMP_MAX_DEPTH:
|
|
256
|
+
logger.error(f"The maximum depth of recursive transform, {Const.DUMP_MAX_DEPTH} is reached.")
|
|
209
257
|
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
210
258
|
if isinstance(args, cls.get_special_types()):
|
|
211
259
|
arg_transform = transform(args, cls._recursive_key_stack)
|
|
@@ -220,7 +268,7 @@ class BaseDataProcessor:
|
|
|
220
268
|
return cls.apply_transform_dict(args_dict, transform, depth)
|
|
221
269
|
elif isinstance(args, (list, tuple)):
|
|
222
270
|
result_list = cls.apply_transform_list(args, transform, depth)
|
|
223
|
-
return
|
|
271
|
+
return result_list
|
|
224
272
|
elif isinstance(args, dict):
|
|
225
273
|
return cls.apply_transform_dict(args, transform, depth)
|
|
226
274
|
elif args is not None:
|
|
@@ -228,12 +276,12 @@ class BaseDataProcessor:
|
|
|
228
276
|
return None
|
|
229
277
|
else:
|
|
230
278
|
return None
|
|
231
|
-
|
|
279
|
+
|
|
232
280
|
@classmethod
|
|
233
281
|
def apply_transform_dict(cls, args, transform, depth):
|
|
234
282
|
result_dict = {}
|
|
235
283
|
for k, arg in args.items():
|
|
236
|
-
cls._recursive_key_stack.append(
|
|
284
|
+
cls._recursive_key_stack.append(k)
|
|
237
285
|
result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1)
|
|
238
286
|
cls._recursive_key_stack.pop()
|
|
239
287
|
return result_dict
|
|
@@ -242,11 +290,21 @@ class BaseDataProcessor:
|
|
|
242
290
|
def apply_transform_list(cls, args, transform, depth):
|
|
243
291
|
result_list = []
|
|
244
292
|
for i, arg in enumerate(args):
|
|
245
|
-
cls._recursive_key_stack.append(
|
|
293
|
+
cls._recursive_key_stack.append(i)
|
|
246
294
|
result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1))
|
|
247
295
|
cls._recursive_key_stack.pop()
|
|
248
296
|
return result_list
|
|
249
297
|
|
|
298
|
+
@classmethod
|
|
299
|
+
def register_hook_single_element(cls, element, suffix_stack, hook_fn):
|
|
300
|
+
if cls.is_hookable_element(element):
|
|
301
|
+
indexes = copy.deepcopy(suffix_stack)
|
|
302
|
+
wrap_hook_fn = partial(hook_fn, indexes=indexes)
|
|
303
|
+
|
|
304
|
+
def real_hook_fn(grad):
|
|
305
|
+
return wrap_hook_fn(grad)
|
|
306
|
+
element.register_hook(real_hook_fn)
|
|
307
|
+
|
|
250
308
|
def if_return_forward_new_output(self):
|
|
251
309
|
return self._return_forward_new_output
|
|
252
310
|
|
|
@@ -383,3 +441,29 @@ class BaseDataProcessor:
|
|
|
383
441
|
suffix + file_format)
|
|
384
442
|
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
|
|
385
443
|
return dump_data_name, file_path
|
|
444
|
+
|
|
445
|
+
def analyze_element_to_all_none(self, element):
|
|
446
|
+
return self.recursive_apply_transform(element, lambda element, stack: None)
|
|
447
|
+
|
|
448
|
+
def analyze_debug_forward(self, variable, name_with_count):
|
|
449
|
+
self.current_api_or_module_name = name_with_count
|
|
450
|
+
self.api_data_category = Const.TENSOR
|
|
451
|
+
# these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt
|
|
452
|
+
data_info = self.analyze_element(variable)
|
|
453
|
+
return data_info
|
|
454
|
+
|
|
455
|
+
def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure):
|
|
456
|
+
def hook_fn(grad, indexes):
|
|
457
|
+
suffix = Const.SEP.join([str(index) for index in indexes])
|
|
458
|
+
self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix
|
|
459
|
+
grad_data_info = self.analyze_element(grad)
|
|
460
|
+
self.save_name = None
|
|
461
|
+
full_index = [grad_name_with_count] + indexes
|
|
462
|
+
try:
|
|
463
|
+
self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info)
|
|
464
|
+
except (ValueError, IndexError) as e:
|
|
465
|
+
logger.warning(f"error occured while recording statistics of {grad_name_with_count} variable, "
|
|
466
|
+
f"skip current recording, detailed infomation: {e}")
|
|
467
|
+
return grad
|
|
468
|
+
wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn)
|
|
469
|
+
self.recursive_apply_transform(variable, wrap_register_hook_single_element)
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
from msprobe.core.common.const import Const
|
|
17
|
+
from msprobe.core.data_dump.data_processor.base import BaseDataProcessor
|
|
17
18
|
|
|
18
19
|
|
|
19
20
|
class DataProcessorFactory:
|
|
@@ -62,6 +63,7 @@ class DataProcessorFactory:
|
|
|
62
63
|
cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
|
|
63
64
|
cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
|
|
64
65
|
cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
|
|
66
|
+
cls.register_processor(Const.PT_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
|
|
65
67
|
cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
|
|
66
68
|
elif framework == Const.MS_FRAMEWORK:
|
|
67
69
|
from msprobe.core.data_dump.data_processor.mindspore_processor import (
|
|
@@ -75,4 +77,5 @@ class DataProcessorFactory:
|
|
|
75
77
|
cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
|
|
76
78
|
cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
|
|
77
79
|
cls.register_processor(Const.MS_FRAMEWORK, Const.KERNEL_DUMP, MindsporeKernelDumpDataProcessor)
|
|
80
|
+
cls.register_processor(Const.MS_FRAMEWORK, Const.STRUCTURE, BaseDataProcessor)
|
|
78
81
|
cls.register_module_processor(Const.MS_FRAMEWORK, CellProcessor)
|