mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/pytorch/grad_probe/grad_monitor.py

```diff
@@ -0,0 +1,90 @@
+import os
+from collections import defaultdict
+
+import torch
+if int(torch.__version__.split('.')[0]) >= 2:
+    from torch.optim.optimizer import register_optimizer_step_pre_hook
+from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv
+from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target
+from msprobe.core.grad_probe.constant import GradConst, level_adp
+from msprobe.core.common.file_check import create_directory
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import remove_path, write_csv, save_npy
+from msprobe.pytorch.common.utils import get_rank_id, print_rank_0, save_pt
+
+
+class GradientMonitor:
+
+    def __init__(self, common_config, task_config):
+        level = task_config.grad_level
+        if level not in level_adp:
+            raise Exception(f"level is valid, not in {level_adp.keys()}")
+        self._level_adp = level_adp[level]
+        self._param_list = task_config.param_list
+        self._target_ranks = common_config.rank
+        logger.info(f"target rank {self._target_ranks}")
+        self._target_step = common_config.step
+        logger.info(f"target step {self._target_step}")
+        self._bounds = task_config.bounds
+        check_numeral_list_ascend(self._bounds)
+        self._output_path = common_config.dump_path
+        if not os.path.exists(self._output_path):
+            create_directory(self._output_path)
+        else:
+            logger.warning(f"the file in {self._output_path} will be recoverd")
+        self._step = -1
+        self._param2name = defaultdict(str)
+
+    @property
+    def output_path(self):
+        return self._output_path
+
+    @staticmethod
+    def save_grad_direction(param_name, grad, save_path):
+        if not os.path.exists(save_path):
+            create_directory(save_path)
+        param_grad = grad.clone().detach()
+        is_positive = param_grad > 0
+        save_filepath = os.path.join(save_path, f"{param_name}.npy")
+        save_npy(is_positive.numpy(), save_filepath)
+
+    def monitor(self, model):
+        print_rank_0("> parameter names:")
+        for name, param in model.named_parameters():
+            self._param2name[param] = name
+            print_rank_0(f"\t{name}")
+        setattr(self, "_rank", get_rank_id())
+        if torch.distributed.is_initialized() and not data_in_list_target(getattr(self, "_rank"), self._target_ranks):
+            return
+        self._hook_optimizer()
+
+    def _hook_optimizer(self):
+        def optimizer_pre_step_hook(optimizer, args, kargs):
+            self._step += 1
+            if not data_in_list_target(self._step, self._target_step):
+                return
+            output_lines = []
+            for param, param_name in self._param2name.items():
+                if not data_in_list_target(param_name, self._param_list):
+                    continue
+                grad = param.main_grad if hasattr(param, "main_grad") else param.grad
+                if grad is None:
+                    logger.info(f"grad is None: {param_name}")
+                    continue
+                grad_info = GradStatCsv.generate_csv_line(param_name, self._level_adp, grad, self._bounds)
+                output_lines.append(grad_info)
+                if self._level_adp["have_grad_direction"]:
+                    GradientMonitor.save_grad_direction(param_name, grad,
+                                                        f'{self._output_path}/rank{self._rank}/step{self._step}')
+            output_dirpath = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}")
+            if not os.path.isdir(output_dirpath):
+                create_directory(output_dirpath)
+            output_path = os.path.join(output_dirpath, f"grad_summary_{self._step}.csv")
+            if os.path.exists(output_path):
+                logger.warning(f"{output_path} will be recoverd")
+                remove_path(output_path)
+            header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds)
+            output_lines.insert(0, header_result)
+            write_csv(output_lines, output_path)
+        if int(torch.__version__.split('.')[0]) >= 2:
+            register_optimizer_step_pre_hook(optimizer_pre_step_hook)
```
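The monitor hinges on `register_optimizer_step_pre_hook`, a global hook point added in PyTorch 2.0 that fires before every `optimizer.step()`, i.e. after `backward()` has populated gradients. A minimal standalone sketch of that mechanism (the model and the hook body are illustrative, not from msprobe):

```python
# Minimal sketch of the PyTorch >= 2.0 hook point grad_monitor.py relies on.
import torch
from torch.optim.optimizer import register_optimizer_step_pre_hook

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def pre_step_hook(optimizer, args, kwargs):
    # Fires before every optimizer.step(), the same point where
    # GradientMonitor samples per-parameter gradient statistics.
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                print(tuple(param.shape), param.grad.norm().item())

handle = register_optimizer_step_pre_hook(pre_step_hook)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()   # the hook runs here
handle.remove()    # global hooks should be removed when done
```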
msprobe/pytorch/grad_probe/grad_stat_csv.py

```diff
@@ -0,0 +1,129 @@
+from abc import ABC, abstractmethod
+from collections import namedtuple
+import hashlib
+import torch
+from msprobe.core.grad_probe.constant import GradConst
+
+CSV_header_input = namedtuple("CSV_header_input", ["bounds"])
+CSV_content_input = namedtuple("CSV_content_input", ["grad", "bounds"])
+
+
+class GradStatCsv:
+    csv = {}
+
+    @staticmethod
+    def generate_csv_header(level, bounds):
+        header = ["param_name"]
+        for key in level["header"]:
+            csv_header_input = CSV_header_input(bounds=bounds)
+            header.extend(GradStatCsv.csv[key].generate_csv_header(csv_header_input))
+        return header
+
+    @staticmethod
+    def generate_csv_line(param_name, level, grad, bounds):
+        line = [param_name]
+        for key in level["header"]:
+            csv_content_input = CSV_content_input(grad=grad, bounds=bounds)
+            line.extend(GradStatCsv.csv[key].generate_csv_content(csv_content_input))
+        return line
+
+
+def register_csv_item(key, cls=None):
+    if cls is None:
+        # Called without a class: return the decorator function.
+        return lambda cls: register_csv_item(key, cls)
+    GradStatCsv.csv[key] = cls
+    return cls
+
+
+class CsvItem(ABC):
+    @abstractmethod
+    def generate_csv_header(csv_header_input):
+        pass
+
+    @abstractmethod
+    def generate_csv_content(csv_content_input):
+        pass
+
+
+@register_csv_item(GradConst.MD5)
+class CSV_md5(CsvItem):
+    def generate_csv_header(csv_header_input):
+        return ["MD5"]
+
+    def generate_csv_content(csv_content_input):
+        grad = csv_content_input.grad
+        tensor_bytes = grad.cpu().detach().float().numpy().tobytes()
+        md5_hash = hashlib.md5(tensor_bytes)
+        return [md5_hash.hexdigest()]
+
+
+@register_csv_item(GradConst.DISTRIBUTION)
+class CSV_distribution(CsvItem):
+    def generate_csv_header(csv_header_input):
+        bounds = csv_header_input.bounds
+        intervals = []
+        if bounds:
+            intervals.append(f"(-inf, {bounds[0]}]")
+            for i in range(1, len(bounds)):
+                intervals.append(f"({bounds[i-1]}, {bounds[i]}]")
+        if intervals:
+            intervals.append(f"({bounds[-1]}, inf)")
+        intervals.append("=0")
+
+        return intervals
+
+    def generate_csv_content(csv_content_input):
+        grad = csv_content_input.grad
+        bounds = csv_content_input.bounds
+        grad = grad.cpu().detach()
+        if grad.dtype == torch.bfloat16:
+            grad = grad.to(torch.float32)
+        element_num = grad.numel()
+        grad_equal_0_num = (grad == 0).sum().item()
+        bound = torch.Tensor(bounds)
+        bucketsize_result = torch.bucketize(grad, bound)
+        interval_nums = [(bucketsize_result == i).sum().item() for i in range(len(bound) + 1)]
+        interval_nums.append(grad_equal_0_num)
+        return_list = [x / element_num if element_num != 0 else 0 for x in interval_nums]
+        return return_list
+
+
+@register_csv_item(GradConst.MAX)
+class CSV_max(CsvItem):
+    def generate_csv_header(csv_header_input):
+        return ["max"]
+
+    def generate_csv_content(csv_content_input):
+        grad = csv_content_input.grad
+        return [torch.max(grad).cpu().detach().float().numpy().tolist()]
+
+
+@register_csv_item(GradConst.MIN)
+class CSV_max(CsvItem):
+    def generate_csv_header(csv_header_input):
+        return ["min"]
+
+    def generate_csv_content(csv_content_input):
+        grad = csv_content_input.grad
+        return [torch.min(grad).cpu().detach().float().numpy().tolist()]
+
+
+@register_csv_item(GradConst.NORM)
+class CSV_max(CsvItem):
+    def generate_csv_header(csv_header_input):
+        return ["norm"]
+
+    def generate_csv_content(csv_content_input):
+        grad = csv_content_input.grad
+        return [torch.norm(grad).cpu().detach().float().numpy().tolist()]
+
+
+@register_csv_item(GradConst.SHAPE)
+class CSV_shape(CsvItem):
+    def generate_csv_header(csv_header_input):
+        return ["shape"]
+
+    def generate_csv_content(csv_content_input):
+        grad = csv_content_input.grad
+        return [list(grad.shape)]
```
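`GradStatCsv` is driven by a small decorator registry: `register_csv_item` maps a `GradConst` key to the class that renders one CSV column, and header/line generation just walks the keys configured for the chosen level. A self-contained sketch of the same pattern (the registry and stat names here are illustrative, not msprobe's):

```python
# Self-contained sketch of the decorator-registry pattern used above.
registry = {}

def register_csv_item(key, cls=None):
    if cls is None:
        # Used as @register_csv_item("max"): return a decorator closed over key.
        return lambda cls: register_csv_item(key, cls)
    registry[key] = cls
    return cls

@register_csv_item("max")
class MaxItem:
    @staticmethod
    def header(bounds):
        return ["max"]

    @staticmethod
    def content(values, bounds):
        return [max(values)]

# Building a row walks the keys configured for the level, as GradStatCsv does.
row = ["param"]
for key in ["max"]:
    row.extend(registry[key].content([0.5, -2.0, 1.5], None))
print(row)  # ['param', 1.5]
```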
msprobe/pytorch/hook_module/hook_module.py

```diff
@@ -17,10 +17,13 @@
 
 import functools
 import threading
+
 import torch
 import torch.nn as nn
 import torch.utils.hooks as full_hooks
+
 from msprobe.core.common.const import Const
+torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 
 
 class HOOKModule(nn.Module):
@@ -46,9 +49,13 @@ class HOOKModule(nn.Module):
         else:
             HOOKModule.module_count[self.prefix] += 1
             self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP
-        forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix)
-
-
+        forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix)
+        if torch_version_above_or_equal_2:
+            self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
+            self.register_forward_hook(forward_hook, with_kwargs=True)
+        else:
+            self.register_forward_pre_hook(forward_pre_hook)
+            self.register_forward_hook(forward_hook)
         self.register_backward_hook(backward_hook)
 
     def __call__(self, *input, **kwargs):
@@ -61,6 +68,10 @@ class HOOKModule(nn.Module):
         HOOKModule.inner_stop_hook[self.current_thread] = False
         return result
 
+    @classmethod
+    def reset_module_stats(cls):
+        cls.module_count = {}
+
     def _call_func(self, *input, **kwargs):
        full_backward_hooks, non_full_backward_hooks = [], []
        if len(self._backward_hooks) > 0:
```
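The new version gate exists because `register_forward_pre_hook(..., with_kwargs=True)` only appeared in PyTorch 2.0; with it, forward hooks also receive the module's keyword arguments, which the dump needs. A hedged sketch of the two hook signatures the diff switches between (assumes PyTorch >= 2.0 for the `with_kwargs` branch):

```python
# Sketch of the with_kwargs=True hook signatures selected on torch >= 2.0.
import torch
import torch.nn as nn

def pre_hook(module, args, kwargs):           # with_kwargs=True signature
    print("pre :", [tuple(a.shape) for a in args], dict(kwargs))

def post_hook(module, args, kwargs, output):  # with_kwargs=True signature
    print("post:", tuple(output.shape))

layer = nn.Linear(3, 5)
if torch.__version__.split('+')[0] >= '2.0':  # the same string gate the diff adds
    layer.register_forward_pre_hook(pre_hook, with_kwargs=True)
    layer.register_forward_hook(post_hook, with_kwargs=True)
layer(torch.randn(2, 3))
```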
msprobe/pytorch/hook_module/utils.py

```diff
@@ -16,14 +16,14 @@
 """
 
 import os
-import
+from msprobe.core.common.utils import load_yaml
 
-from msprobe.core.common.file_check import FileOpen
 
-
-
-
-
-
-
-
+def get_ops():
+    cur_path = os.path.dirname(os.path.realpath(__file__))
+    yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+    ops = load_yaml(yaml_path)
+    wrap_functional = ops.get('functional')
+    wrap_tensor = ops.get('tensor')
+    wrap_torch = ops.get('torch')
+    return set(wrap_functional) | set(wrap_tensor) | set(wrap_torch)
```
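The new `get_ops` simply unions the op-name lists stored under three keys of `support_wrap_ops.yaml`. An equivalent sketch, with `yaml.safe_load` standing in for msprobe's `load_yaml` helper and made-up op lists:

```python
# Equivalent sketch of the new get_ops(); the YAML content is illustrative.
import yaml

text = """
functional: [relu, softmax]
tensor: [add, mul]
torch: [matmul, add]
"""
ops = yaml.safe_load(text)
all_ops = set(ops.get('functional')) | set(ops.get('tensor')) | set(ops.get('torch'))
print(sorted(all_ops))  # ['add', 'matmul', 'mul', 'relu', 'softmax']
```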
msprobe/pytorch/hook_module/wrap_aten.py

```diff
@@ -18,18 +18,17 @@
 import os
 import torch
 
-import yaml
-
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard
 from msprobe.core.common.const import Const
-from msprobe.core.common.
-
+from msprobe.core.common.utils import load_yaml
+from msprobe.pytorch.function_factory import npu_custom_grad_functions
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-
-
+ops = load_yaml(yaml_path)
+wrap_aten_ops = ops.get('aten')
+white_aten_ops = ops.get('white_aten_ops', [])
 
 
 aten_func = {}
@@ -38,9 +37,9 @@ for f in dir(torch.ops.aten):
 
 
 def get_aten_ops():
-    global
+    global wrap_aten_ops
     _all_aten_ops = dir(torch.ops.aten)
-    return set(
+    return set(wrap_aten_ops) & set(_all_aten_ops)
 
 
 class HOOKAtenOP(object):
@@ -48,7 +47,7 @@ class HOOKAtenOP(object):
 
 
 class AtenOPTemplate(HOOKModule):
-    def __init__(self, op, hook):
+    def __init__(self, op, hook, need_hook=True):
         if isinstance(op, torch._ops.OpOverloadPacket):
             op_name_ = op._qualified_op_name.split("::")[-1]
         else:
@@ -58,10 +57,21 @@ class AtenOPTemplate(HOOKModule):
             op_name_ = op_name_ + '.' + overload_name
         self.op = op
         self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP
-
+        self.need_hook = need_hook
+        if self.need_hook:
+            super().__init__(hook)
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
+        if isinstance(self.op, str):
+            if self.op in npu_custom_grad_functions:
+                return npu_custom_grad_functions[self.op](*args, **kwargs)
+            if self.op in white_aten_ops:
+                return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs)
+            if self.op not in aten_func:
+                raise Exception(f"Skip op[{self.op}] accuracy check, because the op is not "
+                                f"in dir(torch.ops.aten) and support yaml.")
+            return aten_func[self.op](*args, **kwargs)
         return self.op(*args, **kwargs)
 
 
```
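The name handling at the top of `AtenOPTemplate.__init__` distinguishes an `OpOverloadPacket` (e.g. `torch.ops.aten.add`) from a specific overload. A quick illustration of the attributes it reads; note these are internal PyTorch attributes (`_qualified_op_name`, `_overloadname`), present in recent versions but not part of the stable API:

```python
# Illustration of the aten op naming that AtenOPTemplate.__init__ parses.
import torch

packet = torch.ops.aten.add                              # an OpOverloadPacket
print(isinstance(packet, torch._ops.OpOverloadPacket))   # True
print(packet._qualified_op_name.split("::")[-1])         # 'add'

overload = torch.ops.aten.add.Tensor                     # a specific OpOverload
print(overload._overloadname)                            # 'Tensor'
```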
msprobe/pytorch/hook_module/wrap_distributed.py

```diff
@@ -18,18 +18,15 @@
 import os
 from functools import wraps
 import torch.distributed as dist
-import yaml
 
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard
 from msprobe.core.common.const import Const
-from msprobe.core.common.
+from msprobe.core.common.utils import load_yaml
 
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-with FileOpen(yaml_path, 'r') as f:
-    WrapDistributedOps = yaml.safe_load(f).get('distributed')
 
 
 distributed_func = {}
@@ -38,9 +35,10 @@ for f in dir(dist):
 
 
 def get_distributed_ops():
-    global WrapDistributedOps
     _all_distributed_ops = dir(dist)
-
+    yaml_data = load_yaml(yaml_path)
+    wrap_distributed_ops = yaml_data.get('distributed')
+    return set(wrap_distributed_ops) & set(_all_distributed_ops)
 
 
 class HOOKDistributedOP(object):
@@ -57,7 +55,12 @@ class DistributedOPTemplate(HOOKModule):
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
-
+        if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]:
+            handle = distributed_func.get(self.op_name_)(*args, **kwargs)
+            handle.wait()
+            return handle
+        else:
+            return distributed_func.get(self.op_name_)(*args, **kwargs)
 
 
 def wrap_distributed_op(op_name, hook):
```
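The new `forward` forces asynchronous collectives to completion so their buffers are final before being dumped: a call made with `async_op=True` (or the point-to-point `isend`/`irecv`) returns a work handle that must be waited on. A single-process gloo sketch of that contract (addresses and port are illustrative):

```python
# Single-process gloo sketch of the async contract the wrapper enforces:
# async_op=True returns a work handle; data is only final after wait().
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.ones(4)
work = dist.all_reduce(t, async_op=True)  # returns immediately
work.wait()                               # what the wrapper calls before returning
print(t)                                  # tensor([1., 1., 1., 1.]) with world_size=1

dist.destroy_process_group()
```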
msprobe/pytorch/hook_module/wrap_functional.py

```diff
@@ -16,15 +16,13 @@
 """
 
 import os
-
 import torch
-import yaml
 
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard
 from msprobe.core.common.const import Const
 from msprobe.pytorch.common.log import logger
-from msprobe.core.common.
+from msprobe.core.common.utils import load_yaml
 
 
 def remove_dropout():
@@ -66,14 +64,13 @@ def remove_dropout():
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-with FileOpen(yaml_path, 'r') as f:
-    WrapFunctionalOps = yaml.safe_load(f).get('functional')
 
 
 def get_functional_ops():
-
+    yaml_data = load_yaml(yaml_path)
+    wrap_functional_ops = yaml_data.get('functional')
     _all_functional_ops = dir(torch.nn.functional)
-    return set(
+    return set(wrap_functional_ops) & set(_all_functional_ops)
 
 
 TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()}
```
msprobe/pytorch/hook_module/wrap_npu_custom.py

```diff
@@ -17,27 +17,33 @@
 
 import os
 import torch
-import torch_npu
-import yaml
 
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
 from msprobe.core.common.const import Const
-from msprobe.core.common.
+from msprobe.core.common.utils import load_yaml
+from msprobe.pytorch.function_factory import npu_custom_functions
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-
-
+
+
+try:
+    import torch_npu
+except ImportError:
+    is_gpu = True
+else:
+    is_gpu = False
 
 
 def get_npu_ops():
-    global WrapNpuOps
     if torch_without_guard_version:
         _npu_ops = dir(torch.ops.npu)
     else:
         _npu_ops = dir(torch_npu._C._VariableFunctionsClass)
-
+    yaml_data = load_yaml(yaml_path)
+    wrap_npu_ops = yaml_data.get('torch_npu')
+    return set(wrap_npu_ops) & set(_npu_ops)
 
 
 class HOOKNpuOP(object):
@@ -46,13 +52,19 @@ class HOOKNpuOP(object):
 
 class NpuOPTemplate(HOOKModule):
 
-    def __init__(self, op_name, hook):
+    def __init__(self, op_name, hook, need_hook=True):
         self.op_name_ = op_name
         self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
-
+        self.need_hook = need_hook
+        if need_hook:
+            super().__init__(hook)
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
+        if not self.need_hook:
+            if self.op_name_ not in npu_custom_functions:
+                raise Exception(f'There is not bench function {self.op_name_}')
+            return npu_custom_functions[self.op_name_](*args, **kwargs)
         if torch_without_guard_version:
             return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
         else:
@@ -60,7 +72,6 @@ class NpuOPTemplate(HOOKModule):
 
 
 def wrap_npu_op(op_name, hook):
-
     def npu_op_template(*args, **kwargs):
         return NpuOPTemplate(op_name, hook)(*args, **kwargs)
 
```
msprobe/pytorch/hook_module/wrap_tensor.py

```diff
@@ -18,23 +18,22 @@
 import os
 
 import torch
-import yaml
 
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter
 from msprobe.core.common.const import Const
-from msprobe.core.common.
+from msprobe.core.common.utils import load_yaml
+
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-with FileOpen(yaml_path, 'r') as f:
-    WrapTensorOps = yaml.safe_load(f).get('tensor')
 
 
 def get_tensor_ops():
-    global WrapTensorOps
     _tensor_ops = dir(torch.Tensor)
-
+    yaml_data = load_yaml(yaml_path)
+    wrap_tensor_ops = yaml_data.get('tensor')
+    return set(wrap_tensor_ops) & set(_tensor_ops)
 
 
 TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
```
msprobe/pytorch/hook_module/wrap_torch.py

```diff
@@ -16,25 +16,23 @@
 """
 
 import os
-
 import torch
-import yaml
 
+from msprobe.core.common.const import Const
+from msprobe.core.common.utils import load_yaml
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.core.common.file_check import FileOpen
 from msprobe.pytorch.common.utils import torch_device_guard
 
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-with FileOpen(yaml_path, 'r') as f:
-    WrapTorchOps = yaml.safe_load(f).get('torch')
 
 
 def get_torch_ops():
-    global WrapTorchOps
     _torch_ops = []
-
+    yaml_data = load_yaml(yaml_path)
+    wrap_torch_ops = yaml_data.get('torch')
+    for operation in wrap_torch_ops:
         if '.' in operation:
             operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1)
             operation_sub_module = getattr(torch, operation_sub_module_name)
```
msprobe/pytorch/hook_module/wrap_vf.py

```diff
@@ -16,24 +16,22 @@
 """
 
 import os
-
 import torch
-import yaml
 
+from msprobe.core.common.const import Const
+from msprobe.core.common.utils import load_yaml
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.core.common.file_check import FileOpen
 from msprobe.pytorch.common.utils import torch_device_guard
-
+
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
-with FileOpen(yaml_path, 'r') as f:
-    WrapVfOps = yaml.safe_load(f).get('_VF')
 
 
 def get_vf_ops():
-
-
+    yaml_data = load_yaml(yaml_path)
+    wrap_vf_ops = yaml_data.get('_VF')
+    return wrap_vf_ops
 
 
 class HOOKVfOP(object):
```
|