mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,8 +1,24 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
from msprobe.core.common.const import Const
|
|
17
|
+
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
3
18
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list
|
|
4
19
|
from msprobe.mindspore.common.log import logger
|
|
5
20
|
|
|
21
|
+
|
|
6
22
|
def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_type=None, accepted_value=None):
|
|
7
23
|
'''
|
|
8
24
|
Args:
|
|
@@ -22,30 +38,30 @@ def check_and_get_from_json_dict(dict_instance, key, key_description, accepted_t
|
|
|
22
38
|
3. value is not accepted type
|
|
23
39
|
4. value is not accepted value
|
|
24
40
|
'''
|
|
25
|
-
parse_failed_exception = ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)
|
|
26
41
|
if not isinstance(dict_instance, dict):
|
|
27
|
-
|
|
42
|
+
error_info = "check_and_get_from_json_dict failed: input is not a dict"
|
|
43
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
28
44
|
value = dict_instance.get(key)
|
|
29
45
|
if value is None:
|
|
30
|
-
|
|
31
|
-
|
|
46
|
+
error_info = f"check_and_get_from_json_dict failed: {key_description} is missing"
|
|
47
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
32
48
|
elif accepted_type is not None and not isinstance(value, accepted_type):
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
parse_failed_exception)
|
|
49
|
+
error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted type: {accepted_type}"
|
|
50
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
36
51
|
elif accepted_value is not None and value not in accepted_value:
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
parse_failed_exception)
|
|
52
|
+
error_info = f"check_and_get_from_json_dict failed: {key_description} is not accepted value: {accepted_value}"
|
|
53
|
+
raise ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed, error_info)
|
|
40
54
|
return value
|
|
41
55
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
56
|
+
|
|
57
|
+
def convert_to_tuple(args):
|
|
58
|
+
if isinstance(args, (tuple, list)):
|
|
59
|
+
return tuple(args)
|
|
45
60
|
else:
|
|
46
|
-
input_list = [
|
|
61
|
+
input_list = [args]
|
|
47
62
|
return tuple(input_list)
|
|
48
63
|
|
|
64
|
+
|
|
49
65
|
def trim_output_compute_element_list(compute_element_list, forward_or_backward):
|
|
50
66
|
'''
|
|
51
67
|
Args:
|
|
@@ -55,12 +71,13 @@ def trim_output_compute_element_list(compute_element_list, forward_or_backward):
|
|
|
55
71
|
trimmed_list = []
|
|
56
72
|
for compute_element in compute_element_list:
|
|
57
73
|
if compute_element.get_parameter() is None or \
|
|
58
|
-
|
|
74
|
+
(forward_or_backward == Const.BACKWARD and compute_element.get_dtype() not in float_dtype_str_list):
|
|
59
75
|
# trim case: 1. parameter is None. 2. backward output has non float parameter
|
|
60
76
|
continue
|
|
61
77
|
trimmed_list.append(compute_element)
|
|
62
78
|
return trimmed_list
|
|
63
79
|
|
|
80
|
+
|
|
64
81
|
class GlobalContext:
|
|
65
82
|
def __init__(self):
|
|
66
83
|
self.is_constructed = True
|
|
@@ -77,4 +94,4 @@ class GlobalContext:
|
|
|
77
94
|
return self.is_constructed
|
|
78
95
|
|
|
79
96
|
|
|
80
|
-
global_context = GlobalContext()
|
|
97
|
+
global_context = GlobalContext()
|
|
@@ -1,17 +1,31 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope
|
|
2
17
|
from msprobe.core.common.const import Const
|
|
3
|
-
from msprobe.mindspore.common.log import logger
|
|
4
18
|
|
|
5
19
|
|
|
6
20
|
class CellProcessor:
|
|
7
21
|
cell_count = {}
|
|
22
|
+
cell_stack = []
|
|
23
|
+
api_parent_node = ""
|
|
24
|
+
module_node = {}
|
|
8
25
|
|
|
9
26
|
def __init__(self, scope):
|
|
10
|
-
if isinstance(scope, ModuleRangeScope)
|
|
11
|
-
|
|
12
|
-
else:
|
|
13
|
-
self.scope = None
|
|
14
|
-
|
|
27
|
+
self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
|
|
28
|
+
|
|
15
29
|
@staticmethod
|
|
16
30
|
def set_cell_count(cell_name):
|
|
17
31
|
if cell_name not in CellProcessor.cell_count:
|
|
@@ -19,16 +33,47 @@ class CellProcessor:
|
|
|
19
33
|
else:
|
|
20
34
|
CellProcessor.cell_count[cell_name] += 1
|
|
21
35
|
return CellProcessor.cell_count[cell_name]
|
|
22
|
-
|
|
36
|
+
|
|
37
|
+
@classmethod
|
|
38
|
+
def reset_cell_stats(cls):
|
|
39
|
+
cls.cell_count = {}
|
|
40
|
+
cls.cell_stack = []
|
|
41
|
+
cls.api_parent_node = ""
|
|
42
|
+
cls.module_node = {}
|
|
43
|
+
|
|
23
44
|
def node_hook(self, name_prefix, start_or_stop, **kwargs):
|
|
24
|
-
def begin_hook(cell,
|
|
25
|
-
|
|
26
|
-
|
|
45
|
+
def begin_hook(cell, input_data):
|
|
46
|
+
full_name = self.set_and_get_reserved_name(cell, name_prefix, is_called_by_pre_hook=True)
|
|
47
|
+
if CellProcessor.cell_stack:
|
|
48
|
+
CellProcessor.module_node[full_name] = CellProcessor.cell_stack[-1]
|
|
49
|
+
else:
|
|
50
|
+
CellProcessor.module_node[full_name] = None
|
|
51
|
+
|
|
52
|
+
CellProcessor.cell_stack.append(full_name)
|
|
53
|
+
CellProcessor.api_parent_node = full_name
|
|
54
|
+
|
|
27
55
|
if self.scope:
|
|
28
56
|
self.scope.begin_module(full_name)
|
|
29
|
-
|
|
30
|
-
def end_hook(cell,
|
|
57
|
+
|
|
58
|
+
def end_hook(cell, input_data, output_data):
|
|
59
|
+
if CellProcessor.cell_stack:
|
|
60
|
+
CellProcessor.cell_stack.pop()
|
|
61
|
+
if CellProcessor.cell_stack:
|
|
62
|
+
CellProcessor.api_parent_node = CellProcessor.cell_stack[-1]
|
|
63
|
+
else:
|
|
64
|
+
CellProcessor.api_parent_node = None
|
|
65
|
+
|
|
31
66
|
if self.scope:
|
|
32
67
|
self.scope.end_module(cell.mindstudio_reserved_name)
|
|
33
68
|
|
|
34
69
|
return begin_hook if Const.START == start_or_stop else end_hook
|
|
70
|
+
|
|
71
|
+
def set_and_get_reserved_name(self, cell, cell_name, is_called_by_pre_hook=False):
|
|
72
|
+
if not is_called_by_pre_hook and hasattr(cell, 'has_pre_hook_called') and cell.has_pre_hook_called:
|
|
73
|
+
cell.has_pre_hook_called = False
|
|
74
|
+
else:
|
|
75
|
+
if is_called_by_pre_hook:
|
|
76
|
+
cell.has_pre_hook_called = True
|
|
77
|
+
index = self.set_cell_count(cell_name)
|
|
78
|
+
cell.mindstudio_reserved_name = cell_name + Const.SEP + str(index)
|
|
79
|
+
return cell.mindstudio_reserved_name
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import numpy as np
|
|
2
17
|
import mindspore as ms
|
|
3
18
|
|
|
@@ -23,31 +38,35 @@ class Const:
|
|
|
23
38
|
ASCEND_910A = "ascend910"
|
|
24
39
|
|
|
25
40
|
OPS_PREFIX = "mindspore.ops."
|
|
26
|
-
|
|
41
|
+
TENSOR_PREFIX = "mindspore.Tensor."
|
|
27
42
|
MINT_PREFIX = "mindspore.mint."
|
|
28
43
|
MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional."
|
|
29
|
-
|
|
30
|
-
COMMUNICATION_API_LIST = [
|
|
31
|
-
"mindspore.communication.comm_func.all_gather_into_tensor",
|
|
32
|
-
"mindspore.communication.comm_func.gather_into_tensor",
|
|
33
|
-
"mindspore.communication.comm_func.all_reduce",
|
|
34
|
-
"mindspore.communication.comm_func.reduce",
|
|
35
|
-
"mindspore.communication.comm_func.reduce_scatter_tensor"
|
|
36
|
-
]
|
|
44
|
+
|
|
37
45
|
TENSOR_DATA_PREFIX = "Tensor."
|
|
38
46
|
STUB_TENSOR_DATA_PREFIX = "Tensor."
|
|
39
47
|
OPS_DATA_PREFIX = "Functional."
|
|
40
48
|
MINT_DATA_PREFIX = "Mint."
|
|
41
49
|
MINT_NN_FUNC_DATA_PREFIX = "MintFunctional."
|
|
50
|
+
DISTRIBUTED_DATA_PREFIX = "Distributed."
|
|
42
51
|
|
|
43
52
|
SUPPORTED_API_LIST_FILE = "support_wrap_ops.yaml"
|
|
44
53
|
SUPPORTED_TENSOR_LIST_KEY = "tensor"
|
|
45
54
|
SUPPORTED_OPS_LIST_KEY = "ops"
|
|
46
55
|
SUPPORTED_MINT_LIST_KEY = "mint.ops"
|
|
47
56
|
SUPPORTED__MINT_NN_FUNC_LIST_KEY = "mint.nn.functional"
|
|
57
|
+
SUPPORTED_COMM_LIST_KEY = "communication.comm_func"
|
|
48
58
|
|
|
49
59
|
DROPOUT_API_NAME_PREFIX = "dropout"
|
|
50
60
|
|
|
61
|
+
GRAPH_DATA_MODE_LIST = [CoreConst.ALL, CoreConst.INPUT, CoreConst.OUTPUT]
|
|
62
|
+
|
|
63
|
+
HOOK_MS_PREFIX_DICT = {
|
|
64
|
+
OPS_DATA_PREFIX: OPS_PREFIX,
|
|
65
|
+
TENSOR_DATA_PREFIX: TENSOR_PREFIX,
|
|
66
|
+
MINT_DATA_PREFIX: MINT_PREFIX,
|
|
67
|
+
MINT_NN_FUNC_DATA_PREFIX: MINT_NN_FUNC_PREFIX
|
|
68
|
+
}
|
|
69
|
+
|
|
51
70
|
|
|
52
71
|
class FreeBenchmarkConst:
|
|
53
72
|
ADD_NOISE = "add_noise"
|
|
@@ -63,19 +82,21 @@ class FreeBenchmarkConst:
|
|
|
63
82
|
DEFAULT_PERT_TYPE = IMPROVE_PRECISION
|
|
64
83
|
DEFAULT_HANDLER_TYPE = CHECK
|
|
65
84
|
DEVICE_LIST = [DEFAULT_DEVICE]
|
|
66
|
-
STAGE_LIST = [CoreConst.FORWARD]
|
|
85
|
+
STAGE_LIST = [CoreConst.FORWARD, CoreConst.BACKWARD]
|
|
67
86
|
DUMP_LEVEL_LIST = [DEFAULT_DUMP_LEVEL]
|
|
68
87
|
PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE, EXCHANGE_VALUE]
|
|
69
88
|
HANDLER_TYPE_LIST = [CHECK, FIX]
|
|
70
89
|
NO_CHANGE_ERROR_THRESHOLD = 1.0
|
|
71
90
|
SYMBOL_FLIPPING_RATIO = 8.0
|
|
72
91
|
|
|
92
|
+
SUPPORTED_CHECK_API_FILE = "support_wrap_ops.yaml"
|
|
93
|
+
CHECK_RESULT_FILE = "free_benchmark.csv"
|
|
94
|
+
|
|
73
95
|
API_PREFIX_DICT = {
|
|
74
96
|
"ops": Const.OPS_PREFIX,
|
|
75
|
-
"Tensor": Const.
|
|
97
|
+
"Tensor": Const.TENSOR_PREFIX,
|
|
76
98
|
"mint": Const.MINT_PREFIX,
|
|
77
|
-
"mint.nn.functional": Const.MINT_NN_FUNC_PREFIX
|
|
78
|
-
"communication": Const.COMM_PREFIX
|
|
99
|
+
"mint.nn.functional": Const.MINT_NN_FUNC_PREFIX
|
|
79
100
|
}
|
|
80
101
|
|
|
81
102
|
PERT_VALUE_DICT = {
|
|
@@ -86,6 +107,7 @@ class FreeBenchmarkConst:
|
|
|
86
107
|
}
|
|
87
108
|
|
|
88
109
|
ERROR_THRESHOLD = {
|
|
110
|
+
ms.bfloat16: 1.004,
|
|
89
111
|
ms.float16: 1.002,
|
|
90
112
|
ms.float32: 1.0002
|
|
91
113
|
}
|
msprobe/mindspore/common/log.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,15 +12,10 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
|
-
import time
|
|
18
|
-
import sys
|
|
19
|
-
|
|
20
|
-
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
21
|
-
from msprobe.core.common.log import BaseLogger
|
|
22
16
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
17
|
+
from msprobe.core.common.log import BaseLogger
|
|
18
|
+
from msprobe.mindspore.common.utils import get_rank_if_initialized
|
|
23
19
|
|
|
24
20
|
|
|
25
21
|
class MindsporeLogger(BaseLogger):
|
|
@@ -35,4 +31,4 @@ class MindsporeLogger(BaseLogger):
|
|
|
35
31
|
return current_rank
|
|
36
32
|
|
|
37
33
|
|
|
38
|
-
logger = MindsporeLogger()
|
|
34
|
+
logger = MindsporeLogger()
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,13 +12,19 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
+
|
|
15
16
|
import os
|
|
17
|
+
import random
|
|
18
|
+
|
|
16
19
|
import mindspore as ms
|
|
17
20
|
|
|
21
|
+
from mindspore import ops
|
|
22
|
+
from mindspore.mint import nn
|
|
18
23
|
from msprobe.core.common.exceptions import DistributedNotInitializedError
|
|
19
24
|
from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy
|
|
20
25
|
from msprobe.core.common.log import logger
|
|
26
|
+
from msprobe.core.common.const import Const
|
|
27
|
+
from msprobe.core.common.utils import CompareException, check_seed_all
|
|
21
28
|
|
|
22
29
|
|
|
23
30
|
def get_rank_if_initialized():
|
|
@@ -36,7 +43,7 @@ def convert_bf16_to_fp32(tensor):
|
|
|
36
43
|
def save_tensor_as_npy(tensor, file_path):
|
|
37
44
|
if not path_len_exceeds_limit(file_path):
|
|
38
45
|
tensor = convert_bf16_to_fp32(tensor)
|
|
39
|
-
saved_tensor = tensor.asnumpy()
|
|
46
|
+
saved_tensor = tensor.contiguous().asnumpy()
|
|
40
47
|
save_npy(saved_tensor, file_path)
|
|
41
48
|
else:
|
|
42
49
|
logger.warning(f'The file path {file_path} length exceeds limit.')
|
|
@@ -53,12 +60,15 @@ def list_lowest_level_directories(root_dir):
|
|
|
53
60
|
check_path_exists(root_dir)
|
|
54
61
|
lowest_level_dirs = []
|
|
55
62
|
|
|
56
|
-
def recurse_dirs(current_dir):
|
|
63
|
+
def recurse_dirs(current_dir, depth=0):
|
|
64
|
+
if depth > Const.MAX_DEPTH:
|
|
65
|
+
logger.error(f'The directory {current_dir} has more than {Const.MAX_DEPTH} levels.')
|
|
66
|
+
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
57
67
|
for entry in os.listdir(current_dir):
|
|
58
68
|
full_path = os.path.join(current_dir, entry)
|
|
59
69
|
if os.path.isdir(full_path):
|
|
60
70
|
if any(os.path.isdir(os.path.join(full_path, subentry)) for subentry in os.listdir(full_path)):
|
|
61
|
-
recurse_dirs(full_path)
|
|
71
|
+
recurse_dirs(full_path, depth=depth+1)
|
|
62
72
|
else:
|
|
63
73
|
lowest_level_dirs.append(full_path)
|
|
64
74
|
|
|
@@ -66,6 +76,16 @@ def list_lowest_level_directories(root_dir):
|
|
|
66
76
|
return lowest_level_dirs
|
|
67
77
|
|
|
68
78
|
|
|
79
|
+
def seed_all(seed=1234, mode=False, rm_dropout=True):
|
|
80
|
+
check_seed_all(seed, mode)
|
|
81
|
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
|
82
|
+
ms.set_seed(seed)
|
|
83
|
+
random.seed(seed)
|
|
84
|
+
ms.set_context(deterministic="ON" if mode else "OFF")
|
|
85
|
+
os.environ['HCCL_DETERMINISTIC'] = str(mode)
|
|
86
|
+
if rm_dropout:
|
|
87
|
+
remove_dropout()
|
|
88
|
+
|
|
69
89
|
|
|
70
90
|
class MsprobeStep(ms.train.Callback):
|
|
71
91
|
|
|
@@ -79,3 +99,38 @@ class MsprobeStep(ms.train.Callback):
|
|
|
79
99
|
def on_train_step_end(self, run_context):
|
|
80
100
|
self.debugger.stop()
|
|
81
101
|
self.debugger.step()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class Dropout(ops.Dropout):
|
|
105
|
+
def __init__(self, keep_prob=0.5, Seed0=0, Seed1=1):
|
|
106
|
+
super().__init__(1., Seed0, Seed1)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class Dropout2D(ops.Dropout2D):
|
|
110
|
+
def __init__(self, keep_prob=0.5):
|
|
111
|
+
super().__init__(1.)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class Dropout3D(ops.Dropout3D):
|
|
115
|
+
def __init__(self, keep_prob=0.5):
|
|
116
|
+
super().__init__(1.)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class DropoutExt(nn.Dropout):
|
|
120
|
+
def __init__(self, p=0.5):
|
|
121
|
+
super().__init__(0)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def dropout_ext(input_tensor, p=0.5, training=True):
|
|
125
|
+
return input_tensor
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def remove_dropout():
|
|
129
|
+
ops.Dropout = Dropout
|
|
130
|
+
ops.operations.Dropout = Dropout
|
|
131
|
+
ops.Dropout2D = Dropout2D
|
|
132
|
+
ops.operations.Dropout2D = Dropout2D
|
|
133
|
+
ops.Dropout3D = Dropout3D
|
|
134
|
+
ops.operations.Dropout3D = Dropout3D
|
|
135
|
+
nn.Dropout = DropoutExt
|
|
136
|
+
nn.functional.dropout = dropout_ext
|
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,31 +12,29 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
import os
|
|
18
|
-
from msprobe.core.common.utils import CompareException
|
|
19
|
-
check_configuration_param, task_dumppath_get
|
|
17
|
+
from msprobe.core.common.utils import CompareException
|
|
20
18
|
from msprobe.core.common.file_utils import create_directory
|
|
21
19
|
from msprobe.core.common.exceptions import FileCheckException
|
|
22
20
|
from msprobe.mindspore.common.log import logger
|
|
23
|
-
from msprobe.mindspore.compare.ms_compare import
|
|
21
|
+
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
24
22
|
from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
|
|
25
23
|
from msprobe.mindspore.compare.ms_graph_compare import GraphMSComparator
|
|
26
24
|
|
|
25
|
+
|
|
27
26
|
def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
28
27
|
if kwargs.get('suffix'):
|
|
29
28
|
logger.error("Argument 'suffix' is not supported for compare_distributed.")
|
|
30
29
|
raise CompareException(CompareException.INVALID_PARAM_ERROR)
|
|
31
|
-
|
|
32
|
-
auto_analyze = kwargs.get('auto_analyze', True)
|
|
33
|
-
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
30
|
+
is_print_compare_log = kwargs.get('is_print_compare_log', True)
|
|
34
31
|
# get the ranks and match by order
|
|
35
32
|
npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
|
|
36
33
|
bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
|
|
37
34
|
if len(npu_ranks) != len(bench_ranks):
|
|
38
35
|
logger.error('The number of ranks in the two runs are different. '
|
|
39
|
-
|
|
40
|
-
|
|
36
|
+
'Unable to match the ranks. Please use another folder to compare '
|
|
37
|
+
'or use compare() api and manually match the ranks.')
|
|
41
38
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
42
39
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
43
40
|
npu_data_dir = os.path.join(npu_dump_dir, nr)
|
|
@@ -50,19 +47,9 @@ def ms_compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
|
|
|
50
47
|
'npu_json_path': npu_path,
|
|
51
48
|
'bench_json_path': bench_path,
|
|
52
49
|
'stack_json_path': stack_path,
|
|
53
|
-
'is_print_compare_log':
|
|
50
|
+
'is_print_compare_log': is_print_compare_log
|
|
54
51
|
}
|
|
55
|
-
|
|
56
|
-
summary_compare, md5_compare = task_dumppath_get(dump_result_param)
|
|
57
|
-
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
|
|
58
|
-
create_directory(output_path)
|
|
59
|
-
check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
|
|
60
|
-
except (CompareException, FileCheckException) as error:
|
|
61
|
-
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
62
|
-
raise CompareException(error.code) from error
|
|
63
|
-
ms_comparator = MSComparator()
|
|
64
|
-
ms_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
|
|
65
|
-
md5_compare=md5_compare, **kwargs)
|
|
52
|
+
ms_compare(input_param=dump_result_param, output_path=output_path, suffix=f'_{nr}-{br}', **kwargs)
|
|
66
53
|
|
|
67
54
|
|
|
68
55
|
def ms_graph_compare(inputs, outputs):
|
|
@@ -71,5 +58,5 @@ def ms_graph_compare(inputs, outputs):
|
|
|
71
58
|
except (CompareException, FileCheckException) as error:
|
|
72
59
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
73
60
|
return
|
|
74
|
-
|
|
75
|
-
|
|
61
|
+
ms_comparator = GraphMSComparator(inputs, outputs)
|
|
62
|
+
ms_comparator.compare_core()
|