mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -25,15 +40,22 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
|
|
|
25
40
|
x_shape = x.shape
|
|
26
41
|
h = x.float()
|
|
27
42
|
grad = dy_tensor.float()
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
43
|
+
if len(r1_shape) < 4 or len(x_shape) < 4:
|
|
44
|
+
raise RuntimeError(f"Shape of r1 and x should at least be 4-dimension, "
|
|
45
|
+
f"but got r1 shape:{r1_shape}, x shape:{x_shape}")
|
|
46
|
+
condition_1 = (r1_shape[0] == 1
|
|
47
|
+
and r1_shape[1] == x_shape[1]
|
|
48
|
+
and r1_shape[2] == 1
|
|
49
|
+
and r1_shape[3] == x_shape[3])
|
|
50
|
+
condition_2 = (r1_shape[0] == 1
|
|
51
|
+
and r1_shape[1] == 1
|
|
52
|
+
and r1_shape[2] == x_shape[2]
|
|
53
|
+
and r1_shape[3] == x_shape[3])
|
|
54
|
+
condition_3 = (r1_shape[0] == x_shape[0]
|
|
55
|
+
and r1_shape[1] == 1
|
|
56
|
+
and r1_shape[2] == 1
|
|
57
|
+
and r1_shape[3] == x_shape[3])
|
|
58
|
+
|
|
37
59
|
if condition_1:
|
|
38
60
|
for i in range(x_shape[0]):
|
|
39
61
|
for j in range(x_shape[2]):
|
|
@@ -49,4 +71,5 @@ def npu_rotary_mul_backward(dy_tensor, x, r1, r2):
|
|
|
49
71
|
for j in range(x_shape[2]):
|
|
50
72
|
r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
|
|
51
73
|
r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
|
|
74
|
+
|
|
52
75
|
return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
@@ -1,16 +1,35 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import torch
|
|
2
17
|
|
|
3
18
|
|
|
4
19
|
def npu_swiglu(x, dim=-1):
|
|
5
20
|
tensor_dtype = x.dtype
|
|
6
21
|
|
|
7
|
-
|
|
22
|
+
try:
|
|
23
|
+
in_tensors = torch.chunk(x, 2, dim=dim)
|
|
24
|
+
except Exception as e:
|
|
25
|
+
raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
|
|
26
|
+
|
|
8
27
|
if tensor_dtype == torch.float32:
|
|
9
|
-
tensor_scalar = torch.sigmoid(torch.mul(
|
|
10
|
-
output_data = torch.mul(torch.mul(tensor_scalar,
|
|
28
|
+
tensor_scalar = torch.sigmoid(torch.mul(in_tensors[0], 1.0))
|
|
29
|
+
output_data = torch.mul(torch.mul(tensor_scalar, in_tensors[0]), in_tensors[1])
|
|
11
30
|
else:
|
|
12
|
-
tensor_self_float =
|
|
13
|
-
tensor_other_float =
|
|
31
|
+
tensor_self_float = in_tensors[0].type(torch.float)
|
|
32
|
+
tensor_other_float = in_tensors[1].type(torch.float)
|
|
14
33
|
tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
|
|
15
34
|
torch.float32) * tensor_other_float
|
|
16
35
|
output_data = tensor_out_float.type(tensor_dtype)
|
|
@@ -19,7 +38,11 @@ def npu_swiglu(x, dim=-1):
|
|
|
19
38
|
|
|
20
39
|
def npu_swiglu_backward(grad, x, dim=-1):
|
|
21
40
|
tensor_dtype = grad.dtype
|
|
22
|
-
|
|
41
|
+
try:
|
|
42
|
+
in_tensors = torch.chunk(x, 2, dim=dim)
|
|
43
|
+
except Exception as e:
|
|
44
|
+
raise RuntimeError(f"Invalid chunk x into 2 tensors with shape {x.shape} and dimension {dim}") from e
|
|
45
|
+
|
|
23
46
|
tensor_grad_out = grad
|
|
24
47
|
|
|
25
48
|
if tensor_dtype == torch.float16:
|
|
@@ -1,2 +1,17 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from .parse_json import parse_json_info_forward_backward
|
|
2
17
|
from .utils import seed_all
|
msprobe/pytorch/common/log.py
CHANGED
|
@@ -1,9 +1,21 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
6
16
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
17
|
+
from msprobe.core.common.log import BaseLogger
|
|
18
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
7
19
|
|
|
8
20
|
|
|
9
21
|
class PyTorchLogger(BaseLogger):
|
|
@@ -18,4 +30,4 @@ class PyTorchLogger(BaseLogger):
|
|
|
18
30
|
return current_rank
|
|
19
31
|
|
|
20
32
|
|
|
21
|
-
logger = PyTorchLogger()
|
|
33
|
+
logger = PyTorchLogger()
|
|
@@ -1,25 +1,32 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
2
15
|
|
|
3
16
|
from msprobe.core.common.exceptions import ParseJsonException
|
|
4
|
-
from msprobe.core.common.file_utils import
|
|
17
|
+
from msprobe.core.common.file_utils import load_json
|
|
18
|
+
from msprobe.core.common.log import logger
|
|
5
19
|
|
|
6
20
|
|
|
7
21
|
def parse_json_info_forward_backward(json_path):
|
|
8
|
-
|
|
9
|
-
name_struct = data_name.split('.')
|
|
10
|
-
if not name_struct[-1] == pattern:
|
|
11
|
-
raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
|
|
12
|
-
f"{data_name} in file {json_path}")
|
|
13
|
-
api_name = '.'.join(name_struct[:-1])
|
|
14
|
-
return api_name
|
|
15
|
-
|
|
16
|
-
with FileOpen(json_path, 'r') as f:
|
|
17
|
-
dump_json = json.load(f)
|
|
22
|
+
dump_json = load_json(json_path)
|
|
18
23
|
|
|
19
24
|
real_data_path = dump_json.get("dump_data_dir")
|
|
20
25
|
dump_data = dump_json.get("data")
|
|
26
|
+
if dump_data is None:
|
|
27
|
+
raise ParseJsonException(ParseJsonException.InvalidDumpJson, "something wrong with dump, no data found in dump.json")
|
|
21
28
|
if not dump_data:
|
|
22
|
-
|
|
29
|
+
logger.warning("data field is empty, no overflow data found.")
|
|
23
30
|
|
|
24
31
|
forward_data = {}
|
|
25
32
|
backward_data = {}
|
|
@@ -27,13 +34,21 @@ def parse_json_info_forward_backward(json_path):
|
|
|
27
34
|
if "Module" in data_name:
|
|
28
35
|
continue
|
|
29
36
|
if "forward" in data_name:
|
|
30
|
-
api_name = parse_data_name_with_pattern(data_name, "forward")
|
|
37
|
+
api_name = parse_data_name_with_pattern(data_name, "forward", json_path)
|
|
31
38
|
forward_data.update({api_name: data_item})
|
|
32
39
|
elif "backward" in data_name:
|
|
33
|
-
api_name = parse_data_name_with_pattern(data_name, "backward")
|
|
40
|
+
api_name = parse_data_name_with_pattern(data_name, "backward", json_path)
|
|
34
41
|
backward_data.update({api_name: data_item})
|
|
35
42
|
else:
|
|
36
43
|
raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
|
|
37
|
-
|
|
44
|
+
f"{data_name} in file {json_path}.")
|
|
38
45
|
|
|
39
46
|
return forward_data, backward_data, real_data_path
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def parse_data_name_with_pattern(data_name, pattern, json_path):
|
|
50
|
+
name_struct = data_name.split('.')
|
|
51
|
+
if not name_struct[-1] == pattern:
|
|
52
|
+
raise ParseJsonException(ParseJsonException.UnexpectedNameStruct, f"{data_name} in file {json_path}")
|
|
53
|
+
api_name = '.'.join(name_struct[:-1])
|
|
54
|
+
return api_name
|
msprobe/pytorch/common/utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,20 +12,23 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import io
|
|
18
17
|
import os
|
|
18
|
+
import pickle
|
|
19
19
|
import random
|
|
20
20
|
import stat
|
|
21
|
+
from functools import wraps
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
21
24
|
import torch
|
|
22
25
|
import torch.distributed as dist
|
|
23
|
-
import numpy as np
|
|
24
|
-
from functools import wraps
|
|
25
26
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
26
|
-
from msprobe.core.common.log import logger
|
|
27
27
|
from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
|
|
28
|
-
check_file_or_directory_path, check_path_before_create)
|
|
29
|
-
|
|
28
|
+
check_file_or_directory_path, check_path_before_create, FileOpen)
|
|
29
|
+
from msprobe.core.common.log import logger
|
|
30
|
+
from msprobe.core.common.utils import check_seed_all
|
|
31
|
+
from packaging import version
|
|
30
32
|
|
|
31
33
|
try:
|
|
32
34
|
import torch_npu
|
|
@@ -35,10 +37,8 @@ except ImportError:
|
|
|
35
37
|
else:
|
|
36
38
|
is_gpu = False
|
|
37
39
|
|
|
38
|
-
|
|
39
40
|
torch_without_guard_version = torch.__version__ >= '2.1'
|
|
40
41
|
|
|
41
|
-
|
|
42
42
|
if not is_gpu and not torch_without_guard_version:
|
|
43
43
|
from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
|
|
44
44
|
|
|
@@ -46,7 +46,6 @@ npu_distributed_api = ['isend', 'irecv']
|
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
def parameter_adapter(func):
|
|
49
|
-
|
|
50
49
|
def handle_masked_select(input_tensor, indices):
|
|
51
50
|
masked_select_func = getattr(torch._C._VariableFunctionsClass, "masked_select")
|
|
52
51
|
if input_tensor.dtype == torch.bfloat16:
|
|
@@ -77,20 +76,22 @@ def parameter_adapter(func):
|
|
|
77
76
|
else:
|
|
78
77
|
res = [input_tensor[tensor_index] for tensor_index in indices]
|
|
79
78
|
return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
|
|
80
|
-
if self.op_name_ == "__eq__" and args[1] is None:
|
|
79
|
+
if self.op_name_ == "__eq__" and len(args) > 1 and args[1] is None:
|
|
81
80
|
return False
|
|
82
81
|
return func(self, *args, **kwargs)
|
|
82
|
+
|
|
83
83
|
return inner
|
|
84
84
|
|
|
85
85
|
|
|
86
86
|
def torch_device_guard(func):
|
|
87
87
|
if is_gpu or torch_without_guard_version:
|
|
88
88
|
return func
|
|
89
|
-
# Parse args/kwargs matched torch.device objects
|
|
90
89
|
|
|
90
|
+
# Parse args/kwargs matched torch.device objects
|
|
91
91
|
@torch_npu_device_guard
|
|
92
92
|
def wrapper(*args, **kwargs):
|
|
93
93
|
return func(*args, **kwargs)
|
|
94
|
+
|
|
94
95
|
return wrapper
|
|
95
96
|
|
|
96
97
|
|
|
@@ -105,20 +106,28 @@ def get_rank_if_initialized():
|
|
|
105
106
|
|
|
106
107
|
|
|
107
108
|
def seed_all(seed=1234, mode=False):
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
torch.cuda
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
torch.
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
109
|
+
check_seed_all(seed, mode)
|
|
110
|
+
try:
|
|
111
|
+
random.seed(seed)
|
|
112
|
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
113
|
+
np.random.seed(seed)
|
|
114
|
+
torch.manual_seed(seed)
|
|
115
|
+
cuda_version = torch.version.cuda
|
|
116
|
+
if cuda_version is not None and version.parse(cuda_version) >= version.parse("10.2"):
|
|
117
|
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
|
118
|
+
os.environ['HCCL_DETERMINISTIC'] = str(mode)
|
|
119
|
+
torch.use_deterministic_algorithms(mode)
|
|
120
|
+
if is_gpu:
|
|
121
|
+
torch.cuda.manual_seed_all(seed)
|
|
122
|
+
torch.cuda.manual_seed(seed)
|
|
123
|
+
torch.backends.cudnn.deterministic = True
|
|
124
|
+
torch.backends.cudnn.enable = False
|
|
125
|
+
torch.backends.cudnn.benchmark = False
|
|
126
|
+
else:
|
|
127
|
+
torch_npu.npu.manual_seed_all(seed)
|
|
128
|
+
torch_npu.npu.manual_seed(seed)
|
|
129
|
+
except Exception as e:
|
|
130
|
+
logger.error(f"There is an unexpected error while determinating randomness. {e}")
|
|
122
131
|
|
|
123
132
|
|
|
124
133
|
class Const:
|
|
@@ -191,10 +200,7 @@ class Const:
|
|
|
191
200
|
ENV_ENABLE = "1"
|
|
192
201
|
ENV_DISABLE = "0"
|
|
193
202
|
|
|
194
|
-
MAX_SEED_VALUE = 2**32 - 1
|
|
195
|
-
|
|
196
|
-
INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
|
|
197
|
-
"_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]
|
|
203
|
+
MAX_SEED_VALUE = 2 ** 32 - 1
|
|
198
204
|
|
|
199
205
|
TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
|
|
200
206
|
LEVEL_LIST = ["L0", "L1", "L2", "mix"]
|
|
@@ -257,34 +263,84 @@ def print_rank_0(message):
|
|
|
257
263
|
logger.info(message)
|
|
258
264
|
else:
|
|
259
265
|
logger.info(message)
|
|
260
|
-
|
|
266
|
+
|
|
261
267
|
|
|
262
268
|
def load_pt(pt_path, to_cpu=False):
    """Load a ``.pt`` file with ``torch.load(..., weights_only=True)``.

    Args:
        pt_path: Path of the file; it is resolved with ``os.path.realpath``
            and then checked by ``check_file_or_directory_path``.
        to_cpu: When truthy, the data is mapped onto the CPU device.

    Returns:
        Whatever ``torch.load`` returns for the file.

    Raises:
        RuntimeError: If loading fails for any reason; the original exception
            is chained as the cause.
    """
    pt_path = os.path.realpath(pt_path)
    check_file_or_directory_path(pt_path)
    try:
        # map_location=None is torch.load's default, i.e. "no remapping".
        target_device = torch.device("cpu") if to_cpu else None
        loaded = torch.load(pt_path, map_location=target_device, weights_only=True)
    except Exception as e:
        raise RuntimeError(f"load pt file {pt_path} failed") from e
    return loaded
|
|
273
279
|
|
|
274
280
|
|
|
275
281
|
def save_pt(tensor, filepath):
    """Write ``tensor`` to ``filepath`` with ``torch.save``.

    The path is checked by ``check_path_before_create`` before it is resolved
    with ``os.path.realpath``. After a successful save the file mode is set
    to ``FileCheckConst.DATA_FILE_AUTHORITY``.

    Raises:
        RuntimeError: If ``torch.save`` fails; the failure is logged first and
            the original exception is chained as the cause.
    """
    check_path_before_create(filepath)
    filepath = os.path.realpath(filepath)
    try:
        torch.save(tensor, filepath)
    except Exception as e:
        logger.error("Save pt file failed, please check according possible error causes: "
                     "1. out of disk space or disk error, "
                     "2. no permission to write files, etc.")
        raise RuntimeError(f"save pt file {filepath} failed") from e
    else:
        # Only reached when the save succeeded.
        change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
286
292
|
|
|
287
293
|
|
|
294
|
+
class TypeCheckingUnpickler(pickle.Unpickler):
    """
    Unpickler that refuses to resolve globals outside a fixed allow-list.

    ``find_class`` is overridden so that only names listed in
    ``allowed_types`` are resolved; anything else aborts unpickling.
    NOTE: only the bare ``name`` is compared, the ``module`` part is not.
    """
    allowed_types = [
        "str",
        "ApiData",
        "OrderedDict",
        "_rebuild_tensor_v2",  # from torch.utils
        "_load_from_bytes"  # from torch.storage
    ]

    def find_class(self, module, name):
        """
        Resolve ``module.name`` for the object being unpickled.
        Raises pickle.UnpicklingError when ``name`` is not in the allow-list.
        """
        if name not in self.allowed_types:
            raise pickle.UnpicklingError("Unsupported object type: {}.{}".format(module, name))
        return super().find_class(module, name)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def save_pkl(tensor, filepath):
    """Save an ApiData or str object to ``filepath`` with pickle.

    The path is checked by ``check_path_before_create`` and then resolved
    with ``os.path.realpath``. After a successful dump the file mode is set
    to ``FileCheckConst.DATA_FILE_AUTHORITY``.

    Raises:
        RuntimeError: If opening or dumping fails; the failure is logged
            first and the original exception is chained as the cause.
    """
    check_path_before_create(filepath)
    filepath = os.path.realpath(filepath)
    try:
        with FileOpen(filepath, 'wb') as out_stream:
            pickle.dump(tensor, out_stream)
    except Exception as e:
        logger.error("Save pt file failed, please check according possible error causes: "
                     "1. out of disk space or disk error, "
                     "2. no permission to write files, etc.")
        raise RuntimeError(f"save pt file {filepath} failed") from e
    else:
        # Only reached when the dump succeeded.
        change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def load_pkl(pt_path):
    """Load an ApiData or str object by pickle for accuracy_checker_online.

    The path is checked by ``check_file_or_directory_path`` and then resolved
    with ``os.path.realpath``. The file is read through
    ``TypeCheckingUnpickler``.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: If opening or unpickling fails; the original exception
            is chained as the cause and its text is appended to the message.
    """
    check_file_or_directory_path(pt_path)
    pt_path = os.path.realpath(pt_path)
    try:
        with FileOpen(pt_path, 'rb') as in_stream:
            restored = TypeCheckingUnpickler(in_stream).load()
    except Exception as e:
        raise RuntimeError(f"load pt file {pt_path} failed: {e}") from e
    return restored
|
|
342
|
+
|
|
343
|
+
|
|
288
344
|
def save_api_data(api_data):
|
|
289
345
|
"""Save data to io stream"""
|
|
290
346
|
try:
|
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2019-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,14 +12,13 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import os
|
|
18
17
|
from msprobe.core.common.utils import CompareException, check_compare_param, \
|
|
19
|
-
check_configuration_param,
|
|
18
|
+
check_configuration_param, set_dump_path, get_dump_mode
|
|
20
19
|
from msprobe.core.common.file_utils import create_directory
|
|
21
20
|
from msprobe.core.common.exceptions import FileCheckException
|
|
22
21
|
from msprobe.pytorch.common.log import logger
|
|
23
|
-
from msprobe.core.common.const import Const
|
|
24
22
|
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
25
23
|
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
26
24
|
|
|
@@ -32,6 +30,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
32
30
|
stack_mode = kwargs.get('stack_mode', False)
|
|
33
31
|
auto_analyze = kwargs.get('auto_analyze', True)
|
|
34
32
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
33
|
+
is_print_compare_log = kwargs.get('is_print_compare_log', True)
|
|
35
34
|
# get the ranks and match by order
|
|
36
35
|
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
37
36
|
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
@@ -51,16 +50,16 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
51
50
|
'npu_json_path': npu_path,
|
|
52
51
|
'bench_json_path': bench_path,
|
|
53
52
|
'stack_json_path': stack_path,
|
|
54
|
-
'is_print_compare_log':
|
|
53
|
+
'is_print_compare_log': is_print_compare_log
|
|
55
54
|
}
|
|
56
55
|
try:
|
|
57
|
-
|
|
58
|
-
|
|
56
|
+
set_dump_path(dump_result_param)
|
|
57
|
+
dump_mode = get_dump_mode(dump_result_param)
|
|
58
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, is_print_compare_log)
|
|
59
59
|
create_directory(output_path)
|
|
60
|
-
check_compare_param(dump_result_param, output_path,
|
|
60
|
+
check_compare_param(dump_result_param, output_path, dump_mode)
|
|
61
61
|
except (CompareException, FileCheckException) as error:
|
|
62
62
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
63
63
|
raise CompareException(error.code) from error
|
|
64
64
|
pt_comparator = PTComparator()
|
|
65
|
-
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
|
|
66
|
-
md5_compare=md5_compare, **kwargs)
|
|
65
|
+
pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', dump_mode=dump_mode, **kwargs)
|
msprobe/pytorch/compare/match.py
CHANGED
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
from msprobe.core.common.utils import CompareException
|
|
3
18
|
from msprobe.core.common.file_utils import load_yaml
|
|
@@ -1,19 +1,52 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os.path
|
|
2
17
|
import torch
|
|
3
18
|
from msprobe.core.common.const import FileCheckConst
|
|
4
19
|
from msprobe.pytorch.common.log import logger
|
|
5
20
|
from msprobe.core.common.exceptions import FileCheckException
|
|
6
21
|
from msprobe.core.compare.acc_compare import Comparator
|
|
7
|
-
from msprobe.core.common.utils import check_configuration_param,
|
|
8
|
-
|
|
22
|
+
from msprobe.core.common.utils import check_configuration_param, check_compare_param, \
|
|
23
|
+
CompareException, set_dump_path, get_dump_mode
|
|
24
|
+
from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
|
|
9
25
|
from msprobe.pytorch.common.utils import load_pt
|
|
10
26
|
|
|
11
27
|
|
|
12
28
|
class PTComparator (Comparator):
|
|
13
|
-
def __init__(self):
|
|
29
|
+
def __init__(self, data_mapping=None):
|
|
14
30
|
self.frame_name = PTComparator.__name__
|
|
31
|
+
self.data_mapping = data_mapping
|
|
32
|
+
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
33
|
+
self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
|
|
34
|
+
elif isinstance(self.data_mapping, dict):
|
|
35
|
+
self.data_mapping_dict = self.data_mapping
|
|
36
|
+
else:
|
|
37
|
+
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
38
|
+
f"{type(self.data_mapping)}")
|
|
39
|
+
|
|
40
|
+
def load_mapping_file(self, mapping_file):
    """Return the mapping loaded from ``mapping_file``.

    A ``str`` argument is passed to ``load_yaml`` and its result is returned;
    any other value produces an empty dict.
    """
    if not isinstance(mapping_file, str):
        return {}
    return load_yaml(mapping_file)
|
|
15
46
|
|
|
16
47
|
def read_npy_data(self, dir_path, file_name):
|
|
48
|
+
if not file_name:
|
|
49
|
+
return None
|
|
17
50
|
data_path = os.path.join(dir_path, file_name)
|
|
18
51
|
path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
|
|
19
52
|
FileCheckConst.PT_SUFFIX, False)
|
|
@@ -35,16 +68,17 @@ class PTComparator (Comparator):
|
|
|
35
68
|
return data_value
|
|
36
69
|
|
|
37
70
|
|
|
38
|
-
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
|
|
71
|
+
def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
    """Validate the arguments and run a PyTorch dump comparison.

    Args:
        input_param: Mapping describing the dumps; it is passed to
            ``set_dump_path``, ``get_dump_mode`` and ``check_compare_param``,
            and its optional ``'is_print_compare_log'`` entry (default
            ``True``) is validated with the other flags.
        output_path: Directory for the result; created via
            ``create_directory``.
        stack_mode: Forwarded to ``compare_core``.
        auto_analyze: Forwarded to ``compare_core``.
        fuzzy_match: Forwarded to ``compare_core``.
        **kwargs: Only ``'data_mapping'`` is read; it is handed to
            ``PTComparator``.

    Raises:
        CompareException: If a validation step raises ``CompareException`` or
            ``FileCheckException``; the error is logged and re-raised with
            the same ``code``.
    """
    # A plain dict lookup cannot raise the exceptions handled below.
    data_mapping = kwargs.get('data_mapping', None)
    try:
        set_dump_path(input_param)
        dump_mode = get_dump_mode(input_param)
        print_compare_log = input_param.get('is_print_compare_log', True)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match, print_compare_log)
        create_directory(output_path)
        check_compare_param(input_param, output_path, dump_mode)
    except (CompareException, FileCheckException) as error:
        logger.error('Compare failed. Please check the arguments and do it again!')
        raise CompareException(error.code) from error
    comparator = PTComparator(data_mapping)
    comparator.compare_core(
        input_param,
        output_path,
        stack_mode=stack_mode,
        auto_analyze=auto_analyze,
        fuzzy_match=fuzzy_match,
        dump_mode=dump_mode,
    )
|