mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
- mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +101 -237
- msprobe/{config/config.json → config.json} +49 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +59 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +341 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +478 -283
- msprobe/core/common/log.py +76 -69
- msprobe/core/common/utils.py +385 -616
- msprobe/core/common_config.py +85 -71
- msprobe/core/compare/acc_compare.py +299 -298
- msprobe/core/compare/check.py +95 -95
- msprobe/core/compare/compare_cli.py +49 -49
- msprobe/core/compare/highlight.py +223 -222
- msprobe/core/compare/multiprocessing_compute.py +149 -149
- msprobe/core/compare/npy_compare.py +295 -295
- msprobe/core/compare/utils.py +430 -429
- msprobe/core/data_dump/data_collector.py +154 -144
- msprobe/core/data_dump/data_processor/base.py +314 -293
- msprobe/core/data_dump/data_processor/factory.py +59 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
- msprobe/core/data_dump/json_writer.py +96 -116
- msprobe/core/data_dump/scope.py +178 -178
- msprobe/core/grad_probe/constant.py +70 -70
- msprobe/core/grad_probe/grad_compare.py +171 -175
- msprobe/core/grad_probe/utils.py +64 -52
- msprobe/docs/01.installation.md +89 -0
- msprobe/docs/02.config_introduction.md +165 -0
- msprobe/docs/03.config_examples.md +247 -0
- msprobe/docs/04.acl_config_examples.md +76 -0
- msprobe/docs/05.data_dump_PyTorch.md +198 -0
- msprobe/docs/06.data_dump_MindSpore.md +243 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
- msprobe/docs/FAQ_PyTorch.md +177 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +34 -34
- msprobe/mindspore/common/const.py +106 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +81 -57
- msprobe/mindspore/compare/distributed_compare.py +75 -75
- msprobe/mindspore/compare/ms_compare.py +219 -117
- msprobe/mindspore/compare/ms_graph_compare.py +348 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +66 -74
- msprobe/mindspore/debugger/precision_debugger.py +126 -107
- msprobe/mindspore/dump/dump_tool_factory.py +35 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
- msprobe/mindspore/dump/jit_dump.py +72 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
- msprobe/mindspore/free_benchmark/common/config.py +12 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
- msprobe/mindspore/free_benchmark/common/utils.py +71 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
- msprobe/mindspore/grad_probe/global_context.py +90 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +378 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +3 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
- msprobe/pytorch/bench_functions/__init__.py +15 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
- msprobe/pytorch/bench_functions/linear.py +12 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
- msprobe/pytorch/bench_functions/rms_norm.py +15 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
- msprobe/pytorch/bench_functions/swiglu.py +55 -55
- msprobe/pytorch/common/__init__.py +2 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +20 -31
- msprobe/pytorch/common/parse_json.py +39 -39
- msprobe/pytorch/common/utils.py +305 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +34 -33
- msprobe/pytorch/compare/pt_compare.py +50 -40
- msprobe/pytorch/debugger/debugger_config.py +95 -95
- msprobe/pytorch/debugger/precision_debugger.py +125 -125
- msprobe/pytorch/free_benchmark/__init__.py +8 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +37 -37
- msprobe/pytorch/free_benchmark/common/params.py +129 -129
- msprobe/pytorch/free_benchmark/common/utils.py +102 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
- msprobe/pytorch/free_benchmark/main.py +105 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
- msprobe/pytorch/function_factory.py +76 -75
- msprobe/pytorch/functional/dump_module.py +39 -39
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/api_registry.py +161 -161
- msprobe/pytorch/hook_module/hook_module.py +120 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +30 -29
- msprobe/pytorch/hook_module/wrap_aten.py +110 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
- msprobe/pytorch/hook_module/wrap_functional.py +105 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
- msprobe/pytorch/hook_module/wrap_torch.py +86 -86
- msprobe/pytorch/hook_module/wrap_vf.py +62 -62
- msprobe/pytorch/module_processer.py +138 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +236 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
- msprobe/pytorch/online_dispatch/utils.py +130 -146
- msprobe/pytorch/parse.py +4 -4
- msprobe/pytorch/parse_tool/cli.py +32 -32
- msprobe/pytorch/parse_tool/lib/compare.py +260 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
- msprobe/pytorch/parse_tool/lib/utils.py +316 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +188 -187
- msprobe/pytorch/service.py +246 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/在线精度比对.md +0 -90
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
|
@@ -1,27 +1,27 @@
|
|
|
1
|
-
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
2
|
-
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
3
|
-
from msprobe.mindspore.grad_probe.hook import hook_optimizer
|
|
4
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class GradientMonitor:
|
|
8
|
-
|
|
9
|
-
def __init__(self, common_dict, task_config):
|
|
10
|
-
config = {}
|
|
11
|
-
config[GradConst.OUTPUT_PATH] = common_dict.dump_path
|
|
12
|
-
config[GradConst.STEP] = common_dict.step
|
|
13
|
-
config[GradConst.RANK] = common_dict.rank
|
|
14
|
-
config[GradConst.PARAM_LIST] = task_config.param_list
|
|
15
|
-
config[GradConst.LEVEL] = task_config.grad_level
|
|
16
|
-
config[GradConst.BOUNDS] = task_config.bounds
|
|
17
|
-
self.config = config
|
|
18
|
-
grad_context.init_context(self.config)
|
|
19
|
-
|
|
20
|
-
@staticmethod
|
|
21
|
-
def monitor(opt):
|
|
22
|
-
csv_generator.init(grad_context)
|
|
23
|
-
hook_optimizer(opt)
|
|
24
|
-
|
|
25
|
-
@staticmethod
|
|
26
|
-
def stop():
|
|
27
|
-
csv_generator.stop()
|
|
1
|
+
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
2
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
3
|
+
from msprobe.mindspore.grad_probe.hook import hook_optimizer
|
|
4
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GradientMonitor:
|
|
8
|
+
|
|
9
|
+
def __init__(self, common_dict, task_config):
|
|
10
|
+
config = {}
|
|
11
|
+
config[GradConst.OUTPUT_PATH] = common_dict.dump_path
|
|
12
|
+
config[GradConst.STEP] = common_dict.step
|
|
13
|
+
config[GradConst.RANK] = common_dict.rank
|
|
14
|
+
config[GradConst.PARAM_LIST] = task_config.param_list
|
|
15
|
+
config[GradConst.LEVEL] = task_config.grad_level
|
|
16
|
+
config[GradConst.BOUNDS] = task_config.bounds
|
|
17
|
+
self.config = config
|
|
18
|
+
grad_context.init_context(self.config)
|
|
19
|
+
|
|
20
|
+
@staticmethod
|
|
21
|
+
def monitor(opt):
|
|
22
|
+
csv_generator.init(grad_context)
|
|
23
|
+
hook_optimizer(opt)
|
|
24
|
+
|
|
25
|
+
@staticmethod
|
|
26
|
+
def stop():
|
|
27
|
+
csv_generator.stop()
|
|
@@ -1,132 +1,132 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
|
-
import hashlib
|
|
3
|
-
|
|
4
|
-
import mindspore
|
|
5
|
-
from mindspore import ops, Tensor
|
|
6
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class CsvInput:
|
|
10
|
-
def __init__(self, param_name, grad, bounds):
|
|
11
|
-
self.param_name = param_name
|
|
12
|
-
self.grad = grad
|
|
13
|
-
self.bounds = bounds
|
|
14
|
-
|
|
15
|
-
class GradStatCsv:
|
|
16
|
-
csv = {}
|
|
17
|
-
|
|
18
|
-
@staticmethod
|
|
19
|
-
def get_csv_header(level, csv_input):
|
|
20
|
-
header = ["param_name"]
|
|
21
|
-
for key in level["header"]:
|
|
22
|
-
header.extend(GradStatCsv.csv[key].generate_csv_header(csv_input))
|
|
23
|
-
return header
|
|
24
|
-
|
|
25
|
-
@staticmethod
|
|
26
|
-
def get_csv_line(level, csv_input):
|
|
27
|
-
line = [csv_input.param_name]
|
|
28
|
-
for key in level["header"]:
|
|
29
|
-
line.extend(GradStatCsv.csv[key].generate_csv_content(csv_input))
|
|
30
|
-
return line
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def register_csv_item(key, cls=None):
|
|
34
|
-
if cls is None:
|
|
35
|
-
# 无参数时,返回装饰器函数
|
|
36
|
-
return lambda cls: register_csv_item(key, cls)
|
|
37
|
-
GradStatCsv.csv[key] = cls
|
|
38
|
-
return cls
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class CsvItem(ABC):
|
|
42
|
-
@staticmethod
|
|
43
|
-
@abstractmethod
|
|
44
|
-
def generate_csv_header(csv_input):
|
|
45
|
-
pass
|
|
46
|
-
|
|
47
|
-
@staticmethod
|
|
48
|
-
@abstractmethod
|
|
49
|
-
def generate_csv_content(csv_input):
|
|
50
|
-
pass
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
@register_csv_item(GradConst.MD5)
|
|
54
|
-
class CsvMd5(CsvItem):
|
|
55
|
-
def generate_csv_header(csv_input):
|
|
56
|
-
return ["MD5"]
|
|
57
|
-
|
|
58
|
-
def generate_csv_content(csv_input):
|
|
59
|
-
grad = csv_input.grad
|
|
60
|
-
tensor_bytes = grad.float().numpy().tobytes()
|
|
61
|
-
md5_hash = hashlib.md5(tensor_bytes)
|
|
62
|
-
return [md5_hash.hexdigest()]
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
@register_csv_item(GradConst.DISTRIBUTION)
|
|
66
|
-
class CsvDistribution(CsvItem):
|
|
67
|
-
def generate_csv_header(csv_input):
|
|
68
|
-
bounds = csv_input.bounds
|
|
69
|
-
intervals = []
|
|
70
|
-
if bounds:
|
|
71
|
-
intervals.append(f"(-inf, {bounds[0]}]")
|
|
72
|
-
for i in range(1, len(bounds)):
|
|
73
|
-
intervals.append(f"({bounds[i-1]}, {bounds[i]}]")
|
|
74
|
-
if intervals:
|
|
75
|
-
intervals.append(f"({bounds[-1]}, inf)")
|
|
76
|
-
intervals.append("=0")
|
|
77
|
-
|
|
78
|
-
return intervals
|
|
79
|
-
|
|
80
|
-
def generate_csv_content(csv_input):
|
|
81
|
-
grad = csv_input.grad
|
|
82
|
-
bounds = csv_input.bounds
|
|
83
|
-
if grad.dtype == mindspore.bfloat16:
|
|
84
|
-
grad = grad.to(mindspore.float32)
|
|
85
|
-
element_num = grad.numel()
|
|
86
|
-
grad_equal_0_num = (grad == 0).sum().item()
|
|
87
|
-
bucketsize_result = ops.bucketize(grad.float(), bounds)
|
|
88
|
-
bucketsize_result = bucketsize_result.astype(mindspore.int8)
|
|
89
|
-
interval_nums = [(bucketsize_result == i).sum().item() for i in range(len(bounds) + 1)]
|
|
90
|
-
interval_nums.append(grad_equal_0_num)
|
|
91
|
-
return_list = [x / element_num if element_num != 0 else 0 for x in interval_nums]
|
|
92
|
-
return return_list
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
@register_csv_item(GradConst.MAX)
|
|
96
|
-
class CsvMax(CsvItem):
|
|
97
|
-
def generate_csv_header(csv_input):
|
|
98
|
-
return ["max"]
|
|
99
|
-
|
|
100
|
-
def generate_csv_content(csv_input):
|
|
101
|
-
grad = csv_input.grad
|
|
102
|
-
return [ops.amax(grad).float().numpy().tolist()]
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
@register_csv_item(GradConst.MIN)
|
|
106
|
-
class CsvMin(CsvItem):
|
|
107
|
-
def generate_csv_header(csv_input):
|
|
108
|
-
return ["min"]
|
|
109
|
-
|
|
110
|
-
def generate_csv_content(csv_input):
|
|
111
|
-
grad = csv_input.grad
|
|
112
|
-
return [ops.amin(grad).float().numpy().tolist()]
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
@register_csv_item(GradConst.NORM)
|
|
116
|
-
class CsvNorm(CsvItem):
|
|
117
|
-
def generate_csv_header(csv_input):
|
|
118
|
-
return ["norm"]
|
|
119
|
-
|
|
120
|
-
def generate_csv_content(csv_input):
|
|
121
|
-
grad = csv_input.grad
|
|
122
|
-
return [ops.norm(grad).float().numpy().tolist()]
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
@register_csv_item(GradConst.SHAPE)
|
|
126
|
-
class CsvShape(CsvItem):
|
|
127
|
-
def generate_csv_header(csv_input):
|
|
128
|
-
return ["shape"]
|
|
129
|
-
|
|
130
|
-
def generate_csv_content(csv_input):
|
|
131
|
-
grad = csv_input.grad
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
import hashlib
|
|
3
|
+
|
|
4
|
+
import mindspore
|
|
5
|
+
from mindspore import ops, Tensor
|
|
6
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CsvInput:
|
|
10
|
+
def __init__(self, param_name, grad, bounds):
|
|
11
|
+
self.param_name = param_name
|
|
12
|
+
self.grad = grad
|
|
13
|
+
self.bounds = bounds
|
|
14
|
+
|
|
15
|
+
class GradStatCsv:
|
|
16
|
+
csv = {}
|
|
17
|
+
|
|
18
|
+
@staticmethod
|
|
19
|
+
def get_csv_header(level, csv_input):
|
|
20
|
+
header = ["param_name"]
|
|
21
|
+
for key in level["header"]:
|
|
22
|
+
header.extend(GradStatCsv.csv[key].generate_csv_header(csv_input))
|
|
23
|
+
return header
|
|
24
|
+
|
|
25
|
+
@staticmethod
|
|
26
|
+
def get_csv_line(level, csv_input):
|
|
27
|
+
line = [csv_input.param_name]
|
|
28
|
+
for key in level["header"]:
|
|
29
|
+
line.extend(GradStatCsv.csv[key].generate_csv_content(csv_input))
|
|
30
|
+
return line
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def register_csv_item(key, cls=None):
|
|
34
|
+
if cls is None:
|
|
35
|
+
# 无参数时,返回装饰器函数
|
|
36
|
+
return lambda cls: register_csv_item(key, cls)
|
|
37
|
+
GradStatCsv.csv[key] = cls
|
|
38
|
+
return cls
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CsvItem(ABC):
|
|
42
|
+
@staticmethod
|
|
43
|
+
@abstractmethod
|
|
44
|
+
def generate_csv_header(csv_input):
|
|
45
|
+
pass
|
|
46
|
+
|
|
47
|
+
@staticmethod
|
|
48
|
+
@abstractmethod
|
|
49
|
+
def generate_csv_content(csv_input):
|
|
50
|
+
pass
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@register_csv_item(GradConst.MD5)
|
|
54
|
+
class CsvMd5(CsvItem):
|
|
55
|
+
def generate_csv_header(csv_input):
|
|
56
|
+
return ["MD5"]
|
|
57
|
+
|
|
58
|
+
def generate_csv_content(csv_input):
|
|
59
|
+
grad = csv_input.grad
|
|
60
|
+
tensor_bytes = grad.float().numpy().tobytes()
|
|
61
|
+
md5_hash = hashlib.md5(tensor_bytes)
|
|
62
|
+
return [md5_hash.hexdigest()]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@register_csv_item(GradConst.DISTRIBUTION)
|
|
66
|
+
class CsvDistribution(CsvItem):
|
|
67
|
+
def generate_csv_header(csv_input):
|
|
68
|
+
bounds = csv_input.bounds
|
|
69
|
+
intervals = []
|
|
70
|
+
if bounds:
|
|
71
|
+
intervals.append(f"(-inf, {bounds[0]}]")
|
|
72
|
+
for i in range(1, len(bounds)):
|
|
73
|
+
intervals.append(f"({bounds[i-1]}, {bounds[i]}]")
|
|
74
|
+
if intervals:
|
|
75
|
+
intervals.append(f"({bounds[-1]}, inf)")
|
|
76
|
+
intervals.append("=0")
|
|
77
|
+
|
|
78
|
+
return intervals
|
|
79
|
+
|
|
80
|
+
def generate_csv_content(csv_input):
|
|
81
|
+
grad = csv_input.grad
|
|
82
|
+
bounds = csv_input.bounds
|
|
83
|
+
if grad.dtype == mindspore.bfloat16:
|
|
84
|
+
grad = grad.to(mindspore.float32)
|
|
85
|
+
element_num = grad.numel()
|
|
86
|
+
grad_equal_0_num = (grad == 0).sum().item()
|
|
87
|
+
bucketsize_result = ops.bucketize(grad.float(), bounds)
|
|
88
|
+
bucketsize_result = bucketsize_result.astype(mindspore.int8)
|
|
89
|
+
interval_nums = [(bucketsize_result == i).sum().item() for i in range(len(bounds) + 1)]
|
|
90
|
+
interval_nums.append(grad_equal_0_num)
|
|
91
|
+
return_list = [x / element_num if element_num != 0 else 0 for x in interval_nums]
|
|
92
|
+
return return_list
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@register_csv_item(GradConst.MAX)
|
|
96
|
+
class CsvMax(CsvItem):
|
|
97
|
+
def generate_csv_header(csv_input):
|
|
98
|
+
return ["max"]
|
|
99
|
+
|
|
100
|
+
def generate_csv_content(csv_input):
|
|
101
|
+
grad = csv_input.grad
|
|
102
|
+
return [ops.amax(grad).float().numpy().tolist()]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@register_csv_item(GradConst.MIN)
|
|
106
|
+
class CsvMin(CsvItem):
|
|
107
|
+
def generate_csv_header(csv_input):
|
|
108
|
+
return ["min"]
|
|
109
|
+
|
|
110
|
+
def generate_csv_content(csv_input):
|
|
111
|
+
grad = csv_input.grad
|
|
112
|
+
return [ops.amin(grad).float().numpy().tolist()]
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@register_csv_item(GradConst.NORM)
|
|
116
|
+
class CsvNorm(CsvItem):
|
|
117
|
+
def generate_csv_header(csv_input):
|
|
118
|
+
return ["norm"]
|
|
119
|
+
|
|
120
|
+
def generate_csv_content(csv_input):
|
|
121
|
+
grad = csv_input.grad
|
|
122
|
+
return [ops.norm(grad).float().numpy().tolist()]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@register_csv_item(GradConst.SHAPE)
|
|
126
|
+
class CsvShape(CsvItem):
|
|
127
|
+
def generate_csv_header(csv_input):
|
|
128
|
+
return ["shape"]
|
|
129
|
+
|
|
130
|
+
def generate_csv_content(csv_input):
|
|
131
|
+
grad = csv_input.grad
|
|
132
132
|
return [list(grad.shape)]
|
|
@@ -1,92 +1,94 @@
|
|
|
1
|
-
|
|
2
|
-
import os
|
|
3
|
-
|
|
4
|
-
import mindspore
|
|
5
|
-
import mindspore as ms
|
|
6
|
-
from mindspore.common.api import jit
|
|
7
|
-
from mindspore.nn.optim.optimizer import Optimizer
|
|
8
|
-
from mindspore.common.parameter import Parameter
|
|
9
|
-
from mindspore.common.initializer import initializer
|
|
10
|
-
|
|
11
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
12
|
-
from msprobe.
|
|
13
|
-
|
|
14
|
-
from msprobe.core.common.
|
|
15
|
-
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
16
|
-
from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
|
|
17
|
-
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
18
|
-
from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
|
|
19
|
-
from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level
|
|
20
|
-
|
|
21
|
-
class HookInput:
|
|
22
|
-
|
|
23
|
-
'''
|
|
24
|
-
HookInput is a class wrapping all the variables used for hooking optimizer
|
|
25
|
-
'''
|
|
26
|
-
|
|
27
|
-
def __init__(self, opt) -> None:
|
|
28
|
-
self.func = opt.construct
|
|
29
|
-
self.g_names = [param.name for param in opt._parameters]
|
|
30
|
-
self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
|
|
31
|
-
self.rank_id = get_rank_id()
|
|
32
|
-
output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
|
|
33
|
-
self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", "Dump")
|
|
34
|
-
self.save_dir = os.path.join(output_path, f"rank{self.rank_id}")
|
|
35
|
-
self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH)
|
|
36
|
-
if os.path.exists(self.save_dir):
|
|
37
|
-
logger.warning(f"Delete existing path {self.save_dir}.")
|
|
38
|
-
remove_path(self.save_dir)
|
|
39
|
-
self.level = grad_context.get_context(GradConst.LEVEL)
|
|
40
|
-
self.bounds = grad_context.get_context(GradConst.BOUNDS)
|
|
41
|
-
self.mode = mindspore.get_context("mode")
|
|
42
|
-
|
|
43
|
-
def hook_graph_mode_optimizer(opt, hook_input):
|
|
44
|
-
@jit
|
|
45
|
-
def new_construct(self, gradients):
|
|
46
|
-
for index, grad_value in enumerate(gradients):
|
|
47
|
-
if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list:
|
|
48
|
-
continue
|
|
49
|
-
grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step,
|
|
50
|
-
grad_value, hook_input.level, hook_input.bounds)
|
|
51
|
-
ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
|
|
52
|
-
self.assignadd(self.dump_step, self.global_step_increase_tensor)
|
|
53
|
-
out = hook_input.func(gradients)
|
|
54
|
-
return out
|
|
55
|
-
|
|
56
|
-
opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step")
|
|
57
|
-
opt.construct = new_construct.__get__(opt, type(opt))
|
|
58
|
-
csv_generator.start()
|
|
59
|
-
|
|
60
|
-
def hook_pynative_optimizer(opt, hook_input):
|
|
61
|
-
level_adapted = get_adapted_level(hook_input.level)
|
|
62
|
-
|
|
63
|
-
def hook_fn(cell, input):
|
|
64
|
-
gradients, = input
|
|
65
|
-
cur_step = grad_context.get_context(GradConst.CURRENT_STEP)
|
|
66
|
-
if grad_context.step_need_dump(cur_step) and grad_context.rank_need_dump(hook_input.rank_id):
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import mindspore
|
|
5
|
+
import mindspore as ms
|
|
6
|
+
from mindspore.common.api import jit
|
|
7
|
+
from mindspore.nn.optim.optimizer import Optimizer
|
|
8
|
+
from mindspore.common.parameter import Parameter
|
|
9
|
+
from mindspore.common.initializer import initializer
|
|
10
|
+
|
|
11
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
12
|
+
from msprobe.mindspore.common.log import logger
|
|
13
|
+
|
|
14
|
+
from msprobe.core.common.file_utils import remove_path, write_csv, create_directory
|
|
15
|
+
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
16
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id
|
|
17
|
+
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
18
|
+
from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput
|
|
19
|
+
from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level
|
|
20
|
+
|
|
21
|
+
class HookInput:
    """Container for all values needed when hooking an optimizer.

    Attributes are filled from the optimizer itself, from ``grad_context``
    and from the current MindSpore context.
    """

    def __init__(self, opt) -> None:
        """Collect hook settings for ``opt``.

        Side effect: if the per-rank output directory already exists it is
        removed (a warning is logged first).

        Args:
            opt: optimizer providing ``construct`` and ``_parameters``.
        """
        # Original construct, kept so a wrapper can delegate to it later.
        self.func = opt.construct
        self.g_names = [parameter.name for parameter in opt._parameters]
        self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
        self.rank_id = get_rank_id()

        output_root = grad_context.get_context(GradConst.OUTPUT_PATH)
        rank_dir = os.path.join(output_root, f"rank{self.rank_id}")
        self.dump_dir = os.path.join(rank_dir, "Dump")
        self.save_dir = rank_dir
        self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH)

        # Start from a clean per-rank directory.
        if os.path.exists(self.save_dir):
            logger.warning(f"Delete existing path {self.save_dir}.")
            remove_path(self.save_dir)

        self.level = grad_context.get_context(GradConst.LEVEL)
        self.bounds = grad_context.get_context(GradConst.BOUNDS)
        self.mode = mindspore.get_context("mode")
|
|
42
|
+
|
|
43
|
+
def hook_graph_mode_optimizer(opt, hook_input):
    """Wrap ``opt.construct`` so every call dumps gradients before updating.

    The replacement ``construct`` is compiled with ``jit``. For each gradient
    it calls ``grad_dump`` unless a parameter filter is configured and the
    parameter name is not in it, then dumps the step-finish flag, increments
    ``opt.dump_step`` and finally delegates to the original ``construct``
    stored in ``hook_input.func``.

    Args:
        opt: optimizer whose ``construct`` is replaced in place.
        hook_input: ``HookInput`` holding names, paths, level and bounds.
    """
    @jit
    def new_construct(self, gradients):
        for idx, grad in enumerate(gradients):
            # An empty/None param_list means "dump every parameter".
            if hook_input.param_list and hook_input.g_names[idx] not in hook_input.param_list:
                continue
            grad_dump(hook_input.dump_dir, hook_input.g_names[idx], self.dump_step,
                      grad, hook_input.level, hook_input.bounds)
        # Flag dump issued after all gradients of the step were dumped.
        ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step)
        self.assignadd(self.dump_step, self.global_step_increase_tensor)
        return hook_input.func(gradients)

    # Step counter owned by the optimizer so the compiled graph can update it.
    opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step")
    # Bind the wrapper as a method of this particular optimizer instance.
    opt.construct = new_construct.__get__(opt, type(opt))
    csv_generator.start()
|
|
59
|
+
|
|
60
|
+
def hook_pynative_optimizer(opt, hook_input):
    """Register a forward pre-hook on ``opt`` that writes gradient statistics.

    On steps and ranks selected by ``grad_context`` the hook writes one CSV
    (``grad_summary_<step>.csv``) with a header plus one line per dumped
    parameter, and optionally saves the gradient direction of each parameter.
    The step counter in ``grad_context`` is advanced on every invocation.

    Args:
        opt: optimizer to attach the hook to.
        hook_input: ``HookInput`` holding names, paths, level and bounds.
    """
    level_adapted = get_adapted_level(hook_input.level)

    def hook_fn(cell, inputs):
        # The hook input is expected to hold exactly one element: the gradients.
        (gradients,) = inputs
        cur_step = grad_context.get_context(GradConst.CURRENT_STEP)
        step_selected = grad_context.step_need_dump(cur_step)
        if step_selected and grad_context.rank_need_dump(hook_input.rank_id):
            create_directory(hook_input.save_dir)
            direction_dir = os.path.join(hook_input.save_dir, f'step{cur_step}')
            stat_rows = []
            for position, gradient in enumerate(gradients):
                name = hook_input.g_names[position]
                # An empty/None param_list means "dump every parameter".
                if hook_input.param_list and name not in hook_input.param_list:
                    continue
                row_input = CsvInput(name, gradient, hook_input.bounds)
                stat_rows.append(GradStatCsv.get_csv_line(level_adapted, row_input))
                if level_adapted["have_grad_direction"]:
                    save_grad_direction(name, gradient, direction_dir)
            output_csv_path = os.path.join(hook_input.save_dir, f"grad_summary_{cur_step}.csv")
            # The header only depends on level and bounds, so a placeholder input suffices.
            header_input = CsvInput(None, None, hook_input.bounds)
            header = GradStatCsv.get_csv_header(level_adapted, header_input)
            write_csv([header] + stat_rows, output_csv_path)
            logger.info(f"write grad data to {output_csv_path}")
        # Advance the step counter whether or not this step was dumped.
        grad_context.update_step()

    opt.register_forward_pre_hook(hook_fn)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def hook_optimizer(opt: Optimizer):
    """Install gradient-probe hooks on ``opt`` for the current execution mode.

    Builds a ``HookInput`` for the optimizer, then uses the graph-mode
    installer when the MindSpore context mode is ``GRAPH_MODE`` and the
    PyNative installer otherwise.

    Args:
        opt: optimizer to instrument.
    """
    hook_input = HookInput(opt)
    install = (
        hook_graph_mode_optimizer
        if hook_input.mode == mindspore.GRAPH_MODE
        else hook_pynative_optimizer
    )
    install(opt, hook_input)
|
|
@@ -1,29 +1,30 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
from msprobe.core.grad_probe.
|
|
6
|
-
from msprobe.core.
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import mindspore
|
|
4
|
+
from msprobe.core.grad_probe.constant import level_adp
|
|
5
|
+
from msprobe.core.grad_probe.utils import check_param
|
|
6
|
+
from msprobe.core.common.file_utils import (create_directory,
|
|
7
|
+
check_path_before_create,
|
|
8
|
+
check_file_or_directory_path,
|
|
9
|
+
save_npy)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def save_grad_direction(param_name, grad, save_path):
    """Save the sign mask (``grad > 0``) of a gradient as ``<param_name>.npy``.

    Args:
        param_name: parameter name; validated with ``check_param`` and used
            as the file stem.
        grad: gradient tensor; bfloat16 input is converted to float32 before
            the comparison.
        save_path: target directory, created when it does not exist yet.
    """
    if not os.path.exists(save_path):
        create_directory(save_path)
    check_file_or_directory_path(save_path, isdir=True)
    check_param(param_name)
    target_file = os.path.join(save_path, f"{param_name}.npy")
    check_path_before_create(target_file)

    if grad.dtype == mindspore.bfloat16:
        grad = grad.to(mindspore.float32)
    # Boolean mask: True where the gradient is strictly positive.
    direction = (grad > 0).numpy()

    save_npy(direction, target_file)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_adapted_level(level: str):
    """Look up the adapted settings for ``level`` in ``level_adp``.

    Args:
        level: key to look up.

    Returns:
        The entry stored in ``level_adp`` for ``level``, or ``None`` when the
        key is not present.
    """
    return level_adp.get(level)
|