mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,11 +1,26 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
from mindspore import Tensor, ops
|
|
4
19
|
|
|
20
|
+
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
5
21
|
from msprobe.mindspore.common.log import logger
|
|
6
|
-
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
7
22
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
8
|
-
from msprobe.mindspore.
|
|
23
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
9
24
|
|
|
10
25
|
|
|
11
26
|
class AddNoisePerturbation(BasePerturbation):
|
|
@@ -17,7 +32,7 @@ class AddNoisePerturbation(BasePerturbation):
|
|
|
17
32
|
"""
|
|
18
33
|
params.fuzzed_value = self.add_noise(params.args[params.index])
|
|
19
34
|
if not self.is_fuzzed:
|
|
20
|
-
logger.warning(f"{self.
|
|
35
|
+
logger.warning(f"{self.api_name_with_id} can not add noise.")
|
|
21
36
|
return False
|
|
22
37
|
return self.get_fuzzed_result(params)
|
|
23
38
|
|
|
@@ -43,25 +58,25 @@ class AddNoisePerturbation(BasePerturbation):
|
|
|
43
58
|
|
|
44
59
|
return inputs
|
|
45
60
|
|
|
46
|
-
def _get_noise(self,
|
|
61
|
+
def _get_noise(self, tensor):
|
|
47
62
|
"""
|
|
48
63
|
得到要添加的噪声值
|
|
49
64
|
|
|
50
65
|
"""
|
|
51
66
|
if self.is_fuzzed:
|
|
52
67
|
return False
|
|
53
|
-
if not ops.is_floating_point(
|
|
68
|
+
if not ops.is_floating_point(tensor) or ops.numel(tensor) == 0:
|
|
54
69
|
return False
|
|
55
70
|
|
|
56
|
-
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(
|
|
71
|
+
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
|
|
57
72
|
if not pert_value:
|
|
58
73
|
return False
|
|
59
74
|
else:
|
|
60
75
|
self.perturbation_value = pert_value
|
|
61
76
|
|
|
62
|
-
max_val = ops.max(ops.abs(
|
|
77
|
+
max_val = ops.max(ops.abs(tensor))[0].item()
|
|
63
78
|
if max_val < pert_value:
|
|
64
79
|
return False
|
|
65
80
|
|
|
66
|
-
noise = ops.full(
|
|
81
|
+
noise = ops.full(tensor.shape, self.perturbation_value, dtype=tensor.dtype)
|
|
67
82
|
return noise
|
|
@@ -1,20 +1,44 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
3
20
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
21
|
+
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
4
22
|
|
|
5
23
|
|
|
6
24
|
class BasePerturbation:
|
|
7
25
|
|
|
8
|
-
def __init__(self,
|
|
9
|
-
self.
|
|
26
|
+
def __init__(self, api_name_with_id: str):
|
|
27
|
+
self.api_name_with_id = api_name_with_id
|
|
10
28
|
self.is_fuzzed = False
|
|
11
29
|
self.perturbation_value = None
|
|
12
30
|
|
|
13
31
|
@staticmethod
|
|
14
32
|
def get_fuzzed_result(params: HandlerParams):
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
33
|
+
if Config.stage == Const.BACKWARD:
|
|
34
|
+
fuzzed_result = Tools.get_grad(params.original_func, *params.args[:params.index],
|
|
35
|
+
params.fuzzed_value, *params.args[params.index + 1:], **params.kwargs)
|
|
36
|
+
|
|
37
|
+
if fuzzed_result is None:
|
|
38
|
+
return False
|
|
39
|
+
else:
|
|
40
|
+
fuzzed_result = params.original_func(*params.args[:params.index], params.fuzzed_value,
|
|
41
|
+
*params.args[params.index + 1:], **params.kwargs)
|
|
18
42
|
return fuzzed_result
|
|
19
43
|
|
|
20
44
|
def handler(self, params: HandlerParams) -> Any:
|
|
@@ -1,10 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
import numpy as np
|
|
4
19
|
from mindspore import Tensor, ops
|
|
5
20
|
|
|
6
|
-
from msprobe.mindspore.common.log import logger
|
|
7
21
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
22
|
+
from msprobe.mindspore.common.log import logger
|
|
8
23
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
9
24
|
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
10
25
|
|
|
@@ -20,12 +35,12 @@ class BitNoisePerturbation(BasePerturbation):
|
|
|
20
35
|
noise_type = list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.keys())[
|
|
21
36
|
list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.values()).index(bit_len_type)]
|
|
22
37
|
noise = ops.full(inputs.shape, 1, dtype=noise_type)
|
|
23
|
-
input_np = inputs.asnumpy()
|
|
38
|
+
input_np = inputs.contiguous().asnumpy()
|
|
24
39
|
input_np_int = input_np.view(bit_len_type)
|
|
25
40
|
result = Tensor(input_np_int)
|
|
26
41
|
result = ops.where(ops.abs(inputs) > sub_normal,
|
|
27
42
|
ops.bitwise_xor(result, noise), result)
|
|
28
|
-
result_np = result.asnumpy()
|
|
43
|
+
result_np = result.contiguous().asnumpy()
|
|
29
44
|
result_np_float = result_np.view(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.get(inputs.dtype))
|
|
30
45
|
self.is_fuzzed = True
|
|
31
46
|
return Tensor(result_np_float)
|
|
@@ -40,24 +55,24 @@ class BitNoisePerturbation(BasePerturbation):
|
|
|
40
55
|
args = params.args
|
|
41
56
|
params.fuzzed_value = self.add_bit_noise(params.args[params.index])
|
|
42
57
|
if not self.is_fuzzed:
|
|
43
|
-
logger.warning(f"{self.
|
|
58
|
+
logger.warning(f"{self.api_name_with_id} can not add bit noise.")
|
|
44
59
|
return False
|
|
45
60
|
params.args = args
|
|
46
61
|
return self.get_fuzzed_result(params)
|
|
47
62
|
|
|
48
|
-
def _get_bit_len_type(self,
|
|
63
|
+
def _get_bit_len_type(self, tensor):
|
|
49
64
|
if self.is_fuzzed:
|
|
50
65
|
return False
|
|
51
|
-
if not isinstance(
|
|
52
|
-
|
|
66
|
+
if not isinstance(tensor, Tensor) or not ops.is_floating_point(tensor) or \
|
|
67
|
+
tensor.numel() == 0:
|
|
53
68
|
return False
|
|
54
|
-
bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(
|
|
69
|
+
bit_len_type = FreeBenchmarkConst.PERT_BIT_DICT.get(tensor.dtype)
|
|
55
70
|
if not bit_len_type:
|
|
56
71
|
return False
|
|
57
|
-
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(
|
|
72
|
+
pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(tensor.dtype)
|
|
58
73
|
if not pert_value:
|
|
59
74
|
return False
|
|
60
|
-
max_val = ops.max(ops.abs(
|
|
75
|
+
max_val = ops.max(ops.abs(tensor))[0].item()
|
|
61
76
|
if max_val < pert_value:
|
|
62
77
|
return False
|
|
63
78
|
return bit_len_type
|
|
@@ -1,14 +1,39 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
|
-
from mindspore import Tensor
|
|
18
|
+
from mindspore import Tensor, ops
|
|
4
19
|
|
|
5
20
|
from msprobe.mindspore.common.log import logger
|
|
6
|
-
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
7
21
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
22
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
8
23
|
|
|
9
24
|
|
|
10
25
|
class ExchangeValuePerturbation(BasePerturbation):
|
|
11
26
|
|
|
27
|
+
@staticmethod
|
|
28
|
+
def _check_tensor_shape(inputs):
|
|
29
|
+
dims = len(inputs.shape)
|
|
30
|
+
if dims == 1 and inputs.shape[0] > 1:
|
|
31
|
+
return True
|
|
32
|
+
if dims > 1 and inputs.shape[1] > 0:
|
|
33
|
+
if inputs.shape[0] > 1 or inputs.shape[1] > 1:
|
|
34
|
+
return True
|
|
35
|
+
return False
|
|
36
|
+
|
|
12
37
|
def handle(self, params: HandlerParams) -> Any:
|
|
13
38
|
"""
|
|
14
39
|
返回首尾交换后的api输出
|
|
@@ -16,7 +41,7 @@ class ExchangeValuePerturbation(BasePerturbation):
|
|
|
16
41
|
"""
|
|
17
42
|
params.fuzzed_value = self.exchange_value(params.args[params.index])
|
|
18
43
|
if not self.is_fuzzed:
|
|
19
|
-
logger.warning(f"{self.
|
|
44
|
+
logger.warning(f"{self.api_name_with_id} can not exchange value.")
|
|
20
45
|
return False
|
|
21
46
|
return self.get_fuzzed_result(params)
|
|
22
47
|
|
|
@@ -25,22 +50,23 @@ class ExchangeValuePerturbation(BasePerturbation):
|
|
|
25
50
|
返回首尾交换后的api输入
|
|
26
51
|
|
|
27
52
|
"""
|
|
28
|
-
if isinstance(inputs, Tensor):
|
|
29
|
-
if
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
53
|
+
if isinstance(inputs, Tensor) and ops.is_floating_point(inputs):
|
|
54
|
+
if self.is_fuzzed or not self._check_tensor_shape(inputs):
|
|
55
|
+
return inputs
|
|
56
|
+
result = inputs.copy()
|
|
57
|
+
if len(inputs.shape) == 1:
|
|
58
|
+
first_element = inputs[0]
|
|
59
|
+
last_element = inputs[-1]
|
|
60
|
+
result[0] = last_element
|
|
61
|
+
result[-1] = first_element
|
|
62
|
+
else:
|
|
63
|
+
first_element = inputs[0][0]
|
|
64
|
+
last_element = inputs[-1][-1]
|
|
65
|
+
result[0][0] = last_element
|
|
66
|
+
result[-1][-1] = first_element
|
|
67
|
+
|
|
68
|
+
self.is_fuzzed = True
|
|
69
|
+
return result
|
|
44
70
|
|
|
45
71
|
if isinstance(inputs, dict):
|
|
46
72
|
return {k: self.exchange_value(v) for k, v in inputs.items()}
|
|
@@ -1,13 +1,29 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
18
|
import mindspore as ms
|
|
4
19
|
from mindspore import Tensor, ops
|
|
5
20
|
|
|
6
|
-
from msprobe.
|
|
7
|
-
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
8
|
-
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
21
|
+
from msprobe.core.common.const import Const
|
|
9
22
|
from msprobe.mindspore.common.log import logger
|
|
10
|
-
from msprobe.mindspore.common.
|
|
23
|
+
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
24
|
+
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
25
|
+
from msprobe.mindspore.free_benchmark.common.utils import Tools
|
|
26
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
11
27
|
|
|
12
28
|
|
|
13
29
|
class ImprovePrecisionPerturbation(BasePerturbation):
|
|
@@ -26,10 +42,15 @@ class ImprovePrecisionPerturbation(BasePerturbation):
|
|
|
26
42
|
def handle(self, params: HandlerParams) -> Any:
|
|
27
43
|
args = self.improve_tensor_precision(params.args)
|
|
28
44
|
kwargs = self.improve_tensor_precision(params.kwargs)
|
|
29
|
-
fuzzed_value = args
|
|
30
|
-
if self.api_name in Const.COMMUNICATION_API_LIST:
|
|
31
|
-
params.fuzzed_value = fuzzed_value
|
|
32
45
|
if not self.is_fuzzed:
|
|
33
|
-
logger.warning(f"{self.
|
|
46
|
+
logger.warning(f"{self.api_name_with_id} can not improve precision.")
|
|
34
47
|
return False
|
|
48
|
+
|
|
49
|
+
if Config.stage == Const.BACKWARD:
|
|
50
|
+
fuzzed_result = Tools.get_grad(params.original_func, *args, **kwargs)
|
|
51
|
+
if fuzzed_result is not None:
|
|
52
|
+
return fuzzed_result
|
|
53
|
+
else:
|
|
54
|
+
return False
|
|
55
|
+
|
|
35
56
|
return params.original_func(*args, **kwargs)
|
|
@@ -1,7 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from typing import Any
|
|
2
17
|
|
|
3
|
-
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
4
18
|
from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
|
|
19
|
+
from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation
|
|
5
20
|
|
|
6
21
|
|
|
7
22
|
class NoChangePerturbation(BasePerturbation):
|
|
@@ -1,10 +1,25 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.common.const import FreeBenchmarkConst
|
|
2
17
|
from msprobe.mindspore.free_benchmark.common.config import Config
|
|
3
|
-
from .add_noise import AddNoisePerturbation
|
|
4
|
-
from .bit_noise import BitNoisePerturbation
|
|
5
|
-
from .
|
|
6
|
-
from .improve_precision import ImprovePrecisionPerturbation
|
|
7
|
-
from .
|
|
18
|
+
from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
|
|
19
|
+
from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
|
|
20
|
+
from msprobe.mindspore.free_benchmark.perturbation.exchange_value import ExchangeValuePerturbation
|
|
21
|
+
from msprobe.mindspore.free_benchmark.perturbation.improve_precision import ImprovePrecisionPerturbation
|
|
22
|
+
from msprobe.mindspore.free_benchmark.perturbation.no_change import NoChangePerturbation
|
|
8
23
|
|
|
9
24
|
|
|
10
25
|
class PerturbationFactory:
|
|
@@ -21,9 +36,9 @@ class PerturbationFactory:
|
|
|
21
36
|
}
|
|
22
37
|
|
|
23
38
|
@staticmethod
|
|
24
|
-
def create(
|
|
39
|
+
def create(api_name_with_id: str):
|
|
25
40
|
perturbation = PerturbationFactory.perturbations.get(Config.pert_type)
|
|
26
41
|
if perturbation:
|
|
27
|
-
return perturbation(
|
|
42
|
+
return perturbation(api_name_with_id)
|
|
28
43
|
else:
|
|
29
44
|
raise Exception(f'{Config.pert_type} is a invalid perturbation type')
|
|
@@ -1,6 +1,21 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.mindspore.common.const import Const
|
|
2
17
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
3
|
-
from msprobe.mindspore.free_benchmark.api_pynative_self_check import
|
|
18
|
+
from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck
|
|
4
19
|
|
|
5
20
|
|
|
6
21
|
class SelfCheckToolFactory:
|
|
@@ -13,7 +28,7 @@ class SelfCheckToolFactory:
|
|
|
13
28
|
Const.API: {
|
|
14
29
|
Const.GRAPH_KBYK_MODE: None,
|
|
15
30
|
Const.GRAPH_GE_MODE: None,
|
|
16
|
-
Const.PYNATIVE_MODE:
|
|
31
|
+
Const.PYNATIVE_MODE: ApiPyNativeSelfCheck
|
|
17
32
|
},
|
|
18
33
|
Const.KERNEL: {
|
|
19
34
|
Const.GRAPH_KBYK_MODE: None,
|
|
@@ -1,15 +1,30 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import threading
|
|
3
|
-
from typing import Dict, Union
|
|
18
|
+
from typing import Dict, Union, Tuple
|
|
4
19
|
|
|
5
|
-
from msprobe.core.
|
|
20
|
+
from msprobe.core.common.utils import is_int
|
|
21
|
+
from msprobe.core.common.file_utils import create_directory, check_path_before_create
|
|
6
22
|
from msprobe.core.grad_probe.constant import GradConst
|
|
23
|
+
from msprobe.core.grad_probe.utils import check_str, check_bounds_element, check_param_element
|
|
7
24
|
from msprobe.mindspore.common.log import logger
|
|
8
|
-
from msprobe.core.common.file_utils import create_directory, check_path_before_create
|
|
9
25
|
|
|
10
26
|
|
|
11
27
|
class GlobalContext:
|
|
12
|
-
|
|
13
28
|
_instance = None
|
|
14
29
|
_instance_lock = threading.Lock()
|
|
15
30
|
_setting = {
|
|
@@ -18,7 +33,7 @@ class GlobalContext:
|
|
|
18
33
|
GradConst.STEP: None,
|
|
19
34
|
GradConst.RANK: None,
|
|
20
35
|
GradConst.CURRENT_STEP: 0,
|
|
21
|
-
GradConst.BOUNDS: [-
|
|
36
|
+
GradConst.BOUNDS: [-1, 0, 1],
|
|
22
37
|
GradConst.OUTPUT_PATH: None
|
|
23
38
|
}
|
|
24
39
|
|
|
@@ -31,19 +46,19 @@ class GlobalContext:
|
|
|
31
46
|
|
|
32
47
|
def init_context(self, config_dict: Dict):
|
|
33
48
|
level = config_dict.get(GradConst.LEVEL)
|
|
34
|
-
check_str(level, variable_name
|
|
49
|
+
check_str(level, variable_name="level in yaml")
|
|
35
50
|
if level in GradConst.SUPPORTED_LEVEL:
|
|
36
51
|
self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL)
|
|
37
52
|
else:
|
|
38
53
|
raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")
|
|
39
54
|
|
|
40
|
-
self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
|
|
41
|
-
self._set_input_list(config_dict, GradConst.BOUNDS, float)
|
|
42
|
-
self._set_input_list(config_dict, GradConst.STEP, int)
|
|
43
|
-
self._set_input_list(config_dict, GradConst.RANK, int)
|
|
55
|
+
self._set_input_list(config_dict, GradConst.PARAM_LIST, (str,), element_check=check_param_element)
|
|
56
|
+
self._set_input_list(config_dict, GradConst.BOUNDS, (float, int), element_check=check_bounds_element)
|
|
57
|
+
self._set_input_list(config_dict, GradConst.STEP, (int,))
|
|
58
|
+
self._set_input_list(config_dict, GradConst.RANK, (int,))
|
|
44
59
|
|
|
45
60
|
output_path = config_dict.get(GradConst.OUTPUT_PATH)
|
|
46
|
-
check_str(output_path, variable_name
|
|
61
|
+
check_str(output_path, variable_name="output_path in yaml")
|
|
47
62
|
try:
|
|
48
63
|
check_path_before_create(output_path)
|
|
49
64
|
except RuntimeError as err:
|
|
@@ -70,21 +85,36 @@ class GlobalContext:
|
|
|
70
85
|
dump_rank_list = self.get_context(GradConst.RANK)
|
|
71
86
|
return (not dump_rank_list) or (rank in dump_rank_list)
|
|
72
87
|
|
|
73
|
-
def
|
|
74
|
-
|
|
88
|
+
def _get_type_str(self, dtype: Union[int, str, float, Tuple[int, str, float]]):
|
|
89
|
+
if isinstance(dtype, tuple):
|
|
90
|
+
return "/".join([self._get_type_str(element) for element in dtype])
|
|
75
91
|
if dtype == int:
|
|
76
92
|
type_str = "integer"
|
|
77
93
|
elif dtype == float:
|
|
78
94
|
type_str = "float"
|
|
79
95
|
else:
|
|
80
96
|
type_str = "string"
|
|
97
|
+
return type_str
|
|
98
|
+
|
|
99
|
+
def _set_input_list(self, config_dict: Dict, name: str,
|
|
100
|
+
dtype: Union[int, str, float, Tuple[int, str, float]], element_check=None):
|
|
101
|
+
value = config_dict.get(name)
|
|
102
|
+
type_str = self._get_type_str(dtype)
|
|
81
103
|
if value and isinstance(value, list):
|
|
82
104
|
for val in value:
|
|
83
105
|
if not isinstance(val, dtype):
|
|
84
|
-
logger.warning(f"Invalid {name} which must be None or list of {type_str}")
|
|
106
|
+
logger.warning(f"Invalid {name} which must be None or list of {type_str}, use default value.")
|
|
107
|
+
return
|
|
108
|
+
elif isinstance(val, int) and not is_int(val):
|
|
109
|
+
logger.warning(f"Invalid {name} which must be None or list of int, use default value.")
|
|
110
|
+
return
|
|
111
|
+
if element_check and not element_check(val):
|
|
112
|
+
logger.warning(f"Given {name} violates some rules, use default value.")
|
|
85
113
|
return
|
|
114
|
+
|
|
86
115
|
self._setting[name] = value
|
|
87
116
|
else:
|
|
88
117
|
logger.warning(f"{name} is None or not a list with valid items, use default value.")
|
|
89
118
|
|
|
119
|
+
|
|
90
120
|
grad_context = GlobalContext()
|
|
@@ -1,20 +1,33 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import multiprocessing
|
|
1
17
|
import os
|
|
2
18
|
import time
|
|
3
|
-
from typing import List, Tuple
|
|
4
|
-
import multiprocessing
|
|
5
19
|
from multiprocessing import Process
|
|
20
|
+
from typing import List
|
|
6
21
|
|
|
7
|
-
import numpy as np
|
|
8
22
|
import mindspore as ms
|
|
9
|
-
|
|
10
|
-
from mindspore.ops import operations as P
|
|
23
|
+
import numpy as np
|
|
11
24
|
from mindspore.common.parameter import Parameter
|
|
12
|
-
|
|
13
|
-
from msprobe.core.grad_probe.utils import ListCache
|
|
14
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
15
|
-
from msprobe.mindspore.common.log import logger
|
|
25
|
+
from mindspore.communication import get_rank
|
|
16
26
|
from msprobe.core.common.file_utils import (create_directory, check_file_or_directory_path,
|
|
17
27
|
write_csv, remove_path, move_file, load_npy)
|
|
28
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
29
|
+
from msprobe.core.grad_probe.utils import ListCache
|
|
30
|
+
from msprobe.mindspore.common.log import logger
|
|
18
31
|
from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext
|
|
19
32
|
|
|
20
33
|
|
|
@@ -28,12 +41,12 @@ def get_rank_id():
|
|
|
28
41
|
|
|
29
42
|
@ms.jit
|
|
30
43
|
def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
|
|
31
|
-
|
|
44
|
+
"""
|
|
32
45
|
Dump gradient statistic data.
|
|
33
46
|
level0: [step, max, min, norm, shape_dim, shape]
|
|
34
47
|
level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
|
|
35
48
|
level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
|
|
36
|
-
|
|
49
|
+
"""
|
|
37
50
|
dump_path = os.path.join(dump_dir, g_name)
|
|
38
51
|
dump_dir_path = dump_path + "_dir"
|
|
39
52
|
save_op = ms.ops.TensorDump()
|
|
@@ -182,7 +195,7 @@ class CSVGenerator(Process):
|
|
|
182
195
|
shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
|
|
183
196
|
file_name = os.path.basename(file_path)
|
|
184
197
|
prefix_idx = len(file_name.split("_")[0])
|
|
185
|
-
param_name = file_name[(prefix_idx + 1)
|
|
198
|
+
param_name = file_name[(prefix_idx + 1): -(len(GradConst.NPY_SUFFIX) + 1)]
|
|
186
199
|
if not param_name:
|
|
187
200
|
raise RuntimeError("Invalid gradient statistic file name.")
|
|
188
201
|
csv_line = [param_name]
|
|
@@ -224,8 +237,9 @@ class CSVGenerator(Process):
|
|
|
224
237
|
if i == 0:
|
|
225
238
|
intervals.append(f"(-inf, {self.bounds[i]}]")
|
|
226
239
|
else:
|
|
227
|
-
intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]")
|
|
240
|
+
intervals.append(f"({self.bounds[i - 1]}, {self.bounds[i]}]")
|
|
228
241
|
intervals.extend([f"({self.bounds[-1]}, inf)", "=0"])
|
|
229
242
|
return intervals
|
|
230
243
|
|
|
244
|
+
|
|
231
245
|
csv_generator = CSVGenerator()
|
|
@@ -1,7 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.grad_probe.constant import GradConst
|
|
1
17
|
from msprobe.mindspore.grad_probe.global_context import grad_context
|
|
2
18
|
from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator
|
|
3
19
|
from msprobe.mindspore.grad_probe.hook import hook_optimizer
|
|
4
|
-
from msprobe.core.grad_probe.constant import GradConst
|
|
5
20
|
|
|
6
21
|
|
|
7
22
|
class GradientMonitor:
|