mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/mindspore/dump/hook_cell/primitive_hooks.py (new file)

@@ -0,0 +1,203 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from mindspore import ops
+from mindspore.common.tensor import Tensor
+
+from msprobe.core.common.utils import Const, DumpException
+from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs,
+                                                        ModuleForwardInputsOutputs)
+from msprobe.mindspore.common.log import logger
+
+
+class PrimitiveHookService:
+    def __init__(self, service_instance):
+        self.primitive_counters = {}
+        self.service_instance = service_instance
+
+    def wrap_primitive(self, origin_func, primitive_name):
+        """
+        Wrap the original primitive function with input and output hooks that capture forward and backward data.
+
+        Args:
+            origin_func (callable): the original primitive function.
+            primitive_name (str): the original primitive name.
+
+        Returns:
+            callable: the wrapped primitive function.
+        """
+
+        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
+            """
+            Create a backward hook used to capture gradients.
+
+            Args:
+                captured_grads (list): list used to store the captured gradients.
+                num_tensors (int): number of tensors.
+                updated_primitive_name (str): the updated primitive name.
+                hook_type (str): hook type (input/output).
+
+            Returns:
+                callable: the backward hook.
+            """
+
+            def backward_hook(grad):
+                captured_grads.extend(grad)
+                backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
+
+                try:
+                    if hook_type == Const.INPUT:
+                        self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
+                        new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
+                        self.service_instance.data_collector.backward_output_data_collect(
+                            backward_primitive_name, self, os.getpid(), new_module_input_output
+                        )
+                    elif hook_type == Const.OUTPUT:
+                        self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
+                        new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+                        self.service_instance.data_collector.backward_input_data_collect(
+                            backward_primitive_name, self, os.getpid(), new_module_input_output
+                        )
+
+                except Exception as exception:
+                    logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, "
+                                 f"updated_primitive_name: {updated_primitive_name}")
+                    raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception
+
+            return backward_hook
+
+        def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
+            """
+            Attach hooks to the forward inputs.
+
+            Args:
+                args (tuple): primitive input arguments.
+                captured_grads_input (list): captured input gradients.
+                updated_primitive_name (str): the updated primitive name.
+
+            Returns:
+                list: inputs with hooks attached.
+            """
+            hooked_inputs = []
+            num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+            input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
+                                                       Const.INPUT)
+            for arg in args:
+                if isinstance(arg, Tensor):
+                    arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+                    hooked_inputs.append(arg_hooked)
+                else:
+                    hooked_inputs.append(arg)
+            return tuple(hooked_inputs)
+
+        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
+            """
+            Attach hooks to the forward outputs.
+
+            Args:
+                out (Tensor/tuple): primitive outputs.
+                captured_grads_output (list): captured output gradients.
+                updated_primitive_name (str): the updated primitive name.
+
+            Returns:
+                Tensor/tuple: outputs with hooks attached.
+            """
+            if isinstance(out, tuple):
+                num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
+            else:
+                num_output_tensors = 1
+            output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
+                                                        updated_primitive_name, Const.OUTPUT)
+
+            if isinstance(out, Tensor):
+                return ops.HookBackward(output_backward_hook)(out)
+            elif isinstance(out, tuple):
+                hooked_outputs = []
+                for tensor in out:
+                    if isinstance(tensor, Tensor):
+                        hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+                    else:
+                        hooked_outputs.append(tensor)
+                return tuple(hooked_outputs)
+            return out
+
+        def wrapped_primitive_call(instance_self, *args, **kwargs):
+            """
+            The wrapped primitive call, which attaches the input and output hooks.
+
+            Args:
+                instance_self (object): the primitive instance.
+                *args: primitive positional arguments.
+                **kwargs: primitive keyword arguments.
+
+            Returns:
+                Tensor/tuple: the primitive's return value.
+            """
+            self.update_primitive_counters(primitive_name)
+            current_count = self.primitive_counters.get(primitive_name, 0)
+            updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"
+
+            if not self.service_instance.primitive_switch:
+                return origin_func(*args, **kwargs)
+
+            captured_grads_input, captured_grads_output = [], []
+
+            try:
+                hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during input hooking: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception
+
+            try:
+                out = origin_func(*hooked_inputs, **kwargs)
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during function call: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.FUNCTION_CALL_ERROR) from exception
+
+            forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
+            self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
+            if self.service_instance.data_collector:
+                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+                try:
+                    self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
+                                                                              os.getpid(), module_input_output)
+                except Exception as exception:
+                    logger.error(f"This is a primitive op dump error during forward data collection: {exception}, "
+                                 f"primitive_name: {primitive_name}")
+                    raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception
+
+                if self.service_instance.data_collector.if_return_forward_new_output():
+                    out = self.service_instance.data_collector.get_forward_new_output()
+
+            try:
+                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
+            except Exception as exception:
+                logger.error(f"This is a primitive op dump error during output hooking: {exception}, "
+                             f"primitive_name: {primitive_name}")
+                raise DumpException(DumpException.OUTPUT_HOOK_ERROR) from exception
+
+            return out
+
+        return wrapped_primitive_call
+
+    def update_primitive_counters(self, primitive_name):
+        if primitive_name not in self.primitive_counters:
+            self.primitive_counters[primitive_name] = 0
+        else:
+            self.primitive_counters[primitive_name] += 1
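The new PrimitiveHookService names every dump entry from a per-primitive counter, so repeated calls to the same primitive stay distinguishable. Below is a minimal, self-contained sketch of that naming scheme, not part of the package: `SEP` and `PREFIX` stand in for `Const.SEP` and `Const.PRIMITIVE_PREFIX`, whose real values live in msprobe's constants.

```python
from collections import defaultdict

SEP = "."             # assumption: stands in for Const.SEP
PREFIX = "Primitive"  # assumption: stands in for Const.PRIMITIVE_PREFIX

counters = defaultdict(lambda: -1)

def dump_name(primitive_name: str, direction: str) -> str:
    # Mirrors update_primitive_counters plus the f-strings in wrapped_primitive_call:
    # the first call of a primitive gets index 0, later calls increment it.
    counters[primitive_name] += 1
    return f"{PREFIX}{SEP}{primitive_name}{SEP}{counters[primitive_name]}{SEP}{direction}"

print(dump_name("Add", "forward"))   # Primitive.Add.0.forward
print(dump_name("Add", "forward"))   # Primitive.Add.1.forward
```

Backward entries reuse the same counter value with the backward suffix, produced by the `ops.HookBackward` hooks attached to the primitive's tensor inputs and outputs.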
msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml

@@ -185,6 +185,7 @@ ops:
   - float_power
   - fmod
   - frac
+  - flash_attention_score
   - gcd
   - hypot
   - igamma
@@ -489,6 +490,31 @@ ops:
   - scatter_update
   - derivative
   - jet
+  - row_stack
+  - gather
+  - arange
+  - cond
+  - slice_scatter
+  - clip_by_norm
+  - eps
+  - layer_norm
+  - cast
+  - numel
+  - permute
+  - select_scatter
+  - group_norm
+  - eq
+  - embedding
+  - ones_like
+  - zeros
+  - nanmean
+  - shape
+  - zeros_like
+  - ones
+  - diagonal_scatter
+  - vander
+  - is_nonzero
+  - rotary_position_embedding
 
 tensor:
   - __abs__
@@ -876,16 +902,60 @@ mint.ops:
   - zeros
   - zeros_ex
   - zeros_like
-
-
-  -
-  -
-  -
-  -
-  -
-  -
-  -
-  -
+  - inverse
+  - select
+  - item
+  - unsqueeze
+  - median
+  - floor
+  - histc
+  - special
+  - arctan2
+  - sign
+  - concat
+  - atanh
+  - greater_equal
+  - eye
+  - fix
+  - argmin
+  - asinh
+  - atan
+  - nan_to_num
+  - tan
+  - round
+  - cosh
+  - norm
+  - roll
+  - log1p
+  - reshape
+  - arccos
+  - outer
+  - arcsin
+  - rand_like
+  - acosh
+  - multinomial
+  - logical_xor
+  - acos
+  - linalg
+  - sinc
+  - arcsinh
+  - asin
+  - narrow
+  - arctanh
+  - trace
+  - erfc
+  - bernoulli
+  - expm1
+  - logaddexp
+  - sinh
+  - arccosh
+  - atan2
+  - rand
+  - arange
+  - trunc
+  - arctan
+  - swapaxes
+  - transpose
 
 mint.nn.functional:
   - absolute_import
@@ -920,3 +990,30 @@ mint.nn.functional:
   - softplus
   - tanh
   - unfold
+  - mse_loss
+  - adaptive_avg_pool1d
+  - binary_cross_entropy
+  - adaptive_avg_pool2d
+  - hardsigmoid
+  - selu
+  - softshrink
+  - prelu
+  - logsigmoid
+  - hardswish
+  - mish
+  - log_softmax
+  - hardshrink
+  - l1_loss
+  - elu
+
+communication.comm_func:
+  - all_reduce
+  - all_gather_into_tensor
+  - reduce
+  - reduce_scatter_tensor
+  - all_to_all_single_with_output_shape
+  - all_to_all_with_output_shape
+  - batch_isend_irecv
+  - broadcast
+  - gather_into_tensor
+  - scatter_tensor
msprobe/mindspore/dump/hook_cell/wrap_api.py

@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,19 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
 
 import os
 
-from mindspore import Tensor,
-from mindspore.mint.nn import functional
+from mindspore import Tensor, mint, ops
 from mindspore.common._stub_tensor import StubTensor
+from mindspore.communication import comm_func
+from mindspore.mint.nn import functional
 
-from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 from msprobe.core.common.const import Const
-from msprobe.mindspore.common.const import Const as MsConst
 from msprobe.core.common.file_utils import load_yaml
-
+from msprobe.mindspore.common.const import Const as MsConst
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 
 cur_path = os.path.dirname(os.path.realpath(__file__))
 yaml_path = os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE)
@@ -51,6 +49,10 @@ class HOOKMintNNFunctionalOP(object):
     pass
 
 
+class HOOKDistributedOP(object):
+    pass
+
+
 class ApiTemplate(HOOKCell):
     def __init__(self, api_name, api_dict, prefix, hook):
         self.api_name = api_name
@@ -65,12 +67,14 @@ class ApiTemplate(HOOKCell):
 
 
 class WrapApiName:
-    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names
+    def __init__(self, tensor_api_names, stub_tensor_api_names, ops_api_names, mint_api_names, mint_nn_func_api_names,
+                 distributed_api_names):
         self.tensor_api_names = tensor_api_names
         self.stub_tensor_api_names = stub_tensor_api_names
        self.ops_api_names = ops_api_names
         self.mint_api_names = mint_api_names
         self.mint_nn_func_api_names = mint_nn_func_api_names
+        self.distributed_api_names = distributed_api_names
 
 
 def get_wrap_api_list():
@@ -79,11 +83,13 @@ def get_wrap_api_list():
     ops_api = api_list.get(MsConst.SUPPORTED_OPS_LIST_KEY)
     mint_api = api_list.get(MsConst.SUPPORTED_MINT_LIST_KEY)
     mint_nn_func_api = api_list.get(MsConst.SUPPORTED__MINT_NN_FUNC_LIST_KEY)
+    distributed_api = api_list.get(MsConst.SUPPORTED_COMM_LIST_KEY)
     wrap_api_name = WrapApiName(set(tensor_api) & set(dir(Tensor)),
                                 set(tensor_api) & set(dir(StubTensor)),
                                 set(ops_api) & set(dir(ops)),
                                 set(mint_api) & set(dir(mint)),
-                                set(mint_nn_func_api) & set(dir(functional))
+                                set(mint_nn_func_api) & set(dir(functional)),
+                                set(distributed_api) & set(dir(comm_func)))
     return wrap_api_name
 
 
@@ -111,3 +117,5 @@ def setup_hooks(hook):
                            MsConst.MINT_DATA_PREFIX, hook, HOOKMintOP)
     wrap_api_func_and_bind(wrap_api_name.mint_nn_func_api_names, {f: getattr(functional, f) for f in dir(functional)},
                            MsConst.MINT_NN_FUNC_DATA_PREFIX, hook, HOOKMintNNFunctionalOP)
+    wrap_api_func_and_bind(wrap_api_name.distributed_api_names, {f: getattr(comm_func, f) for f in dir(comm_func)},
+                           MsConst.DISTRIBUTED_DATA_PREFIX, hook, HOOKDistributedOP)
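`get_wrap_api_list` only hooks names that both appear in support_wrap_ops.yaml and actually exist in the installed MindSpore build, so stale yaml entries are skipped rather than failing. A hedged sketch of that filtering step for the new `communication.comm_func` section (the yaml list below is hypothetical; the intersection mirrors the diff above and needs a MindSpore build that provides `mindspore.communication.comm_func`):

```python
from mindspore.communication import comm_func

yaml_comm_api = ["all_reduce", "broadcast", "not_a_real_api"]  # hypothetical yaml content

# Same pattern as set(distributed_api) & set(dir(comm_func)) in get_wrap_api_list.
hookable = set(yaml_comm_api) & set(dir(comm_func))
print(hookable)  # only the names comm_func really exposes get wrapped, e.g. {'all_reduce', 'broadcast'}
```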
msprobe/mindspore/dump/jit_dump.py

@@ -1,12 +1,30 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
+from collections import defaultdict
 
 from mindspore import Tensor
-from mindspore.common.api import _MindsporeFunctionExecutor
 from mindspore._c_expression import PyNativeExecutor_
+from mindspore.common.api import _MindsporeFunctionExecutor
 
-from msprobe.
-from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
+from msprobe.core.common.log import logger
+from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 from msprobe.core.common.const import Const
+from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
+from msprobe.mindspore.dump.hook_cell.api_registry import api_register
 
 
 def dump_jit(name, in_feat, out_feat, is_forward):
@@ -15,35 +33,53 @@ def dump_jit(name, in_feat, out_feat, is_forward):
     index = ori_args.find("<")
     if index != 0 and index != -1:
         result = ori_args[0:index]
+    elif name is not None and "<" not in str(name):
+        result = str(name)
     else:
         result = "JitFunction"
-    if is_forward:
-        name_template = "Jit." + result + ".forward"
-    else:
-        name_template = "Jit." + result + ".backward"
     if JitDump.need_dump():
-
-
-
+        if is_forward:
+            JitDump.jit_count[result] += 1
+            name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
+                            Const.FORWARD
+            JitDump.data_collector.update_api_or_module_name(name_template)
+            module_input_output = ModuleForwardInputsOutputs(args=in_feat, kwargs={}, output=out_feat)
+            JitDump.data_collector.forward_data_collect(name_template, None, pid, module_input_output)
+        else:
+            name_template = Const.JIT + Const.SEP + result + Const.SEP + str(JitDump.jit_count[result]) + Const.SEP + \
+                            Const.BACKWARD
+            JitDump.data_collector.update_api_or_module_name(name_template)
+            module_input_output = ModuleBackwardInputsOutputs(grad_input=in_feat, grad_output=out_feat)
+            JitDump.data_collector.backward_data_collect(name_template, None, pid, module_input_output)
 
 
 class JitDump(_MindsporeFunctionExecutor):
     dump_config = None
     jit_enable = False
+    jit_dump_switch = True
+    jit_count = defaultdict(int)
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.name = None
+        if len(args) > 0:
+            self.name = args[0].__name__
         self._executor = PyNativeExecutor_.get_instance()
 
     def __call__(self, *args, **kwargs):
-
+        if JitDump.jit_dump_switch:
+            api_register.api_set_ori_func()
         out = super().__call__(*args, **kwargs)
-        if
-
-
-
-
+        if JitDump.jit_dump_switch and len(args) > 0:
+            if self.name and self.name != "construct":
+                dump_jit(self.name, args, out, True)
+            else:
+                dump_jit(args[0], args, out, True)
+            JitDump.jit_enable = True
+        elif len(args) == 0:
+            logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
+        if JitDump.jit_dump_switch:
+            api_register.api_set_hook_func()
         return out
 
     @classmethod
@@ -62,11 +98,11 @@ class JitDump(_MindsporeFunctionExecutor):
             return False
         return True
 
-    def grad(self, obj, grad, weights, grad_position, *args,
-        if JitDump.jit_enable:
+    def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
+        if JitDump.jit_dump_switch and JitDump.jit_enable:
             api_register.api_set_ori_func()
         output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
-        if JitDump.jit_enable:
+        if JitDump.jit_dump_switch and JitDump.jit_enable:
             dump_jit(obj, args, None, False)
         api_register.api_set_hook_func()
         return output
msprobe/mindspore/dump/kernel_graph_dump.py

@@ -1,8 +1,23 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
-
-from msprobe.
+
+from msprobe.core.common.file_utils import create_directory, save_json
 from msprobe.mindspore.common.log import logger
-from msprobe.
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 
 
 class KernelGraphDump:
@@ -46,8 +61,7 @@ class KernelGraphDump:
         json_path = self.dump_json["common_dump_settings"]["path"]
         create_directory(json_path)
         json_path = os.path.join(json_path, "kernel_graph_dump.json")
-
-        json.dump(self.dump_json, f)
+        save_json(json_path, self.dump_json, indent=4)
         logger.info(json_path + " has been created.")
         os.environ["MINDSPORE_DUMP_CONFIG"] = json_path
         if self.dump_json["common_dump_settings"]["dump_mode"] == 0:
msprobe/mindspore/dump/kernel_kbyk_dump.py

@@ -1,10 +1,24 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
-import json
 
-from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
-from msprobe.mindspore.common.log import logger
-from msprobe.core.common.file_utils import FileOpen, create_directory
 from msprobe.core.common.const import Const
+from msprobe.core.common.file_utils import create_directory, save_json
+from msprobe.mindspore.common.log import logger
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
 
 
 class KernelKbykDump:
@@ -55,8 +69,7 @@ class KernelKbykDump:
         json_path = self.dump_json[KernelKbykDump.COMMON_SETTINGS]["path"]
         create_directory(json_path)
         json_path = os.path.join(json_path, "kernel_kbyk_dump.json")
-
-        json.dump(self.dump_json, f)
+        save_json(json_path, self.dump_json, indent=4)
         logger.info(json_path + " has been created.")
 
         os.environ["MINDSPORE_DUMP_CONFIG"] = json_path