mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
- msprobe/README.md +25 -20
- msprobe/core/common/const.py +110 -66
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/utils.py +30 -34
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -2
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +20 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_processor/base.py +2 -2
- msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
- msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
- msprobe/core/data_dump/json_writer.py +38 -35
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +2 -1
- msprobe/docs/02.config_introduction.md +17 -15
- msprobe/docs/05.data_dump_PyTorch.md +70 -2
- msprobe/docs/06.data_dump_MindSpore.md +33 -12
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +124 -62
- msprobe/docs/21.visualization_PyTorch.md +32 -13
- msprobe/docs/22.visualization_MindSpore.md +32 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +31 -19
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +6 -4
- msprobe/mindspore/debugger/precision_debugger.py +22 -10
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +14 -9
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +354 -302
- msprobe/mindspore/monitor/utils.py +46 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +23 -17
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/common/utils.py +29 -7
- msprobe/pytorch/debugger/precision_debugger.py +10 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
- msprobe/pytorch/monitor/csv2tb.py +8 -2
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +131 -105
- msprobe/pytorch/monitor/module_metric.py +3 -0
- msprobe/pytorch/monitor/optimizer_collect.py +55 -4
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +68 -1
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +11 -7
- msprobe/pytorch/service.py +11 -8
- msprobe/visualization/builder/graph_builder.py +44 -5
- msprobe/visualization/builder/msprobe_adapter.py +0 -1
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +8 -1
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +1 -1
- msprobe/visualization/utils.py +2 -33
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/parse.py +0 -19
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
msprobe/pytorch/common/utils.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,7 +28,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
                                             check_file_or_directory_path, check_path_before_create, FileOpen)
 from msprobe.core.common.log import logger
-from msprobe.core.common.utils import check_seed_all
+from msprobe.core.common.utils import check_seed_all, is_save_variable_valid
 from packaging import version
 
 try:
@@ -57,7 +57,7 @@ def parameter_adapter(func):
 
     @wraps(func)
     def inner(self, *args, **kwargs):
-        if self.
+        if self.api_name == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
             input_tensor = args[0]
             indices = args[1]
             if indices.dtype == torch.uint8:
@@ -77,7 +77,7 @@ def parameter_adapter(func):
         else:
             res = [input_tensor[tensor_index] for tensor_index in indices]
             return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
-        if self.
+        if self.api_name == "__eq__" and len(args) > 1 and args[1] is None:
             return False
         return func(self, *args, **kwargs)
 
@@ -261,6 +261,10 @@ class Const:
     NPU = 'NPU'
     DISTRIBUTED = 'Distributed'
 
+    HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
+    FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
+    FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
+
     RAISE_PRECISION = {
         torch.float16: torch.float32,
         torch.bfloat16: torch.float32,
@@ -419,7 +423,11 @@ def is_recomputation():
         bool: True if in the re-computation phase, False otherwise.
     """
     backward_function_indices = []
-
+    try:
+        call_stack = inspect.stack()
+    except Exception as e:
+        logger.warning(f"Failed to capture stack trace, recomputation validation may be incorrect, error info: {e}.")
+        return False
 
     # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
     for frame_info in call_stack:
@@ -449,9 +457,11 @@ def is_recomputation():
 
 def check_save_param(variable, name, save_backward):
     # try catch this api to skip invalid call
-
+    valid_data_types = tuple([torch.Tensor, int, float, str])
+    if not is_save_variable_valid(variable, valid_data_types):
+        valid_data_types_with_nested_types = valid_data_types + (dict, tuple, list)
         logger.warning("PrecisionDebugger.save variable type not valid, "
-                       "should be one of
+                       f"should be one of {valid_data_types_with_nested_types}"
                        "Skip current save process.")
         raise ValueError
     if not isinstance(name, str):
@@ -473,3 +483,15 @@ def replace_last_occurrence(text, old, new):
     if index != -1:
         return text[:index] + text[index:].replace(old, new, 1)
     return text
+
+
+def is_hifloat8_tensor(tensor):
+    if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
+        return True
+    return False
+
+
+def is_float8_tensor(tensor):
+    if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
+        return True
+    return is_hifloat8_tensor(tensor)
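The new float8 helpers are used elsewhere in this release to skip tensors that cannot safely be cloned or perturbed. A minimal standalone sketch of that guard pattern, with the dtype strings inlined instead of read from `Const` (an assumption made for illustration only):

```python
import torch

# Values of Const.FLOAT8_E5M2_TYPE / Const.FLOAT8_E4M3FN_TYPE, inlined for this sketch.
FLOAT8_DTYPES = ("torch.float8_e5m2", "torch.float8_e4m3fn")

def is_float8_tensor(tensor):
    # Compare the dtype's string form so the check also works on builds without float8 support.
    return str(tensor.dtype) in FLOAT8_DTYPES

def clone_if_supported(result):
    # Clone ordinary tensors; leave float8 tensors untouched, mirroring how clone_if_tensor is guarded.
    if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
        return result.clone()
    return result
```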
msprobe/pytorch/debugger/precision_debugger.py
CHANGED
@@ -19,7 +19,7 @@ import torch
 from msprobe.core.common.const import Const, FileCheckConst, MsgConst
 from msprobe.core.common.exceptions import MsprobeException
 from msprobe.core.common.file_utils import FileChecker
-from msprobe.core.common.utils import get_real_step_or_rank
+from msprobe.core.common.utils import get_real_step_or_rank, check_init_step
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import check_save_param
 from msprobe.pytorch.debugger.debugger_config import DebuggerConfig
@@ -172,6 +172,15 @@ class PrecisionDebugger:
             return
         instance.service.save(variable, name, save_backward)
 
+    @classmethod
+    def set_init_step(cls, step):
+        instance = cls._instance
+        if not instance:
+            raise Exception(MsgConst.NOT_CREATED_INSTANCE)
+        check_init_step(step)
+        instance.service.init_step = step
+        instance.service.loop = 0
+
 
 def module_dump(module, dump_name):
     if not isinstance(module, torch.nn.Module):
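A hedged usage sketch of the new `set_init_step` classmethod; the config path and step value below are placeholders rather than values taken from the msprobe documentation:

```python
from msprobe.pytorch import PrecisionDebugger

# Create the debugger first; set_init_step raises if no instance exists yet.
debugger = PrecisionDebugger(config_path="./config.json")  # placeholder path
# Align msprobe's step counter with a run resumed from a checkpoint (illustrative value).
PrecisionDebugger.set_init_step(100)
```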
msprobe/pytorch/dump/module_dump/module_dump.py
CHANGED
@@ -17,7 +17,7 @@ import torch
 from msprobe.core.common.const import Const
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.hook_module.
+from msprobe.pytorch.hook_module.api_register import get_api_register
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
@@ -26,13 +26,14 @@ class ModuleDumper:
     def __init__(self, service):
         self.service = service
         self.hook_handle_list = []
+        self.api_register = get_api_register()
 
     def start_module_dump(self, module, dump_name):
-        api_register.
+        self.api_register.restore_all_api()
         self.register_hook(module, dump_name)
 
     def stop_module_dump(self):
-        api_register.
+        self.api_register.register_all_api()
         for hook_handle in self.hook_handle_list:
             if isinstance(hook_handle, torch.utils.hooks.RemovableHandle):
                 hook_handle.remove()
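The change above pairs `restore_all_api()` with `register_all_api()` around a single module dump. A minimal standalone sketch of that pattern; only the two method names come from this diff, everything else is a stand-in:

```python
class DummyApiRegister:
    """Stand-in for the shared registry object used by ModuleDumper."""

    def restore_all_api(self):
        print("API wrappers removed before the module dump")

    def register_all_api(self):
        print("API wrappers re-installed after the module dump")

def dump_one_module(register, run_module):
    # Disable API-level hooks so only the module-level hooks fire during the dump.
    register.restore_all_api()
    try:
        run_module()
    finally:
        register.register_all_api()

dump_one_module(DummyApiRegister(), lambda: print("dumping target module"))
```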
msprobe/pytorch/dump/module_dump/module_processer.py
CHANGED
@@ -16,15 +16,17 @@
 from functools import wraps
 
 import torch
+from torch.utils.hooks import BackwardHook
+
 from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import replace_last_occurrence
-from torch.utils.checkpoint import checkpoint as origin_checkpoint
-from torch.utils.checkpoint import set_checkpoint_early_stop
-from torch.utils.hooks import BackwardHook
+from msprobe.pytorch.common.utils import replace_last_occurrence, is_float8_tensor
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
+if torch_version_above_or_equal_2:
+    from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
 
 
 def checkpoint_without_early_stop(*args, **kwargs):
@@ -33,7 +35,8 @@ def checkpoint_without_early_stop(*args, **kwargs):
 
 
 def replace_checkpoint():
-
+    if torch_version_above_or_equal_2:
+        torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
 
 
 class ModuleProcesser:
@@ -58,8 +61,9 @@ class ModuleProcesser:
         return clone_return_value_func
 
     @staticmethod
+    @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def clone_if_tensor(result):
-        if isinstance(result, torch.Tensor):
+        if isinstance(result, torch.Tensor) and not is_float8_tensor(result):
             return result.clone()
         elif type(result) is tuple:
             return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
@@ -109,6 +113,8 @@ class ModuleProcesser:
         for name, module in modules_and_names:
             if module == model:
                 continue
+            if module.__class__.__name__ == "FullyShardedDataParallel":
+                continue
             module_index = (index + Const.SEP) if index != "-1" else ""
             prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
                            name + Const.SEP + module.__class__.__name__ + Const.SEP)
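`clone_if_tensor` is now guarded by `recursion_depth_decorator` from the new `msprobe/core/common/decorator.py`. The actual implementation is not shown in this diff; a hedged, self-contained sketch of what such a depth guard typically looks like:

```python
import functools

def recursion_depth_decorator(label, max_depth=50):
    """Abort with a clear error (instead of RecursionError) once a wrapped call nests too deeply."""
    def wrapper(func):
        depth = 0

        @functools.wraps(func)
        def inner(*args, **kwargs):
            nonlocal depth
            depth += 1
            try:
                if depth > max_depth:
                    raise RuntimeError(f"{label}: recursion depth exceeded {max_depth}")
                return func(*args, **kwargs)
            finally:
                depth -= 1
        return inner
    return wrapper

@recursion_depth_decorator("demo.flatten", max_depth=10)
def flatten(obj):
    # Recursively flatten nested tuples/lists into a flat list of leaves.
    if isinstance(obj, (tuple, list)):
        return [leaf for item in obj for leaf in flatten(item)]
    return [obj]

print(flatten([1, (2, [3, 4]), 5]))  # [1, 2, 3, 4, 5]
```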
msprobe/pytorch/free_benchmark/common/utils.py
CHANGED
@@ -16,7 +16,7 @@
 
 import torch
 from msprobe.core.common.exceptions import FreeBenchmarkException
-from msprobe.core.common.
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark.common.enums import DeviceType
 
 
msprobe/pytorch/free_benchmark/compare/single_benchmark.py
CHANGED
@@ -16,7 +16,7 @@
 import math
 
 import torch
-from msprobe.core.common.
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.utils import TorchC
msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py
CHANGED
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from msprobe.core.common.
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -95,13 +95,13 @@ class AddNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when
+                f"when calculating the maximum value, the tensor is changed to float32."
             )
             max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
             )
             return False
         return True
msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py
CHANGED
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from msprobe.core.common.
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
@@ -100,13 +100,13 @@ class BitNoiseLayer(NpuBaseLayer):
         except Exception:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"when calculate
+                f"when calculate the maximum value, the tensor is changed to float32."
             )
             max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item()
         if max_val < abs_tol:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.api_name}, "
-                f"
+                f"maximum value is less than the minimum threshold. Cancel adding noise."
             )
             return False
         return True
msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py
CHANGED
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import torch
-from msprobe.core.common.
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
 from msprobe.pytorch.free_benchmark.common.params import DataParams
msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
CHANGED
@@ -15,7 +15,7 @@
 
 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.common.
+from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.free_benchmark import logger
 from msprobe.pytorch.free_benchmark.common.constant import CommonField
 from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode
msprobe/pytorch/free_benchmark/result_handlers/check_handler.py
CHANGED
@@ -49,6 +49,6 @@ class CheckerHandler(FuzzHandler):
         except Exception as e:
             logger.warning_on_rank_0(
                 f"[msprobe] Free Benchmark: For {self.params.api_name}, "
-                f"when
+                f"when comparing the results, an exception is raised: {e}"
             )
             return data_params.original_result
msprobe/pytorch/function_factory.py
CHANGED
@@ -70,7 +70,7 @@ class Register(dict):
 
         def add_register_item(key, value):
             if key in self._dict:
-                logger.warning(f"{value.__name__} has been registered before, so we will
+                logger.warning(f"{value.__name__} has been registered before, so we will override it.")
             self[key] = value
             return value
 
msprobe/pytorch/grad_probe/grad_monitor.py
CHANGED
@@ -46,7 +46,7 @@ class GradientMonitor:
         if not os.path.exists(self._output_path):
             create_directory(self._output_path)
         else:
-            logger.warning(f"the file in {self._output_path} will be
+            logger.warning(f"the file in {self._output_path} will be deleted")
         self._step = -1
         self._param2name = defaultdict(str)
 
@@ -97,7 +97,7 @@ class GradientMonitor:
         create_directory(output_dirpath)
         output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
         if os.path.exists(output_path):
-            logger.warning(f"{output_path} will be
+            logger.warning(f"{output_path} will be deleted")
             remove_path(output_path)
         header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
         output_lines.insert(0, header_result)
msprobe/pytorch/hook_module/api_register.py
ADDED
@@ -0,0 +1,131 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+
+from msprobe.core.common.const import Const
+from msprobe.core.data_dump.api_registry import ApiRegistry
+from msprobe.pytorch.common.utils import (
+    torch_without_guard_version, is_gpu, torch_device_guard, parameter_adapter
+)
+from msprobe.pytorch.function_factory import npu_custom_functions
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+
+
+torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
+
+_api_types = {
+    Const.PT_FRAMEWORK: {
+        Const.PT_API_TYPE_FUNCTIONAL: (torch.nn.functional, (torch.nn.functional,)),
+        Const.PT_API_TYPE_TENSOR: (torch.Tensor, (torch.Tensor,)),
+        Const.PT_API_TYPE_TORCH: (torch, (torch,)),
+        Const.PT_API_TYPE_VF: (torch._C._VariableFunctionsClass, (torch._VF,)),
+        Const.PT_API_TYPE_DIST: (dist, (dist, dist.distributed_c10d))
+    }
+}
+if not is_gpu:
+    import torch_npu
+    if torch_without_guard_version:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {
+                Const.PT_API_TYPE_NPU: (torch.ops.npu, (torch_npu, torch.ops.npu))
+            }
+        )
+    else:
+        _api_types.get(Const.PT_FRAMEWORK).update(
+            {Const.PT_API_TYPE_NPU: (torch_npu._C._VariableFunctionsClass, (torch_npu,))}
+        )
+    _api_types.get(Const.PT_FRAMEWORK).update(
+        {
+            Const.PT_API_TYPE_NPU_DIST: (torch_npu.distributed, (torch_npu.distributed,
+                                                                 torch_npu.distributed.distributed_c10d))
+        }
+    )
+
+_inner_used_api = {}
+_supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
+_cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
+
+
+@parameter_adapter
+def tensor_module_forward(module, *args, **kwargs):
+    return module.api_func(*args, **kwargs)
+
+
+def dist_module_forward(module, *args, **kwargs):
+    handle = module.api_func(*args, **kwargs)
+    if kwargs.get("async_op") or module.api_name in ["isend", "irecv"]:
+        if handle and hasattr(handle, 'wait'):
+            handle.wait()
+    if module.api_name == "batch_isend_irecv":
+        if isinstance(handle, list):
+            for req in handle:
+                req.wait()
+    return handle
+
+
+def npu_module_forward(module, *args, **kwargs):
+    if not module.need_hook:
+        if module.api_name not in npu_custom_functions:
+            raise Exception(f'There is not bench function {module.api_name}')
+        if module.device == Const.CUDA_LOWERCASE:
+            module.api_name = _cuda_func_mapping.get(module.api_name, module.api_name)
+        if module.device in [Const.CUDA_LOWERCASE, Const.CPU_LOWERCASE]:
+            return npu_custom_functions[module.api_name](*args, **kwargs)
+    return module.api_func(*args, **kwargs)
+
+
+forward_methods = {
+    "Tensor": tensor_module_forward,
+    "Distributed": dist_module_forward,
+    "NPU": npu_module_forward
+}
+
+
+class ApiTemplate(HOOKModule):
+    def __init__(self, api_name, api_func, prefix, hook_build_func, need_hook=True, device=Const.CPU_LOWERCASE):
+        self.api_name = api_name
+        self.api_func = api_func
+        self.prefix = prefix
+        self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
+        self.need_hook = need_hook
+        self.device = device
+        if self.need_hook:
+            super().__init__(hook_build_func)
+        if prefix == Const.DIST_API_TYPE_PREFIX:
+            self.op_is_distributed = True
+
+    @torch_device_guard
+    def forward(self, *args, **kwargs):
+        exec_func = forward_methods.get(self.prefix)
+        exec_func = functools.partial(exec_func, self) if exec_func else self.api_func
+        return exec_func(*args, **kwargs)
+
+
+api_register = None
+
+
+def get_api_register(return_new=False):
+    if return_new:
+        return ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+
+    global api_register
+    if api_register is None:
+        api_register = ApiRegistry(_api_types, _inner_used_api, _supported_api_list_path, ApiTemplate)
+    return api_register
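`get_api_register()` caches one `ApiRegistry` per process unless `return_new=True` is passed. A generic, standalone sketch of that lazy-singleton accessor (names below are illustrative, not the msprobe implementation):

```python
_registry = None

class Registry:
    def __init__(self):
        self.wrapped_apis = {}

def get_registry(return_new=False):
    # return_new=True hands back an isolated instance, e.g. for one-off checks;
    # otherwise every caller shares (and mutates) the same registry.
    if return_new:
        return Registry()
    global _registry
    if _registry is None:
        _registry = Registry()
    return _registry

assert get_registry() is get_registry()
assert get_registry(return_new=True) is not get_registry()
```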
msprobe/pytorch/hook_module/hook_module.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +21,8 @@ import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
+from msprobe.pytorch.common.utils import is_float8_tensor
+
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
@@ -28,28 +30,27 @@ class HOOKModule(nn.Module):
     module_count = defaultdict(int)
     inner_stop_hook = {}
 
-    def __init__(self,
+    def __init__(self, hook_build_func) -> None:
         super(HOOKModule, self).__init__()
         self.has_overflow = False
-        self.prefix = ""
         self.current_thread = threading.current_thread().ident
         if self.current_thread not in HOOKModule.inner_stop_hook:
             HOOKModule.inner_stop_hook[self.current_thread] = False
         self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)
 
         if not self.stop_hook:
-            if hasattr(self, "prefix_op_name_"):
-                self.prefix = self.prefix_op_name_
-
             self.forward_data_collected = False
-
-            if
-
-
-
-
-
-
+
+            prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+            if callable(hook_build_func):
+                forward_pre_hook, forward_hook, backward_hook, _ = hook_build_func(prefix)
+                if torch_version_above_or_equal_2:
+                    self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+                    self.register_forward_hook(forward_hook, with_kwargs=True)
+                else:
+                    self.register_forward_pre_hook(forward_pre_hook)
+                    self.register_forward_hook(forward_hook)
+                self.register_backward_hook(backward_hook)
 
     def __call__(self, *args, **kwargs):
         changed = False
@@ -111,6 +112,10 @@
             return result
         else:
             return result
+
+        if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()):
+            return result
+
         grad_fn = var.grad_fn
         if grad_fn is not None:
             for hook in non_full_backward_hooks:
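The reworked `HOOKModule.__init__` now takes a `hook_build_func` that returns a 4-tuple of hooks and registers them itself. A hedged standalone sketch of that contract on a plain `nn.Module` (hook bodies and names are illustrative; the diff itself registers the backward hook via `register_backward_hook`):

```python
import torch
import torch.nn as nn

def build_hooks(prefix):
    # Mirrors the 4-tuple shape (forward_pre_hook, forward_hook, backward_hook, extra)
    # that hook_build_func is expected to return in this release.
    def forward_pre_hook(module, args, kwargs):
        print(f"{prefix}forward pre: {len(args)} positional args")

    def forward_hook(module, args, kwargs, output):
        print(f"{prefix}forward done")

    def backward_hook(module, grad_input, grad_output):
        print(f"{prefix}backward done")

    return forward_pre_hook, forward_hook, backward_hook, None

layer = nn.Linear(4, 4)
pre, fwd, bwd, _ = build_hooks("Demo.Linear.")
layer.register_forward_pre_hook(pre, with_kwargs=True)   # with_kwargs needs torch >= 2.0
layer.register_forward_hook(fwd, with_kwargs=True)
layer.register_full_backward_hook(bwd)
layer(torch.randn(2, 4)).sum().backward()
```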
msprobe/pytorch/hook_module/register_optimizer_hook.py
CHANGED
@@ -32,8 +32,9 @@ def register_optimizer_hook(data_collector):
     def patch_clip_grad(func):
         def wrapper(*args, **kwargs):
             data_collector.optimizer_status = Const.CLIP_GRAD
-            func(*args, **kwargs)
+            result = func(*args, **kwargs)
             data_collector.optimizer_status = Const.END_PREFIX + Const.CLIP_GRAD
+            return result
 
         return wrapper
 
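The wrapper now passes the wrapped call's return value through, which matters because utilities such as `torch.nn.utils.clip_grad_norm_` return a value (the total gradient norm) that callers may use. A minimal sketch of a return-value-transparent patch (generic names, not the msprobe code):

```python
import functools

def transparent_patch(func, before, after):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        before()
        result = func(*args, **kwargs)  # keep the original return value
        after()
        return result
    return wrapper
```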