mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import threading
|
|
3
|
+
from typing import Dict, Union
|
|
4
|
+
|
|
5
|
+
from msprobe.core.grad_probe.utils import check_str
|
|
6
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
7
|
+
from msprobe.core.common.log import logger
|
|
8
|
+
from msprobe.core.common.file_check import create_directory
|
|
9
|
+
from msprobe.core.common.utils import check_path_before_create
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GlobalContext:
    """Process-wide singleton holding the grad-probe runtime configuration.

    Settings (level, monitored params, steps, ranks, histogram bounds,
    output path) are populated once via :meth:`init_context` and then read
    by the dump/analysis machinery through :meth:`get_context`.
    """

    _instance = None
    _instance_lock = threading.Lock()
    _setting = {
        GradConst.LEVEL: None,
        GradConst.PARAM_LIST: None,
        GradConst.STEP: None,
        GradConst.RANK: None,
        GradConst.CURRENT_STEP: 0,
        GradConst.BOUNDS: [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10],
        GradConst.OUTPUT_PATH: None
    }

    def __new__(cls, *args, **kwargs):
        # Double-checked locking: re-test the instance inside the lock so two
        # threads racing past the first check cannot both create an instance.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def init_context(self, config_dict: Dict):
        """Validate *config_dict* and copy its entries into the shared settings.

        Raises:
            ValueError: if ``level`` is unsupported or ``output_path`` is invalid.
        """
        level = config_dict.get(GradConst.LEVEL)
        check_str(level, variable_name="level in yaml")
        if level in GradConst.SUPPORTED_LEVEL:
            self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL)
        else:
            raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")

        self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
        self._set_input_list(config_dict, GradConst.BOUNDS, float)
        self._set_input_list(config_dict, GradConst.STEP, int)
        self._set_input_list(config_dict, GradConst.RANK, int)

        output_path = config_dict.get(GradConst.OUTPUT_PATH)
        check_str(output_path, variable_name="output_path in yaml")
        try:
            check_path_before_create(output_path)
        except RuntimeError as err:
            raise ValueError(f"Invalid output_path: {output_path}. The error message is {err}.") from err
        self._setting[GradConst.OUTPUT_PATH] = output_path
        if not os.path.isdir(self._setting.get(GradConst.OUTPUT_PATH)):
            create_directory(self._setting.get(GradConst.OUTPUT_PATH))
        else:
            logger.warning("The output_path exists, the data will be covered.")

    def get_context(self, key: str):
        """Return the setting stored under *key* (warns on unknown keys)."""
        if key not in self._setting:
            logger.warning(f"Unrecognized {key}.")
        return self._setting.get(key)

    def update_step(self):
        """Advance the internally tracked training step counter by one."""
        self._setting[GradConst.CURRENT_STEP] += 1

    def step_need_dump(self, step):
        """True when *step* should be dumped (empty step list means every step)."""
        dump_step_list = self.get_context(GradConst.STEP)
        return (not dump_step_list) or (step in dump_step_list)

    def rank_need_dump(self, rank):
        """True when *rank* should be dumped (empty rank list means every rank)."""
        dump_rank_list = self.get_context(GradConst.RANK)
        return (not dump_rank_list) or (rank in dump_rank_list)

    def _set_input_list(self, config_dict: Dict, name: str, dtype: Union[int, str, float]):
        """Copy the list *name* from *config_dict* if every element has type *dtype*.

        Invalid values are skipped with a warning so the default stays in place.
        """
        value = config_dict.get(name)
        if dtype == int:
            type_str = "integer"
        elif dtype == float:
            type_str = "float"
        else:
            type_str = "string"
        # Float lists also accept ints (the default bounds mix both); bool is
        # rejected explicitly because bool is a subclass of int in Python.
        allowed = (int, float) if dtype == float else dtype
        if value and isinstance(value, list):
            for val in value:
                if isinstance(val, bool) or not isinstance(val, allowed):
                    logger.warning(f"Invalid {name} which must be None or list of {type_str}")
                    return
            self._setting[name] = value
        else:
            logger.warning(f"{name} is None or not a list with valid items, use default value.")
|
|
90
|
+
|
|
91
|
+
# Module-level singleton: the grad-probe configuration shared by the dump and
# analysis code (GlobalContext.__new__ guarantees there is only one instance).
grad_context = GlobalContext()
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
import multiprocessing
|
|
5
|
+
from multiprocessing import Process
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import mindspore as ms
|
|
9
|
+
from mindspore.communication import get_rank
|
|
10
|
+
from mindspore.ops import operations as P
|
|
11
|
+
from mindspore.common.parameter import Parameter
|
|
12
|
+
|
|
13
|
+
from msprobe.core.grad_probe.utils import ListCache
|
|
14
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
15
|
+
from msprobe.core.common.log import logger
|
|
16
|
+
from msprobe.core.common.file_check import create_directory
|
|
17
|
+
from msprobe.core.common.utils import check_file_or_directory_path, write_csv, remove_path, move_file
|
|
18
|
+
from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_rank_id():
    """Return the communication rank id, falling back to 0 when unavailable.

    ``get_rank`` raises if no communication group has been initialised
    (e.g. single-device runs), in which case rank 0 is assumed.
    """
    try:
        return get_rank()
    except Exception:
        return 0
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@ms.jit
def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
    '''
    Dump gradient statistic data via TensorDump, compiled in graph mode.

    Depending on level, the flattened statistic vector written to
    "<dump_dir>/<g_name>" contains:
    level0: [step, max, min, norm, shape_dim, shape]
    level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
    level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
    where grad_bool_data (the elementwise sign mask) goes to a separate
    "<dump_dir>/<g_name>_dir" file.
    '''
    dump_path = os.path.join(dump_dir, g_name)
    dump_dir_path = dump_path + "_dir"
    save_op = ms.ops.TensorDump()

    # Scalar extrema/norm are computed on the flattened gradient and cast to
    # float so they can be stacked into one homogeneous tensor.
    grad_flat = grad.reshape(-1)
    max_val = grad_flat.max(axis=0).float()
    min_val = grad_flat.min(axis=0).float()
    norm_val = grad_flat.norm(ord=2).float()
    shape = grad.shape
    extrem_list = [dump_step[0].float(), max_val, min_val, norm_val]
    extrem_stat = ms.ops.stack(extrem_list)
    # Shape is encoded as [rank, dim0, dim1, ...] so the reader can recover it.
    shape_list = [len(shape)] + list(shape)
    shape_stat = ms.Tensor(shape_list).float()
    level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
    level_stat = level0_stat

    if level == GradConst.LEVEL2:
        # Value distribution: one bucket per interval defined by `bounds`,
        # plus a dedicated "== 0" count; dist_dim = len(bounds) + 2 records
        # the histogram length for the CSV parser.
        zero_grad = (grad == 0).sum()
        dist_dim = ms.Tensor([len(bounds) + 2]).float()
        bucket_result = ms.ops.bucketize(grad.float(), bounds)
        bucket_result = bucket_result.astype(ms.int8)
        dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
        dist_stat.append(zero_grad)
        dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty
        dist_stat = ms.ops.stack(dist_stat, axis=0).float()
        level2_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0)
        level_stat = level2_stat

    save_op(dump_path, level_stat)
    if level == GradConst.LEVEL1 or level == GradConst.LEVEL2:
        # Direction mask (grad > 0) is dumped separately; the CSV generator
        # later moves these "_dir" files into the per-step directory.
        grad_direction = grad > 0
        save_op(dump_dir_path, grad_direction)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class CSVGenerator(Process):
    """Background process converting dumped gradient statistics into CSV files.

    The training process dumps one npy file per gradient (see ``grad_dump``);
    this process polls the dump directory, appends one CSV row per statistic
    file to ``grad_summary_<step>.csv`` and moves direction ("_dir") dumps
    into the matching per-step directory.
    """

    def __init__(self) -> None:
        super().__init__()
        self.dump_dir = None       # directory polled for freshly dumped npy files
        self.save_dir = None       # per-rank output directory for csv/step data
        self.level = GradConst.LEVEL0
        self.cache_list = ListCache()
        self.current_step = None
        self.stop_event = None
        self.last_finish = False
        # Fixed upstream defect: a trailing comma here previously turned the
        # default bounds into a one-element tuple containing the list.
        self.bounds = [-0.1, 0.0, 0.1]

    def init(self, context: "GlobalContext"):
        """Derive per-rank paths and dump settings from the global context."""
        rank_id = get_rank_id()
        output_path = context.get_context(GradConst.OUTPUT_PATH)
        self.level = context.get_context(GradConst.LEVEL)
        self.bounds = context.get_context(GradConst.BOUNDS)
        self.dump_dir = f"{output_path}/rank{rank_id}/Dump/"
        self.save_dir = f"{output_path}/rank{rank_id}/"
        self.current_step = None
        self.stop_event = multiprocessing.Event()
        self.last_finish = False

    def run(self):
        """Poll the dump directory until stopped and all pending files are consumed."""
        while True:
            if not os.path.exists(self.dump_dir):
                time.sleep(0.1)
                if self.stop_event.is_set():
                    break
                continue
            npy_files = os.listdir(self.dump_dir)
            # Dump file names are prefixed with a monotonically increasing id,
            # so sorting on the prefix restores dump order.
            npy_files.sort(key=lambda x: int(x.split("_")[0]))
            self.traverse_files(npy_files)
            empty = len(os.listdir(self.dump_dir)) == 0
            if self.stop_event.is_set() and empty and self.last_finish:
                break
        if os.path.exists(self.dump_dir):
            remove_path(self.dump_dir)

    def stop(self):
        """Signal the polling loop to exit once outstanding files are processed."""
        self.stop_event.set()

    def traverse_files(self, npy_files: List):
        """Dispatch each dumped file: step-finish marker, direction dump, or statistics."""
        for npy_file in npy_files:
            file_path = os.path.join(self.dump_dir, npy_file)
            # The dump op may create the name before the content is flushed.
            while not os.path.exists(file_path):
                time.sleep(0.01)
            check_file_or_directory_path(file_path)
            if GradConst.STEP_FINISH in npy_file:
                self.cache_list.flush()
                remove_path(file_path)
                self.last_finish = True
            elif file_path.split("_")[-1] == GradConst.DIR_SUFFIX:
                # Strip the numeric prefix and rename "_dir" dumps to ".npy"
                # before moving them into the current step directory.
                prefix_idx = len(npy_file.split("_")[0])
                new_name = npy_file[prefix_idx + 1:].replace("_" + GradConst.DIR_SUFFIX, "." + GradConst.NPY_SUFFIX)
                if not new_name:
                    raise RuntimeError("Invalid dump data name.")
                if self.current_step is None:
                    raise RuntimeError("Current record step is None.")
                step_dir = os.path.join(self.save_dir, f"step{self.current_step}")
                if not os.path.exists(step_dir):
                    create_directory(step_dir)
                dst_file = os.path.join(step_dir, new_name)
                move_file(file_path, dst_file)
                self.last_finish = False
            elif file_path.split(".")[-1] == GradConst.NPY_SUFFIX:
                stat_data = self.load_npy_data(file_path)
                if stat_data is None:
                    continue
                if not self.check_valid(stat_data):
                    os.remove(file_path)
                    continue
                step = int(stat_data[GradConst.STEP_IDX])
                update_step = self.current_step is None or step != self.current_step
                self.current_step = step
                if update_step:
                    self.create_csv_file()
                self.gen_csv_line(file_path, stat_data)
                os.remove(file_path)
                self.last_finish = False

    def check_valid(self, stat_data):
        """Return True when the statistic vector has the length its level implies."""
        level = grad_context.get_context(GradConst.LEVEL)
        try:
            shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
            if level == GradConst.LEVEL2:
                dist_dim = int(stat_data[shape_dim + GradConst.SHAPE_DIM_IDX + 1])
                length = shape_dim + dist_dim + 7
            else:
                length = shape_dim + 5
        except IndexError:
            return False
        return length == len(stat_data)

    def load_npy_data(self, file_path: str):
        """Load *file_path* with retries; return None if loading never succeeds.

        The dumping process may still be writing the file, so transient load
        failures are retried up to 10 times with a short sleep.
        """
        max_try = 10
        while max_try:
            try:
                return np.load(file_path)
            except Exception:
                logger.warning("load numpy file failed, retry...")
                max_try -= 1
                time.sleep(0.1)
        return None

    def gen_csv_line(self, file_path: str, stat_data) -> None:
        """Append one row (param name, optional distribution, extrema) to the cache.

        Raises:
            RuntimeError: if the file name yields no parameter name.
        """
        shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
        file_name = os.path.basename(file_path)
        # Drop the numeric dump-order prefix and the ".npy" suffix.
        prefix_idx = len(file_name.split("_")[0])
        param_name = file_name[(prefix_idx + 1): -(len(GradConst.NPY_SUFFIX) + 1)]
        if not param_name:
            raise RuntimeError("Invalid gradient statistic file name.")
        csv_line = [param_name]
        if self.level == GradConst.LEVEL2:
            csv_line.extend(self.get_dist_data(shape_dim, stat_data))
        csv_line.extend(self.get_extrem_data(shape_dim, stat_data))
        self.cache_list.append(csv_line)

    def get_dist_data(self, shape_dim: int, stat_data: np.ndarray):
        """Return the normalised value-distribution histogram stored after the shape."""
        # Slice off the trailing sentinel element appended by grad_dump; the
        # last remaining entry is excluded from the normalisation total.
        dist_data = stat_data[(shape_dim + GradConst.SHAPE_DIM_IDX + 2):-1]
        element_num = dist_data.sum() - dist_data[-1]
        if element_num != 0:
            dist_data = dist_data / element_num
        return list(dist_data)

    def get_extrem_data(self, shape_dim: int, stat_data: np.ndarray):
        """Return [max, min, norm, shape-list] extracted from the statistic vector."""
        extrem_data = list(stat_data[(GradConst.STEP_IDX + 1):(GradConst.STEP_IDX + 4)])
        shape_data = stat_data[(GradConst.SHAPE_DIM_IDX + 1):(GradConst.SHAPE_DIM_IDX + shape_dim + 1)]
        shape_data = list(shape_data.astype(int))
        extrem_data.append(shape_data)
        return extrem_data

    def create_csv_file(self):
        """Start a new per-step CSV file with the appropriate header row."""
        headers = ["Param_name"]
        if self.level == GradConst.LEVEL2:
            headers.extend(self.get_dist_header())
        headers.extend(self.get_extrem_headers())
        output_path = f"{self.save_dir}/grad_summary_{self.current_step}.csv"
        write_csv([headers], output_path)
        self.cache_list.set_output_file(output_path)
        self.cache_list.clear()

    def get_extrem_headers(self) -> List[str]:
        """Column names for the extremum statistics."""
        return ["Max", "Min", "Norm", "Shape"]

    def get_dist_header(self) -> List[str]:
        """Human-readable half-open interval labels for the distribution buckets."""
        intervals = []
        for i, _ in enumerate(self.bounds):
            if i == 0:
                intervals.append(f"(-inf, {self.bounds[i]}]")
            else:
                intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]")
        intervals.extend([f"({self.bounds[-1]}, inf)", "=0"])
        return intervals
|
|
230
|
+
|
|
231
|
+
# Module-level singleton used by GradientMonitor: initialised with the global
# context via csv_generator.init(...) and stopped via csv_generator.stop().
csv_generator = CSVGenerator()
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
2
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
3
|
+
from msprobe.mindspore.grad_probe.hook import hook_optimizer
|
|
4
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GradientMonitor:
    """Facade wiring grad-probe settings into the global context and CSV generator."""

    def __init__(self, common_dict, task_config):
        """Translate the framework config objects into a GradConst-keyed dict."""
        self.config = {
            GradConst.OUTPUT_PATH: common_dict.dump_path,
            GradConst.STEP: common_dict.step,
            GradConst.RANK: common_dict.rank,
            GradConst.PARAM_LIST: task_config.param_list,
            GradConst.LEVEL: task_config.grad_level,
            GradConst.BOUNDS: task_config.bounds,
        }
        grad_context.init_context(self.config)

    @staticmethod
    def monitor(opt):
        """Attach the gradient hook to optimizer *opt* and start CSV generation."""
        csv_generator.init(grad_context)
        hook_optimizer(opt)

    @staticmethod
    def stop():
        """Request the CSV generator process to shut down."""
        csv_generator.stop()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
import hashlib
|
|
3
|
+
|
|
4
|
+
import mindspore
|
|
5
|
+
from mindspore import ops, Tensor
|
|
6
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CsvInput:
    """Plain value object bundling a parameter name, its gradient tensor and
    the histogram bounds used by the CSV item generators."""

    def __init__(self, param_name, grad, bounds):
        # Attribute order mirrors the positional argument order.
        self.param_name, self.grad, self.bounds = param_name, grad, bounds
|
|
14
|
+
|
|
15
|
+
class GradStatCsv:
    """Registry-driven builder for gradient-statistics CSV headers and rows."""

    # Maps a statistic key to its CsvItem class; populated by register_csv_item.
    csv = {}

    @staticmethod
    def get_csv_header(level, csv_input):
        """Assemble the header row for the statistic keys listed in level["header"]."""
        columns = ["param_name"]
        for item_key in level["header"]:
            columns += GradStatCsv.csv[item_key].generate_csv_header(csv_input)
        return columns

    @staticmethod
    def get_csv_line(level, csv_input):
        """Assemble one data row for the statistic keys listed in level["header"]."""
        row = [csv_input.param_name]
        for item_key in level["header"]:
            row += GradStatCsv.csv[item_key].generate_csv_content(csv_input)
        return row
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def register_csv_item(key, cls=None):
    """Register a CsvItem subclass under ``key`` in GradStatCsv.csv.

    Usable either as a direct call ``register_csv_item(key, cls)`` or as a
    decorator factory ``@register_csv_item(key)``.
    """
    if cls is not None:
        GradStatCsv.csv[key] = cls
        return cls

    # Called with only the key: return a decorator that performs the
    # registration when applied to the class.
    def _decorator(target_cls):
        return register_csv_item(key, target_cls)

    return _decorator
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CsvItem(ABC):
    """Interface for one CSV statistic: header column(s) plus row content.

    Concrete subclasses are registered via register_csv_item and looked up
    through GradStatCsv.csv; both methods take the CsvInput as their only
    argument.
    """

    @staticmethod
    @abstractmethod
    def generate_csv_header(csv_input):
        """Return the list of header column names for this statistic."""

    @staticmethod
    @abstractmethod
    def generate_csv_content(csv_input):
        """Return the list of values for this statistic, aligned with the header."""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@register_csv_item(GradConst.MD5)
class CsvMd5(CsvItem):
    """CSV item emitting an MD5 digest of the raw gradient bytes.

    NOTE: like the other registered items, these are invoked through the
    registry as plain functions (no self/cls argument).
    """

    def generate_csv_header(csv_input):
        return ["MD5"]

    def generate_csv_content(csv_input):
        # MD5 is used here as a cheap change-detection fingerprint over the
        # float32 byte representation, not for security.
        raw_bytes = csv_input.grad.float().numpy().tobytes()
        return [hashlib.md5(raw_bytes).hexdigest()]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@register_csv_item(GradConst.DISTRIBUTION)
class CsvDistribution(CsvItem):
    """CSV item bucketing gradient values into user-configured intervals.

    Columns are the half-open buckets defined by ``bounds`` plus a trailing
    "=0" column counting exact zeros; contents are per-bucket fractions.
    """

    def generate_csv_header(csv_input):
        limits = csv_input.bounds
        # No bounds configured -> no distribution columns at all.
        if not limits:
            return []
        header = [f"(-inf, {limits[0]}]"]
        header += [f"({lo}, {hi}]" for lo, hi in zip(limits, limits[1:])]
        header.append(f"({limits[-1]}, inf)")
        header.append("=0")
        return header

    def generate_csv_content(csv_input):
        grad = csv_input.grad
        limits = csv_input.bounds
        # bfloat16 has no numpy counterpart; widen before any conversion.
        if grad.dtype == mindspore.bfloat16:
            grad = grad.to(mindspore.float32)
        total = grad.numel()
        zero_count = (grad == 0).sum().item()
        # NOTE(review): int8 assumes len(limits) + 1 <= 127 bucket indices —
        # confirm the bounds list is length-checked upstream.
        bucket_idx = ops.bucketize(grad.float(), limits).astype(mindspore.int8)
        counts = [(bucket_idx == i).sum().item() for i in range(len(limits) + 1)]
        counts.append(zero_count)
        # Report fractions; an empty tensor yields all-zero columns.
        return [c / total if total != 0 else 0 for c in counts]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@register_csv_item(GradConst.MAX)
class CsvMax(CsvItem):
    """CSV item reporting the maximum element of the gradient."""

    def generate_csv_header(csv_input):
        return ["max"]

    def generate_csv_content(csv_input):
        # amax yields a scalar tensor; tolist() unwraps it to a Python number.
        maximum = ops.amax(csv_input.grad)
        return [maximum.float().numpy().tolist()]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@register_csv_item(GradConst.MIN)
class CsvMin(CsvItem):
    """CSV item reporting the minimum element of the gradient."""

    def generate_csv_header(csv_input):
        return ["min"]

    def generate_csv_content(csv_input):
        # amin yields a scalar tensor; tolist() unwraps it to a Python number.
        minimum = ops.amin(csv_input.grad)
        return [minimum.float().numpy().tolist()]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@register_csv_item(GradConst.NORM)
class CsvNorm(CsvItem):
    """CSV item reporting the norm of the gradient (ops.norm default)."""

    def generate_csv_header(csv_input):
        return ["norm"]

    def generate_csv_content(csv_input):
        norm_value = ops.norm(csv_input.grad)
        return [norm_value.float().numpy().tolist()]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@register_csv_item(GradConst.SHAPE)
class CsvShape(CsvItem):
    """CSV item reporting the gradient's shape as a list of dimensions."""

    def generate_csv_header(csv_input):
        return ["shape"]

    def generate_csv_content(csv_input):
        return [list(csv_input.grad.shape)]
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import mindspore
|
|
5
|
+
import mindspore as ms
|
|
6
|
+
from mindspore.common.api import jit
|
|
7
|
+
from mindspore.nn.optim.optimizer import Optimizer
|
|
8
|
+
from mindspore.common.parameter import Parameter
|
|
9
|
+
from mindspore.common.initializer import initializer
|
|
10
|
+
|
|
11
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
12
|
+
from msprobe.core.common.log import logger
|
|
13
|
+
|
|
14
|
+
from msprobe.core.common.utils import write_csv, remove_path
|
|
15
|
+
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
16
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
|
|
17
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
18
|
+
from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
|
|
19
|
+
from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level
|
|
20
|
+
|
|
21
|
+
class HookInput:
    """Bundles all per-optimizer state needed to hook the optimizer.

    Gathers parameter names from the optimizer, dump/save paths and level,
    bounds and parameter filters from the global grad_context, and the
    current MindSpore execution mode.
    """

    def __init__(self, opt) -> None:
        self.func = opt.construct
        self.g_names = [param.name for param in opt._parameters]
        self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
        self.rank_id = get_rank_id()
        output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
        rank_dir = os.path.join(output_path, f"rank{self.rank_id}")
        self.save_dir = rank_dir
        self.dump_dir = os.path.join(rank_dir, "Dump")
        # Marker file the graph-mode hook dumps at the end of every step.
        self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH)
        # Start from a clean per-rank directory; stale results are discarded.
        if os.path.exists(self.save_dir):
            logger.warning(f"Delete existing path {self.save_dir}.")
            remove_path(self.save_dir)
        self.level = grad_context.get_context(GradConst.LEVEL)
        self.bounds = grad_context.get_context(GradConst.BOUNDS)
        self.mode = mindspore.get_context("mode")
|
|
42
|
+
|
|
43
|
+
def hook_graph_mode_optimizer(opt, hook_input):
    """Instrument a graph-mode optimizer by replacing its construct method.

    The replacement dumps each (filtered) gradient via grad_dump, writes a
    step-finished marker file, increments the step counter, then delegates to
    the original construct. Also starts the CSV generator thread.
    """
    @jit
    def new_construct(self, gradients):
        for index, grad_value in enumerate(gradients):
            # An empty param_list means "dump everything"; otherwise skip
            # parameters that were not requested.
            if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list:
                continue
            grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step,
                      grad_value, hook_input.level, hook_input.bounds)
        # Emit the step-finish flag file so the CSV generator knows this
        # step's dumps are complete.
        ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
        self.assignadd(self.dump_step, self.global_step_increase_tensor)
        # Delegate to the original optimizer construct.
        out = hook_input.func(gradients)
        return out

    # Per-optimizer step counter used to tag dump files.
    opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step")
    # Bind the jit-compiled wrapper as the optimizer's construct method.
    opt.construct = new_construct.__get__(opt, type(opt))
    csv_generator.start()
|
|
59
|
+
|
|
60
|
+
def hook_pynative_optimizer(opt, hook_input):
    """Instrument a PyNative-mode optimizer via a forward-pre-hook.

    On each dumped step the hook writes a grad_summary_<step>.csv with one
    row per (filtered) parameter, optionally saving gradient directions, and
    always advances the global step counter.
    """
    level_adapted = get_adapted_level(hook_input.level)

    def hook_fn(cell, hook_args):
        # Renamed from `input` to avoid shadowing the builtin.
        gradients, = hook_args
        cur_step = grad_context.get_context(GradConst.CURRENT_STEP)
        if grad_context.step_need_dump(cur_step) and grad_context.rank_need_dump(hook_input.rank_id):
            rows = []
            for idx, grad_value in enumerate(gradients):
                name = hook_input.g_names[idx]
                # Empty param_list means dump all parameters.
                if hook_input.param_list and name not in hook_input.param_list:
                    continue
                rows.append(GradStatCsv.get_csv_line(level_adapted, CsvInput(name, grad_value, hook_input.bounds)))
                if level_adapted["have_grad_direction"]:
                    save_grad_direction(name, grad_value, os.path.join(hook_input.save_dir, f'step{cur_step}'))
            summary_path = os.path.join(hook_input.save_dir, f"grad_summary_{cur_step}.csv")
            # Header row uses a dummy CsvInput (no name/grad needed).
            rows.insert(0, GradStatCsv.get_csv_header(level_adapted, CsvInput(None, None, hook_input.bounds)))
            write_csv(rows, summary_path)
        grad_context.update_step()

    opt.register_forward_pre_hook(hook_fn)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def hook_optimizer(opt: Optimizer):
    """Install the gradient probe on `opt`, dispatching on execution mode."""
    hook_input = HookInput(opt)

    if hook_input.mode != mindspore.GRAPH_MODE:
        # PyNative mode: intercept via a forward-pre-hook.
        hook_pynative_optimizer(opt, hook_input)
    else:
        # Graph mode: replace construct with a jit-compiled wrapper.
        hook_graph_mode_optimizer(opt, hook_input)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import mindspore
|
|
5
|
+
from msprobe.core.grad_probe.constant import GradConst, level_adp
|
|
6
|
+
from msprobe.core.grad_probe.utils import check_param
|
|
7
|
+
from msprobe.core.common.file_check import create_directory
|
|
8
|
+
from msprobe.core.common.utils import check_path_before_create, change_mode, check_file_or_directory_path, save_npy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def save_grad_direction(param_name, grad, save_path):
    """Save the elementwise sign mask (grad > 0) as <param_name>.npy under save_path.

    Creates and validates save_path first; the parameter name is checked
    before being used as a file name component.
    """
    if not os.path.exists(save_path):
        create_directory(save_path)
    check_file_or_directory_path(save_path, isdir=True)
    check_param(param_name)
    target_file = os.path.join(save_path, f"{param_name}.npy")
    check_path_before_create(target_file)

    # bfloat16 has no numpy counterpart; widen before comparing/converting.
    if grad.dtype == mindspore.bfloat16:
        grad = grad.to(mindspore.float32)
    direction_mask = (grad > 0).numpy()
    save_npy(direction_mask, target_file)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_adapted_level(level: str):
    """Look up the adapted level configuration for the given level name.

    NOTE(review): returns None for an unknown level and callers index the
    result directly — presumably the level is validated upstream; confirm.
    """
    return level_adp.get(level)
|