mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
- mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
- msprobe/README.md +131 -237
- msprobe/__init__.py +16 -1
- msprobe/{config/config.json → config.json} +47 -49
- msprobe/core/advisor/advisor.py +124 -124
- msprobe/core/advisor/advisor_const.py +58 -59
- msprobe/core/advisor/advisor_result.py +58 -58
- msprobe/core/common/const.py +402 -318
- msprobe/core/common/exceptions.py +99 -99
- msprobe/core/common/{file_check.py → file_utils.py} +523 -283
- msprobe/core/common/inplace_op_checker.py +38 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +86 -69
- msprobe/core/common/utils.py +371 -616
- msprobe/core/common_config.py +78 -71
- msprobe/core/compare/acc_compare.py +472 -298
- msprobe/core/compare/check.py +180 -95
- msprobe/core/compare/compare_cli.py +69 -49
- msprobe/core/compare/highlight.py +259 -222
- msprobe/core/compare/multiprocessing_compute.py +174 -149
- msprobe/core/compare/npy_compare.py +310 -295
- msprobe/core/compare/utils.py +464 -429
- msprobe/core/data_dump/data_collector.py +153 -144
- msprobe/core/data_dump/data_processor/base.py +337 -293
- msprobe/core/data_dump/data_processor/factory.py +76 -59
- msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
- msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
- msprobe/core/data_dump/json_writer.py +117 -116
- msprobe/core/data_dump/scope.py +194 -178
- msprobe/core/grad_probe/constant.py +74 -70
- msprobe/core/grad_probe/grad_compare.py +170 -175
- msprobe/core/grad_probe/utils.py +77 -52
- msprobe/docs/01.installation.md +99 -0
- msprobe/docs/02.config_introduction.md +137 -0
- msprobe/docs/03.config_examples.md +237 -0
- msprobe/docs/04.acl_config_examples.md +78 -0
- msprobe/docs/05.data_dump_PyTorch.md +326 -0
- msprobe/docs/06.data_dump_MindSpore.md +285 -0
- msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
- msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
- msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
- msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
- msprobe/docs/12.overflow_check_PyTorch.md +79 -0
- msprobe/docs/13.overflow_check_MindSpore.md +31 -0
- msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
- msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
- msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
- msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
- msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
- msprobe/docs/FAQ.md +189 -0
- msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
- msprobe/docs/img/free_benchmark_framework.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +2 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
- msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
- msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
- msprobe/mindspore/api_accuracy_checker/main.py +8 -15
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
- msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
- msprobe/mindspore/cell_processor.py +58 -34
- msprobe/mindspore/common/const.py +108 -87
- msprobe/mindspore/common/log.py +37 -37
- msprobe/mindspore/common/utils.py +97 -57
- msprobe/mindspore/compare/distributed_compare.py +62 -75
- msprobe/mindspore/compare/layer_mapping.py +146 -0
- msprobe/mindspore/compare/modify_mapping.py +107 -0
- msprobe/mindspore/compare/ms_compare.py +357 -117
- msprobe/mindspore/compare/ms_graph_compare.py +364 -317
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
- msprobe/mindspore/debugger/debugger_config.py +69 -74
- msprobe/mindspore/debugger/precision_debugger.py +150 -107
- msprobe/mindspore/dump/dump_tool_factory.py +50 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
- msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
- msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
- msprobe/mindspore/dump/jit_dump.py +96 -56
- msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
- msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
- msprobe/mindspore/free_benchmark/common/config.py +27 -12
- msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
- msprobe/mindspore/free_benchmark/common/utils.py +85 -71
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
- msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
- msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
- msprobe/mindspore/grad_probe/global_context.py +100 -91
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
- msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
- msprobe/mindspore/grad_probe/hook.py +94 -92
- msprobe/mindspore/grad_probe/utils.py +29 -28
- msprobe/mindspore/ms_config.py +128 -126
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
- msprobe/mindspore/runtime.py +4 -4
- msprobe/mindspore/service.py +297 -354
- msprobe/mindspore/task_handler_factory.py +24 -24
- msprobe/msprobe.py +105 -107
- msprobe/pytorch/__init__.py +23 -4
- msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
- msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
- msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +30 -15
- msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
- msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
- msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
- msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
- msprobe/pytorch/bench_functions/linear.py +27 -12
- msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
- msprobe/pytorch/bench_functions/rms_norm.py +30 -15
- msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
- msprobe/pytorch/bench_functions/swiglu.py +70 -55
- msprobe/pytorch/common/__init__.py +17 -2
- msprobe/pytorch/common/compare_script.template +14 -14
- msprobe/pytorch/common/log.py +33 -32
- msprobe/pytorch/common/parse_json.py +54 -39
- msprobe/pytorch/common/utils.py +310 -300
- msprobe/pytorch/compare/distributed_compare.py +66 -66
- msprobe/pytorch/compare/mapping.yaml +607 -607
- msprobe/pytorch/compare/match.py +49 -33
- msprobe/pytorch/compare/pt_compare.py +82 -40
- msprobe/pytorch/debugger/debugger_config.py +108 -95
- msprobe/pytorch/debugger/precision_debugger.py +173 -125
- msprobe/pytorch/free_benchmark/__init__.py +23 -8
- msprobe/pytorch/free_benchmark/common/constant.py +70 -70
- msprobe/pytorch/free_benchmark/common/counter.py +71 -71
- msprobe/pytorch/free_benchmark/common/enums.py +65 -37
- msprobe/pytorch/free_benchmark/common/params.py +144 -129
- msprobe/pytorch/free_benchmark/common/utils.py +118 -102
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
- msprobe/pytorch/free_benchmark/main.py +120 -105
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
- msprobe/pytorch/function_factory.py +91 -75
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
- msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +166 -161
- msprobe/pytorch/hook_module/hook_module.py +118 -120
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
- msprobe/pytorch/hook_module/utils.py +28 -29
- msprobe/pytorch/hook_module/wrap_aten.py +111 -110
- msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
- msprobe/pytorch/hook_module/wrap_functional.py +104 -105
- msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
- msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
- msprobe/pytorch/hook_module/wrap_torch.py +84 -86
- msprobe/pytorch/hook_module/wrap_vf.py +60 -62
- msprobe/pytorch/module_processer.py +153 -138
- msprobe/pytorch/online_dispatch/__init__.py +20 -20
- msprobe/pytorch/online_dispatch/compare.py +235 -236
- msprobe/pytorch/online_dispatch/dispatch.py +271 -271
- msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
- msprobe/pytorch/online_dispatch/single_compare.py +391 -391
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
- msprobe/pytorch/online_dispatch/utils.py +127 -146
- msprobe/pytorch/parse.py +19 -4
- msprobe/pytorch/parse_tool/cli.py +31 -32
- msprobe/pytorch/parse_tool/lib/compare.py +259 -271
- msprobe/pytorch/parse_tool/lib/config.py +52 -52
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
- msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
- msprobe/pytorch/parse_tool/lib/utils.py +320 -321
- msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
- msprobe/pytorch/pt_config.py +317 -187
- msprobe/pytorch/service.py +311 -252
- mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
- msprobe/config/README.md +0 -539
- msprobe/mindspore/doc/compare.md +0 -58
- msprobe/mindspore/doc/dump.md +0 -217
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/doc/FAQ.md +0 -193
- msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
- msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
- msprobe/pytorch/doc/dump.md +0 -260
- msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
- msprobe/pytorch/doc/run_overflow_check.md +0 -25
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
- /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
- /msprobe/{config → docs}/img/free_benchmark.png +0 -0
- /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
- /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
- /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
- /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/mindspore/service.py
CHANGED
|
@@ -1,354 +1,297 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
import os
|
|
17
|
-
import copy
|
|
18
|
-
|
|
19
|
-
import
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
from mindspore
|
|
24
|
-
from mindspore import
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
from msprobe.core.data_dump.
|
|
34
|
-
from msprobe.
|
|
35
|
-
from msprobe.
|
|
36
|
-
from msprobe.
|
|
37
|
-
from msprobe.
|
|
38
|
-
from msprobe.core.common.
|
|
39
|
-
from msprobe.
|
|
40
|
-
from msprobe.mindspore.dump.hook_cell.
|
|
41
|
-
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
|
|
42
|
-
ModuleBackwardInputs, ModuleBackwardOutputs
|
|
43
|
-
from msprobe.core.common.exceptions import MsprobeException
|
|
44
|
-
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
45
|
-
from msprobe.mindspore.cell_processor import CellProcessor
|
|
46
|
-
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class Service:
|
|
50
|
-
def __init__(self, config):
|
|
51
|
-
self.model = None
|
|
52
|
-
self.config = copy.deepcopy(config)
|
|
53
|
-
self.config.level = self.config.level_ori
|
|
54
|
-
self.data_collector = build_data_collector(self.config)
|
|
55
|
-
self.cell_processor = CellProcessor(self.data_collector.scope)
|
|
56
|
-
self.
|
|
57
|
-
self.
|
|
58
|
-
self.
|
|
59
|
-
self.
|
|
60
|
-
self.
|
|
61
|
-
self.
|
|
62
|
-
self.
|
|
63
|
-
self.
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
self.
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
if
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
if not self.
|
|
103
|
-
return
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
return
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
def
|
|
249
|
-
self.
|
|
250
|
-
self.
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
self.
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
self.start_call = False
|
|
299
|
-
self.data_collector.write_json()
|
|
300
|
-
|
|
301
|
-
def create_dirs(self):
    """Create the dump directory tree for the current step/rank and register output paths.

    Layout produced under ``self.config.dump_path``::

        step<current_iter>/rank<current_rank or ''>/
            dump.json, stack.json, construct.json
            dump_tensor_data/   (only for tasks in data_collector.tasks_need_tensor_data)

    The resulting file paths are handed to ``self.data_collector.update_dump_paths``.
    Sets ``self.dump_iter_dir`` as a side effect.
    """
    check_path_before_create(self.config.dump_path)
    if not os.path.exists(self.config.dump_path):
        # parents=True so a nested dump_path whose parent directories do not
        # exist yet is created (like the per-rank directory below) instead of
        # failing with FileNotFoundError. Note: pathlib creates missing parents
        # with default permissions; only the leaf gets mode 0o750.
        Path(self.config.dump_path).mkdir(mode=0o750, parents=True, exist_ok=True)
    file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
    file_check.common_check()

    self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
    # Without a known rank the directory is simply named "rank".
    cur_rank = self.current_rank if self.current_rank is not None else ''
    dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
    if not os.path.exists(dump_dir):
        Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)

    if self.config.task in self.data_collector.tasks_need_tensor_data:
        dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
        Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
    else:
        dump_data_dir = None

    dump_file_path = os.path.join(dump_dir, "dump.json")
    stack_file_path = os.path.join(dump_dir, "stack.json")
    construct_file_path = os.path.join(dump_dir, "construct.json")
    self.data_collector.update_dump_paths(
        dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
|
|
323
|
-
|
|
324
|
-
def empty(self, *args, **kwargs):
    """Do nothing: a placeholder callable that accepts and ignores any arguments."""
    return None
|
|
326
|
-
|
|
327
|
-
def register_hook_new(self):
    """Mount the dump hooks that match ``self.config.level``.

    * ``"L1"``: install ``build_hook`` (API scope) on the wrapped APIs through
      ``api_register``; if a model is set, also call ``register_hooks``.
    * ``"L0"``: a model is mandatory. Every sub-cell except the root model
      gets the data-dump forward/backward hooks, followed by
      ``cell_processor.node_hook`` start/stop hooks around its forward and
      backward passes (registration order is significant and kept as-is).

    Raises:
        MsprobeException: INVALID_PARAM_ERROR when the level is L0 and no
            model has been set.
    """
    logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))

    if self.config.level == "L1":
        api_hook_builder = functools.partial(self.build_hook, BaseScope.Module_Type_API)
        api_register.initialize_hook(api_hook_builder)
        api_register.api_set_hook_func()
        if self.model:
            self.register_hooks()

    if self.config.level == "L0":
        if not self.model:
            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "The current level is L0, the model cannot be None")
        for cell_name, sub_cell in self.model.cells_and_names():
            # cells_and_names() also yields the root model; only sub-cells are hooked.
            if sub_cell == self.model:
                continue
            prefix = Const.SEP.join(('Cell', cell_name, sub_cell.__class__.__name__)) + Const.SEP
            data_fwd_hook, data_bwd_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
            fwd_scope = prefix + Const.FORWARD
            bwd_scope = prefix + Const.BACKWARD

            sub_cell.register_forward_hook(data_fwd_hook)
            sub_cell.register_backward_hook(data_bwd_hook)

            sub_cell.register_forward_pre_hook(self.cell_processor.node_hook(fwd_scope, Const.START))
            sub_cell.register_forward_hook(self.cell_processor.node_hook(fwd_scope, Const.STOP))
            sub_cell.register_backward_pre_hook(self.cell_processor.node_hook(bwd_scope, Const.START))
            sub_cell.register_backward_hook(self.cell_processor.node_hook(bwd_scope, Const.STOP))
|
|
1
|
+
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import copy
|
|
18
|
+
import functools
|
|
19
|
+
from collections import defaultdict
|
|
20
|
+
|
|
21
|
+
import mindspore as ms
|
|
22
|
+
from mindspore.common.tensor import Tensor
|
|
23
|
+
from mindspore import ops
|
|
24
|
+
from mindspore import nn
|
|
25
|
+
try:
|
|
26
|
+
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
27
|
+
pijit_label = True
|
|
28
|
+
except ImportError:
|
|
29
|
+
pijit_label = False
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
33
|
+
from msprobe.core.data_dump.scope import BaseScope
|
|
34
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
35
|
+
from msprobe.core.common.file_utils import create_directory
|
|
36
|
+
from msprobe.mindspore.common.log import logger
|
|
37
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
38
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
39
|
+
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
40
|
+
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
41
|
+
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
|
|
42
|
+
ModuleBackwardInputs, ModuleBackwardOutputs
|
|
43
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
44
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
45
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
46
|
+
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Service:
|
|
50
|
+
def __init__(self, config):
    """Prepare a dump service driven by a private copy of ``config``.

    ``config`` is deep-copied and ``level`` is overwritten with ``level_ori``
    on the copy only, so the caller's object is left untouched. The data
    collector, the cell processor (bound to the collector's scope) and the
    primitive hook service are then built, the run-state flags and counters
    are initialised, and ``check_level_valid`` rejects unsupported levels.
    """
    self.model = None
    local_config = copy.deepcopy(config)
    local_config.level = local_config.level_ori
    self.config = local_config

    self.data_collector = build_data_collector(self.config)
    self.cell_processor = CellProcessor(self.data_collector.scope)
    # Receives this partially initialised service; kept after the collector
    # and processor are set (presumably it reads them - TODO confirm).
    self.primitive_hook_service = PrimitiveHookService(self)

    # Dump switches, initially off.
    self.switch = self.primitive_switch = False
    self.current_iter = 0
    self.first_start = True
    self.current_rank = self.dump_iter_dir = None
    self.start_call = False
    self.check_level_valid()
    self.should_stop_service = False
|
|
66
|
+
|
|
67
|
+
@staticmethod
def check_model_valid(model):
    """Validate the optional ``model`` argument.

    Args:
        model: ``None`` (no model supplied) or a ``mindspore.nn.Cell``.

    Returns:
        ``model`` unchanged when it is valid.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR when ``model`` is neither
            ``None`` nor an ``nn.Cell``.
    """
    # Compare against None explicitly: a truthiness test would let any falsy
    # non-Cell value (0, "", [], {}) slip through validation unnoticed.
    if model is None or isinstance(model, nn.Cell):
        return model
    raise MsprobeException(
        MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
    )
|
|
74
|
+
|
|
75
|
+
def check_level_valid(self):
    """Reject dump levels that this service cannot handle.

    Raises:
        MsprobeException: INVALID_PARAM_ERROR when ``self.config.level`` is
            ``Const.LEVEL_L2``.
    """
    unsupported = self.config.level == Const.LEVEL_L2
    if not unsupported:
        return
    raise MsprobeException(
        MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
    )
|
|
80
|
+
|
|
81
|
+
def build_hook(self, target_type, name):
    """Build the forward/backward dump hook pair for one API or cell.

    Args:
        target_type: ``BaseScope.Module_Type_API`` for wrapped API cells, or
            ``BaseScope.Module_Type_Module`` for model cells (whose dump name is
            then taken from ``cell.mindstudio_reserved_name``).
        name: dump-name prefix; ``Const.FORWARD`` / ``Const.BACKWARD`` is
            appended to form the forward and backward names.

    Returns:
        ``(wrap_forward_hook, wrap_backward_hook)`` taking
        ``(cell, input, output)`` and ``(cell, grad_input, grad_output)``.
    """
    def release_input_kwargs(cell):
        # API cells carry their call kwargs on ``input_kwargs``. Drop them once
        # the forward data has been collected, on every path, so they (and any
        # tensors inside) are not left attached to the cell.
        if target_type == BaseScope.Module_Type_API:
            del cell.input_kwargs

    def forward_hook(api_or_cell_name, cell, input, output):
        if not self.should_excute_hook():
            return None

        if target_type == BaseScope.Module_Type_Module:
            api_or_cell_name = cell.mindstudio_reserved_name
            module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
        else:
            module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs,
                                                             output=output)

        self.data_collector.update_api_or_module_name(api_or_cell_name)
        self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
        if self.data_collector.if_return_forward_new_output():
            result = self.data_collector.get_forward_new_output()
        else:
            result = output
        # Previously skipped when a new output was returned, leaving stale kwargs.
        release_input_kwargs(cell)
        return result

    def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
        if not self.should_excute_hook():
            return

        if target_type == BaseScope.Module_Type_Module:
            api_or_cell_name = cell.mindstudio_reserved_name
        # Every use of the collector sits under the guard, not just some of them.
        if self.data_collector:
            self.data_collector.update_api_or_module_name(api_or_cell_name)
            # The framework's latest interface change aligned the meaning of
            # grad_input/grad_output with torch, so they are passed swapped here.
            module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
            self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)

    # Captured by the closures above; bound before any hook can fire.
    pid = os.getpid()
    forward_name_template = name + Const.FORWARD
    backward_name_template = name + Const.BACKWARD
    forward_hook = functools.partial(forward_hook, forward_name_template)
    backward_hook = functools.partial(backward_hook, backward_name_template)

    # Plain-function wrappers around the partials; presumably cell hook
    # registration expects real functions rather than functools.partial
    # objects - TODO confirm.
    def wrap_forward_hook(cell, input, output):
        return forward_hook(cell, input, output)

    def wrap_backward_hook(cell, grad_input, grad_output):
        return backward_hook(cell, grad_input, grad_output)

    return wrap_forward_hook, wrap_backward_hook
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def update_primitive_counters(self, primitive_name):
    """Advance the call index stored for ``primitive_name``.

    A name seen for the first time is recorded with index 0; each later call
    adds one, so the stored value is the zero-based index of the latest call.
    """
    last_index = self.primitive_counters.get(primitive_name, -1)
    self.primitive_counters[primitive_name] = last_index + 1
|
|
133
|
+
|
|
134
|
+
def register_primitive_hooks(self):
    """Attach the primitive hook to every primitive reachable from ``self.model``.

    Every distinct ``(attribute name, primitive)`` pair found in the
    ``_primitives`` mapping of each cell yielded by ``cells_and_names()`` is
    patched exactly once. The patch swaps the primitive's class for a
    dynamically created subclass named ``NewPrimitive``. That subclass's
    ``__call__`` is whatever ``self.primitive_hook_service.wrap_primitive``
    returns for the original bound ``__call__``.
    """
    unique_primitives = {
        (attr_name, prim)
        for _, sub_cell in self.model.cells_and_names()
        for attr_name, prim in sub_cell._primitives.items()
    }

    for attr_name, prim in unique_primitives:
        # Implicit calls like ``prim(...)`` look ``__call__`` up on the type,
        # so the wrapper is installed on a per-instance subclass.
        hooked_call = self.primitive_hook_service.wrap_primitive(prim.__call__, attr_name)
        hooked_cls = type('NewPrimitive', (prim.__class__,), {'__call__': hooked_call})
        prim.__class__ = hooked_cls
|
|
144
|
+
|
|
145
|
+
def step(self):
    """Move to the next iteration and reset all per-step bookkeeping.

    The new iteration number is pushed to the data collector. The following
    are then reset so the next step starts numbering from scratch:
    HOOKCell's cell counter, CellProcessor's stats, the primitive call
    counters, the data writer cache and JitDump's jit counter.
    """
    self.current_iter = self.current_iter + 1
    collector = self.data_collector
    collector.update_iter(self.current_iter)
    # Counters kept on HOOKCell / CellProcessor / JitDump and on the
    # primitive hook service are restarted for the new step.
    HOOKCell.cell_count = defaultdict(int)
    CellProcessor.reset_cell_stats()
    self.primitive_hook_service.primitive_counters.clear()
    collector.data_writer.reset_cache()
    JitDump.jit_count = defaultdict(int)
|
|
153
|
+
|
|
154
|
+
def start(self, model=None):
    """Turn the dump switch on for the current step.

    Behaviour by case:
    - After the service has ended, this is a no-op.
    - When ``need_end_service()`` reports the end (current iteration past the
      largest configured step, or the data processor terminated), the
      original APIs are restored and the service is stopped permanently.
    - Steps outside ``config.step`` are ignored.
    - The first accepted call also runs the one-time setup
      (rank resolution, hook registration, jit patching). If ``config.rank``
      excludes this rank, nothing is switched on.

    Args:
        model: Optional model; validated by ``check_model_valid`` and stored
            on ``self.model``.
    """
    self.start_call = True
    if self.should_stop_service:
        return
    if self.need_end_service():
        # Dumping is over for good: put the original APIs back and shut down.
        api_register.api_set_ori_func()
        self.should_stop_service = True
        self.switch = False
        self.primitive_switch = False
        print_tools_ends_info()
        return
    wanted_steps = self.config.step
    if wanted_steps and self.current_iter not in wanted_steps:
        return
    self.model = self.check_model_valid(model)

    logger.info(f"{Const.TOOL_NAME}: debugger.start() is set successfully")

    if self.first_start and not self._init_on_first_start():
        return

    api_register.api_set_hook_func()
    self.switch = True
    self.primitive_switch = True
    logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
    self.create_dirs()
    logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
    JitDump.jit_dump_switch = True

def _init_on_first_start(self):
    """Run the one-time setup performed by the first accepted ``start()``.

    Steps performed:
    - Resolve ``self.current_rank`` (``None`` when
      ``DistributedNotInitializedError`` is raised).
    - Register the hooks.
    - For mix/L1 levels, configure JitDump and install it in place of
      ``ms.common.api._MindsporeFunctionExecutor`` and
      ``_PyNativeExecutor.grad``.

    Returns:
        False if ``config.rank`` excludes the current rank. In that case the
        setup is skipped and ``first_start`` stays True. Otherwise True.
    """
    try:
        self.current_rank = get_rank_if_initialized()
    except DistributedNotInitializedError:
        self.current_rank = None

    allowed_ranks = self.config.rank
    if allowed_ranks and self.current_rank not in allowed_ranks:
        return False
    self.register_hook_new()
    if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
        JitDump.set_config(self.config)
        JitDump.set_data_collector(self.data_collector)
        ms.common.api._MindsporeFunctionExecutor = JitDump
        ms.common.api._PyNativeExecutor.grad = JitDump.grad
        if pijit_label:
            # Make PIJit capture enter/exit no-ops. Presumably this keeps PIJit
            # from bypassing the dump hooks -- confirm.
            PIJitCaptureContext.__enter__ = self.empty
            PIJitCaptureContext.__exit__ = self.empty
    self.first_start = False
    return True
|
|
197
|
+
|
|
198
|
+
def forward_backward_dump_end(self):
    """Stop API/primitive dumping once the forward/backward pass is done.

    Must be called after ``debugger.start()``. For steps or ranks filtered out
    by the config this is a no-op. Otherwise the primitive switch is turned off
    and the original APIs are restored.

    Raises:
        Exception: if ``debugger.start()`` was not called, or if the dump
            switch is off for a step/rank that is configured for dumping
            (i.e. ``debugger.stop()`` was already called).
    """
    if self.should_stop_service:
        return
    logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
    if not self.start_call:
        logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
        raise Exception("debugger.start() is not set in the current scope.")
    # start() returns before setting ``self.switch = True`` for filtered steps
    # and excluded ranks. The filters must therefore run before the switch
    # check; otherwise every non-dumped step would raise below.
    # This matches the check order used in stop().
    if self.config.step and self.current_iter not in self.config.step:
        return
    if self.config.rank and self.current_rank not in self.config.rank:
        return
    if not self.switch:
        logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
                     "debugger.start() and debugger.stop() ")
        raise Exception("debugger.stop() is already called. ")
    self.primitive_switch = False
    api_register.api_set_ori_func()
|
|
215
|
+
|
|
216
|
+
def stop(self):
    """Turn the dump switch off and flush the collected data.

    Behaviour by case:
    - No-op once the service has ended.
    - Raises if ``debugger.start()`` was not called first.
    - For steps/ranks filtered out by the config, nothing else changes.
    - Otherwise both switches and ``start_call`` are cleared,
      ``data_collector.write_json()`` is called, and jit dumping is disabled.

    Raises:
        Exception: if ``debugger.start()`` was not called in the current scope.
    """
    if self.should_stop_service:
        return
    logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. "
                "Please set debugger.start() to turn on the dump switch again. ")
    if not self.start_call:
        logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
        raise Exception("debugger.start() is not set in the current scope.")
    cfg = self.config
    if (cfg.step and self.current_iter not in cfg.step) or \
            (cfg.rank and self.current_rank not in cfg.rank):
        return
    self.switch = self.primitive_switch = self.start_call = False
    self.data_collector.write_json()
    JitDump.jit_dump_switch = False
|
|
233
|
+
|
|
234
|
+
def need_end_service(self):
    """Return True when dumping should end permanently.

    That is the case when either:
    - ``config.step`` is non-empty and the current iteration is past its
      largest entry, or
    - a data collector exists whose data processor reports ``is_terminated``.
    """
    steps = self.config.step
    if steps and self.current_iter > max(steps):
        return True
    collector = self.data_collector
    return bool(collector and collector.data_processor.is_terminated)
|
|
240
|
+
|
|
241
|
+
def should_excute_hook(self):
    """Return True if hooks should record data right now.

    Requires the dump switch to be on, and a data collector whose data
    processor has not terminated.
    """
    if not self.switch:
        return False
    collector = self.data_collector
    return bool(collector) and not collector.data_processor.is_terminated
|
|
247
|
+
|
|
248
|
+
def create_dirs(self):
    """Create the dump directory tree for the current step and rank.

    Layout is ``<dump_path>/step<iter>/rank<rank>/``; the rank number is left
    empty when ``current_rank`` is None. A ``dump_tensor_data`` subdirectory is
    added when the task is in ``data_collector.tasks_need_tensor_data``.
    Paths for ``dump.json``, ``stack.json`` and ``construct.json`` in the rank
    directory, together with the tensor directory (or None), are passed to
    ``data_collector.update_dump_paths``.
    """
    root = self.config.dump_path
    create_directory(root)
    self.dump_iter_dir = os.path.join(root, f"step{self.current_iter}")
    rank_tag = '' if self.current_rank is None else self.current_rank
    rank_dir = os.path.join(self.dump_iter_dir, f"rank{rank_tag}")
    create_directory(rank_dir)

    tensor_dir = None
    if self.config.task in self.data_collector.tasks_need_tensor_data:
        tensor_dir = os.path.join(rank_dir, "dump_tensor_data")
        create_directory(tensor_dir)

    dump_json, stack_json, construct_json = (
        os.path.join(rank_dir, file_name)
        for file_name in ("dump.json", "stack.json", "construct.json")
    )
    self.data_collector.update_dump_paths(
        dump_json, stack_json, construct_json, tensor_dir, None)
|
|
265
|
+
|
|
266
|
+
def empty(self, *args, **kwargs):
    """Accept any arguments, do nothing, and return None.

    Serves as a no-op stand-in; it is assigned to
    ``PIJitCaptureContext.__enter__`` / ``__exit__`` during first-start setup.
    """
|
|
268
|
+
|
|
269
|
+
def register_hook_new(self):
|
|
270
|
+
logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
|
|
271
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
272
|
+
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
273
|
+
api_register.api_set_hook_func()
|
|
274
|
+
if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
|
|
275
|
+
self.register_primitive_hooks()
|
|
276
|
+
|
|
277
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
|
|
278
|
+
if not self.model:
|
|
279
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
280
|
+
f"The current level is {self.config.level}, the model cannot be None")
|
|
281
|
+
for name, cell in self.model.cells_and_names():
|
|
282
|
+
if cell == self.model:
|
|
283
|
+
continue
|
|
284
|
+
prefix = 'Cell' + Const.SEP + name + Const.SEP + \
|
|
285
|
+
cell.__class__.__name__ + Const.SEP
|
|
286
|
+
forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
|
|
287
|
+
cell.register_forward_hook(forward_hook)
|
|
288
|
+
cell.register_backward_hook(backward_hook)
|
|
289
|
+
|
|
290
|
+
cell.register_forward_pre_hook(
|
|
291
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
|
|
292
|
+
cell.register_forward_hook(
|
|
293
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
|
|
294
|
+
cell.register_backward_pre_hook(
|
|
295
|
+
self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
296
|
+
cell.register_backward_hook(
|
|
297
|
+
self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|