mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
|
|
2
|
-
#
|
|
3
|
-
|
|
4
|
-
#
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
5
|
# you may not use this file except in compliance with the License.
|
|
7
6
|
# You may obtain a copy of the License at
|
|
8
7
|
#
|
|
@@ -13,7 +12,7 @@
|
|
|
13
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
13
|
# See the License for the specific language governing permissions and
|
|
15
14
|
# limitations under the License.
|
|
16
|
-
|
|
15
|
+
|
|
17
16
|
from msprobe.pytorch.parse_tool.lib.interactive_cli import InteractiveCli
|
|
18
17
|
from msprobe.pytorch.common.log import logger
|
|
19
18
|
|
|
@@ -22,7 +22,7 @@ from collections import namedtuple
|
|
|
22
22
|
from msprobe.pytorch.parse_tool.lib.utils import Util
|
|
23
23
|
from msprobe.pytorch.parse_tool.lib.config import Const
|
|
24
24
|
from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
25
|
-
from msprobe.core.common.file_utils import
|
|
25
|
+
from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class Compare:
|
|
@@ -49,10 +49,10 @@ class Compare:
|
|
|
49
49
|
dump_file = self.util.path_strip(dump_file)
|
|
50
50
|
file_name = ""
|
|
51
51
|
if os.path.isfile(dump_file):
|
|
52
|
-
self.log.info("Covert file is: %s"
|
|
52
|
+
self.log.info("Covert file is: %s" % dump_file)
|
|
53
53
|
file_name = os.path.basename(dump_file)
|
|
54
54
|
elif os.path.isdir(dump_file):
|
|
55
|
-
self.log.info("Convert all files in path: %s"
|
|
55
|
+
self.log.info("Convert all files in path: %s" % dump_file)
|
|
56
56
|
file_name = ""
|
|
57
57
|
output = output if output else Const.DUMP_CONVERT_DIR
|
|
58
58
|
convert = self.convert(dump_file, data_format, output, msaccucmp_path)
|
|
@@ -62,7 +62,7 @@ class Compare:
|
|
|
62
62
|
summary_txt = ["SrcFile: %s" % dump_file]
|
|
63
63
|
for convert_file in convert_files.values():
|
|
64
64
|
summary_txt.append(" - %s" % convert_file.file_name)
|
|
65
|
-
self.log.info("Transfer result is saved in : %s"
|
|
65
|
+
self.log.info("Transfer result is saved in : %s" % os.path.realpath(output))
|
|
66
66
|
self.util.print_panel("\n".join(summary_txt))
|
|
67
67
|
|
|
68
68
|
def convert(self, dump_file, data_format, output, msaccucmp_path):
|
|
@@ -114,11 +114,11 @@ class Compare:
|
|
|
114
114
|
shape_left = data_left.shape
|
|
115
115
|
shape_right = data_right.shape
|
|
116
116
|
if shape_left != shape_right:
|
|
117
|
-
self.log.warning("Data shape not equal: %s vs %s"
|
|
117
|
+
self.log.warning("Data shape not equal: %s vs %s" % (data_left.shape, data_right.shape))
|
|
118
118
|
data_left = data_left.reshape(-1)
|
|
119
119
|
data_right = data_right.reshape(-1)
|
|
120
120
|
if data_left.shape[0] != data_right.shape[0]:
|
|
121
|
-
self.log.warning("Data size not equal: %s vs %s"
|
|
121
|
+
self.log.warning("Data size not equal: %s vs %s" % (data_left.shape, data_right.shape))
|
|
122
122
|
if data_left.shape[0] < data_right.shape[0]:
|
|
123
123
|
data_left = np.pad(data_left, (0, data_right.shape[0] - data_left.shape[0]), 'constant')
|
|
124
124
|
else:
|
|
@@ -160,7 +160,7 @@ class Compare:
|
|
|
160
160
|
if shape != bench_shape or dtype != bench_dtype:
|
|
161
161
|
self.log.error(
|
|
162
162
|
"Shape or dtype between two npy files is inconsistent. Please check the two files."
|
|
163
|
-
"File 1: %s, file 2: %s"
|
|
163
|
+
"File 1: %s, file 2: %s" % (file, bench_file))
|
|
164
164
|
self.util.deal_with_dir_or_file_inconsistency(output_path)
|
|
165
165
|
return
|
|
166
166
|
md5_consistency = False
|
|
@@ -236,25 +236,18 @@ class Compare:
|
|
|
236
236
|
golden_subdir_path = os.path.join(golden_dump_dir, golden_subdir_name)
|
|
237
237
|
self.compare_timestamp_directory(my_subdir_path, golden_subdir_path, output_path)
|
|
238
238
|
self.util.change_filemode_safe(output_path)
|
|
239
|
-
self.log.info("Compare result is saved in : %s"
|
|
239
|
+
self.log.info("Compare result is saved in : %s" % (output_path))
|
|
240
240
|
|
|
241
241
|
def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path):
|
|
242
242
|
dump_dir = self.util.path_strip(dump_dir)
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
output_path = os.path.join(output_dir, op_name, timestamp)
|
|
255
|
-
self.convert_dump_to_npy(file_path, param, output_path, msaccucmp_path)
|
|
256
|
-
path_depth = root.count(os.sep)
|
|
257
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
258
|
-
yield root, _, files
|
|
259
|
-
else:
|
|
260
|
-
_[:] = []
|
|
243
|
+
files = os_walk_for_files(dump_dir, Const.MAX_TRAVERSAL_DEPTH)
|
|
244
|
+
filepaths = [os.path.join(file['root'], file['file']) for file in files]
|
|
245
|
+
for path in filepaths:
|
|
246
|
+
filename = os.path.basename(path)
|
|
247
|
+
parts = filename.split(".")
|
|
248
|
+
if len(parts) < 5:
|
|
249
|
+
continue
|
|
250
|
+
op_name = parts[1]
|
|
251
|
+
timestamp = parts[-1]
|
|
252
|
+
output_path = os.path.join(output_dir, op_name, timestamp)
|
|
253
|
+
self.convert_dump_to_npy(path, param, output_path, msaccucmp_path)
|
|
@@ -33,7 +33,7 @@ class Const:
|
|
|
33
33
|
OFFLINE_DUMP_CONVERT_PATTERN = \
|
|
34
34
|
r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \
|
|
35
35
|
r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$"
|
|
36
|
-
NUMPY_PATTERN = r"^[\w\-_
|
|
36
|
+
NUMPY_PATTERN = r"^[\w\-_.]+\.npy$"
|
|
37
37
|
NPY_SUFFIX = ".npy"
|
|
38
38
|
PKL_SUFFIX = ".pkl"
|
|
39
39
|
DIRECTORY_LENGTH = 4096
|
|
@@ -110,6 +110,9 @@ class ParseTool:
|
|
|
110
110
|
parser.add_argument('-al', '--atol', dest='atol', default=0.001, type=float, help='set rtol')
|
|
111
111
|
parser.add_argument('-rl', '--rtol', dest='rtol', default=0.001, type=float, help='set atol')
|
|
112
112
|
args = parser.parse_args(argv)
|
|
113
|
+
self.util.check_positive(args.count)
|
|
114
|
+
self.util.check_positive(args.rtol)
|
|
115
|
+
self.util.check_positive(args.atol)
|
|
113
116
|
self.util.check_path_valid(args.my_dump_path)
|
|
114
117
|
self.util.check_path_valid(args.golden_dump_path)
|
|
115
118
|
self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
|
|
@@ -129,8 +132,7 @@ class ParseTool:
|
|
|
129
132
|
" '-m' and '-g'.")
|
|
130
133
|
raise ParseException("My directory path and golden directory path is same.")
|
|
131
134
|
output_path = self.util.path_strip(args.output_path) if args.output_path else Const.BATCH_COMPARE_DIR
|
|
132
|
-
|
|
133
|
-
os.makedirs(output_path, mode=0o750)
|
|
135
|
+
create_directory(output_path)
|
|
134
136
|
self.compare.compare_converted_dir(my_dump_dir, golden_dump_dir, output_path)
|
|
135
137
|
|
|
136
138
|
@catch_exception
|
|
@@ -28,7 +28,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
|
28
28
|
from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
|
|
29
29
|
check_path_executable, check_path_owner_consistent
|
|
30
30
|
from msprobe.core.common.const import FileCheckConst
|
|
31
|
-
from msprobe.core.common.file_utils import
|
|
31
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files
|
|
32
32
|
from msprobe.pytorch.common.log import logger
|
|
33
33
|
|
|
34
34
|
|
|
@@ -71,31 +71,21 @@ class Util:
|
|
|
71
71
|
check_path_executable(path)
|
|
72
72
|
|
|
73
73
|
@staticmethod
|
|
74
|
-
def get_subdir_count(
|
|
74
|
+
def get_subdir_count(directory):
|
|
75
75
|
subdir_count = 0
|
|
76
|
-
|
|
77
|
-
path_checker.common_check()
|
|
76
|
+
check_file_or_directory_path(directory, isdir=True)
|
|
78
77
|
for _, dirs, _ in os.walk(directory):
|
|
79
78
|
subdir_count += len(dirs)
|
|
80
79
|
break
|
|
81
80
|
return subdir_count
|
|
82
81
|
|
|
83
82
|
@staticmethod
|
|
84
|
-
def get_subfiles_count(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
path_checker = FileChecker(root)
|
|
88
|
-
path_checker.common_check()
|
|
89
|
-
file_count += len(files)
|
|
90
|
-
path_depth = root.count(os.sep)
|
|
91
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
92
|
-
yield root, _, files
|
|
93
|
-
else:
|
|
94
|
-
_[:] = []
|
|
95
|
-
return file_count
|
|
83
|
+
def get_subfiles_count(directory):
|
|
84
|
+
files = os_walk_for_files(directory, Const.MAX_TRAVERSAL_DEPTH)
|
|
85
|
+
return len(files)
|
|
96
86
|
|
|
97
87
|
@staticmethod
|
|
98
|
-
def get_sorted_subdirectories_names(
|
|
88
|
+
def get_sorted_subdirectories_names(directory):
|
|
99
89
|
subdirectories = []
|
|
100
90
|
for item in os.listdir(directory):
|
|
101
91
|
item_path = os.path.join(directory, item)
|
|
@@ -104,7 +94,7 @@ class Util:
|
|
|
104
94
|
return sorted(subdirectories)
|
|
105
95
|
|
|
106
96
|
@staticmethod
|
|
107
|
-
def get_sorted_files_names(
|
|
97
|
+
def get_sorted_files_names(directory):
|
|
108
98
|
files = []
|
|
109
99
|
for item in os.listdir(directory):
|
|
110
100
|
item_path = os.path.join(directory, item)
|
|
@@ -113,7 +103,7 @@ class Util:
|
|
|
113
103
|
return sorted(files)
|
|
114
104
|
|
|
115
105
|
@staticmethod
|
|
116
|
-
def check_npy_files_valid_in_dir(
|
|
106
|
+
def check_npy_files_valid_in_dir(dir_path):
|
|
117
107
|
for file_name in os.listdir(dir_path):
|
|
118
108
|
file_path = os.path.join(dir_path, file_name)
|
|
119
109
|
check_file_or_directory_path(file_path)
|
|
@@ -123,18 +113,18 @@ class Util:
|
|
|
123
113
|
return True
|
|
124
114
|
|
|
125
115
|
@staticmethod
|
|
126
|
-
def get_md5_for_numpy(
|
|
116
|
+
def get_md5_for_numpy(obj):
|
|
127
117
|
np_bytes = obj.tobytes()
|
|
128
118
|
md5_hash = hashlib.md5(np_bytes)
|
|
129
119
|
return md5_hash.hexdigest()
|
|
130
120
|
|
|
131
121
|
@staticmethod
|
|
132
|
-
def deal_with_dir_or_file_inconsistency(
|
|
122
|
+
def deal_with_dir_or_file_inconsistency(output_path):
|
|
133
123
|
remove_path(output_path)
|
|
134
124
|
raise ParseException("Inconsistent directory structure or file.")
|
|
135
125
|
|
|
136
126
|
@staticmethod
|
|
137
|
-
def deal_with_value_if_has_zero(
|
|
127
|
+
def deal_with_value_if_has_zero(data):
|
|
138
128
|
if data.dtype in Const.FLOAT_TYPE:
|
|
139
129
|
zero_mask = (data == 0)
|
|
140
130
|
# 给0的地方加上eps防止除0
|
|
@@ -147,26 +137,19 @@ class Util:
|
|
|
147
137
|
return data
|
|
148
138
|
|
|
149
139
|
@staticmethod
|
|
150
|
-
def dir_contains_only(
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
if not file.endswith(endfix):
|
|
156
|
-
return False
|
|
157
|
-
path_depth = root.count(os.sep)
|
|
158
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
159
|
-
yield root, _, files
|
|
160
|
-
else:
|
|
161
|
-
_[:] = []
|
|
140
|
+
def dir_contains_only(path, endfix):
|
|
141
|
+
files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
|
|
142
|
+
for file in files:
|
|
143
|
+
if not file['file'].endswith(endfix):
|
|
144
|
+
return False
|
|
162
145
|
return True
|
|
163
146
|
|
|
164
147
|
@staticmethod
|
|
165
|
-
def localtime_str(
|
|
148
|
+
def localtime_str():
|
|
166
149
|
return time.strftime("%Y%m%d%H%M%S", time.localtime())
|
|
167
150
|
|
|
168
151
|
@staticmethod
|
|
169
|
-
def change_filemode_safe(
|
|
152
|
+
def change_filemode_safe(path):
|
|
170
153
|
change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
171
154
|
|
|
172
155
|
@staticmethod
|
|
@@ -183,7 +166,7 @@ class Util:
|
|
|
183
166
|
if not cmd:
|
|
184
167
|
self.log.error("Commond is None")
|
|
185
168
|
return -1
|
|
186
|
-
self.log.info("[RUN CMD]: %s"
|
|
169
|
+
self.log.info("[RUN CMD]: %s" % cmd)
|
|
187
170
|
cmd = cmd.split(" ")
|
|
188
171
|
complete_process = subprocess.run(cmd, shell=False)
|
|
189
172
|
return complete_process.returncode
|
|
@@ -205,7 +188,7 @@ class Util:
|
|
|
205
188
|
result = subprocess.run(
|
|
206
189
|
[self.python, target_file, "--help"], stdout=subprocess.PIPE, shell=False)
|
|
207
190
|
if result.returncode == 0:
|
|
208
|
-
self.log.info("Check [%s] success."
|
|
191
|
+
self.log.info("Check [%s] success." % (target_file))
|
|
209
192
|
else:
|
|
210
193
|
self.log.error("Check msaccucmp failed in dir %s" % target_file)
|
|
211
194
|
self.log.error("Please specify a valid msaccucmp.py path or install the cann package")
|
|
@@ -244,8 +227,11 @@ class Util:
|
|
|
244
227
|
|
|
245
228
|
def check_path_valid(self, path):
|
|
246
229
|
path = self.path_strip(path)
|
|
247
|
-
|
|
248
|
-
|
|
230
|
+
if not path or not os.path.exists(path):
|
|
231
|
+
self.log.error("The path %s does not exist." % path)
|
|
232
|
+
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
233
|
+
isdir = check_file_type(path) == FileCheckConst.DIR
|
|
234
|
+
check_file_or_directory_path(path, isdir=isdir)
|
|
249
235
|
return True
|
|
250
236
|
|
|
251
237
|
def check_files_in_path(self, path):
|
|
@@ -273,21 +259,15 @@ class Util:
|
|
|
273
259
|
self.check_path_valid(path)
|
|
274
260
|
file_list = {}
|
|
275
261
|
re_pattern = re.compile(pattern)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
file_list[name] = gen_info_func(name, match, dir_path)
|
|
286
|
-
path_depth = dir_path.count(os.sep)
|
|
287
|
-
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
288
|
-
yield dir_path, _, file_names
|
|
289
|
-
else:
|
|
290
|
-
_[:] = []
|
|
262
|
+
files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
|
|
263
|
+
for file in files:
|
|
264
|
+
name = file["file"]
|
|
265
|
+
match = re_pattern.match(name)
|
|
266
|
+
if not match:
|
|
267
|
+
continue
|
|
268
|
+
if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
|
|
269
|
+
continue
|
|
270
|
+
file_list[name] = gen_info_func(name, match, file["root"])
|
|
291
271
|
return file_list
|
|
292
272
|
|
|
293
273
|
def check_file_path_format(self, path, suffix):
|
|
@@ -314,3 +294,8 @@ class Util:
|
|
|
314
294
|
dir1_count = self.get_subdir_count(dir1)
|
|
315
295
|
dir2_count = self.get_subdir_count(dir2)
|
|
316
296
|
return dir1_count == dir2_count
|
|
297
|
+
|
|
298
|
+
def check_positive(self, value):
|
|
299
|
+
if value <= 0.0:
|
|
300
|
+
self.log.error("Invalid value. It must be greater than 0.")
|
|
301
|
+
raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
|
|
@@ -28,7 +28,7 @@ class Visualization:
|
|
|
28
28
|
self.util = Util()
|
|
29
29
|
|
|
30
30
|
def print_npy_summary(self, target_file):
|
|
31
|
-
np_data = load_npy(target_file
|
|
31
|
+
np_data = load_npy(target_file)
|
|
32
32
|
table = self.util.create_table('', ['Index', 'Data'])
|
|
33
33
|
flatten_data = np_data.flatten()
|
|
34
34
|
tablesize = 8
|
|
@@ -65,6 +65,8 @@ class Visualization:
|
|
|
65
65
|
self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line))
|
|
66
66
|
self.util.log.warning("Please check the pkl file")
|
|
67
67
|
raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e
|
|
68
|
+
if not isinstance(msg, list) or len(msg) == 0:
|
|
69
|
+
break
|
|
68
70
|
info_prefix = msg[0]
|
|
69
71
|
if not info_prefix.startswith(api_name):
|
|
70
72
|
continue
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -1,12 +1,35 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
2
16
|
import os
|
|
17
|
+
import re
|
|
3
18
|
|
|
4
|
-
from msprobe.core.common_config import CommonConfig, BaseConfig
|
|
5
|
-
from msprobe.core.common.file_utils import FileOpen
|
|
6
19
|
from msprobe.core.common.const import Const
|
|
7
|
-
from msprobe.
|
|
20
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
21
|
+
from msprobe.core.common.file_utils import FileOpen, load_json, check_file_or_directory_path, check_crt_valid
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.utils import is_int
|
|
24
|
+
from msprobe.core.common_config import BaseConfig, CommonConfig
|
|
8
25
|
from msprobe.core.grad_probe.constant import level_adp
|
|
9
|
-
from msprobe.core.grad_probe.utils import
|
|
26
|
+
from msprobe.core.grad_probe.utils import check_bounds
|
|
27
|
+
from msprobe.pytorch.free_benchmark.common.enums import (
|
|
28
|
+
DeviceType,
|
|
29
|
+
HandlerType,
|
|
30
|
+
PytorchFreeBenchmarkConst,
|
|
31
|
+
)
|
|
32
|
+
from msprobe.pytorch.hook_module.utils import get_ops
|
|
10
33
|
|
|
11
34
|
|
|
12
35
|
class TensorConfig(BaseConfig):
|
|
@@ -16,23 +39,39 @@ class TensorConfig(BaseConfig):
|
|
|
16
39
|
self.nfs_path = json_config.get("nfs_path", "")
|
|
17
40
|
self.host = json_config.get("host", "")
|
|
18
41
|
self.port = json_config.get("port", -1)
|
|
19
|
-
self.tls_path = json_config.get("tls_path", "")
|
|
42
|
+
self.tls_path = json_config.get("tls_path", "./")
|
|
43
|
+
self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
|
|
20
44
|
self.check_config()
|
|
21
45
|
self._check_file_format()
|
|
22
|
-
self.
|
|
46
|
+
if self.online_run_ut:
|
|
47
|
+
self._check_online_run_ut()
|
|
23
48
|
|
|
24
49
|
def _check_file_format(self):
    """Reject a ``file_format`` other than ``"npy"`` or ``"bin"``.

    ``None`` means the option was not supplied and is accepted as-is.

    Raises:
        Exception: if ``file_format`` is set to an unsupported value.
    """
    fmt = self.file_format
    if fmt is None or fmt in ["npy", "bin"]:
        return
    raise Exception("file_format is invalid")
|
|
27
52
|
|
|
28
|
-
def
|
|
53
|
+
def _check_online_run_ut(self):
    """Validate the options that only matter when online run_ut is enabled.

    Checks, in order:
      1. ``online_run_ut`` and ``online_run_ut_recompute`` are real booleans.
      2. If ``nfs_path`` is set, it is passed to
         ``check_file_or_directory_path(..., isdir=True)`` and validation
         stops there; host/port/TLS settings are not inspected.
      3. Otherwise, if ``tls_path`` is set, the directory itself plus the
         ``client.key`` / ``client.crt`` files inside it are checked, and the
         certificate is passed to ``check_crt_valid``.
      4. ``host`` must be a string matching ``Const.ipv4_pattern`` and
         ``port`` must be an integer in 1..65535.

    Raises:
        Exception: for a non-boolean switch, an invalid host or an invalid
            port. The path/certificate helpers presumably raise on their own
            failures -- confirm against ``msprobe.core.common.file_utils``.
    """
    if not isinstance(self.online_run_ut, bool):
        raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")

    if not isinstance(self.online_run_ut_recompute, bool):
        raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")

    if self.nfs_path:
        check_file_or_directory_path(self.nfs_path, isdir=True)
        return

    if self.tls_path:
        check_file_or_directory_path(self.tls_path, isdir=True)
        client_crt = os.path.join(self.tls_path, "client.crt")
        check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
        check_file_or_directory_path(client_crt)
        check_crt_valid(client_crt)

    if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
        raise Exception(f"host: {self.host} is invalid.")

    # bool is a subclass of int, so ``True`` would otherwise be accepted as
    # port 1; reject it explicitly.
    port_is_int = isinstance(self.port, int) and not isinstance(self.port, bool)
    if not port_is_int or not (0 < self.port <= 65535):
        # Port 0 is rejected above, so the advertised range starts at 1.
        raise Exception(f"port: {self.port} is invalid, port range 1-65535.")
|
|
36
75
|
|
|
37
76
|
|
|
38
77
|
class StatisticsConfig(BaseConfig):
|
|
@@ -54,30 +93,149 @@ class OverflowCheckConfig(BaseConfig):
|
|
|
54
93
|
self.check_overflow_config()
|
|
55
94
|
|
|
56
95
|
def check_overflow_config(self):
    """Validate the overflow-check options.

    ``overflow_nums`` and ``check_mode`` are both optional (``None`` is
    accepted). When present, ``overflow_nums`` must satisfy ``is_int`` and
    ``check_mode`` must be one of ``"all"``, ``"aicore"`` or ``"atomic"``.

    Raises:
        Exception: if either option holds an unsupported value.
    """
    nums = self.overflow_nums
    if nums is not None:
        if not is_int(nums):
            raise Exception("overflow_num is invalid")

    mode = self.check_mode
    if mode is None:
        return
    if mode not in ["all", "aicore", "atomic"]:
        raise Exception("check_mode is invalid")
|
|
61
100
|
|
|
62
101
|
|
|
63
102
|
class FreeBenchmarkCheckConfig(BaseConfig):
    """Free-benchmark task options, read from the JSON config and validated.

    Every option falls back to a ``PytorchFreeBenchmarkConst`` default (or
    ``False`` for ``if_preheat``) when the key is absent, and the whole set is
    validated at the end of ``__init__``.

    Validation failures are reported through ``logger.error_log_with_exp``
    with an ``MsprobeException(INVALID_PARAM_ERROR)``; that helper presumably
    raises the exception it is given -- confirm in ``msprobe.core.common.log``.
    """

    def __init__(self, json_config):
        super().__init__(json_config)
        self.fuzz_device = json_config.get("fuzz_device", PytorchFreeBenchmarkConst.DEFAULT_DEVICE)
        self.pert_mode = json_config.get("pert_mode", PytorchFreeBenchmarkConst.DEFAULT_MODE)
        self.handler_type = json_config.get("handler_type", PytorchFreeBenchmarkConst.DEFAULT_HANDLER)
        self.fuzz_level = json_config.get("fuzz_level", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_LEVEL)
        self.fuzz_stage = json_config.get("fuzz_stage", PytorchFreeBenchmarkConst.DEFAULT_FUZZ_STAGE)
        self.if_preheat = json_config.get("if_preheat", False)
        self.preheat_step = json_config.get("preheat_step", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
        # NOTE(review): max_sample falls back to DEFAULT_PREHEAT_STEP, which
        # looks like a copy-paste of the line above. Confirm whether
        # PytorchFreeBenchmarkConst defines a dedicated max-sample default.
        self.max_sample = json_config.get("max_sample", PytorchFreeBenchmarkConst.DEFAULT_PREHEAT_STEP)
        self.check_freebenchmark_config()

    @staticmethod
    def _raise_invalid_param(msg):
        """Report ``msg`` as an INVALID_PARAM_ERROR through the shared logger."""
        logger.error_log_with_exp(
            msg, MsprobeException(MsprobeException.INVALID_PARAM_ERROR, msg)
        )

    def _check_choice(self, field, allowed):
        """Report an error unless the attribute named ``field`` is in ``allowed``."""
        if getattr(self, field) not in allowed:
            self._raise_invalid_param(
                f"{field} is invalid, it should be one of {allowed}"
            )

    def _check_positive_int(self, field):
        """Report an error unless the attribute named ``field`` is an int > 0."""
        value = getattr(self, field)
        if not is_int(value):
            self._raise_invalid_param(f"{field} is invalid, it should be an integer")
        if value <= 0:
            self._raise_invalid_param(f"{field} must be greater than 0")

    def check_freebenchmark_config(self):
        """Run every per-field check, then the handler/preheat specific ones."""
        self._check_pert_mode()
        self._check_fuzz_device()
        self._check_handler_type()
        self._check_fuzz_stage()
        self._check_fuzz_level()
        self._check_if_preheat()
        if self.handler_type == HandlerType.FIX:
            self._check_fix_config()
        if self.if_preheat:
            self._check_preheat_config()

    def _check_pert_mode(self):
        """``pert_mode`` must be a known perturbation mode."""
        self._check_choice("pert_mode", PytorchFreeBenchmarkConst.PERTURBATION_MODE_LIST)

    def _check_fuzz_device(self):
        """``fuzz_device`` must be known and consistent with ``pert_mode``."""
        self._check_choice("fuzz_device", PytorchFreeBenchmarkConst.DEVICE_LIST)
        # CPU device and a CPU-only perturbation mode must come together:
        # having exactly one of the two (XOR) is a misconfiguration.
        device_is_cpu = self.fuzz_device == DeviceType.CPU
        mode_is_cpu_only = self.pert_mode in PytorchFreeBenchmarkConst.CPU_MODE_LIST
        if device_is_cpu ^ mode_is_cpu_only:
            self._raise_invalid_param(
                f"You need to and can only set fuzz_device as {DeviceType.CPU} "
                f"when pert_mode in {PytorchFreeBenchmarkConst.CPU_MODE_LIST}"
            )

    def _check_handler_type(self):
        """``handler_type`` must be a known handler."""
        self._check_choice("handler_type", PytorchFreeBenchmarkConst.HANDLER_LIST)

    def _check_fuzz_stage(self):
        """``fuzz_stage`` must be a known stage."""
        self._check_choice("fuzz_stage", PytorchFreeBenchmarkConst.FUZZ_STAGE_LIST)

    def _check_fuzz_level(self):
        """``fuzz_level`` must be a known level."""
        self._check_choice("fuzz_level", PytorchFreeBenchmarkConst.FUZZ_LEVEL_LIST)

    def _check_if_preheat(self):
        """``if_preheat`` must be a real boolean."""
        if not isinstance(self.if_preheat, bool):
            self._raise_invalid_param("if_preheat is invalid, it should be a boolean")

    def _check_preheat_config(self):
        """When preheating, ``preheat_step`` and ``max_sample`` must be ints > 0."""
        self._check_positive_int("preheat_step")
        self._check_positive_int("max_sample")

    def _check_fix_config(self):
        """Extra restrictions that apply only to the FIX handler."""
        if self.if_preheat:
            self._raise_invalid_param(
                f"Preheating is not supported for {HandlerType.FIX} handler type"
            )
        if self.fuzz_stage not in PytorchFreeBenchmarkConst.FIX_STAGE_LIST:
            self._raise_invalid_param(
                f"The fuzz_stage when opening {HandlerType.FIX} handler must be one of "
                f"{PytorchFreeBenchmarkConst.FIX_STAGE_LIST}"
            )
        if self.pert_mode not in PytorchFreeBenchmarkConst.FIX_MODE_LIST:
            self._raise_invalid_param(
                f"The pert_mode when opening {HandlerType.FIX} handler must be one of "
                f"{PytorchFreeBenchmarkConst.FIX_MODE_LIST}"
            )
|
|
81
239
|
|
|
82
240
|
|
|
83
241
|
class RunUTConfig(BaseConfig):
|
|
@@ -93,7 +251,7 @@ class RunUTConfig(BaseConfig):
|
|
|
93
251
|
self.host = json_config.get("host", "")
|
|
94
252
|
self.port = json_config.get("port", -1)
|
|
95
253
|
self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
|
|
96
|
-
self.tls_path = json_config.get("tls_path", "")
|
|
254
|
+
self.tls_path = json_config.get("tls_path", "./")
|
|
97
255
|
self.check_run_ut_config()
|
|
98
256
|
|
|
99
257
|
@classmethod
|
|
@@ -118,13 +276,8 @@ class RunUTConfig(BaseConfig):
|
|
|
118
276
|
|
|
119
277
|
@classmethod
|
|
120
278
|
def check_tls_path_config(cls, tls_path):
    """Verify that a configured ``tls_path`` exists on disk.

    An empty/falsy ``tls_path`` is accepted without touching the filesystem.

    Raises:
        Exception: if ``tls_path`` is set but does not exist.
    """
    if not tls_path:
        return
    if os.path.exists(tls_path):
        return
    raise Exception("tls_path: %s does not exist" % tls_path)
|
|
128
281
|
|
|
129
282
|
def check_run_ut_config(self):
|
|
130
283
|
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
|
|
@@ -141,13 +294,13 @@ class GradToolConfig(BaseConfig):
|
|
|
141
294
|
self.param_list = json_config.get("param_list", [])
|
|
142
295
|
self.bounds = json_config.get("bounds", [-1, 0, 1])
|
|
143
296
|
self._check_config()
|
|
144
|
-
|
|
297
|
+
|
|
145
298
|
def _check_config(self):
    """Validate the gradient-tool options.

    ``grad_level`` must be a key of ``level_adp``, ``param_list`` must be a
    list, and ``bounds`` is handed to ``check_bounds`` for its own validation.

    Raises:
        Exception: for an unknown ``grad_level`` or a non-list ``param_list``.
    """
    valid_levels = level_adp.keys()
    if self.grad_level not in valid_levels:
        raise Exception(f"grad_level must be one of {valid_levels}")
    if not isinstance(self.param_list, list):
        raise Exception("param_list must be a list")
    check_bounds(self.bounds)
|
|
151
304
|
|
|
152
305
|
|
|
153
306
|
def parse_task_config(task, json_config):
|
|
@@ -178,10 +331,9 @@ def parse_json_config(json_file_path, task):
|
|
|
178
331
|
if not json_file_path:
|
|
179
332
|
config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
180
333
|
json_file_path = os.path.join(config_dir, "config.json")
|
|
181
|
-
|
|
182
|
-
json_config = json.load(file)
|
|
334
|
+
json_config = load_json(json_file_path)
|
|
183
335
|
common_config = CommonConfig(json_config)
|
|
184
|
-
if task
|
|
336
|
+
if task:
|
|
185
337
|
task_config = parse_task_config(task, json_config)
|
|
186
338
|
else:
|
|
187
339
|
task_config = parse_task_config(common_config.task, json_config)
|