mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -22,6 +22,8 @@ from collections import namedtuple
|
|
|
22
22
|
from msprobe.pytorch.parse_tool.lib.utils import Util
|
|
23
23
|
from msprobe.pytorch.parse_tool.lib.config import Const
|
|
24
24
|
from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
25
|
+
from msprobe.core.common.utils import create_directory, write_csv, save_npy_to_txt
|
|
26
|
+
from msprobe.core.common.file_check import FileChecker
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class Compare:
|
|
@@ -36,7 +38,7 @@ class Compare:
|
|
|
36
38
|
self.log.info("Compare finished!!")
|
|
37
39
|
|
|
38
40
|
def compare_vector(self, my_dump_path, golden_dump_path, result_dir, msaccucmp_path):
|
|
39
|
-
|
|
41
|
+
create_directory(result_dir)
|
|
40
42
|
self.util.check_path_valid(result_dir)
|
|
41
43
|
call_msaccucmp = self.util.check_msaccucmp(msaccucmp_path)
|
|
42
44
|
cmd = '%s %s compare -m %s -g %s -out %s' % (
|
|
@@ -65,7 +67,7 @@ class Compare:
|
|
|
65
67
|
self.util.print_panel("\n".join(summary_txt))
|
|
66
68
|
|
|
67
69
|
def convert(self, dump_file, data_format, output, msaccucmp_path):
|
|
68
|
-
|
|
70
|
+
create_directory(output)
|
|
69
71
|
self.util.check_path_valid(output)
|
|
70
72
|
call_msaccucmp = self.util.check_msaccucmp(msaccucmp_path)
|
|
71
73
|
if data_format:
|
|
@@ -83,21 +85,22 @@ class Compare:
|
|
|
83
85
|
(left, right, save_txt, rl, al, diff_count) = args
|
|
84
86
|
if left is None or right is None:
|
|
85
87
|
raise ParseException("invalid input or output")
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
88
|
+
if self.util.check_path_valid(left) and self.util.check_path_valid(right):
|
|
89
|
+
try:
|
|
90
|
+
left_data = np.load(left)
|
|
91
|
+
right_data = np.load(right)
|
|
92
|
+
except UnicodeError as e:
|
|
93
|
+
self.log.error("%s %s" % ("UnicodeError", str(e)))
|
|
94
|
+
self.log.warning("Please check the npy file")
|
|
95
|
+
raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
|
|
96
|
+
except IOError:
|
|
97
|
+
self.log.error("Failed to load npy %s or %s." % (left, right))
|
|
98
|
+
raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e
|
|
96
99
|
|
|
97
100
|
# save to txt
|
|
98
101
|
if save_txt:
|
|
99
|
-
|
|
100
|
-
|
|
102
|
+
save_npy_to_txt(left_data, left + ".txt")
|
|
103
|
+
save_npy_to_txt(right_data, right + ".txt")
|
|
101
104
|
# compare data
|
|
102
105
|
(total_cnt, all_close, cos_sim, err_percent) = self.do_compare_data(left_data, right_data, rl, al, diff_count)
|
|
103
106
|
content = ['Left:', ' ├─ NpyFile: %s' % left]
|
|
@@ -157,8 +160,10 @@ class Compare:
|
|
|
157
160
|
return res
|
|
158
161
|
|
|
159
162
|
def compare_npy(self, file, bench_file, output_path):
|
|
160
|
-
|
|
161
|
-
|
|
163
|
+
if self.util.check_path_valid(file):
|
|
164
|
+
data = np.load(file)
|
|
165
|
+
if self.util.check_path_valid(bench_file):
|
|
166
|
+
bench_data = np.load(bench_file)
|
|
162
167
|
shape, dtype = data.shape, data.dtype
|
|
163
168
|
bench_shape, bench_dtype = bench_data.shape, bench_data.dtype
|
|
164
169
|
filename = os.path.basename(file)
|
|
@@ -181,7 +186,7 @@ class Compare:
|
|
|
181
186
|
rel_diff_max = np.max(rel_error)
|
|
182
187
|
compare_result = [[filename, bench_filename, data_mean, bench_data_mean, md5_consistency, abs_diff_max,
|
|
183
188
|
rel_diff_max]]
|
|
184
|
-
|
|
189
|
+
write_csv(compare_result, output_path)
|
|
185
190
|
|
|
186
191
|
def compare_all_file_in_directory(self, my_dump_dir, golden_dump_dir, output_path):
|
|
187
192
|
if not (self.util.is_subdir_count_equal(my_dump_dir, golden_dump_dir)
|
|
@@ -228,7 +233,7 @@ class Compare:
|
|
|
228
233
|
"Max Abs Error",
|
|
229
234
|
"Max Relative Error"
|
|
230
235
|
]]
|
|
231
|
-
|
|
236
|
+
write_csv(title_rows, output_path)
|
|
232
237
|
|
|
233
238
|
my_ordered_subdirs = self.util.get_sorted_subdirectories_names(my_dump_dir)
|
|
234
239
|
golden_ordered_subdirs = self.util.get_sorted_subdirectories_names(golden_dump_dir)
|
|
@@ -246,7 +251,9 @@ class Compare:
|
|
|
246
251
|
|
|
247
252
|
def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path):
|
|
248
253
|
dump_dir = self.util.path_strip(dump_dir)
|
|
249
|
-
for root, _, files in os.walk(dump_dir):
|
|
254
|
+
for root, _, files in os.walk(dump_dir, topdown=True):
|
|
255
|
+
path_checker = FileChecker(root)
|
|
256
|
+
path_checker.common_check()
|
|
250
257
|
for file in files:
|
|
251
258
|
file_path = os.path.join(root, file)
|
|
252
259
|
file_name = os.path.basename(file_path)
|
|
@@ -257,3 +264,8 @@ class Compare:
|
|
|
257
264
|
timestamp = parts[-1]
|
|
258
265
|
output_path = os.path.join(output_dir, op_name, timestamp)
|
|
259
266
|
self.convert_dump_to_npy(file_path, param, output_path, msaccucmp_path)
|
|
267
|
+
path_depth = root.count(os.sep)
|
|
268
|
+
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
269
|
+
yield root, _, files
|
|
270
|
+
else:
|
|
271
|
+
_[:] = []
|
|
@@ -33,11 +33,12 @@ class Const:
|
|
|
33
33
|
OFFLINE_DUMP_CONVERT_PATTERN = \
|
|
34
34
|
r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \
|
|
35
35
|
r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$"
|
|
36
|
-
NUMPY_PATTERN = r"
|
|
36
|
+
NUMPY_PATTERN = r"^[\w\-_-]\.npy$"
|
|
37
37
|
NPY_SUFFIX = ".npy"
|
|
38
38
|
PKL_SUFFIX = ".pkl"
|
|
39
39
|
DIRECTORY_LENGTH = 4096
|
|
40
40
|
FILE_NAME_LENGTH = 255
|
|
41
|
+
MAX_TRAVERSAL_DEPTH = 5
|
|
41
42
|
FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
|
|
42
43
|
ONE_GB = 1 * 1024 * 1024 * 1024
|
|
43
44
|
TEN_GB = 10 * 1024 * 1024 * 1024
|
|
@@ -23,7 +23,7 @@ from msprobe.pytorch.parse_tool.lib.utils import Util
|
|
|
23
23
|
from msprobe.pytorch.parse_tool.lib.compare import Compare
|
|
24
24
|
from msprobe.pytorch.parse_tool.lib.visualization import Visualization
|
|
25
25
|
from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
|
|
26
|
-
|
|
26
|
+
from msprobe.core.common.utils import create_directory
|
|
27
27
|
|
|
28
28
|
class ParseTool:
|
|
29
29
|
def __init__(self):
|
|
@@ -33,7 +33,7 @@ class ParseTool:
|
|
|
33
33
|
|
|
34
34
|
@catch_exception
|
|
35
35
|
def prepare(self):
|
|
36
|
-
|
|
36
|
+
create_directory(Const.DATA_ROOT_DIR)
|
|
37
37
|
|
|
38
38
|
@catch_exception
|
|
39
39
|
def do_vector_compare(self, args):
|
|
@@ -112,8 +112,8 @@ class ParseTool:
|
|
|
112
112
|
args = parser.parse_args(argv)
|
|
113
113
|
self.util.check_path_valid(args.my_dump_path)
|
|
114
114
|
self.util.check_path_valid(args.golden_dump_path)
|
|
115
|
-
self.util.
|
|
116
|
-
self.util.
|
|
115
|
+
self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
|
|
116
|
+
self.util.check_file_path_format(args.golden_dump_path, Const.NPY_SUFFIX)
|
|
117
117
|
compare_data_args = namedtuple('compare_data_args', ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
|
|
118
118
|
compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20)
|
|
119
119
|
res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count)
|
|
@@ -31,7 +31,7 @@ from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
|
31
31
|
from msprobe.core.common.file_check import change_mode, check_other_user_writable,\
|
|
32
32
|
check_path_executable, check_path_owner_consistent
|
|
33
33
|
from msprobe.core.common.const import FileCheckConst
|
|
34
|
-
from msprobe.core.common.file_check import FileOpen
|
|
34
|
+
from msprobe.core.common.file_check import FileOpen, FileChecker
|
|
35
35
|
from msprobe.core.common.utils import check_file_or_directory_path
|
|
36
36
|
from msprobe.pytorch.common.log import logger
|
|
37
37
|
|
|
@@ -57,12 +57,7 @@ except ImportError as err:
|
|
|
57
57
|
class Util:
|
|
58
58
|
def __init__(self):
|
|
59
59
|
self.ms_accu_cmp = None
|
|
60
|
-
|
|
61
|
-
level=Const.LOG_LEVEL,
|
|
62
|
-
format="%(asctime)s (%(process)d) -[%(levelname)s]%(message)s",
|
|
63
|
-
datefmt="%Y-%m-%d %H:%M:%S"
|
|
64
|
-
)
|
|
65
|
-
self.log = logging.getLogger()
|
|
60
|
+
self.log = logger
|
|
66
61
|
self.python = sys.executable
|
|
67
62
|
|
|
68
63
|
@staticmethod
|
|
@@ -82,6 +77,8 @@ class Util:
|
|
|
82
77
|
@staticmethod
|
|
83
78
|
def get_subdir_count(self, directory):
|
|
84
79
|
subdir_count = 0
|
|
80
|
+
path_checker = FileChecker(directory)
|
|
81
|
+
path_checker.common_check()
|
|
85
82
|
for _, dirs, _ in os.walk(directory):
|
|
86
83
|
subdir_count += len(dirs)
|
|
87
84
|
break
|
|
@@ -90,8 +87,15 @@ class Util:
|
|
|
90
87
|
@staticmethod
|
|
91
88
|
def get_subfiles_count(self, directory):
|
|
92
89
|
file_count = 0
|
|
93
|
-
for
|
|
90
|
+
for root, _, files in os.walk(directory, topdown=True):
|
|
91
|
+
path_checker = FileChecker(root)
|
|
92
|
+
path_checker.common_check()
|
|
94
93
|
file_count += len(files)
|
|
94
|
+
path_depth = root.count(os.sep)
|
|
95
|
+
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
96
|
+
yield root, _, files
|
|
97
|
+
else:
|
|
98
|
+
_[:] = []
|
|
95
99
|
return file_count
|
|
96
100
|
|
|
97
101
|
@staticmethod
|
|
@@ -128,17 +132,6 @@ class Util:
|
|
|
128
132
|
md5_hash = hashlib.md5(np_bytes)
|
|
129
133
|
return md5_hash.hexdigest()
|
|
130
134
|
|
|
131
|
-
@staticmethod
|
|
132
|
-
def write_csv(self, data, filepath):
|
|
133
|
-
need_change_mode = False
|
|
134
|
-
if not os.path.exists(filepath):
|
|
135
|
-
need_change_mode = True
|
|
136
|
-
with FileOpen(filepath, 'a') as f:
|
|
137
|
-
writer = csv.writer(f)
|
|
138
|
-
writer.writerows(data)
|
|
139
|
-
if need_change_mode:
|
|
140
|
-
change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
141
|
-
|
|
142
135
|
@staticmethod
|
|
143
136
|
def deal_with_dir_or_file_inconsistency(self, output_path):
|
|
144
137
|
if os.path.exists(output_path):
|
|
@@ -160,10 +153,17 @@ class Util:
|
|
|
160
153
|
|
|
161
154
|
@staticmethod
|
|
162
155
|
def dir_contains_only(self, path, endfix):
|
|
163
|
-
for
|
|
156
|
+
for root, _, files in os.walk(path, topdown=True):
|
|
157
|
+
path_checker = FileChecker(root)
|
|
158
|
+
path_checker.common_check()
|
|
164
159
|
for file in files:
|
|
165
160
|
if not file.endswith(endfix):
|
|
166
161
|
return False
|
|
162
|
+
path_depth = root.count(os.sep)
|
|
163
|
+
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
164
|
+
yield root, _, files
|
|
165
|
+
else:
|
|
166
|
+
_[:] = []
|
|
167
167
|
return True
|
|
168
168
|
|
|
169
169
|
@staticmethod
|
|
@@ -188,7 +188,7 @@ class Util:
|
|
|
188
188
|
if not cmd:
|
|
189
189
|
self.log.error("Commond is None")
|
|
190
190
|
return -1
|
|
191
|
-
self.log.
|
|
191
|
+
self.log.info("[RUN CMD]: %s", cmd)
|
|
192
192
|
cmd = cmd.split(" ")
|
|
193
193
|
complete_process = subprocess.run(cmd, shell=False)
|
|
194
194
|
return complete_process.returncode
|
|
@@ -208,7 +208,7 @@ class Util:
|
|
|
208
208
|
"Check msaccucmp failed in dir %s. This is not a correct msaccucmp file" % target_file)
|
|
209
209
|
raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
|
|
210
210
|
result = subprocess.run(
|
|
211
|
-
[self.python, target_file, "--help"], stdout=subprocess.PIPE)
|
|
211
|
+
[self.python, target_file, "--help"], stdout=subprocess.PIPE, shell=False)
|
|
212
212
|
if result.returncode == 0:
|
|
213
213
|
self.log.info("Check [%s] success.", target_file)
|
|
214
214
|
else:
|
|
@@ -217,37 +217,12 @@ class Util:
|
|
|
217
217
|
raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
|
|
218
218
|
return target_file
|
|
219
219
|
|
|
220
|
-
def create_dir(self, path):
|
|
221
|
-
path = self.path_strip(path)
|
|
222
|
-
if os.path.exists(path):
|
|
223
|
-
return
|
|
224
|
-
self.check_path_name(path)
|
|
225
|
-
try:
|
|
226
|
-
os.makedirs(path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
|
|
227
|
-
except OSError as e:
|
|
228
|
-
self.log.error("Failed to create %s.", path)
|
|
229
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from e
|
|
230
|
-
|
|
231
220
|
def gen_npy_info_txt(self, source_data):
|
|
232
221
|
(shape, dtype, max_data, min_data, mean) = \
|
|
233
222
|
self.npy_info(source_data)
|
|
234
223
|
return \
|
|
235
224
|
'[Shape: %s] [Dtype: %s] [Max: %s] [Min: %s] [Mean: %s]' % (shape, dtype, max_data, min_data, mean)
|
|
236
225
|
|
|
237
|
-
def save_npy_to_txt(self, data, dst_file='', align=0):
|
|
238
|
-
if os.path.exists(dst_file):
|
|
239
|
-
self.log.info("Dst file %s exists, will not save new one.", dst_file)
|
|
240
|
-
return
|
|
241
|
-
shape = data.shape
|
|
242
|
-
data = data.flatten()
|
|
243
|
-
if align == 0:
|
|
244
|
-
align = 1 if len(shape) == 0 else shape[-1]
|
|
245
|
-
elif data.size % align != 0:
|
|
246
|
-
pad_array = np.zeros((align - data.size % align,))
|
|
247
|
-
data = np.append(data, pad_array)
|
|
248
|
-
np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
|
|
249
|
-
change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
250
|
-
|
|
251
226
|
def list_convert_files(self, path, external_pattern=""):
|
|
252
227
|
return self.list_file_with_pattern(
|
|
253
228
|
path, Const.OFFLINE_DUMP_CONVERT_PATTERN, external_pattern, self._gen_npu_dump_convert_file_info
|
|
@@ -274,27 +249,8 @@ class Util:
|
|
|
274
249
|
|
|
275
250
|
def check_path_valid(self, path):
|
|
276
251
|
path = self.path_strip(path)
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
280
|
-
if os.path.islink(path):
|
|
281
|
-
self.log.error('The file path {} is a soft link.'.format(path))
|
|
282
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
283
|
-
if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
|
|
284
|
-
Const.FILE_NAME_LENGTH:
|
|
285
|
-
self.log.error('The file path length exceeds limit.')
|
|
286
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
287
|
-
if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
|
|
288
|
-
self.log.error('The file path {} contains special characters.'.format(path))
|
|
289
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
290
|
-
if os.path.isfile(path):
|
|
291
|
-
file_size = os.path.getsize(path)
|
|
292
|
-
if path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
|
|
293
|
-
self.log.error('The file {} size is greater than 1GB.'.format(path))
|
|
294
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
295
|
-
if path.endswith(Const.NPY_SUFFIX) and file_size > Const.TEN_GB:
|
|
296
|
-
self.log.error('The file {} size is greater than 10GB.'.format(path))
|
|
297
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
252
|
+
path_checker = FileChecker(path)
|
|
253
|
+
path_checker.common_check()
|
|
298
254
|
return True
|
|
299
255
|
|
|
300
256
|
def check_files_in_path(self, path):
|
|
@@ -322,17 +278,24 @@ class Util:
|
|
|
322
278
|
self.check_path_valid(path)
|
|
323
279
|
file_list = {}
|
|
324
280
|
re_pattern = re.compile(pattern)
|
|
325
|
-
for dir_path, _, file_names in os.walk(path,
|
|
281
|
+
for dir_path, _, file_names in os.walk(path, topdown=True):
|
|
282
|
+
path_checker = FileChecker(dir)
|
|
283
|
+
path_checker.common_check()
|
|
326
284
|
for name in file_names:
|
|
327
285
|
match = re_pattern.match(name)
|
|
328
286
|
if not match:
|
|
329
287
|
continue
|
|
330
|
-
if extern_pattern != '' and not re.match(extern_pattern, name):
|
|
288
|
+
if extern_pattern != '' and re_pattern.match(extern_pattern) and not re.match(extern_pattern, name):
|
|
331
289
|
continue
|
|
332
290
|
file_list[name] = gen_info_func(name, match, dir_path)
|
|
291
|
+
path_depth = dir_path.count(os.sep)
|
|
292
|
+
if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
|
|
293
|
+
yield dir_path, _, file_names
|
|
294
|
+
else:
|
|
295
|
+
_[:] = []
|
|
333
296
|
return file_list
|
|
334
297
|
|
|
335
|
-
def
|
|
298
|
+
def check_file_path_format(self, path, suffix):
|
|
336
299
|
if os.path.isfile(path):
|
|
337
300
|
if not path.endswith(suffix):
|
|
338
301
|
self.log.error("%s is not a %s file." % (path, suffix))
|
|
@@ -344,15 +307,6 @@ class Util:
|
|
|
344
307
|
self.log.error("The file path %s is invalid" % path)
|
|
345
308
|
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
346
309
|
|
|
347
|
-
def check_path_name(self, path):
|
|
348
|
-
if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
|
|
349
|
-
Const.FILE_NAME_LENGTH:
|
|
350
|
-
self.log.error('The file path length exceeds limit.')
|
|
351
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
352
|
-
if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
|
|
353
|
-
self.log.error('The file path {} contains special characters.'.format(path))
|
|
354
|
-
raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
|
|
355
|
-
|
|
356
310
|
def check_str_param(self, param):
|
|
357
311
|
if len(param) > Const.FILE_NAME_LENGTH:
|
|
358
312
|
self.log.error('The parameter length exceeds limit')
|
|
@@ -21,6 +21,7 @@ from msprobe.pytorch.parse_tool.lib.config import Const
|
|
|
21
21
|
from msprobe.pytorch.parse_tool.lib.utils import Util
|
|
22
22
|
from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
|
|
23
23
|
from msprobe.core.common.file_check import FileOpen
|
|
24
|
+
from msprobe.core.common.utils import save_npy_to_txt
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
class Visualization:
|
|
@@ -43,18 +44,18 @@ class Visualization:
|
|
|
43
44
|
summary = ['[yellow]%s[/yellow]' % self.util.gen_npy_info_txt(np_data), 'Path: %s' % target_file,
|
|
44
45
|
"TextFile: %s.txt" % target_file]
|
|
45
46
|
self.util.print_panel(self.util.create_columns([table, "\n".join(summary)]), target_file)
|
|
46
|
-
|
|
47
|
+
save_npy_to_txt(np_data, target_file + ".txt")
|
|
47
48
|
|
|
48
49
|
def print_npy_data(self, file_name):
|
|
49
50
|
file_name = self.util.path_strip(file_name)
|
|
50
51
|
self.util.check_path_valid(file_name)
|
|
51
|
-
self.util.
|
|
52
|
+
self.util.check_file_path_format(file_name, Const.NPY_SUFFIX)
|
|
52
53
|
return self.print_npy_summary(file_name)
|
|
53
54
|
|
|
54
55
|
def parse_pkl(self, path, api_name):
|
|
55
56
|
path = self.util.path_strip(path)
|
|
56
57
|
self.util.check_path_valid(path)
|
|
57
|
-
self.util.
|
|
58
|
+
self.util.check_file_path_format(path, Const.PKL_SUFFIX)
|
|
58
59
|
self.util.check_str_param(api_name)
|
|
59
60
|
with FileOpen(path, "r") as pkl_handle:
|
|
60
61
|
title_printed = False
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -4,18 +4,36 @@ import os
|
|
|
4
4
|
from msprobe.core.common_config import CommonConfig, BaseConfig
|
|
5
5
|
from msprobe.core.common.file_check import FileOpen
|
|
6
6
|
from msprobe.core.common.const import Const
|
|
7
|
+
from msprobe.pytorch.hook_module.utils import get_ops
|
|
8
|
+
from msprobe.core.grad_probe.constant import level_adp
|
|
9
|
+
from msprobe.core.grad_probe.utils import check_numeral_list_ascend
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
class TensorConfig(BaseConfig):
|
|
10
13
|
def __init__(self, json_config):
|
|
11
14
|
super().__init__(json_config)
|
|
15
|
+
self.online_run_ut = json_config.get("online_run_ut", False)
|
|
16
|
+
self.nfs_path = json_config.get("nfs_path", "")
|
|
17
|
+
self.host = json_config.get("host", "")
|
|
18
|
+
self.port = json_config.get("port", -1)
|
|
19
|
+
self.tls_path = json_config.get("tls_path", "")
|
|
12
20
|
self.check_config()
|
|
13
21
|
self._check_file_format()
|
|
22
|
+
self._check_tls_path_config()
|
|
14
23
|
|
|
15
24
|
def _check_file_format(self):
|
|
16
25
|
if self.file_format is not None and self.file_format not in ["npy", "bin"]:
|
|
17
26
|
raise Exception("file_format is invalid")
|
|
18
27
|
|
|
28
|
+
def _check_tls_path_config(self):
|
|
29
|
+
if self.tls_path:
|
|
30
|
+
if not os.path.exists(self.tls_path):
|
|
31
|
+
raise Exception("tls_path: %s does not exist" % self.tls_path)
|
|
32
|
+
if not os.path.exists(os.path.join(self.tls_path, "client.key")):
|
|
33
|
+
raise Exception("tls_path does not contain client.key")
|
|
34
|
+
if not os.path.exists(os.path.join(self.tls_path, "client.crt")):
|
|
35
|
+
raise Exception("tls_path does not contain client.crt")
|
|
36
|
+
|
|
19
37
|
|
|
20
38
|
class StatisticsConfig(BaseConfig):
|
|
21
39
|
def __init__(self, json_config):
|
|
@@ -31,12 +49,12 @@ class StatisticsConfig(BaseConfig):
|
|
|
31
49
|
class OverflowCheckConfig(BaseConfig):
|
|
32
50
|
def __init__(self, json_config):
|
|
33
51
|
super().__init__(json_config)
|
|
34
|
-
self.
|
|
52
|
+
self.overflow_nums = json_config.get("overflow_nums")
|
|
35
53
|
self.check_mode = json_config.get("check_mode")
|
|
36
54
|
self.check_overflow_config()
|
|
37
55
|
|
|
38
56
|
def check_overflow_config(self):
|
|
39
|
-
if self.
|
|
57
|
+
if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
|
|
40
58
|
raise Exception("overflow_num is invalid")
|
|
41
59
|
if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
|
|
42
60
|
raise Exception("check_mode is invalid")
|
|
@@ -61,20 +79,96 @@ class FreeBenchmarkCheckConfig(BaseConfig):
|
|
|
61
79
|
if self.preheat_step and self.preheat_step == 0:
|
|
62
80
|
raise Exception("preheat_step cannot be 0")
|
|
63
81
|
|
|
82
|
+
|
|
83
|
+
class RunUTConfig(BaseConfig):
|
|
84
|
+
WrapApi = get_ops()
|
|
85
|
+
|
|
86
|
+
def __init__(self, json_config):
|
|
87
|
+
super().__init__(json_config)
|
|
88
|
+
self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
|
|
89
|
+
self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
|
|
90
|
+
self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
|
|
91
|
+
self.is_online = json_config.get("is_online", False)
|
|
92
|
+
self.nfs_path = json_config.get("nfs_path", "")
|
|
93
|
+
self.host = json_config.get("host", "")
|
|
94
|
+
self.port = json_config.get("port", -1)
|
|
95
|
+
self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
|
|
96
|
+
self.tls_path = json_config.get("tls_path", "")
|
|
97
|
+
self.check_run_ut_config()
|
|
98
|
+
|
|
99
|
+
@classmethod
|
|
100
|
+
def check_filter_list_config(cls, key, filter_list):
|
|
101
|
+
if not isinstance(filter_list, list):
|
|
102
|
+
raise Exception("%s must be a list type" % key)
|
|
103
|
+
if not all(isinstance(item, str) for item in filter_list):
|
|
104
|
+
raise Exception("All elements in %s must be string type" % key)
|
|
105
|
+
invalid_api = [item for item in filter_list if item not in cls.WrapApi]
|
|
106
|
+
if invalid_api:
|
|
107
|
+
raise Exception("Invalid api in %s: %s" % (key, invalid_api))
|
|
108
|
+
|
|
109
|
+
@classmethod
|
|
110
|
+
def check_error_data_path_config(cls, error_data_path):
|
|
111
|
+
if not os.path.exists(error_data_path):
|
|
112
|
+
raise Exception("error_data_path: %s does not exist" % error_data_path)
|
|
113
|
+
|
|
114
|
+
@classmethod
|
|
115
|
+
def check_nfs_path_config(cls, nfs_path):
|
|
116
|
+
if nfs_path and not os.path.exists(nfs_path):
|
|
117
|
+
raise Exception("nfs_path: %s does not exist" % nfs_path)
|
|
118
|
+
|
|
119
|
+
@classmethod
|
|
120
|
+
def check_tls_path_config(cls, tls_path):
|
|
121
|
+
if tls_path:
|
|
122
|
+
if not os.path.exists(tls_path):
|
|
123
|
+
raise Exception("tls_path: %s does not exist" % tls_path)
|
|
124
|
+
if not os.path.exists(os.path.join(tls_path, "server.key")):
|
|
125
|
+
raise Exception("tls_path does not contain server.key")
|
|
126
|
+
if not os.path.exists(os.path.join(tls_path, "server.crt")):
|
|
127
|
+
raise Exception("tls_path does not contain server.crt")
|
|
128
|
+
|
|
129
|
+
def check_run_ut_config(self):
|
|
130
|
+
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
|
|
131
|
+
RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
|
|
132
|
+
RunUTConfig.check_error_data_path_config(self.error_data_path)
|
|
133
|
+
RunUTConfig.check_nfs_path_config(self.nfs_path)
|
|
134
|
+
RunUTConfig.check_tls_path_config(self.tls_path)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class GradToolConfig(BaseConfig):
|
|
138
|
+
def __init__(self, json_config):
|
|
139
|
+
super().__init__(json_config)
|
|
140
|
+
self.grad_level = json_config.get("grad_level", "L1")
|
|
141
|
+
self.param_list = json_config.get("param_list", [])
|
|
142
|
+
self.bounds = json_config.get("bounds", [])
|
|
143
|
+
|
|
144
|
+
def _check_config(self):
|
|
145
|
+
if self.grad_level not in level_adp.keys():
|
|
146
|
+
raise Exception(f"grad_level must be one of {level_adp.keys()}")
|
|
147
|
+
if not isinstance(self.param_list, list):
|
|
148
|
+
raise Exception(f"param_list must be a list")
|
|
149
|
+
check_numeral_list_ascend(self.bounds)
|
|
150
|
+
|
|
151
|
+
|
|
64
152
|
def parse_task_config(task, json_config):
|
|
65
153
|
default_dic = {}
|
|
66
154
|
if task == Const.TENSOR:
|
|
67
|
-
config_dic = json_config.get(Const.TENSOR
|
|
155
|
+
config_dic = json_config.get(Const.TENSOR, default_dic)
|
|
68
156
|
return TensorConfig(config_dic)
|
|
69
157
|
elif task == Const.STATISTICS:
|
|
70
|
-
config_dic = json_config.get(Const.STATISTICS
|
|
158
|
+
config_dic = json_config.get(Const.STATISTICS, default_dic)
|
|
71
159
|
return StatisticsConfig(config_dic)
|
|
72
160
|
elif task == Const.OVERFLOW_CHECK:
|
|
73
|
-
config_dic = json_config.get(Const.OVERFLOW_CHECK
|
|
161
|
+
config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic)
|
|
74
162
|
return OverflowCheckConfig(config_dic)
|
|
75
163
|
elif task == Const.FREE_BENCHMARK:
|
|
76
|
-
config_dic = json_config.get(Const.FREE_BENCHMARK
|
|
164
|
+
config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic)
|
|
77
165
|
return FreeBenchmarkCheckConfig(config_dic)
|
|
166
|
+
elif task == Const.RUN_UT:
|
|
167
|
+
config_dic = json_config.get(Const.RUN_UT, default_dic)
|
|
168
|
+
return RunUTConfig(config_dic)
|
|
169
|
+
elif task == Const.GRAD_PROBE:
|
|
170
|
+
config_dic = json_config.get(Const.GRAD_PROBE, default_dic)
|
|
171
|
+
return GradToolConfig(config_dic)
|
|
78
172
|
else:
|
|
79
173
|
return StatisticsConfig(default_dic)
|
|
80
174
|
|